# Train / maggin.py
# (Hugging Face Space page residue preserved as comments:
#  "Yjhhh's picture" / "Rename main.py to maggin.py" / commit e2d0b71 verified)
# Standard library
import io
import json
import multiprocessing
import os
import random
import time
import uuid

# Third-party
import redis
import requests
import torch
import torch.nn as nn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
)
load_dotenv()
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = os.getenv('REDIS_PORT')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
app = FastAPI()
default_language = "es"
class ChatbotService:
    """Generates chat replies from a GPT-2 model whose weights live in Redis.

    The model state_dict and extra tokenizer tokens are fetched fresh from
    Redis on every call to :meth:`get_response`.
    """

    def __init__(self):
        # decode_responses must be False: the stored model state_dict is raw
        # binary and cannot round-trip through UTF-8 text decoding.
        # (json.loads on the tokenizer payload accepts bytes directly.)
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST,
            port=REDIS_PORT,
            password=REDIS_PASSWORD,
            decode_responses=False,
        )
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"

    def get_response(self, user_id, message, language=default_language):
        """Return a generated reply to *message*.

        Falls back to a Spanish "not ready" notice when either the model or
        the tokenizer has not been pushed to Redis yet. *user_id* and
        *language* are accepted but not used by the generation itself.
        """
        model = self.load_model_from_redis()
        tokenizer = self.load_tokenizer_from_redis()
        if model is None or tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
        input_text = f"Usuario: {message} Asistente:"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cpu")
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the assistant's continuation remains.
        response = response.replace(input_text, "").strip()
        return response

    def load_model_from_redis(self):
        """Return a GPT-2 causal LM with Redis-stored weights, or None if absent."""
        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if model_data_bytes:
            model = AutoModelForCausalLM.from_pretrained("gpt2")
            # torch.load expects a file-like object, not a raw byte string.
            model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
            return model
        else:
            return None

    def load_tokenizer_from_redis(self):
        """Return a GPT-2 tokenizer extended with Redis-stored tokens, or None."""
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if tokenizer_data_bytes:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            # json.loads handles bytes input; payload is the stored vocab/tokens.
            tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
            return tokenizer
        else:
            return None
# Module-level singleton shared by the /process handler and helper functions.
chatbot_service = ChatbotService()
class UnifiedModel(nn.Module):
    """Ensemble head: concatenates per-model logits, projects, re-classifies.

    Each wrapped model is expected to return an object with a ``.logits``
    attribute of shape (batch, 3) and expose ``config.hidden_size``.
    """

    def __init__(self, models):
        super(UnifiedModel, self).__init__()
        self.models = nn.ModuleList(models)
        hidden_size = self.models[0].config.hidden_size
        # Concatenated logits (len(models) * 3) are projected back to the
        # backbone hidden size. Was hard-coded to 768 (GPT-2's hidden size);
        # deriving it from the config keeps it aligned with self.classifier
        # for any backbone. Behavior is identical for GPT-2.
        self.projection = nn.Linear(len(models) * 3, hidden_size)
        self.classifier = nn.Linear(hidden_size, 3)

    def forward(self, input_ids, attention_mask):
        """Run each sub-model on its own tensors and fuse the logits.

        *input_ids* / *attention_mask* are parallel lists with one tensor per
        sub-model (zipped positionally against ``self.models``).
        """
        hidden_states = []
        for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
            outputs = model(
                input_ids=input_id,
                attention_mask=attn_mask
            )
            hidden_states.append(outputs.logits)
        concatenated_hidden_states = torch.cat(hidden_states, dim=1)  # (batch, len(models)*3)
        projected_features = self.projection(concatenated_hidden_states)
        logits = self.classifier(projected_features)
        return logits

    @staticmethod
    def load_model_from_redis(redis_client):
        """Build a UnifiedModel around a GPT-2 classifier, loading stored weights if any.

        NOTE(review): *redis_client* must return raw bytes (decode_responses=False)
        for the stored state_dict to be loadable — confirm against the caller.
        """
        model_name = "unified_model"
        model_data_bytes = redis_client.get(f"model:{model_name}")
        # A fresh head is built either way; stored weights are overlaid if present.
        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
        if model_data_bytes:
            # torch.load expects a file-like object, not a raw byte string.
            model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
        # The same instance is deliberately used twice, preserving the original
        # two-slot ensemble shape.
        return UnifiedModel([model, model])
class SyntheticDataset(Dataset):
    """Torch dataset that encodes each sample's text with every tokenizer.

    *tokenizers* maps a name to a callable tokenizer; *data* is a list of
    ``{"text": str, "label": int}`` dicts. Each item yields per-tokenizer
    ``input_ids_<name>`` / ``attention_mask_<name>`` tensors plus ``labels``.
    """

    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        features = {}
        for name, tok in self.tokenizers.items():
            encoded = tok(
                sample['text'],
                padding="max_length",
                truncation=True,
                max_length=128,
            )
            features[f"input_ids_{name}"] = torch.tensor(encoded["input_ids"])
            features[f"attention_mask_{name}"] = torch.tensor(encoded["attention_mask"])
        features["labels"] = torch.tensor(sample['label'])
        return features
# In-memory per-user message history keyed by user_id; not persisted and not
# shared with the background training process.
conversation_history = {}
@app.post("/process")
async def process(request: Request):
    """Handle both training submissions and chat messages.

    A JSON body with ``train`` queues ``user_data`` (or a small default set)
    for asynchronous fine-tuning; a body with ``message`` classifies the
    recent conversation, generates a reply, and queues the exchange as a new
    training sample. Anything else is a 400.
    """
    data = await request.json()
    # decode_responses=False: the stored model state_dict is raw binary and
    # cannot round-trip through UTF-8 decoding. json.loads accepts bytes for
    # the tokenizer payload, and rpush of str values is unaffected.
    redis_client = redis.StrictRedis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        password=REDIS_PASSWORD,
        decode_responses=False,
    )
    tokenizers = {}
    models = {}
    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"
    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
    # A fresh GPT-2 classification head is built either way; stored weights
    # are overlaid when present.
    model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
    if model_data_bytes:
        # torch.load expects a file-like object, not a raw byte string.
        model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
    models[model_name] = model
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer_data_bytes:
        tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
    tokenizers[tokenizer_name] = tokenizer
    unified_model = UnifiedModel(list(models.values()))
    unified_model.to(torch.device("cpu"))
    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            # Seed samples so training can proceed even on an empty payload.
            user_data = [
                {"text": "Hola", "label": 1},
                {"text": "Necesito ayuda", "label": 2},
                {"text": "No entiendo", "label": 0}
            ]
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": user_data
        }))
        return {"message": "Training data received. Model will be updated asynchronously."}
    elif data.get("message"):
        user_id = data.get("user_id")
        text = data['message']
        language = data.get("language", default_language)
        if user_id not in conversation_history:
            conversation_history[user_id] = []
        conversation_history[user_id].append(text)
        # Classify using up to the last three turns as context.
        contextualized_text = " ".join(conversation_history[user_id][-3:])
        tokenized_inputs = [
            tokenizers[name](contextualized_text, return_tensors="pt")
            for name in tokenizers.keys()
        ]
        input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
        attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
        with torch.no_grad():
            logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(logits, dim=-1).item()
        response = chatbot_service.get_response(user_id, contextualized_text, language)
        # Feed this exchange back into the queue (self-training loop).
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": [{"text": contextualized_text, "label": predicted_class}]
        }))
        return {"answer": response}
    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
def get_chatbot_response(user_id, question, predicted_class, language):
    """Record *question* in the user's history and return the chatbot's reply.

    *predicted_class* is accepted for signature compatibility but not used.
    """
    conversation_history.setdefault(user_id, []).append(question)
    return chatbot_service.get_response(user_id, question, language)
@app.get("/")
async def get_home():
    """Serve the single-page chat UI.

    A fresh UUID is embedded in a hidden input so the in-page script can tag
    every /process request with a per-visit user id. Doubled braces in the
    f-string below are literal braces in the emitted CSS/JS.
    """
    user_id = str(uuid.uuid4())
    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Chatbot</title>
        <style>
            body {{
                font-family: 'Arial', sans-serif;
                background-color: #f4f4f9;
                margin: 0;
                padding: 0;
                display: flex;
                align-items: center;
                justify-content: center;
                min-height: 100vh;
            }}
            .container {{
                background-color: #fff;
                border-radius: 10px;
                box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
                overflow: hidden;
                width: 400px;
                max-width: 90%;
            }}
            h1 {{
                color: #333;
                text-align: center;
                padding: 20px;
                margin: 0;
                background-color: #f8f9fa;
                border-bottom: 1px solid #eee;
            }}
            #chatbox {{
                height: 400px;
                padding: 20px;
                overflow-y: auto;
            }}
            .message {{
                margin-bottom: 15px;
                padding: 10px;
                border-radius: 5px;
                max-width: 70%;
                animation: slide-in 0.3s ease-out;
            }}
            .user-message {{
                text-align: right;
                background-color: #eee;
                margin-left: 30%;
            }}
            .bot-message {{
                text-align: left;
                background-color: #ccf5ff;
                margin-right: 30%;
            }}
            #input-area {{
                display: flex;
                padding: 10px;
                background-color: #f8f9fa;
                border-top: 1px solid #eee;
            }}
            #message-input {{
                flex: 1;
                padding: 10px;
                border: 1px solid #ccc;
                border-radius: 5px;
                margin-right: 10px;
            }}
            #send-button {{
                padding: 10px 15px;
                background-color: #28a745;
                color: white;
                border: none;
                cursor: pointer;
                border-radius: 5px;
                transition: background-color 0.3s ease;
            }}
            #send-button:hover {{
                background-color: #218838;
            }}
            @keyframes slide-in {{
                from {{
                    transform: translateX(-100%);
                    opacity: 0;
                }}
                to {{
                    transform: translateX(0);
                    opacity: 1;
                }}
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Chatbot</h1>
            <div id="chatbox"></div>
            <div id="input-area">
                <input type="hidden" id="user-id" value="{user_id}">
                <input type="text" id="message-input" placeholder="Escribe tu mensaje...">
                <button id="send-button">Enviar</button>
            </div>
        </div>
        <script>
            const chatbox = document.getElementById('chatbox');
            const messageInput = document.getElementById('message-input');
            const sendButton = document.getElementById('send-button');
            const userId = document.getElementById('user-id').value;
            sendButton.addEventListener('click', sendMessage);
            function sendMessage() {{
                const message = messageInput.value;
                if (message.trim() === '') return;
                appendMessage('user', message);
                messageInput.value = '';
                fetch('/process', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json'
                    }},
                    body: JSON.stringify({{ message: message, user_id: userId, language: 'es' }})
                }})
                .then(response => response.json())
                .then(data => {{
                    appendMessage('bot', data.answer);
                }});
            }}
            function appendMessage(sender, message) {{
                const messageElement = document.createElement('div');
                messageElement.classList.add('message', `${{sender}}-message`);
                messageElement.textContent = message;
                chatbox.appendChild(messageElement);
                chatbox.scrollTop = chatbox.scrollHeight;
            }}
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code)
def push_to_redis(models, tokenizers, redis_client, model_name, tokenizer_name):
    """Serialize every model's state_dict and every tokenizer's vocab into Redis.

    Models are stored under ``model:<name>`` as raw torch-serialized bytes;
    tokenizer vocabs under ``tokenizer:<name>`` as JSON.

    NOTE(review): the *model_name* / *tokenizer_name* parameters are shadowed
    by the loop variables below and were never used; kept only for caller
    compatibility.
    """
    for model_name, model in models.items():
        # Serialize in memory. The original saved to a temp file named after
        # the model in the CWD and never deleted it.
        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        redis_client.set(f"model:{model_name}", buffer.getvalue())
    for tokenizer_name, tokenizer in tokenizers.items():
        tokens = tokenizer.get_vocab()
        redis_client.set(f"tokenizer:{tokenizer_name}", json.dumps(tokens))
def continuous_training():
    """Background worker: pop queued samples from Redis and fine-tune the model.

    Runs forever; started as a separate process from the __main__ guard.
    Each queue entry is a JSON object with "tokenizers" and "data" keys
    (pushed by the /process endpoint).
    """
    # NOTE(review): decode_responses=True means model bytes fetched inside
    # UnifiedModel.load_model_from_redis arrive as str (or fail to decode) —
    # confirm the intended storage format.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
    while True:
        try:
            data = redis_client.lpop("training_queue")
            if data:
                data = json.loads(data)
                unified_model = UnifiedModel.load_model_from_redis(redis_client)
                unified_model.train()
                # NOTE(review): data["tokenizers"] holds vocab dicts (from
                # get_vocab()), not callable tokenizers — SyntheticDataset
                # calls tokenizer(text, ...) in __getitem__ and would fail.
                # Confirm the intended queue payload.
                train_dataset = SyntheticDataset(data["tokenizers"], data["data"])
                train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
                optimizer = AdamW(unified_model.parameters(), lr=5e-5)
                for epoch in range(10):
                    for batch in train_loader:
                        # One input/mask tensor per tokenizer, matching
                        # UnifiedModel.forward's parallel-list contract.
                        input_ids = [batch[f"input_ids_{name}"].to("cpu") for name in data["tokenizers"].keys()]
                        attention_mask = [batch[f"attention_mask_{name}"].to("cpu") for name in data["tokenizers"].keys()]
                        labels = batch["labels"].to("cpu")
                        outputs = unified_model(input_ids=input_ids, attention_mask=attention_mask)
                        loss = nn.CrossEntropyLoss()(outputs, labels)
                        loss.backward()
                        optimizer.step()
                        optimizer.zero_grad()
                        print(f"Epoch {epoch}, Loss {loss.item()}")
                # NOTE(review): `tokenizer` is not defined in this function or
                # at module level — this raises NameError when reached.
                push_to_redis(
                    {"response_model": unified_model},
                    {"response_tokenizer": tokenizer},
                    redis_client,
                    "response_model",
                    "response_tokenizer",
                )
            time.sleep(10)
        except Exception as e:
            # Broad catch keeps the worker alive; failures are only printed.
            print(f"Error in continuous training: {e}")
            time.sleep(5)
if __name__ == "__main__":
    # Start the background fine-tuning worker, then serve the API.
    training_process = multiprocessing.Process(target=continuous_training)
    training_process.start()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)