JMonga commited on
Commit
3c7cb3b
·
verified ·
1 Parent(s): 63b252b

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +70 -40
train.py CHANGED
@@ -1,43 +1,73 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
2
- from datasets import load_dataset
3
- import torch
4
-
5
- # 📂 Charger les données d'entraînement
6
- dataset = load_dataset("json", data_files="training_data.jsonl", split="train")
7
-
8
- # 🔥 Charger le modèle GPT-2
9
- MODEL_NAME = "gpt2" # Change avec ton propre modèle si besoin
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
12
-
13
- # 🔄 Tokenisation des données
14
- def tokenize_function(examples):
15
- return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
16
-
17
- dataset = dataset.map(tokenize_function, batched=True)
18
-
19
- # 🎯 Définition des paramètres d'entraînement
20
- training_args = TrainingArguments(
21
- output_dir="./trained_model",
22
- per_device_train_batch_size=2,
23
- per_device_eval_batch_size=2,
24
- num_train_epochs=3,
25
- save_steps=1000,
26
- save_total_limit=2,
27
- logging_dir="./logs",
28
- push_to_hub=True,
29
- hub_model_id="TON_NOM_UTILISATEUR/mon_chatbot_gpt2"
30
- )
31
 
32
- trainer = Trainer(
33
- model=model,
34
- args=training_args,
35
- train_dataset=dataset
36
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # 🚀 Lancer l'entraînement
39
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # 💾 Sauvegarder et envoyer le modèle sur Hugging Face
42
- trainer.push_to_hub()
43
- tokenizer.push_to_hub("TON_NOM_UTILISATEUR/mon_chatbot_gpt2")
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # Initialisation du client Hugging Face
7
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
+
9
+ # Fichier de stockage des questions-réponses personnalisées
10
+ DATA_FILE = "qna_data.json"
11
+ CREATOR_CODE = "LnMonga2025" # Code secret du créateur
12
+
13
+ # Fonction pour charger les données enregistrées
14
+ def load_json():
15
+ if not os.path.exists(DATA_FILE) or os.stat(DATA_FILE).st_size == 0:
16
+ return {}
17
+ try:
18
+ with open(DATA_FILE, "r", encoding="utf-8") as f:
19
+ return json.load(f)
20
+ except json.JSONDecodeError:
21
+ return {}
22
+
23
+ # Fonction pour sauvegarder une nouvelle question-réponse
24
+ def save_json(data):
25
+ with open(DATA_FILE, "w", encoding="utf-8") as f:
26
+ json.dump(data, f, indent=4, ensure_ascii=False)
27
+
28
+ # Fonction principale de réponse
29
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
30
+ data = load_json()
31
 
32
+ # Vérification si l'utilisateur est le créateur en utilisant le code spécial
33
+ if message.startswith(f"{CREATOR_CODE}:input:"):
34
+ question = message.replace(f"{CREATOR_CODE}:input:", "").strip()
35
+ response = f"{CREATOR_CODE}:output:"
36
+
37
+ # Demande manuelle de réponse
38
+ response = input(f"Entrez la réponse pour '{question}': ").strip()
39
+
40
+ if response:
41
+ data[question] = response
42
+ save_json(data)
43
+ return f"✅ Réponse enregistrée : {response}"
44
+
45
+ # Vérifier si la question est déjà enregistrée
46
+ if message in data:
47
+ return data[message]
48
+
49
+ # Si non enregistré, utiliser GPT pour répondre normalement
50
+ messages = [{"role": "system", "content": system_message}] + [
51
+ {"role": "user", "content": q} if i % 2 == 0 else {"role": "assistant", "content": a}
52
+ for i, (q, a) in enumerate(history)
53
+ ] + [{"role": "user", "content": message}]
54
+
55
+ response = ""
56
+ for msg in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
57
+ token = msg.choices[0].delta.content
58
+ response += token
59
+ yield response
60
+
61
+ # Interface Gradio
62
+ demo = gr.ChatInterface(
63
+ respond,
64
+ additional_inputs=[
65
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
66
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
67
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
68
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
69
+ ],
70
+ )
71
 
72
+ if __name__ == "__main__":
73
+ demo.launch()