import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import torch
import torch.nn as nn
# transformers' AdamW was deprecated (and removed in recent releases); use the PyTorch implementation instead.
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AlbertTokenizer, AlbertModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

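# Training configuration and output location for this taxonomy level.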
EPOCHS = 10
VAL_SPLIT = 0.1
VAL_EVERY_STEPS = 1000
BATCH_SIZE = 38
LEARNING_RATE = 5e-5
LOG_EVERY_STEP = True
SAVE_CHECKPOINTS = True
MAX_SEQ_LENGTH = 512
EARLY_STOPPING_PATIENCE = 3
MODEL_NAME = 'albert/albert-base-v2'
LEVEL = 4
OUTPUT_DIR = f'level{LEVEL}'

os.makedirs(OUTPUT_DIR, exist_ok=True)

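# Load this level's responses and build the label vocabulary.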
df = pd.read_csv(f'level_{LEVEL}.csv')
df.rename(columns={'response': 'text'}, inplace=True)

labels = sorted(df[str(LEVEL)].unique())
label_to_index = {label: i for i, label in enumerate(labels)}
index_to_label = {i: label for label, i in label_to_index.items()}
num_labels = len(labels)

np.save(os.path.join(OUTPUT_DIR, 'label_map.npy'), label_to_index)

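# The previous level's label map is used to one-hot encode each sample's parent category.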
parent_level = LEVEL - 1
parent_label_to_index = np.load(f'level{parent_level}/label_map.npy', allow_pickle=True).item()
num_parent_labels = len(parent_label_to_index)

df['label'] = df[str(LEVEL)].map(label_to_index)
train_df, val_df = train_test_split(df, test_size=VAL_SPLIT, random_state=42)

tokenizer = AlbertTokenizer.from_pretrained(MODEL_NAME)

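# Dataset that tokenizes each response and attaches a one-hot encoding of its parent label.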
class TaxonomyDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len, parent_label_to_index):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.parent_label_to_index = parent_label_to_index

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        text = str(self.data.iloc[index].text)
        label = int(self.data.iloc[index].label)
        parent_id = int(self.data.iloc[index][str(LEVEL - 1)])

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # One-hot encode the parent label when one is present (0 is treated as "no parent").
        parent_one_hot = torch.zeros(len(self.parent_label_to_index))
        if parent_id != 0:
            parent_index = self.parent_label_to_index.get(parent_id)
            if parent_index is not None:
                parent_one_hot[parent_index] = 1

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'parent_ids': parent_one_hot,
            'labels': torch.tensor(label, dtype=torch.long)
        }

train_dataset = TaxonomyDataset(train_df, tokenizer, MAX_SEQ_LENGTH, parent_label_to_index)
val_dataset = TaxonomyDataset(val_df, tokenizer, MAX_SEQ_LENGTH, parent_label_to_index)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

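# ALBERT encoder whose pooled output is concatenated with the parent one-hot vector before classification.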
class TaxonomyClassifier(nn.Module):
    def __init__(self, base_model_name, num_parent_labels, num_labels):
        super().__init__()
        self.albert = AlbertModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.albert.config.hidden_size + num_parent_labels, num_labels)

    def forward(self, input_ids, attention_mask, parent_ids):
        outputs = self.albert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        combined_features = torch.cat((pooled_output, parent_ids), dim=1)
        logits = self.classifier(combined_features)
        return logits

model = TaxonomyClassifier(MODEL_NAME, num_parent_labels, num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

loss_fn = nn.CrossEntropyLoss()

train_losses = []
val_losses = []
val_steps = []
best_val_loss = float('inf')
early_stopping_counter = 0
global_step = 0

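# Streamlit dashboard: progress bar, status line, and live loss charts.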
st.title(f'Level {LEVEL} Model Training')
progress_bar = st.progress(0)
status_text = st.empty()
train_loss_fig, train_loss_ax = plt.subplots()
val_loss_fig, val_loss_ax = plt.subplots()
# Empty placeholders so the charts can be redrawn in place as training progresses.
train_loss_chart = st.empty()
val_loss_chart = st.empty()

def update_loss_charts():
    train_loss_ax.clear()
    train_loss_ax.plot(range(len(train_losses)), train_losses)
    train_loss_ax.set_xlabel("Steps")
    train_loss_ax.set_ylabel("Loss")
    train_loss_ax.set_title("Training Loss")
    train_loss_chart.pyplot(train_loss_fig)

    val_loss_ax.clear()
    val_loss_ax.plot(val_steps, val_losses)
    val_loss_ax.set_xlabel("Steps")
    val_loss_ax.set_ylabel("Loss")
    val_loss_ax.set_title("Validation Loss")
    val_loss_chart.pyplot(val_loss_fig)

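# Training loop with periodic validation, checkpointing, and early stopping.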
for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0
    for batch in tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{EPOCHS}', leave=False):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        parent_ids = batch['parent_ids'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask, parent_ids)
        loss = loss_fn(outputs, labels)
        total_train_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
        global_step += 1

        train_losses.append(loss.item())

        if LOG_EVERY_STEP:
            status_text.text(f"Epoch {epoch+1}/{EPOCHS}, Step {global_step}, Training Loss: {loss.item():.4f}")
            update_loss_charts()

        if global_step % VAL_EVERY_STEPS == 0:
            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for val_batch in val_dataloader:
                    input_ids = val_batch['input_ids'].to(device)
                    attention_mask = val_batch['attention_mask'].to(device)
                    parent_ids = val_batch['parent_ids'].to(device)
                    labels = val_batch['labels'].to(device)
                    outputs = model(input_ids, attention_mask, parent_ids)
                    # Use a separate name so the training loss reported below is not overwritten.
                    val_loss = loss_fn(outputs, labels)
                    total_val_loss += val_loss.item()

            avg_val_loss = total_val_loss / len(val_dataloader)
            val_losses.append(avg_val_loss)
            val_steps.append(global_step)
            status_text.text(f"Epoch {epoch+1}/{EPOCHS}, Step {global_step}, Training Loss: {loss.item():.4f}, Validation Loss: {avg_val_loss:.4f}")
            update_loss_charts()

            if SAVE_CHECKPOINTS:
                checkpoint_dir = os.path.join(OUTPUT_DIR, f'level{LEVEL}_step{global_step}')
                os.makedirs(checkpoint_dir, exist_ok=True)
                # torch.save writes a standard PyTorch pickle checkpoint, so use a .pt name rather than .safetensors.
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'model.pt'))
                tokenizer.save_pretrained(checkpoint_dir)
                status_text.text(f"Checkpoint saved at step {global_step}")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
                    status_text.text(f"Early stopping triggered at step {global_step}")
                    progress_bar.progress(100)

                    torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'model.pt'))
                    tokenizer.save_pretrained(os.path.join(OUTPUT_DIR, 'model'))
                    exit()

            # Return to training mode after evaluation so dropout is re-enabled.
            model.train()

        progress_bar.progress(int((global_step / total_steps) * 100))

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f'Epoch {epoch+1}/{EPOCHS} Average Training Loss: {avg_train_loss:.4f}')

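# Save the final model weights and tokenizer for this level.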
torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'model.pt'))
tokenizer.save_pretrained(os.path.join(OUTPUT_DIR, 'model'))
status_text.success("Training complete!")