"""
Description: Fine-tune Llama-3.2-3B-Instruct with LoRA adapters (via Unsloth)
on conversation-style JSONL data, then evaluate exact-match metapath-prediction
accuracy on a held-out test split.
"""
|
# Import Unsloth before transformers/trl so its performance patches take effect.
import os

from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, train_on_responses_only

import ast
import re

import contractions
from datasets import load_dataset
from transformers import DataCollatorForSeq2Seq, TextStreamer, TrainingArguments
from trl import SFTTrainer

from utils import remove_inner_single_quotes

|
max_seq_length = 2048
dtype = None         # None lets Unsloth auto-detect (bfloat16 where supported, else float16)
load_in_4bit = True  # 4-bit quantization to reduce GPU memory use
|

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

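# Attach LoRA adapters; only these low-rank projection matrices are trained,
# while the 4-bit base weights stay frozen.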
|
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

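# Apply the Llama-3.1 chat template so prompt formatting matches at train and inference time.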
|
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)
|


def formatting_prompts_func(data):
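    """Render each conversation into a single chat-formatted training string."""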
|
    convos = data['conversations']
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
|


def train(train_dataset, val_dataset, dataset_name):
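    """Fine-tune the LoRA adapters and save them under ./checkpoints/<dataset_name>/lora_model."""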
|
    texts_train = train_dataset.map(formatting_prompts_func, batched=True)
    print(texts_train)
    texts_val = val_dataset.map(formatting_prompts_func, batched=True)
    print(texts_val)
|

    trainer = SFTTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = texts_train,
        eval_dataset = texts_val,
        dataset_text_field = "text",
        max_seq_length = max_seq_length,
        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
        dataset_num_proc = 2,
        packing = False,
        args = TrainingArguments(
            per_device_train_batch_size = 4,
            gradient_accumulation_steps = 8,  # effective batch size: 4 * 8 = 32
            warmup_steps = 5,
            num_train_epochs = 100,
            max_steps = 1000,  # max_steps takes precedence over num_train_epochs, so training stops at 1000 steps
            learning_rate = 2e-4,
            fp16 = not is_bfloat16_supported(),
            bf16 = is_bfloat16_supported(),
            logging_steps = 1,
            optim = "adamw_8bit",
            weight_decay = 0.01,
            lr_scheduler_type = "linear",
            seed = 3407,
            output_dir = "outputs",
            do_eval = True,
            report_to = 'wandb',
            evaluation_strategy = "epoch",
            save_strategy = "epoch",
            load_best_model_at_end = True,
            metric_for_best_model = "loss",
            greater_is_better = False,
            save_total_limit = 1,
        ),
    )

|
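    # Mask user-turn tokens from the loss so only assistant responses are learned.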
|
    trainer = train_on_responses_only(
        trainer,
        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
        response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
    )
|

    trainer_stats = trainer.train()
|

    # Save the LoRA adapter weights and tokenizer for later reloading.
    checkpoint_dir = f"./checkpoints/{dataset_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, "lora_model")
    model.save_pretrained(checkpoint_path)
    tokenizer.save_pretrained(checkpoint_path)

    print("Training completed.")
    return checkpoint_path
|


def evaluate(test_dataset, checkpoint_path):
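    """Reload the saved adapters and report exact-match metapath accuracy."""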
|
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = checkpoint_path,
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
    )
    FastLanguageModel.for_inference(model)  # switch the model to inference mode
|

    acc = 0
    for idx, dp in enumerate(test_dataset):
        message = dp['conversations'][0]
        label = dp['conversations'][1]
        assert label['role'] == "assistant"
|

        inputs = tokenizer.apply_chat_template([message], tokenize = True, add_generation_prompt = True, return_tensors = 'pt').to('cuda')
|

        text_streamer = TextStreamer(tokenizer, skip_prompt = True)
        outputs = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,
                                 use_cache = True, temperature = 1.5, min_p = 0.1)
        outputs = tokenizer.batch_decode(outputs)
        parts = outputs[0].split("<|start_header_id|>assistant<|end_header_id|>\n\n")
|

        # Parse the generated assistant reply as a Python dict literal;
        # fall back to an empty prediction if it cannot be parsed.
        results = parts[1].replace("<|eot_id|>", "").strip()
        results = contractions.fix(results)
        try:
            results = ast.literal_eval(results)
        except Exception:
            try:
                # Retry after cleaning stray single quotes inside list items.
                results = re.sub(r"\['(.*?)'", remove_inner_single_quotes, results)
                results = ast.literal_eval(results)
            except Exception:
                results = {
                    "Metapath": "",
                    "Restriction": {},
                }
|

        pred_metapath = results['Metapath']
        ground_metapath = ast.literal_eval(label['content'])['Metapath']
        if pred_metapath == ground_metapath:
            acc += 1
        print(f"Prediction: {pred_metapath}")
        print(f"Ground truth: {ground_metapath}")
|

    print(f"Accuracy: {acc / len(test_dataset)}")
|


def main(dataset_name, model_name):
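    """Load one dataset, split it 80/10/10, fine-tune, then evaluate."""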
|
    data_dir = "./data/finetune"
    data_path = os.path.join(data_dir, f"{dataset_name}/llama_ft.jsonl")

    dataset = load_dataset("json", data_files=data_path)
    dataset = dataset['train']
|

    dataset = dataset.add_column("index", list(range(len(dataset))))

    # 80/10/10 train/val/test split: hold out 20%, then halve it into val and test.
    train_test = dataset.train_test_split(test_size=0.2, seed=42)
    val_test = train_test['test'].train_test_split(test_size=0.5, seed=42)
|

    train_dataset = train_test['train']
    val_dataset = val_test['train']
    test_dataset = val_test['test']

    train_dataset = train_dataset.remove_columns(['index'])
    val_dataset = val_dataset.remove_columns(['index'])
    test_dataset = test_dataset.remove_columns(['index'])
|

    checkpoint_path = train(train_dataset, val_dataset, dataset_name)
    evaluate(test_dataset, checkpoint_path)
|


if __name__ == "__main__":
    dataset_name_list = ['mag', 'amazon', 'prime']
    model_names = ["4o"]  # currently unused downstream; kept for the sweep structure
    for dataset_name in dataset_name_list:
        for model_name in model_names:
            main(dataset_name, model_name)
|