import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
import gradio as gr
from tqdm import tqdm
import tiktoken
from transformer import GPT, GPTConfig  # Import from transformer.py instead
from torch.cuda.amp import autocast, GradScaler
# DataLoader class for handling input.txt
class DataLoaderLite:
    def __init__(self, B, T, config):
        self.B = B
        self.T = T
        self.config = config
        # Load and tokenize input.txt
        with open('input.txt', 'r', encoding='utf-8') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text), dtype=torch.long)
        # Create dataset chunks for faster loading
        self.data = []
        for i in range(0, len(self.tokens) - T, B * T):
            chunk = self.tokens[i:i + B * T + 1]
            if len(chunk) == B * T + 1:
                self.data.append(chunk)
        print(f'Loaded {len(self.tokens)} tokens')
        print(f'Created {len(self.data)} batches')
        self.current_idx = 0

    def next_batch(self):
        chunk = self.data[self.current_idx]
        x = chunk[:-1].view(self.B, self.T)
        y = chunk[1:].view(self.B, self.T)
        self.current_idx = (self.current_idx + 1) % len(self.data)
        if self.config.pin_memory:
            x = x.pin_memory()
            y = y.pin_memory()
        return x, y
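# A minimal usage sketch (illustrative, not part of the training flow): each
# batch is a pair of (B, T) LongTensors where y is x shifted left by one
# token, the next-token-prediction target layout the model expects:
#   loader = DataLoaderLite(B=4, T=256, config=config)
#   x, y = loader.next_batch()  # x.shape == y.shape == (4, 256)
#   assert torch.equal(x[0, 1:], y[0, :-1])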
class TrainingConfig:
    def __init__(self):
        # Smaller model architecture (~15M params)
        self.n_layer = 6       # Reduced from 12
        self.n_head = 6        # Reduced from 12
        self.n_embd = 384      # Reduced from 768
        self.block_size = 256  # Keep this the same
        self.dropout = 0.2
        # Optimized training hyperparameters for faster convergence
        self.learning_rate = 1e-4  # Reduced learning rate for stability
        self.max_iters = 50000     # Increased max iterations
        self.batch_size = 4        # Reduced batch size
        self.grad_clip = 0.5       # Reduced gradient clipping
        self.weight_decay = 0.1
        self.betas = (0.9, 0.95)
        self.warmup_iters = 2000
        self.lr_decay_iters = 40000  # Increased decay iterations
        self.min_lr = 1e-5
        self.eval_interval = 100  # More frequent evaluation
        self.eval_iters = 20
        # Performance optimization flags
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.gradient_checkpointing = True
        self.mixed_precision = True
        self.gradient_accumulation_steps = 8  # Increased for effective batch size
        self.num_workers = 4
        self.pin_memory = True
        # Check if Triton is available before enabling compile
        try:
            import triton
            self.compile_model = True
        except ImportError:
            print("Triton not available, disabling model compilation")
            self.compile_model = False
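# Note on effective batch size: with batch_size=4 and
# gradient_accumulation_steps=8, each optimizer step accumulates gradients
# over 4 * 8 = 32 sequences, i.e. 32 * 256 = 8192 tokens per update.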
class TrainingLogger:
    def __init__(self, log_file='training_log.txt'):
        self.log_file = log_file
        self.start_time = time.time()
        # Initialize log file
        with open(self.log_file, 'w') as f:
            f.write("Training Log\n")
            f.write("=" * 50 + "\n")
            f.write(f"Training started at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("Iteration | Train Loss | Val Loss | Learning Rate | Tokens/sec\n")
            f.write("-" * 65 + "\n")

    def log_step(self, iter_num, train_loss, val_loss, lr, tokens_per_sec):
        log_line = f"{iter_num:>9} | {train_loss:>10.4f} | {val_loss:>8.4f} | {lr:>12.2e} | {tokens_per_sec:>9.2f}"
        print(log_line)
        with open(self.log_file, 'a') as f:
            f.write(log_line + "\n")

    def log_message(self, message):
        print(message)
        with open(self.log_file, 'a') as f:
            f.write("\n" + message + "\n")

    def finish(self):
        total_time = (time.time() - self.start_time) / 3600  # Convert to hours
        message = f"\nTraining completed in {total_time:.2f} hours"
        self.log_message(message)
def get_lr(it, config):
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    if it > config.lr_decay_iters:
        return config.min_lr
    decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config.min_lr + coeff * (config.learning_rate - config.min_lr)
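# Worked values of the schedule under the defaults above (learning_rate=1e-4,
# warmup_iters=2000, lr_decay_iters=40000, min_lr=1e-5):
#   get_lr(1000, config)   -> 5.0e-5  (halfway through linear warmup)
#   get_lr(2000, config)   -> 1.0e-4  (warmup done, cosine decay begins)
#   get_lr(21000, config)  -> 5.5e-5  (cosine midpoint: min_lr + 0.5 * (lr - min_lr))
#   get_lr(45000, config)  -> 1.0e-5  (past lr_decay_iters, clamped to min_lr)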
def evaluate_loss(model, train_loader, config):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for _ in range(config.eval_iters):
            x, y = train_loader.next_batch()
            x, y = x.to(config.device), y.to(config.device)
            _, loss = model(x, y)
            total_loss += loss.item()
    model.train()
    return total_loss / config.eval_iters
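# Note: this "validation" loss is estimated on batches drawn from the same
# training stream; there is no held-out split in this script. A hedged sketch
# of a held-out variant would build a second loader over a separate corpus:
#   val_loader = DataLoaderLite(B=config.batch_size, T=config.block_size, config=config)
#   val_loss = evaluate_loss(model, val_loader, config)
# where DataLoaderLite would need a file-path parameter instead of the
# hard-coded 'input.txt'.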
def train_model():
    config = TrainingConfig()
    logger = TrainingLogger()
    # Create and optimize model
    model_config = GPTConfig(
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout
    )
    model = GPT(model_config)
    if config.compile_model and hasattr(torch, 'compile'):
        try:
            model = torch.compile(model)
            logger.log_message("Model compilation successful")
        except Exception as e:
            logger.log_message(f"Model compilation failed: {e}")
            logger.log_message("Continuing without compilation")
    # Guard the call: a custom GPT from transformer.py may not expose the
    # HF-style gradient_checkpointing_enable() method
    if config.gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
    model.to(config.device)
    logger.log_message(f"Number of parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=config.betas,
        weight_decay=config.weight_decay
    )
    train_loader = DataLoaderLite(B=config.batch_size, T=config.block_size, config=config)
    scaler = GradScaler() if config.mixed_precision else None
    best_val_loss = float('inf')
    no_improvement_count = 0
    for iter_num in tqdm(range(config.max_iters)):
        iter_start = time.time()
        # Training step
        x, y = train_loader.next_batch()
        x, y = x.to(config.device, non_blocking=True), y.to(config.device, non_blocking=True)
        lr = get_lr(iter_num, config)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        if config.mixed_precision:
            with autocast():
                logits, loss = model(x, y)
                loss = loss / config.gradient_accumulation_steps
            scaler.scale(loss).backward()
            if (iter_num + 1) % config.gradient_accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
        else:
            logits, loss = model(x, y)
            loss = loss / config.gradient_accumulation_steps
            loss.backward()
            if (iter_num + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
        # Calculate metrics
        iter_time = time.time() - iter_start
        tokens_per_sec = config.batch_size * config.block_size / iter_time
        # Evaluation and logging
        if iter_num % config.eval_interval == 0:
            val_loss = evaluate_loss(model, train_loader, config)
            # Undo the gradient-accumulation scaling so the logged train loss
            # is comparable to val_loss
            train_loss = loss.item() * config.gradient_accumulation_steps
            logger.log_step(iter_num, train_loss, val_loss, lr, tokens_per_sec)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                no_improvement_count = 0
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'iter': iter_num,
                    'config': model_config
                }, 'best_model.pt')
                logger.log_message(f"New best model saved with validation loss: {val_loss:.6f}")
            else:
                no_improvement_count += 1
            if val_loss < 0.099999:
                logger.log_message(f"Target loss achieved at iteration {iter_num}")
                logger.log_message(f"Final validation loss: {val_loss:.6f}")
                break
            if no_improvement_count >= 5:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.5
                no_improvement_count = 0
                logger.log_message("Reducing learning rate due to no improvement")
    logger.finish()
    return model
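# A hedged helper for reloading 'best_model.pt' later; a sketch, not part of
# the original flow. It assumes the checkpoint layout saved above, and strips
# the '_orig_mod.' prefix that torch.compile adds to state_dict keys so the
# weights also load into an uncompiled GPT.
def load_best_model(path='best_model.pt', device='cpu'):
    checkpoint = torch.load(path, map_location=device)
    model = GPT(checkpoint['config'])  # the GPTConfig stored in the checkpoint
    state_dict = {k.removeprefix('_orig_mod.'): v
                  for k, v in checkpoint['model_state_dict'].items()}
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model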
def generate_text(model, prompt, max_length=100, temperature=0.7):
    model.eval()
    # nn.Module has no .device attribute; infer the device from the parameters
    device = next(model.parameters()).device
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    with torch.no_grad():
        output_sequence = []
        for _ in range(max_length):
            # Crop the context to block_size tokens so generation can run past
            # the model's context window (assumes the GPT exposes its GPTConfig
            # as model.config, as nanoGPT-style implementations do)
            idx_cond = input_ids[:, -model.config.block_size:]
            outputs = model(idx_cond)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            next_token_logits = logits[:, -1, :]
            # Apply temperature
            next_token_logits = next_token_logits / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            output_sequence.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)
    return enc.decode(output_sequence)
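# Example (illustrative): sample 50 new tokens from a prompt after training:
#   sample = generate_text(model, "Once upon a time", max_length=50, temperature=0.8)
#   print(sample)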
if __name__ == "__main__":
    # Train the model
    model = train_model()

    # Create and launch Gradio interface
    def predict(prompt, length, temp=0.7):
        # Slider values arrive as floats; cast before passing to range()
        return generate_text(model, prompt, int(length), temp)

    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(lines=2, label="Enter your prompt"),
            gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max Length"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature", step=0.1)
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom Transformer Text Generator",
        description="Enter a prompt and adjust parameters to generate text"
    )
    iface.launch(share=True)