"""Redis-backed FastAPI chatbot with background self-training workers.

``/process`` answers chat messages and queues training examples in Redis
lists; the worker processes started under ``__main__`` consume those queues.
"""

import asyncio
import io
import json
import logging
import multiprocessing
import os
import uuid

import numpy as np
import redis
import torch
from diffusers import FluxPipeline
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoModelForTextToWaveform,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    pipeline,
)

logger = logging.getLogger(__name__)

load_dotenv()

# Defaults mirror redis-py's own; an unset REDIS_PORT would otherwise be
# passed through as None, which redis-py cannot convert to a port number.
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD")

# Seconds each worker blocks on BLPOP before polling again.
QUEUE_POLL_TIMEOUT = 5
# Number of most recent user messages used as conversational context.
HISTORY_TURNS = 3
# On-disk location of the unified classifier checkpoint.
UNIFIED_MODEL_PATH = "models/unified_model"

app = FastAPI()
default_language = "es"


def _make_redis_client():
    """Return a new Redis client configured from the environment."""
    return redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)


def _blocking_pop(redis_client, queue_name):
    """Return the raw head of *queue_name*, or None if it stayed empty.

    Uses BLPOP with ``QUEUE_POLL_TIMEOUT`` so idle workers sleep instead of
    spinning on LPOP at full CPU.
    """
    popped = redis_client.blpop(queue_name, timeout=QUEUE_POLL_TIMEOUT)
    return None if popped is None else popped[1]


def _gpt2_tokenizer(extra_tokens_json=None):
    """Return a GPT-2 tokenizer, optionally extended with JSON-encoded tokens.

    Args:
        extra_tokens_json: Raw Redis value (bytes) or None. It is decoded and
            handed to ``add_tokens``, so it is presumably a JSON list of token
            strings — verify against whatever writes the ``tokenizer:*`` keys.

    GPT-2 has no pad token, so the EOS token is reused for padding.
    """
    tok = AutoTokenizer.from_pretrained("gpt2")
    if extra_tokens_json:
        tok.add_tokens(json.loads(extra_tokens_json.decode("utf-8")))
    tok.pad_token = tok.eos_token
    return tok


def _grow_embeddings(model, tok):
    """Resize *model*'s token embeddings when *tok* has more tokens than it.

    Without this, ids of tokens added via ``add_tokens`` index past the end
    of the embedding matrix.
    """
    if len(tok) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tok))


_api_redis_client = _make_redis_client()


class ChatbotService:
    """Generate chat replies with a GPT-2 causal LM whose weights live in Redis."""

    def __init__(self):
        self.redis_client = _make_redis_client()
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model_from_redis()
        self.tokenizer = self.load_tokenizer_from_redis()

    def get_response(self, user_id, message, language=default_language):
        """Return a generated reply to *message*.

        ``user_id`` and ``language`` are accepted for API compatibility but
        are not used. Returns a fixed Spanish notice when the model or
        tokenizer was not found in Redis.
        """
        if self.model is None or self.tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
        input_text = f"Usuario: {message} Asistente:"
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cpu")
        with torch.no_grad():
            # NOTE(review): max_length counts the prompt tokens too, so long
            # prompts leave little or no room for the reply.
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                pad_token_id=self.tokenizer.pad_token_id,
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )
        # Decode only the newly generated tokens rather than stripping the
        # prompt text back out of the full decode, which breaks whenever
        # tokenization does not round-trip the prompt exactly.
        generated = output[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True).strip()

    def load_model_from_redis(self):
        """Return GPT-2 loaded with the state dict at ``model:<model_name>``, or None."""
        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if not model_data_bytes:
            return None
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        # torch.load needs a file-like object, not raw bytes. weights_only=True
        # refuses to unpickle arbitrary objects from the shared Redis store.
        state_dict = torch.load(
            io.BytesIO(model_data_bytes), map_location="cpu", weights_only=True
        )
        # If the saved weights include tokens added after the base GPT-2 vocab,
        # grow the embeddings first so load_state_dict shapes match.
        embedding_key = "transformer.wte.weight"
        if embedding_key in state_dict:
            saved_vocab = state_dict[embedding_key].shape[0]
            if saved_vocab != model.get_input_embeddings().num_embeddings:
                model.resize_token_embeddings(saved_vocab)
        model.load_state_dict(state_dict)
        return model

    def load_tokenizer_from_redis(self):
        """Return the tokenizer extended from ``tokenizer:<tokenizer_name>``, or None."""
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if not tokenizer_data_bytes:
            return None
        return _gpt2_tokenizer(tokenizer_data_bytes)


chatbot_service = ChatbotService()


class UnifiedModel(AutoModelForSequenceClassification):
    """Three-label sequence classifier persisted at ``UNIFIED_MODEL_PATH``.

    NOTE(review): ``from_pretrained`` on an Auto* subclass presumably returns
    the concrete architecture (a GPT-2 classifier), not a ``UnifiedModel``.
    """

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def load_model_from_redis(redis_client):
        """Load the classifier from disk, creating it from ``gpt2`` on first use.

        Any copy stored under the ``model:unified_model`` Redis key is deleted;
        the model itself is always read from (or first written to) disk.
        """
        model_name = "unified_model"
        # DEL on a missing key is a no-op, so no EXISTS check is needed.
        redis_client.delete(f"model:{model_name}")
        is_new = not os.path.exists(UNIFIED_MODEL_PATH)
        if is_new:
            model = UnifiedModel.from_pretrained("gpt2", num_labels=3)
        else:
            model = UnifiedModel.from_pretrained(UNIFIED_MODEL_PATH)
        # GPT-2 defines no pad token; the classification head needs one to
        # find the last real token and to accept batches larger than one.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = model.config.eos_token_id
        if is_new:
            model.save_pretrained(UNIFIED_MODEL_PATH)
        return model


class SyntheticDataset(Dataset):
    """Torch dataset over ``{"text": ..., "label": ...}`` dicts for ``Trainer``.

    Each text is padded/truncated to exactly 128 tokens.
    """

    def __init__(self, tokenizer, data):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        tokens = self.tokenizer(
            item["text"],
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        # Drop the leading batch dimension added by return_tensors="pt".
        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": item["label"],
        }


# Per-user message lists, keyed by the client-supplied user_id.
conversation_history = {}
tokenizer_name = "unified_tokenizer"
# Set by startup_event in the API process, or by train_unified_model in a worker.
tokenizer = None
unified_model = None

musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small")
musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")

image_pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
image_pipeline.enable_model_cpu_offload()

DEFAULT_TRAINING_DATA = [
    {"text": "Hola", "label": 1},
    {"text": "Necesito ayuda", "label": 2},
    {"text": "No entiendo", "label": 0},
]


def _load_unified_tokenizer(redis_client):
    """Return the unified-model tokenizer, extended from Redis when stored there."""
    return _gpt2_tokenizer(redis_client.get(f"tokenizer:{tokenizer_name}"))


def _enqueue_training(redis_client, examples):
    """Append *examples* and the current tokenizer vocabulary to ``training_queue``."""
    redis_client.rpush(
        "training_queue",
        json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": examples,
        }),
    )


@app.on_event("startup")
async def startup_event():
    """Load the unified tokenizer and classifier into module globals."""
    global tokenizer, unified_model
    redis_client = _make_redis_client()
    tokenizer = _load_unified_tokenizer(redis_client)
    unified_model = UnifiedModel.load_model_from_redis(redis_client)
    unified_model.to(torch.device("cpu"))
    _grow_embeddings(unified_model, tokenizer)


@app.post("/process")
async def process(request: Request):
    """Handle a chat message or a batch of training data.

    JSON body, one of:
      * ``{"train": true, "user_data": [{"text": ..., "label": ...}, ...]}`` —
        queue the examples (a built-in sample set is used when empty/missing).
      * ``{"message": ..., "user_id": ..., "language": ...}`` — reply to the
        message; the classifier's prediction for the recent context is queued
        as a new training example.

    Raises:
        HTTPException(400): the body is not a JSON object, or has neither key.
    """
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    if data.get("train"):
        _enqueue_training(_api_redis_client, data.get("user_data") or DEFAULT_TRAINING_DATA)
        return {"message": "Training data received. Model will be updated asynchronously."}

    if data.get("message"):
        user_id = data.get("user_id")
        text = data["message"]
        language = data.get("language", default_language)

        history = conversation_history.setdefault(user_id, [])
        history.append(text)
        # Only the last HISTORY_TURNS messages are ever read, so drop older
        # ones to keep per-user memory bounded.
        del history[:-HISTORY_TURNS]
        contextualized_text = " ".join(history)

        # Truncate to the tokenizer's max length so long context cannot
        # overflow GPT-2's position embeddings.
        tokenized_input = tokenizer(contextualized_text, return_tensors="pt", truncation=True)
        # NOTE(review): classification and generation run synchronously here
        # and block the event loop while they execute.
        with torch.no_grad():
            logits = unified_model(**tokenized_input).logits
        predicted_class = torch.argmax(logits, dim=-1).item()

        response = chatbot_service.get_response(user_id, contextualized_text, language)
        _enqueue_training(
            _api_redis_client, [{"text": contextualized_text, "label": predicted_class}]
        )
        return {"answer": response}

    raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")


@app.get("/")
async def get_home():
    """Serve the chat page.

    NOTE(review): the HTML template appears to have been stripped from this
    file (only the "Chatbot" text remains and ``user_id`` is never
    interpolated); restore the original markup.
    """
    user_id = str(uuid.uuid4())
    html_code = f""" Chatbot

Chatbot

"""
    return HTMLResponse(content=html_code)


def _ensure_unified_loaded(redis_client):
    """Populate the ``tokenizer``/``unified_model`` globals if still unset.

    Needed in worker processes, where the FastAPI startup hook never runs.
    """
    global tokenizer, unified_model
    if tokenizer is None:
        tokenizer = _load_unified_tokenizer(redis_client)
    if unified_model is None:
        unified_model = UnifiedModel.load_model_from_redis(redis_client)
    _grow_embeddings(unified_model, tokenizer)


def _train_unified_on_item(redis_client, item_data):
    """Fine-tune the global ``unified_model`` on one ``training_queue`` payload.

    The payload has the shape written by ``_enqueue_training``:
    ``{"tokenizers": {name: vocab}, "data": [examples]}``. The checkpoint is
    saved to ``UNIFIED_MODEL_PATH`` after training.
    """
    tokenizer_data = item_data["tokenizers"]
    queued_name = next(iter(tokenizer_data))
    if redis_client.exists(f"tokenizer:{queued_name}"):
        tokenizer.add_tokens(list(tokenizer_data[queued_name].keys()))
        _grow_embeddings(unified_model, tokenizer)

    examples = item_data["data"]
    if not examples:
        return
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=8,
        num_train_epochs=3,
    )
    trainer = Trainer(
        model=unified_model,
        args=training_args,
        train_dataset=SyntheticDataset(tokenizer, examples),
    )
    trainer.train()
    unified_model.save_pretrained(UNIFIED_MODEL_PATH)


def train_unified_model():
    """Consume ``training_queue`` forever, fine-tuning the unified classifier.

    Intended as a worker-process target. NOTE(review): checkpoints are written
    to disk, but nothing here makes the API process reload them.
    """
    redis_client = _make_redis_client()
    _ensure_unified_loaded(redis_client)
    while True:
        raw = _blocking_pop(redis_client, "training_queue")
        if raw is None:
            continue
        try:
            _train_unified_on_item(redis_client, json.loads(raw.decode("utf-8")))
        except Exception:
            # Worker boundary: one malformed item must not kill the loop.
            logger.exception("Failed to train on a training_queue item; skipping it")


async def auto_learn():
    """Run ``train_unified_model`` on a worker thread without blocking the loop.

    NOTE(review): in the API process this trains the same model object that
    ``/process`` serves from, concurrently with request handling.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, train_unified_model)


def auto_learn_music():
    """Consume ``music_training_queue``, taking one optimizer step per item.

    Each item is JSON passed straight to the MusicGen tokenizer (presumably a
    list of prompt strings). NOTE(review): tokenizer output has no ``labels``
    key, so items are skipped with a warning until a real training target is
    supplied.
    """
    redis_client = _make_redis_client()
    # One optimizer for the worker's lifetime so Adam's running moment
    # estimates carry over between items instead of resetting each time.
    optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    while True:
        raw = _blocking_pop(redis_client, "music_training_queue")
        if raw is None:
            continue
        try:
            prompts = json.loads(raw.decode("utf-8"))
            inputs = musicgen_tokenizer(prompts, return_tensors="pt", padding=True)
            if "labels" not in inputs:
                logger.warning("music_training_queue item has no labels; skipping it")
                continue
            musicgen_model.train()
            outputs = musicgen_model(**inputs)
            loss = loss_fn(outputs.logits, inputs["labels"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        except Exception:
            logger.exception("Failed to train on a music_training_queue item; skipping it")


def auto_learn_images():
    """Consume ``image_training_queue`` (a JSON list of prompts per item).

    For each prompt an image is rendered, then ``image_pipeline.model`` takes
    one MSE step pushing its output toward zeros. NOTE(review): FluxPipeline
    does not appear to expose a ``.model`` attribute, in which case this
    worker logs an error and exits; the zero-target objective should also be
    revisited before enabling it.
    """
    model = getattr(image_pipeline, "model", None)
    if model is None:
        logger.error("image_pipeline has no 'model' attribute; image training disabled")
        return
    redis_client = _make_redis_client()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    loss_fn = torch.nn.MSELoss()
    while True:
        raw = _blocking_pop(redis_client, "image_training_queue")
        if raw is None:
            continue
        try:
            for image_prompt in json.loads(raw.decode("utf-8")):
                image = image_pipeline(
                    image_prompt,
                    guidance_scale=0.0,
                    num_inference_steps=4,
                    max_sequence_length=256,
                    generator=torch.Generator("cpu").manual_seed(0),
                ).images[0]
                # np.array(PIL image) is uint8; MSE and autograd need floats.
                image_tensor = torch.tensor(np.array(image), dtype=torch.float32).unsqueeze(0)
                model.train()
                outputs = model(image_tensor)
                loss = loss_fn(outputs, torch.zeros_like(image_tensor))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        except Exception:
            logger.exception("Failed to train on an image_training_queue item; skipping it")


if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(level=logging.INFO)
    # Plain (non-async) targets so the loops actually run in the child
    # processes; daemon=True so they exit with the server instead of keeping
    # the interpreter alive after uvicorn shuts down.
    for worker in (train_unified_model, auto_learn_music, auto_learn_images):
        multiprocessing.Process(target=worker, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=7860)