dgdgdgdgd / app.py
Yhhxhfh's picture
Update app.py
57c7d28 verified
raw
history blame
7.71 kB
import os
import platform
from dotenv import load_dotenv
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login
import time
import uvicorn
from fastapi import FastAPI
import threading
import logging
import warnings
# Ignorar advertencias espec铆ficas si lo deseas (opcional)
warnings.filterwarnings("ignore", category=FutureWarning)
# Configurar logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("training.log"),
logging.StreamHandler()
]
)
# Cargar las variables de entorno
load_dotenv()
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
if huggingface_token is None:
raise ValueError("HUGGINGFACE_TOKEN no encontrado en las variables de entorno.")
# Iniciar sesi贸n en Hugging Face
login(token=huggingface_token)
# Definir la aplicaci贸n FastAPI
app = FastAPI()
@app.get("/")
async def root():
return {"message": "Modelo entrenado y en ejecuci贸n."}
def load_and_train():
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Asignar el pad_token al eos_token
tokenizer.pad_token = tokenizer.eos_token
# Redimensionar las embeddings del modelo para incluir el pad_token
model.resize_token_embeddings(len(tokenizer))
# Verificar dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logging.info(f"Entrenando en: {device}")
# Determinar cache_dir
if platform.system() == "Linux":
cache_dir = '/dev/shm'
else:
cache_dir = './cache'
# Crear el directorio de cach茅 si no existe
os.makedirs(cache_dir, exist_ok=True)
# Intentar cargar los datasets con manejo de errores
try:
dataset_humanizado = load_dataset('daily_dialog', split='train', cache_dir=cache_dir, trust_remote_code=True)
dataset_codigo = load_dataset('code_search_net', split='train', cache_dir=cache_dir, trust_remote_code=True)
except Exception as e:
logging.error(f"Error al cargar los datasets: {e}")
# Intentar cargar un dataset alternativo
time.sleep(60) # Esperar 60 segundos antes de reintentar
try:
dataset_humanizado = load_dataset('alternative_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True)
dataset_codigo = load_dataset('alternative_code_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True)
except Exception as e:
logging.error(f"Error al cargar el dataset alternativo: {e}")
return
logging.info("Daily Dialog columnas: %s", dataset_humanizado.column_names)
logging.info("Code Search Net columnas: %s", dataset_codigo.column_names)
# Combinar los datasets en memoria
combined_dataset = concatenate_datasets([dataset_humanizado, dataset_codigo])
logging.info("Dataset combinado columnas: %s", combined_dataset.column_names)
# Funci贸n para crear un campo 'text' estandarizado
def concatenate_text_fields(examples):
"""
Crea un nuevo campo 'text' concatenando los campos de texto disponibles en cada ejemplo.
Prioriza 'dialog', luego 'whole_func_string', y luego 'func_documentation_string'.
Si ninguno est谩 presente, asigna una cadena vac铆a.
Args:
examples (dict): Diccionario con listas de valores para cada columna.
Returns:
dict: Diccionario con el nuevo campo 'text'.
"""
texts = []
# Determinar el tama帽o del lote
batch_size = len(next(iter(examples.values())))
for i in range(batch_size):
text = ''
if 'dialog' in examples and examples['dialog'][i]:
# Verificar si el campo es una lista y concatenar si es necesario
dialog = examples['dialog'][i]
if isinstance(dialog, list):
dialog = ' '.join(dialog)
text = dialog
elif 'whole_func_string' in examples and examples['whole_func_string'][i]:
whole_func = examples['whole_func_string'][i]
if isinstance(whole_func, list):
whole_func = ' '.join(whole_func)
text = whole_func
elif 'func_documentation_string' in examples and examples['func_documentation_string'][i]:
func_doc = examples['func_documentation_string'][i]
if isinstance(func_doc, list):
func_doc = ' '.join(func_doc)
text = func_doc
else:
text = '' # Asignar cadena vac铆a si no hay texto disponible
# Asegurar que 'text' es una cadena de texto
if not isinstance(text, str):
text = str(text)
texts.append(text)
examples['text'] = texts
return examples
# Crear el campo 'text'
combined_dataset = combined_dataset.map(concatenate_text_fields, batched=True)
# Funci贸n de tokenizaci贸n basada en el campo 'text'
def tokenize_function(examples):
return tokenizer(
examples['text'],
truncation=True,
padding='max_length',
max_length=512,
clean_up_tokenization_spaces=True # Para evitar la advertencia de FutureWarning
)
# Tokenizar el dataset
tokenized_dataset = combined_dataset.map(
tokenize_function,
batched=True
)
# Configurar argumentos de entrenamiento
training_args = TrainingArguments(
output_dir=os.path.join(cache_dir, 'results'), # Almacenar temporalmente en RAM
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
num_train_epochs=1,
learning_rate=1e-5,
logging_steps=100,
save_total_limit=1,
seed=42,
weight_decay=0.01,
warmup_ratio=0.1,
evaluation_strategy="epoch",
lr_scheduler_type="linear",
save_strategy="epoch", # Guardar solo al final de cada epoch
logging_dir=os.path.join(cache_dir, 'logs'), # Directorio de logs
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
)
while True:
try:
trainer.train()
# Subir el modelo a Hugging Face desde la RAM
model.push_to_hub(
'Yhhxhfh/nombre_de_tu_modelo',
commit_message="Actualizaci贸n del modelo",
add_to_git_credential=False # Desactivar la configuraci贸n autom谩tica de credenciales de Git
)
tokenizer.push_to_hub(
'Yhhxhfh/nombre_de_tu_modelo',
commit_message="Actualizaci贸n del tokenizador",
add_to_git_credential=False # Desactivar la configuraci贸n autom谩tica de credenciales de Git
)
logging.info("Modelo y tokenizador subidos exitosamente.")
time.sleep(0) # Esperar 0 segundos antes de la siguiente iteraci贸n
except Exception as e:
logging.error(f"Error durante el entrenamiento: {e}. Reiniciando el proceso de entrenamiento...")
time.sleep(0) # Esperar 0 segundos antes de reintentar
if __name__ == "__main__":
# Correr FastAPI en un hilo separado
threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True).start()
load_and_train()