{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2024-10-20T05:15:35.864396Z","iopub.status.busy":"2024-10-20T05:15:35.863473Z","iopub.status.idle":"2024-10-20T05:15:45.149676Z","shell.execute_reply":"2024-10-20T05:15:45.148906Z","shell.execute_reply.started":"2024-10-20T05:15:35.864342Z"},"trusted":true},"outputs":[],"source":[
"from transformers import (\n",
"    AutoModelForCausalLM,\n",
"    AutoTokenizer,\n",
"    BitsAndBytesConfig,\n",
"    HfArgumentParser,\n",
"    TrainingArguments,\n",
"    pipeline,\n",
"    logging,\n",
")\n",
"from peft import (\n",
"    LoraConfig,\n",
"    PeftModel,\n",
"    prepare_model_for_kbit_training,\n",
"    get_peft_model,\n",
")\n",
"import os, torch, wandb\n",
"from datasets import load_dataset\n",
"from trl import SFTTrainer, setup_chat_format"]},
{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:29:05.099594Z","iopub.status.busy":"2024-10-20T05:29:05.099135Z","iopub.status.idle":"2024-10-20T05:29:08.829349Z","shell.execute_reply":"2024-10-20T05:29:08.828348Z","shell.execute_reply.started":"2024-10-20T05:29:05.099552Z"},"trusted":true},"outputs":[],"source":[
"# Authenticate with the Hugging Face Hub and Weights & Biases using Kaggle secrets\n",
"# (tokens are read at runtime, never hard-coded in the notebook).\n",
"from huggingface_hub import login\n",
"from kaggle_secrets import UserSecretsClient\n",
"user_secrets = UserSecretsClient()\n",
"\n",
"hf_token = user_secrets.get_secret(\"HF_TOKEN\")\n",
"\n",
"login(token=hf_token)\n",
"\n",
"wb_token = user_secrets.get_secret(\"wandb\")\n",
"\n",
"wandb.login(key=wb_token)\n",
"run = wandb.init(\n",
"    project='Fine-tune Llama 3 8B on Mathematical Word Problems',\n",
"    job_type=\"training\",\n",
"    anonymous=\"allow\"\n",
")"]},
{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:15:45.405144Z","iopub.status.busy":"2024-10-20T05:15:45.404830Z","iopub.status.idle":"2024-10-20T05:15:45.409442Z","shell.execute_reply":"2024-10-20T05:15:45.408553Z","shell.execute_reply.started":"2024-10-20T05:15:45.405112Z"},"trusted":true},"outputs":[],"source":[
"# Paths, names, and the random seed used throughout the notebook.\n",
"base_model = \"/kaggle/input/llama-3.1/transformers/8b-instruct/2\"  # Kaggle-attached model files\n",
"dataset_name = \"microsoft/orca-math-word-problems-200k\"\n",
"new_model = \"llama-3.1-8b-chat-math-teacher\"  # also used as the Trainer output_dir\n",
"seed = 42  # shared by dataset shuffling and the train/test split for reproducibility"]},
{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:15:45.412376Z","iopub.status.busy":"2024-10-20T05:15:45.411921Z","iopub.status.idle":"2024-10-20T05:15:45.418455Z","shell.execute_reply":"2024-10-20T05:15:45.417527Z","shell.execute_reply.started":"2024-10-20T05:15:45.412338Z"},"trusted":true},"outputs":[],"source":[
"# Compute dtype for the 4-bit quantized layers (passed to bnb_config below).\n",
"# NOTE(review): float16 rather than bfloat16 -- presumably for GPUs without bf16 support; confirm.\n",
"torch_dtype = torch.float16\n",
"attn_implementation = \"eager\""]},
{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:15:45.419811Z","iopub.status.busy":"2024-10-20T05:15:45.419500Z","iopub.status.idle":"2024-10-20T05:15:45.428504Z","shell.execute_reply":"2024-10-20T05:15:45.427628Z","shell.execute_reply.started":"2024-10-20T05:15:45.419780Z"},"trusted":true},"outputs":[],"source":[
"# QLoRA config: store base weights as 4-bit NF4 with double quantization;\n",
"# bnb_4bit_compute_dtype sets the dtype used for computation.\n",
"bnb_config = BitsAndBytesConfig(\n",
"    load_in_4bit=True,\n",
"    bnb_4bit_quant_type=\"nf4\",\n",
"    bnb_4bit_compute_dtype=torch_dtype,\n",
"    bnb_4bit_use_double_quant=True,\n",
")"]},
{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:19:33.334065Z","iopub.status.busy":"2024-10-20T05:19:33.332939Z","iopub.status.idle":"2024-10-20T05:20:18.698489Z","shell.execute_reply":"2024-10-20T05:20:18.697745Z","shell.execute_reply.started":"2024-10-20T05:19:33.334006Z"},"trusted":true},"outputs":[],"source":[
"# Load model (quantized with bnb_config, placed on available devices automatically)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    base_model,\n",
"    quantization_config=bnb_config,\n",
"    device_map=\"auto\",\n",
"    attn_implementation=attn_implementation\n",
")"]},
{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:20:23.623910Z","iopub.status.busy":"2024-10-20T05:20:23.623492Z","iopub.status.idle":"2024-10-20T05:20:24.294298Z","shell.execute_reply":"2024-10-20T05:20:24.293468Z","shell.execute_reply.started":"2024-10-20T05:20:23.623870Z"},"trusted":true},"outputs":[],"source":[
"# Load tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(base_model)\n",
"# NOTE(review): per the TRL docs, setup_chat_format switches the tokenizer to the ChatML\n",
"# template, adds its special tokens, and resizes the model's embedding layer to match.\n",
"# The LoRA target_modules below include neither the embedding nor the output layer, so the\n",
"# rows for those new tokens stay at their initial values during training. base_model is an\n",
"# *instruct* checkpoint that presumably already ships a chat template -- consider keeping that\n",
"# native template instead (any later cell that reloads the base model must make the same choice).\n",
"model, tokenizer = setup_chat_format(model, tokenizer)"]},
{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:20:32.472505Z","iopub.status.busy":"2024-10-20T05:20:32.471630Z","iopub.status.idle":"2024-10-20T05:20:33.257761Z","shell.execute_reply":"2024-10-20T05:20:33.256711Z","shell.execute_reply.started":"2024-10-20T05:20:32.472461Z"},"trusted":true},"outputs":[],"source":[
"# LoRA config: rank-16 adapters on the attention (q/k/v/o_proj) and MLP\n",
"# (gate/up/down_proj) projections; the quantized base weights stay frozen.\n",
"peft_config = LoraConfig(\n",
"    r=16,\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.05,\n",
"    bias=\"none\",\n",
"    task_type=\"CAUSAL_LM\",\n",
"    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']\n",
")\n",
"# Wrap the model here. SFTTrainer below therefore receives a ready PeftModel and is\n",
"# not given peft_config again (that would request a second LoRA wrapping).\n",
"model = get_peft_model(model, peft_config)"]},
{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:20:40.458399Z","iopub.status.busy":"2024-10-20T05:20:40.457863Z","iopub.status.idle":"2024-10-20T05:20:44.412895Z","shell.execute_reply":"2024-10-20T05:20:44.411800Z","shell.execute_reply.started":"2024-10-20T05:20:40.458348Z"},"trusted":true},"outputs":[],"source":[
"# Load the dataset and keep a reproducible 1,000-example subset for a quick demo.\n",
"# Stage-named variables keep this cell and the split cell safe to re-run.\n",
"raw_dataset = load_dataset(dataset_name, split=\"all\")\n",
"sampled_dataset = raw_dataset.shuffle(seed=seed).select(range(1000))\n",
"\n",
"def format_chat_template(row):\n",
"    \"\"\"Render one question/answer pair as a single chat-formatted training string.\n",
"\n",
"    Adds a \"text\" field built with the tokenizer's chat template (user turn =\n",
"    row[\"question\"], assistant turn = row[\"answer\"]) and returns the updated row.\n",
"    \"\"\"\n",
"    row_json = [{\"role\": \"user\", \"content\": row[\"question\"]},\n",
"                {\"role\": \"assistant\", \"content\": row[\"answer\"]}]\n",
"    row[\"text\"] = tokenizer.apply_chat_template(row_json, tokenize=False)\n",
"    return row\n",
"\n",
"formatted_dataset = sampled_dataset.map(\n",
"    format_chat_template,\n",
"    num_proc=4,\n",
")\n",
"\n",
"formatted_dataset['text'][3]"]},
{"cell_type":"code","execution_count":11,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:21:15.549983Z","iopub.status.busy":"2024-10-20T05:21:15.548983Z","iopub.status.idle":"2024-10-20T05:21:15.566472Z","shell.execute_reply":"2024-10-20T05:21:15.565431Z","shell.execute_reply.started":"2024-10-20T05:21:15.549934Z"},"trusted":true},"outputs":[],"source":[
"# Hold out 20% for evaluation; the fixed seed makes the split reproducible across runs.\n",
"dataset = formatted_dataset.train_test_split(test_size=0.2, seed=seed)"]},
{"cell_type":"code","execution_count":18,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:29:56.567672Z","iopub.status.busy":"2024-10-20T05:29:56.566941Z","iopub.status.idle":"2024-10-20T05:29:56.595910Z","shell.execute_reply":"2024-10-20T05:29:56.595117Z","shell.execute_reply.started":"2024-10-20T05:29:56.567628Z"},"trusted":true},"outputs":[],"source":[
"training_arguments = TrainingArguments(\n",
"    output_dir=new_model,\n",
"    per_device_train_batch_size=1,\n",
"    per_device_eval_batch_size=1,\n",
"    gradient_accumulation_steps=2,  # effective batch size of 2 per device\n",
"    optim=\"paged_adamw_32bit\",\n",
"    num_train_epochs=1,\n",
"    eval_strategy=\"steps\",\n",
"    eval_steps=0.2,  # a float < 1 is a fraction of total steps: evaluate every 20% of training\n",
"    logging_steps=1,\n",
"    warmup_steps=10,\n",
"    logging_strategy=\"steps\",\n",
"    learning_rate=2e-4,\n",
"    fp16=False,  # no mixed-precision autocast; 4-bit layers compute in torch_dtype via bnb_config\n",
"    bf16=False,\n",
"    group_by_length=True,\n",
"    report_to=\"wandb\"\n",
")"]},
{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:30:01.680282Z","iopub.status.busy":"2024-10-20T05:30:01.679557Z","iopub.status.idle":"2024-10-20T05:30:03.062818Z","shell.execute_reply":"2024-10-20T05:30:03.061888Z","shell.execute_reply.started":"2024-10-20T05:30:01.680235Z"},"trusted":true},"outputs":[],"source":[
"# `model` is already a PeftModel (wrapped in the LoRA cell), so peft_config is\n",
"# deliberately not passed here.\n",
"trainer = SFTTrainer(\n",
"    model=model,\n",
"    train_dataset=dataset[\"train\"],\n",
"    eval_dataset=dataset[\"test\"],\n",
"    max_seq_length=512,  # longer examples are truncated\n",
"    dataset_text_field=\"text\",\n",
"    tokenizer=tokenizer,\n",
"    args=training_arguments,\n",
"    packing=False,\n",
")"]},
{"cell_type":"code","execution_count":20,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:30:06.284325Z","iopub.status.busy":"2024-10-20T05:30:06.283458Z","iopub.status.idle":"2024-10-20T06:12:45.133545Z","shell.execute_reply":"2024-10-20T06:12:45.132593Z","shell.execute_reply.started":"2024-10-20T05:30:06.284279Z"},"trusted":true},"outputs":[{"data":{"text/html":["\n","
Step | \n","Training Loss | \n","Validation Loss | \n","
---|---|---|
80 | \n","0.438500 | \n","0.620739 | \n","
160 | \n","0.396200 | \n","0.606882 | \n","
240 | \n","0.606400 | \n","0.591038 | \n","
320 | \n","0.565300 | \n","0.583970 | \n","
400 | \n","0.944000 | \n","0.576970 | \n","
"],"text/plain":["Run history:
eval/loss █▆▃▂▁ eval/runtime ▃▅▆▁█ eval/samples_per_second ▁▁▁▁▁ eval/steps_per_second ▁▁▁▁▁ train/epoch ▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▇▇▇▇▇█████ train/global_step ▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▆▆▆▆▆▆▆▇▇▇▇▇█████ train/grad_norm ▄▃▅▃▃▁▂▃▂▂▃▁▂▃▂▃▅▇▂▂▂▄█▂▂▃▄▃▅▇▂▂▃▃▃▁▁▂▃▅ train/learning_rate ▂████▇▇▇▇▇▆▆▆▆▆▆▆▆▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▁▁▁ train/loss ▅▆▄▅▂▅▂▂▅▅▃▄▅▃▅▅▂▃▄▂▄▇▃▁▄▃▇▂▃▃▂▃▃▁▃▆▁▄▁█ Run summary:
eval/loss 0.57697 eval/runtime 195.8812 eval/samples_per_second 1.021 eval/steps_per_second 1.021 total_flos 1.1602099108503552e+16 train/epoch 1 train/global_step 400 train/grad_norm 1.6999 train/learning_rate 0 train/loss 0.944 train_loss 0.57197 train_runtime 2557.7609 train_samples_per_second 0.313 train_steps_per_second 0.156
View project at: https://wandb.ai/ccapo-astro-siaa-research/Fine-tune%20Llama%203%208B%20on%20Mathematical%20Word%20Problems
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"],"text/plain":["./wandb/run-20241020_052905-n7jezz5x/logs
"],"text/plain":["