import os
import shutil

import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.loggers import WandbLogger

from config import SmolLM2Config
from model import SmolLM2Lightning
from env_setup import setup_environment, cleanup_environment
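
# Synchronous CUDA kernel launches and device-side assertions: slower, but CUDA
# errors are reported at the call that caused them, which simplifies debugging.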
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'


def setup_training():
    """Configure CUDA backends and return the device to train on."""
    try:
        if torch.cuda.is_available():
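            # TF32 matmuls and cuDNN autotuning trade a little numerical
            # precision for speed on Ampere and newer GPUs.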
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            torch.set_float32_matmul_precision('high')

            device = torch.device('cuda:0')
            torch.cuda.set_device(device)

            print(f"Using GPU: {torch.cuda.get_device_name()}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
            return device
    except Exception as e:
        print(f"CUDA setup error: {str(e)}")

    print("Using CPU")
    return torch.device('cpu')


def cleanup_training():
    """Release GPU memory and close the active WandB run, if any."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            wandb.finish()
        except Exception:
            pass
    except Exception as e:
        print(f"Cleanup error: {str(e)}")


class GenerationMonitorCallback(Callback):
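    """Generates a sample from a fixed prompt every N training steps so output quality can be tracked."""
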
    def __init__(self, prompt="Explain what machine learning is:", sample_every_n_steps=500):
        super().__init__()
        self.prompt = prompt
        self.sample_every_n_steps = sample_every_n_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        try:
            if (trainer.global_step + 1) % self.sample_every_n_steps == 0:
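                # Sample in eval mode without gradients; training mode is restored below.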
                pl_module.eval()

                with torch.no_grad():
                    inputs = pl_module.tokenizer(
                        self.prompt,
                        return_tensors="pt",
                        truncation=True,
                        max_length=pl_module.config.model.max_position_embeddings,
                        padding=True,
                    ).to(pl_module.device)
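
                    # Nucleus (top-p) and top-k sampling with a mild temperature keep samples varied but readable.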
                    try:
                        generated_ids = pl_module.generate(
                            input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            max_length=100,
                            temperature=0.7,
                            top_p=0.9,
                            top_k=50,
                            do_sample=True,
                            pad_token_id=pl_module.tokenizer.pad_token_id,
                            bos_token_id=pl_module.tokenizer.bos_token_id,
                            eos_token_id=pl_module.tokenizer.eos_token_id,
                        )

                        generated_text = pl_module.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

                        print(f"\n=== Generation at step {trainer.global_step + 1} ===")
                        print(f"Prompt: {self.prompt}")
                        print(f"Generated: {generated_text}\n")

                    except RuntimeError as e:
                        print(f"\nError during generation at step {trainer.global_step + 1}: {str(e)}")
                        print(f"Input shape: {inputs.input_ids.shape}")
                        print(f"Input device: {inputs.input_ids.device}")

                pl_module.train()
        except Exception as e:
            print(f"\nCallback error at step {trainer.global_step + 1}: {str(e)}")


def init_wandb(project_name, run_name):
    """Initialize WandB with error handling and cleanup."""
    try:
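        # Remove any leftover local wandb/ directory so each run starts from a clean slate.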
        wandb_dir = os.path.join(os.getcwd(), "wandb")
        if os.path.exists(wandb_dir):
            try:
                shutil.rmtree(wandb_dir)
                print("Cleaned up existing wandb directory")
            except Exception as e:
                print(f"Warning: Could not clean up wandb directory: {str(e)}")

        os.makedirs(wandb_dir, exist_ok=True)

        logger = WandbLogger(
            project=project_name,
            name=run_name,
            save_dir=os.getcwd(),
            settings=wandb.Settings(start_method="thread"),
        )
        return logger

    except Exception as e:
        print(f"Error initializing WandB: {str(e)}")
        print("Continuing without WandB logging...")
        return None


def main():
    device = setup_training()

    try:
        config = SmolLM2Config("config.yaml")
        model = SmolLM2Lightning(config)
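
        # Phase 1: train from scratch, checkpointing on train_loss and sampling
        # generations periodically.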
        print("\n=== Starting Phase 1 Training ===")

        wandb_logger = init_wandb("smol-lm2", "training_run_phase1")
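
        # Keep the three best checkpoints by training loss, saved every save_steps steps.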
        checkpoint_callback = ModelCheckpoint(
            dirpath=config.training.checkpoint_dir,
            filename="smol-lm2-phase1-{epoch:02d}-{train_loss:.2f}",
            save_top_k=3,
            monitor="train_loss",
            mode="min",
            every_n_train_steps=config.training.save_steps,
        )

        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.sample_frequency,
        )
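
        # Trainer settings (devices, precision, gradient clipping/accumulation) all come from config.yaml.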
        trainer_phase1 = pl.Trainer(
            max_steps=config.training.first_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[checkpoint_callback, generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
            strategy='auto',
        )

        try:
            trainer_phase1.fit(model)
        except Exception as e:
            print(f"Error during phase 1 training: {str(e)}")
            raise

        phase1_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-phase1-final.ckpt")
        trainer_phase1.save_checkpoint(phase1_checkpoint_path)
        print(f"Phase 1 completed. Model saved to {phase1_checkpoint_path}")
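
        # Free cached GPU memory before building the phase 2 trainer.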
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
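
        # Phase 2: resume from the phase 1 checkpoint and continue training,
        # sampling at the second-phase frequency from the config.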
        print("\n=== Starting Phase 2 Training ===")

        try:
            model = SmolLM2Lightning.load_from_checkpoint(phase1_checkpoint_path, config=config)
        except Exception as e:
            print(f"Error loading checkpoint for phase 2: {str(e)}")
            raise

        wandb_logger = init_wandb("smol-lm2", "training_run_phase2")

        generation_callback = GenerationMonitorCallback(
            prompt=config.training.sample_prompt,
            sample_every_n_steps=config.training.second_phase_sample_frequency,
        )

        trainer_phase2 = pl.Trainer(
            max_steps=config.training.second_phase_steps,
            accelerator=config.hardware.accelerator,
            devices=config.hardware.devices,
            precision=config.hardware.precision,
            logger=wandb_logger,
            callbacks=[generation_callback],
            gradient_clip_val=config.hardware.gradient_clip,
            accumulate_grad_batches=config.training.gradient_accumulation_steps,
            log_every_n_steps=config.training.logging_steps,
            deterministic=False,
            benchmark=True,
        )

        try:
            trainer_phase2.fit(model)
        except Exception as e:
            print(f"Error during phase 2 training: {str(e)}")
            raise

        final_checkpoint_path = os.path.join(config.training.checkpoint_dir, "smol-lm2-final.ckpt")
        trainer_phase2.save_checkpoint(final_checkpoint_path)
        print(f"Phase 2 completed. Final model saved to {final_checkpoint_path}")

    except Exception as e:
        print(f"\nTraining failed with error: {str(e)}")
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            print(f"CUDA memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
        raise
    finally:
        cleanup_training()


if __name__ == "__main__":
    main()