Yhhxhfh committed on
Commit
a9bd469
1 Parent(s): 94c9328

Update app.py

Files changed (1)
  1. app.py +26 -22
app.py CHANGED
@@ -1,10 +1,9 @@
 import os
 from dotenv import load_dotenv
 import torch
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
+from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
 from datasets import load_dataset, concatenate_datasets
 from huggingface_hub import login
-from autotrain import AutoTrain
 import time
 
 load_dotenv()
@@ -14,6 +13,7 @@ model_name = 'gpt2'
 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 model = GPT2LMHeadModel.from_pretrained(model_name)
 
+# Load the datasets and keep everything in RAM
 dataset_humanizado = load_dataset('daily_dialog', split='train')
 dataset_codigo = load_dataset('code_search_net', split='train')
 dataset_prompts = load_dataset('openai_humaneval', split='train')
@@ -24,36 +24,40 @@ combined_dataset = concatenate_datasets([
     dataset_prompts
 ])
 
+# Tokenize and keep everything in RAM
 def tokenize_function(examples):
     return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
 
 tokenized_dataset = combined_dataset.map(tokenize_function, batched=True)
 
-training_args = {
-    "output_dir": './results',
-    "per_device_train_batch_size": 100,
-    "per_device_eval_batch_size": 100,
-    "num_train_epochs": 0,
-    "learning_rate": 1e-5,
-    "logging_steps": -1,
-    "max_grad_norm": 1,
-    "save_total_limit": 1,
-    "seed": 42,
-    "weight_decay": 0,
-    "warmup_ratio": 0.0,
-    "evaluation_strategy": "no",
-    "optim": "adamw_torch",
-    "lr_scheduler_type": "constant",
-    "model_max_length": 2098989848
-}
-
-autotrain = AutoTrain(model=model, args=training_args)
+training_args = TrainingArguments(
+    output_dir='./results',  # can be used to save results, but not needed in RAM
+    per_device_train_batch_size=100,
+    per_device_eval_batch_size=100,
+    num_train_epochs=0,
+    learning_rate=1e-5,
+    logging_steps=-1,
+    max_grad_norm=1,
+    save_total_limit=1,
+    seed=42,
+    weight_decay=0,
+    warmup_ratio=0.0,
+    evaluation_strategy="no",
+    optim="adamw_torch",
+    lr_scheduler_type="constant",
+)
 
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=tokenized_dataset,
+)
 
 @spaces.gpu
 def run_training():
     while True:
         try:
-            autotrain.train(tokenized_dataset)
+            trainer.train()
             model.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', repo_type='model', use_temp_dir=True, commit_message="Model update")
             tokenizer.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', repo_type='model', use_temp_dir=True, commit_message="Tokenizer update")
             time.sleep(5)
@@ -64,4 +68,4 @@ def run_training():
 run_training()
 
 import shutil
-shutil.rmtree('./results', ignore_errors=True)
+shutil.rmtree('./results', ignore_errors=True)  # clean up if needed, though it may not be required when everything stays in RAM
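
As committed, the script would still fail before it ever trains: spaces is never imported (and the decorator in the spaces package is spaces.GPU, not spaces.gpu), GPT-2's tokenizer has no pad token so padding='max_length' raises an error, none of the three datasets exposes a 'text' column (daily_dialog, for example, stores each sample as a list of utterances under 'dialog'), num_train_epochs=0 performs no training at all, and the Trainer receives no labels for the causal-LM loss. Below is a minimal working sketch of the same Trainer wiring, trimmed to a single dataset for brevity; the batch size, epoch count, and logging interval are illustrative assumptions, not values from the commit.

from datasets import load_dataset
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, Trainer,
                          TrainingArguments, DataCollatorForLanguageModeling)

model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
model = GPT2LMHeadModel.from_pretrained(model_name)

dataset = load_dataset('daily_dialog', split='train')

def tokenize_function(examples):
    # daily_dialog keeps each sample as a list of utterances under 'dialog'
    text = [' '.join(turns) for turns in examples['dialog']]
    return tokenizer(text, truncation=True, padding='max_length', max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True,
                                remove_columns=dataset.column_names)

training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=8,  # illustrative; 100 rarely fits on one GPU at 512 tokens
    num_train_epochs=1,             # 0 epochs would train nothing
    learning_rate=1e-5,
    logging_steps=50,               # must be a positive value; -1 is not accepted
    save_total_limit=1,
    seed=42,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    # mlm=False makes the collator copy input_ids into labels for the causal-LM loss
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
model.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', use_temp_dir=True)
tokenizer.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', use_temp_dir=True)

repo_type is omitted from the push_to_hub calls since model and tokenizer pushes already target a model repo, and the unconditional while True retraining loop is left out of the sketch.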