|
import os |
|
import platform |
|
from dotenv import load_dotenv |
|
import torch |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments |
|
from datasets import load_dataset, concatenate_datasets |
|
from huggingface_hub import login |
|
import time |
|
import uvicorn |
|
from fastapi import FastAPI |
|
import threading |
|
import logging |
|
import warnings |
|
|
|
|
|
# Silence FutureWarning messages (e.g. deprecation notices raised inside
# third-party libraries) for the whole process.
warnings.simplefilter("ignore", FutureWarning)

# Configure the root logger once at import time: INFO and above, timestamped,
# sent both to the file "training.log" and to a console stream handler.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=(
        logging.FileHandler("training.log"),
        logging.StreamHandler(),
    ),
)
|
|
|
|
|
# Read a local .env file (via python-dotenv) into the process environment,
# then authenticate against the Hugging Face Hub. The token is needed later
# by load_and_train() to push the model, so a missing token aborts at import.
load_dotenv()

huggingface_token = os.getenv('HUGGINGFACE_TOKEN')

# An empty value (e.g. "HUGGINGFACE_TOKEN=" in .env) is as unusable as an
# unset one; reject both here instead of passing '' on to login().
if not huggingface_token:
    raise ValueError("HUGGINGFACE_TOKEN no encontrado en las variables de entorno.")

login(token=huggingface_token)
|
|
|
|
|
app = FastAPI()


@app.get("/")
async def root():
    """Status endpoint for GET /.

    Returns:
        dict: A fixed JSON message. It is returned unconditionally and does
        not reflect whether training has actually started or finished.
    """
    # The literal previously contained mis-decoded UTF-8 ("ejecuci贸n"),
    # which clients received as garbled text.
    return {"message": "Modelo entrenado y en ejecución."}
|
|
|
# Dataset columns that can supply training text, tried in this order:
# 'dialog' (daily_dialog), then 'whole_func_string' and
# 'func_documentation_string' (code_search_net).
_TEXT_FIELDS = ('dialog', 'whole_func_string', 'func_documentation_string')


def _select_cache_dir():
    """Pick and create the directory for dataset caches and Trainer output.

    Returns:
        str: '/dev/shm' on Linux, './cache' on every other platform.
    """
    # NOTE(review): /dev/shm is usually a memory-backed tmpfs and, inside
    # Docker, defaults to a small size; confirm it can hold code_search_net.
    cache_dir = '/dev/shm' if platform.system() == "Linux" else './cache'
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def _load_raw_datasets(cache_dir):
    """Load the dialogue and code training splits, with one fallback attempt.

    Args:
        cache_dir (str): Directory passed to ``load_dataset`` as its cache.

    Returns:
        tuple | None: ``(dataset_humanizado, dataset_codigo)``, or ``None``
        when both the primary and the fallback datasets fail to load (each
        failure is logged).
    """
    try:
        dataset_humanizado = load_dataset('daily_dialog', split='train', cache_dir=cache_dir, trust_remote_code=True)
        dataset_codigo = load_dataset('code_search_net', split='train', cache_dir=cache_dir, trust_remote_code=True)
        return dataset_humanizado, dataset_codigo
    except Exception as e:
        logging.error("Error al cargar los datasets: %s", e)

    # Pause before trying the fallback datasets.
    time.sleep(60)

    # NOTE(review): these dataset ids look like placeholders; unless they are
    # replaced with real Hub ids this fallback cannot succeed.
    try:
        dataset_humanizado = load_dataset('alternative_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True)
        dataset_codigo = load_dataset('alternative_code_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True)
    except Exception as e:
        logging.error("Error al cargar el dataset alternativo: %s", e)
        return None
    return dataset_humanizado, dataset_codigo


def _concatenate_text_fields(examples):
    """Add a 'text' column taken from the first non-empty text field of each row.

    Fields are tried in ``_TEXT_FIELDS`` order ('dialog', then
    'whole_func_string', then 'func_documentation_string'). List values (such
    as dialogue turns) are joined with single spaces; a row with no usable
    field gets an empty string. Meant for ``Dataset.map(..., batched=True)``.

    Args:
        examples (Mapping[str, list]): One batch, mapping each column name to
            a list of values. A column may be absent, or hold falsy values
            (e.g. None) for rows that came from the other source dataset.

    Returns:
        The same batch mapping with the 'text' column added.
    """
    batch_size = len(next(iter(examples.values())))
    texts = []
    for i in range(batch_size):
        text = ''
        for field in _TEXT_FIELDS:
            value = examples[field][i] if field in examples else None
            if not value:
                continue
            # Sequence values (e.g. dialogue turns) are flattened to one string.
            text = ' '.join(map(str, value)) if isinstance(value, list) else value
            break
        texts.append(text if isinstance(text, str) else str(text))
    examples['text'] = texts
    return examples


def load_and_train(*, model_name='gpt2', repo_id='Yhhxhfh/nombre_de_tu_modelo',
                   max_length=512, retry_delay_seconds=60):
    """Fine-tune a GPT-2 causal LM on dialogue + code text and push it to the Hub.

    Loads the model and tokenizer, builds one text dataset from daily_dialog
    and code_search_net (or the fallbacks in ``_load_raw_datasets``), tokenizes
    it, then loops with no exit condition: train one epoch, push model and
    tokenizer to ``repo_id``, repeat. Returns early, without training, only
    when no dataset could be loaded.

    Args:
        model_name (str): Hub id of the pretrained model to start from.
        repo_id (str): Hub repository the model and tokenizer are pushed to.
        max_length (int): Token length at which every example is truncated.
        retry_delay_seconds (float): Pause after a failed train/push round, so
            a persistent error does not spin in a tight retry loop.
    """
    from transformers import DataCollatorForLanguageModeling

    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)

    # GPT-2 has no padding token, so EOS is reused. No new token is added,
    # which means the resize below leaves the embedding matrix unchanged.
    tokenizer.pad_token = tokenizer.eos_token
    model.resize_token_embeddings(len(tokenizer))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    logging.info("Entrenando en: %s", device)

    cache_dir = _select_cache_dir()

    raw_datasets = _load_raw_datasets(cache_dir)
    if raw_datasets is None:
        return
    dataset_humanizado, dataset_codigo = raw_datasets

    logging.info("Daily Dialog columnas: %s", dataset_humanizado.column_names)
    logging.info("Code Search Net columnas: %s", dataset_codigo.column_names)

    combined_dataset = concatenate_datasets([dataset_humanizado, dataset_codigo])
    logging.info("Dataset combinado columnas: %s", combined_dataset.column_names)

    combined_dataset = combined_dataset.map(_concatenate_text_fields, batched=True)
    # Rows without usable text would tokenize to empty sequences that carry
    # no training signal, so drop them.
    combined_dataset = combined_dataset.filter(
        lambda batch: [bool(text.strip()) for text in batch['text']],
        batched=True,
    )

    def tokenize_function(examples):
        # Truncate only; padding is done per batch by the data collator.
        # (clean_up_tokenization_spaces is a decode option and is not passed.)
        return tokenizer(examples['text'], truncation=True, max_length=max_length)

    # Drop every raw column so only the tokenizer outputs reach the Trainer.
    tokenized_dataset = combined_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=combined_dataset.column_names,
    )

    # Causal-LM collator: pads each batch and adds 'labels' (a copy of
    # input_ids with padding positions set to -100). GPT2LMHeadModel only
    # returns a loss when labels are supplied. Because pad == EOS, EOS
    # positions are excluded from the loss as well.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # No evaluation strategy is set: there is no eval dataset, and requesting
    # evaluation without one makes the Trainer raise.
    training_args = TrainingArguments(
        output_dir=os.path.join(cache_dir, 'results'),
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=1,
        learning_rate=1e-5,
        logging_steps=100,
        save_total_limit=1,
        seed=42,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        save_strategy="epoch",
        logging_dir=os.path.join(cache_dir, 'logs'),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    # No exit condition: each round trains one more epoch on the same data
    # and pushes the result, so the Hub copy keeps being updated.
    while True:
        try:
            trainer.train()
            model.push_to_hub(repo_id, commit_message="Actualización del modelo")
            tokenizer.push_to_hub(repo_id, commit_message="Actualización del tokenizador")
            logging.info("Modelo y tokenizador subidos exitosamente.")
        except Exception as e:
            logging.exception(
                "Error durante el entrenamiento: %s. Reiniciando el proceso de entrenamiento...", e
            )
            # Back off so a persistent failure (auth, network, OOM) does not
            # hammer the Hub and flood the log.
            time.sleep(retry_delay_seconds)
|
|
|
if __name__ == "__main__":
    # Run the HTTP server on a background daemon thread so the main thread
    # can run the blocking training loop. As a daemon it does not keep the
    # process alive: when load_and_train() returns, the process exits.
    threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 7860},
        daemon=True,
    ).start()
    load_and_train()
|
|