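# Provenance (metadata from the hosting page this file was copied from):
# uploaded by dejanseo in commit f196edb ("Upload 7 files", verified);
# file size 9.03 kB (page links: raw, history, blame).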
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AlbertTokenizer, AlbertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
import numpy as np
import os
from tqdm.auto import tqdm
import streamlit as st
import matplotlib.pyplot as plt
import torch.nn as nn
# ---------------------------------------------------------------------------
# Run configuration / hyper-parameters
# ---------------------------------------------------------------------------
EPOCHS = 10                   # full passes over the training split
VAL_SPLIT = 0.1               # fraction of rows held out for validation
VAL_EVERY_STEPS = 1000        # run validation every N optimizer steps
BATCH_SIZE = 38
LEARNING_RATE = 5e-5
LOG_EVERY_STEP = True         # refresh status text and loss charts after every step
SAVE_CHECKPOINTS = True       # write a checkpoint directory at every validation
MAX_SEQ_LENGTH = 512          # tokenizer pad/truncate length
EARLY_STOPPING_PATIENCE = 3   # consecutive non-improving validations before stopping

MODEL_NAME = 'albert/albert-base-v2'
LEVEL = 4                     # taxonomy level this run trains a classifier for
OUTPUT_DIR = 'level' + str(LEVEL)

# Label map, checkpoints and the final model are all written under OUTPUT_DIR.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Data loading and label bookkeeping
# ---------------------------------------------------------------------------
df = pd.read_csv(f'level_{LEVEL}.csv')
# The dataset reads its input from a column called "text".
df.rename(columns={'response': 'text'}, inplace=True)

# This level's labels live in the column named after the level (e.g. "4").
labels = sorted(df[str(LEVEL)].unique())
label_to_index = dict(zip(labels, range(len(labels))))
index_to_label = dict(enumerate(labels))
num_labels = len(labels)

# Persist this level's label -> index map. np.save stores the dict as a pickled
# object array, which is why the parent map below is loaded with allow_pickle.
np.save(os.path.join(OUTPUT_DIR, 'label_map.npy'), label_to_index)

# The parent level's map is read from the previous level's output directory.
# NOTE(review): allow_pickle=True unpickles the file; only load trusted files.
parent_level = LEVEL - 1
parent_label_to_index = np.load(
    f'level{parent_level}/label_map.npy', allow_pickle=True
).item()
num_parent_labels = len(parent_label_to_index)

# Integer training targets, then a fixed-seed train/validation split.
df['label'] = df[str(LEVEL)].map(label_to_index)
train_df, val_df = train_test_split(df, test_size=VAL_SPLIT, random_state=42)

tokenizer = AlbertTokenizer.from_pretrained(MODEL_NAME)
class TaxonomyDataset(Dataset):
    """Map-style dataset yielding tokenized text, a parent one-hot vector and a label.

    Args:
        dataframe: rows with a ``text`` column, an integer ``label`` column and a
            parent-id column (see ``parent_column``).
        tokenizer: object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
        max_len: pad/truncate length passed to the tokenizer.
        parent_label_to_index: maps a parent id to its position in the one-hot
            vector; its length fixes the vector size.
        parent_column: name of the column holding each row's parent id. When
            omitted it defaults to ``str(LEVEL - 1)`` (the module-level level
            constant), which is what this class always used before.
    """

    def __init__(self, dataframe, tokenizer, max_len, parent_label_to_index,
                 parent_column=None):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.parent_label_to_index = parent_label_to_index
        # Resolve the parent column once, instead of consulting the LEVEL
        # global on every item; callers may now pass it explicitly.
        self.parent_column = (
            str(LEVEL - 1) if parent_column is None else str(parent_column)
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the tensors for row ``index``.

        Keys: ``input_ids`` / ``attention_mask`` (flattened tokenizer output),
        ``parent_ids`` (float one-hot of the parent id) and ``labels`` (long
        scalar).
        """
        # One positional lookup per item instead of three.
        row = self.data.iloc[index]
        text = str(row['text'])
        label = int(row['label'])
        parent_id = int(row[self.parent_column])

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # One-hot encode the parent id. A parent id of 0 (presumably "no
        # parent" -- TODO confirm) and ids missing from the map both leave the
        # vector all zeros.
        parent_one_hot = torch.zeros(len(self.parent_label_to_index))
        if parent_id != 0:
            parent_index = self.parent_label_to_index.get(parent_id)
            if parent_index is not None:
                parent_one_hot[parent_index] = 1

        # With return_tensors='pt' the tokenizer returns a leading batch
        # dimension of 1 (assumed shape (1, max_len)); flatten drops it.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'parent_ids': parent_one_hot,
            'labels': torch.tensor(label, dtype=torch.long)
        }
# Wrap both splits in the same dataset type; only the training loader shuffles.
train_dataset, val_dataset = (
    TaxonomyDataset(split_df, tokenizer, MAX_SEQ_LENGTH, parent_label_to_index)
    for split_df in (train_df, val_df)
)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)
# Model definition
class TaxonomyClassifier(nn.Module):
    """ALBERT encoder plus a linear head that also sees the parent one-hot vector.

    The pooled ALBERT output (after dropout) is concatenated with the parent
    one-hot vector, and a single linear layer maps the result to
    ``num_labels`` logits.
    """

    def __init__(self, base_model_name, num_parent_labels, num_labels):
        super().__init__()
        self.albert = AlbertModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        head_in = self.albert.config.hidden_size + num_parent_labels
        self.classifier = nn.Linear(head_in, num_labels)

    def forward(self, input_ids, attention_mask, parent_ids):
        """Return raw (unnormalised) class logits for a batch."""
        encoded = self.albert(input_ids, attention_mask=attention_mask)
        sentence_vec = self.dropout(encoded.pooler_output)
        # Parent one-hot is appended along the feature axis (dim=1).
        features = torch.cat([sentence_vec, parent_ids], dim=1)
        return self.classifier(features)
# Model initialization
model = TaxonomyClassifier(MODEL_NAME, num_parent_labels, num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Optimizer and scheduler.
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. torch's
# defaults differ (eps=1e-8, weight_decay=0.01), so transformers.AdamW's
# defaults (eps=1e-6, no weight decay) are passed explicitly to keep the
# optimisation settings as before. The AdamW imported from transformers at the
# top of the file is no longer used here.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0
)
total_steps = len(train_dataloader) * EPOCHS
# No warmup: the learning rate decays linearly from LEARNING_RATE to 0 over
# total_steps optimizer steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps
)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Loss history (for the charts) and early-stopping state
train_losses = []
val_losses = []
val_steps = []
best_val_loss = float('inf')
early_stopping_counter = 0
global_step = 0
# Streamlit UI: title, progress bar, a single status line, and one chart per
# loss curve (training first, validation second).
st.title(f'Level {LEVEL} Model Training')
progress_bar = st.progress(0)
status_text = st.empty()

train_loss_fig, train_loss_ax = plt.subplots()
train_loss_chart = st.pyplot(train_loss_fig)
val_loss_fig, val_loss_ax = plt.subplots()
val_loss_chart = st.pyplot(val_loss_fig)
def update_loss_charts():
    """Redraw both loss plots from the module-level histories and push them to Streamlit.

    The training curve is plotted against its list index (one point per
    recorded step); the validation curve against the recorded global steps.
    """

    def _redraw(ax, fig, chart, xs, ys, title):
        # Replace the axes' previous contents and re-render the figure in place.
        ax.clear()
        ax.plot(xs, ys)
        ax.set_xlabel("Steps")
        ax.set_ylabel("Loss")
        ax.set_title(title)
        chart.pyplot(fig)

    _redraw(train_loss_ax, train_loss_fig, train_loss_chart,
            range(len(train_losses)), train_losses, "Training Loss")
    _redraw(val_loss_ax, val_loss_fig, val_loss_chart,
            val_steps, val_losses, "Validation Loss")
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def _batch_to_device(batch):
    """Move a collated batch to ``device``.

    Returns ``(input_ids, attention_mask, parent_ids, labels)`` in the order the
    model and loss function consume them.
    """
    return (
        batch['input_ids'].to(device),
        batch['attention_mask'].to(device),
        batch['parent_ids'].to(device),
        batch['labels'].to(device),
    )


def _validation_loss():
    """Return the mean per-batch loss over ``val_dataloader``.

    Switches the model to eval mode (dropout off) and runs without gradients.
    The caller must switch the model back to train mode afterwards.
    """
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for val_batch in val_dataloader:
            input_ids, attention_mask, parent_ids, labels = _batch_to_device(val_batch)
            logits = model(input_ids, attention_mask, parent_ids)
            total_val_loss += loss_fn(logits, labels).item()
    return total_val_loss / len(val_dataloader)


stop_training = False
for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0
    for batch in tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{EPOCHS}', leave=False):
        optimizer.zero_grad()
        input_ids, attention_mask, parent_ids, labels = _batch_to_device(batch)
        outputs = model(input_ids, attention_mask, parent_ids)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        global_step += 1

        # Keep the training loss in its own variable. Validation previously
        # reused `loss`, so the post-validation status line reported the last
        # *validation* batch loss as "Training Loss".
        train_loss = loss.item()
        total_train_loss += train_loss
        train_losses.append(train_loss)

        if LOG_EVERY_STEP:
            status_text.text(f"Epoch {epoch+1}/{EPOCHS}, Step {global_step}, Training Loss: {train_loss:.4f}")
            update_loss_charts()

        if global_step % VAL_EVERY_STEPS == 0:
            avg_val_loss = _validation_loss()
            # Restore train mode. Previously the model stayed in eval mode
            # (dropout disabled) for the rest of the epoch after a validation.
            model.train()
            val_losses.append(avg_val_loss)
            val_steps.append(global_step)
            status_text.text(f"Epoch {epoch+1}/{EPOCHS}, Step {global_step}, Training Loss: {train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
            update_loss_charts()

            if SAVE_CHECKPOINTS:
                checkpoint_dir = os.path.join(OUTPUT_DIR, f'level{LEVEL}_step{global_step}')
                os.makedirs(checkpoint_dir, exist_ok=True)
                # NOTE(review): torch.save writes PyTorch's pickle-based format,
                # not the safetensors format the file name suggests.
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'model.safetensors'))
                tokenizer.save_pretrained(checkpoint_dir)
                status_text.text(f"Checkpoint saved at step {global_step}")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
                    message = f"Early stopping triggered at step {global_step}"
                    status_text.text(message)
                    print(message)
                    progress_bar.progress(100)
                    # Leave both loops so the final-save code following this
                    # loop writes the model. The previous exit() relied on a
                    # `site` builtin and bypassed the rest of the script.
                    # NOTE(review): the weights saved are the current ones, not
                    # those of the best validation; with SAVE_CHECKPOINTS the
                    # best step's checkpoint directory is still on disk.
                    stop_training = True
                    break

        progress_bar.progress(int((global_step / total_steps) * 100))

    if stop_training:
        break

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f'Epoch {epoch+1}/{EPOCHS} Average Training Loss: {avg_train_loss:.4f}')
# Persist the final weights and tokenizer, then mark the run as finished in the UI.
# NOTE(review): torch.save writes PyTorch's pickle-based format, not the
# safetensors format the ".safetensors" extension suggests. The weights go
# directly into OUTPUT_DIR, while the tokenizer goes into OUTPUT_DIR/model.
final_weights_path = os.path.join(OUTPUT_DIR, 'model.safetensors')
final_tokenizer_dir = os.path.join(OUTPUT_DIR, 'model')
torch.save(model.state_dict(), final_weights_path)
tokenizer.save_pretrained(final_tokenizer_dir)
status_text.success("Training complete!")