File size: 17,187 Bytes
564354e |
1 2 |
{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2024-10-20T05:15:35.864396Z","iopub.status.busy":"2024-10-20T05:15:35.863473Z","iopub.status.idle":"2024-10-20T05:15:45.149676Z","shell.execute_reply":"2024-10-20T05:15:45.148906Z","shell.execute_reply.started":"2024-10-20T05:15:35.864342Z"},"trusted":true},"outputs":[],"source":["from transformers import (\n"," AutoModelForCausalLM,\n"," AutoTokenizer,\n"," BitsAndBytesConfig,\n"," HfArgumentParser,\n"," TrainingArguments,\n"," pipeline,\n"," logging,\n",")\n","from peft import (\n"," LoraConfig,\n"," PeftModel,\n"," prepare_model_for_kbit_training,\n"," get_peft_model,\n",")\n","import os, torch, wandb\n","from datasets import load_dataset\n","from trl import SFTTrainer, setup_chat_format"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:29:05.099594Z","iopub.status.busy":"2024-10-20T05:29:05.099135Z","iopub.status.idle":"2024-10-20T05:29:08.829349Z","shell.execute_reply":"2024-10-20T05:29:08.828348Z","shell.execute_reply.started":"2024-10-20T05:29:05.099552Z"},"trusted":true},"outputs":[],"source":["from huggingface_hub import login\n","from kaggle_secrets import UserSecretsClient\n","user_secrets = UserSecretsClient()\n","\n","hf_token = user_secrets.get_secret(\"HF_TOKEN\")\n","\n","login(token = hf_token)\n","\n","wb_token = user_secrets.get_secret(\"wandb\")\n","\n","wandb.login(key=wb_token)\n","run = wandb.init(\n"," project='Fine-tune Llama 3 8B on Mathematical Word Problems', \n"," job_type=\"training\", \n"," 
anonymous=\"allow\"\n",")"]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:15:45.405144Z","iopub.status.busy":"2024-10-20T05:15:45.404830Z","iopub.status.idle":"2024-10-20T05:15:45.409442Z","shell.execute_reply":"2024-10-20T05:15:45.408553Z","shell.execute_reply.started":"2024-10-20T05:15:45.405112Z"},"trusted":true},"outputs":[],"source":["base_model = \"/kaggle/input/llama-3.1/transformers/8b-instruct/2\"\n","dataset_name = \"microsoft/orca-math-word-problems-200k\"\n","new_model = \"llama-3.1-8b-chat-math-teacher\""]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:15:45.412376Z","iopub.status.busy":"2024-10-20T05:15:45.411921Z","iopub.status.idle":"2024-10-20T05:15:45.418455Z","shell.execute_reply":"2024-10-20T05:15:45.417527Z","shell.execute_reply.started":"2024-10-20T05:15:45.412338Z"},"trusted":true},"outputs":[],"source":["torch_dtype = torch.float16\n","attn_implementation = \"eager\""]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:15:45.419811Z","iopub.status.busy":"2024-10-20T05:15:45.419500Z","iopub.status.idle":"2024-10-20T05:15:45.428504Z","shell.execute_reply":"2024-10-20T05:15:45.427628Z","shell.execute_reply.started":"2024-10-20T05:15:45.419780Z"},"trusted":true},"outputs":[],"source":["# QLoRA config\n","bnb_config = BitsAndBytesConfig(\n"," load_in_4bit=True,\n"," bnb_4bit_quant_type=\"nf4\",\n"," bnb_4bit_compute_dtype=torch_dtype,\n"," bnb_4bit_use_double_quant=True,\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:19:33.334065Z","iopub.status.busy":"2024-10-20T05:19:33.332939Z","iopub.status.idle":"2024-10-20T05:20:18.698489Z","shell.execute_reply":"2024-10-20T05:20:18.697745Z","shell.execute_reply.started":"2024-10-20T05:19:33.334006Z"},"trusted":true},"outputs":[],"source":["# Load model\n","model = 
AutoModelForCausalLM.from_pretrained(\n"," base_model,\n"," quantization_config=bnb_config,\n"," device_map=\"auto\",\n"," attn_implementation=attn_implementation\n",")"]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:20:23.623910Z","iopub.status.busy":"2024-10-20T05:20:23.623492Z","iopub.status.idle":"2024-10-20T05:20:24.294298Z","shell.execute_reply":"2024-10-20T05:20:24.293468Z","shell.execute_reply.started":"2024-10-20T05:20:23.623870Z"},"trusted":true},"outputs":[],"source":["# Load tokenizer\n","tokenizer = AutoTokenizer.from_pretrained(base_model)\n","# NOTE(review): this 8b-instruct checkpoint presumably already ships its own chat template.\n","# setup_chat_format replaces it with ChatML and adds new special tokens, which resizes the\n","# embeddings (hence PEFT's save_embedding_layers warning after training). Recent TRL versions\n","# may refuse to overwrite an existing template - consider skipping this call when\n","# tokenizer.chat_template is already set (TODO confirm against the installed TRL version).\n","model, tokenizer = setup_chat_format(model, tokenizer)"]},{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:20:32.472505Z","iopub.status.busy":"2024-10-20T05:20:32.471630Z","iopub.status.idle":"2024-10-20T05:20:33.257761Z","shell.execute_reply":"2024-10-20T05:20:33.256711Z","shell.execute_reply.started":"2024-10-20T05:20:32.472461Z"},"trusted":true},"outputs":[],"source":["# LoRA config\n","peft_config = LoraConfig(\n"," r=16,\n"," lora_alpha=32,\n"," lora_dropout=0.05,\n"," bias=\"none\",\n"," task_type=\"CAUSAL_LM\",\n"," target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']\n",")\n","model = get_peft_model(model, 
peft_config)"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:20:40.458399Z","iopub.status.busy":"2024-10-20T05:20:40.457863Z","iopub.status.idle":"2024-10-20T05:20:44.412895Z","shell.execute_reply":"2024-10-20T05:20:44.411800Z","shell.execute_reply.started":"2024-10-20T05:20:40.458348Z"},"trusted":true},"outputs":[],"source":["# Importing the dataset\n","dataset = load_dataset(dataset_name, split=\"all\")\n","dataset = dataset.shuffle(seed=42).select(range(1000)) # Only use 1000 samples for quick demo\n","\n","def format_chat_template(row):\n"," row_json = [{\"role\": \"user\", \"content\": 
row[\"question\"]},\n"," {\"role\": \"assistant\", \"content\": row[\"answer\"]}]\n"," row[\"text\"] = tokenizer.apply_chat_template(row_json, tokenize=False)\n"," return row\n","\n","dataset = dataset.map(\n"," format_chat_template,\n"," num_proc=4,\n",")\n","\n","dataset['text'][3]"]},{"cell_type":"code","execution_count":11,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:21:15.549983Z","iopub.status.busy":"2024-10-20T05:21:15.548983Z","iopub.status.idle":"2024-10-20T05:21:15.566472Z","shell.execute_reply":"2024-10-20T05:21:15.565431Z","shell.execute_reply.started":"2024-10-20T05:21:15.549934Z"},"trusted":true},"outputs":[],"source":["# Fixed seed (matching the shuffle above) so the train/test partition - and therefore the\n","# reported validation loss - is reproducible across runs.\n","dataset = dataset.train_test_split(test_size=0.2, seed=42)"]},{"cell_type":"code","execution_count":18,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:29:56.567672Z","iopub.status.busy":"2024-10-20T05:29:56.566941Z","iopub.status.idle":"2024-10-20T05:29:56.595910Z","shell.execute_reply":"2024-10-20T05:29:56.595117Z","shell.execute_reply.started":"2024-10-20T05:29:56.567628Z"},"trusted":true},"outputs":[],"source":["training_arguments = TrainingArguments(\n"," output_dir=new_model,\n"," per_device_train_batch_size=1,\n"," per_device_eval_batch_size=1,\n"," gradient_accumulation_steps=2,\n"," optim=\"paged_adamw_32bit\",\n"," num_train_epochs=1,\n"," eval_strategy=\"steps\",\n"," eval_steps=0.2,\n"," logging_steps=1,\n"," warmup_steps=10,\n"," logging_strategy=\"steps\",\n"," learning_rate=2e-4,\n"," fp16=False,\n"," bf16=False,\n"," group_by_length=True,\n"," report_to=\"wandb\"\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:30:01.680282Z","iopub.status.busy":"2024-10-20T05:30:01.679557Z","iopub.status.idle":"2024-10-20T05:30:03.062818Z","shell.execute_reply":"2024-10-20T05:30:03.061888Z","shell.execute_reply.started":"2024-10-20T05:30:01.680235Z"},"trusted":true},"outputs":[],"source":["trainer = SFTTrainer(\n"," model=model,\n"," 
train_dataset=dataset[\"train\"],\n"," eval_dataset=dataset[\"test\"],\n"," peft_config=peft_config,\n"," max_seq_length=512,\n"," dataset_text_field=\"text\",\n"," tokenizer=tokenizer,\n"," args=training_arguments,\n"," packing= False,\n",")"]},{"cell_type":"code","execution_count":20,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T05:30:06.284325Z","iopub.status.busy":"2024-10-20T05:30:06.283458Z","iopub.status.idle":"2024-10-20T06:12:45.133545Z","shell.execute_reply":"2024-10-20T06:12:45.132593Z","shell.execute_reply.started":"2024-10-20T05:30:06.284279Z"},"trusted":true},"outputs":[{"data":{"text/html":["\n"," <div>\n"," \n"," <progress value='400' max='400' style='width:300px; height:20px; vertical-align: middle;'></progress>\n"," [400/400 42:31, Epoch 1/1]\n"," </div>\n"," <table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: left;\">\n"," <th>Step</th>\n"," <th>Training Loss</th>\n"," <th>Validation Loss</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <td>80</td>\n"," <td>0.438500</td>\n"," <td>0.620739</td>\n"," </tr>\n"," <tr>\n"," <td>160</td>\n"," <td>0.396200</td>\n"," <td>0.606882</td>\n"," </tr>\n"," <tr>\n"," <td>240</td>\n"," <td>0.606400</td>\n"," <td>0.591038</td>\n"," </tr>\n"," <tr>\n"," <td>320</td>\n"," <td>0.565300</td>\n"," <td>0.583970</td>\n"," </tr>\n"," <tr>\n"," <td>400</td>\n"," <td>0.944000</td>\n"," <td>0.576970</td>\n"," </tr>\n"," </tbody>\n","</table><p>"],"text/plain":["<IPython.core.display.HTML object>"]},"metadata":{},"output_type":"display_data"},{"name":"stderr","output_type":"stream","text":["/opt/conda/lib/python3.10/site-packages/peft/utils/save_and_load.py:257: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n"," warnings.warn(\n"]}],"source":["history = 
trainer.train()"]},{"cell_type":"code","execution_count":21,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T06:12:45.135502Z","iopub.status.busy":"2024-10-20T06:12:45.135185Z","iopub.status.idle":"2024-10-20T06:12:46.781460Z","shell.execute_reply":"2024-10-20T06:12:46.780763Z","shell.execute_reply.started":"2024-10-20T06:12:45.135469Z"},"trusted":true},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"","version_major":2,"version_minor":0},"text/plain":["VBox(children=(Label(value='0.028 MB of 0.028 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["<style>\n"," table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n"," .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n"," .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n"," </style>\n","<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/loss</td><td>βββββ</td></tr><tr><td>eval/runtime</td><td>ββ
βββ</td></tr><tr><td>eval/samples_per_second</td><td>βββββ</td></tr><tr><td>eval/steps_per_second</td><td>βββββ</td></tr><tr><td>train/epoch</td><td>ββββββββββββββββββββββββ
β
β
β
β
ββββββββββββ</td></tr><tr><td>train/global_step</td><td>ββββββββββββββββββββββ
β
βββββββββββββββββ</td></tr><tr><td>train/grad_norm</td><td>βββ
ββββββββββββββ
ββββββββββββ
βββββββββββ
</td></tr><tr><td>train/learning_rate</td><td>βββββββββββββββββββ
β
ββββββββββββββββββββ</td></tr><tr><td>train/loss</td><td>β
βββ
ββ
βββ
β
βββ
ββ
β
ββββββββββββββββββββββββ</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/loss</td><td>0.57697</td></tr><tr><td>eval/runtime</td><td>195.8812</td></tr><tr><td>eval/samples_per_second</td><td>1.021</td></tr><tr><td>eval/steps_per_second</td><td>1.021</td></tr><tr><td>total_flos</td><td>1.1602099108503552e+16</td></tr><tr><td>train/epoch</td><td>1</td></tr><tr><td>train/global_step</td><td>400</td></tr><tr><td>train/grad_norm</td><td>1.6999</td></tr><tr><td>train/learning_rate</td><td>0</td></tr><tr><td>train/loss</td><td>0.944</td></tr><tr><td>train_loss</td><td>0.57197</td></tr><tr><td>train_runtime</td><td>2557.7609</td></tr><tr><td>train_samples_per_second</td><td>0.313</td></tr><tr><td>train_steps_per_second</td><td>0.156</td></tr></table><br/></div></div>"],"text/plain":["<IPython.core.display.HTML object>"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":[" View run <strong style=\"color:#cdcd00\">zesty-snowball-1</strong> at: <a href='https://wandb.ai/ccapo-astro-siaa-research/Fine-tune%20Llama%203%208B%20on%20Mathematical%20Word%20Problems/runs/n7jezz5x' target=\"_blank\">https://wandb.ai/ccapo-astro-siaa-research/Fine-tune%20Llama%203%208B%20on%20Mathematical%20Word%20Problems/runs/n7jezz5x</a><br/> View project at: <a href='https://wandb.ai/ccapo-astro-siaa-research/Fine-tune%20Llama%203%208B%20on%20Mathematical%20Word%20Problems' target=\"_blank\">https://wandb.ai/ccapo-astro-siaa-research/Fine-tune%20Llama%203%208B%20on%20Mathematical%20Word%20Problems</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"],"text/plain":["<IPython.core.display.HTML object>"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["Find logs at: <code>./wandb/run-20241020_052905-n7jezz5x/logs</code>"],"text/plain":["<IPython.core.display.HTML 
object>"]},"metadata":{},"output_type":"display_data"}],"source":["wandb.finish()\n","model.config.use_cache = True"]},{"cell_type":"code","execution_count":29,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T06:25:32.333755Z","iopub.status.busy":"2024-10-20T06:25:32.332841Z","iopub.status.idle":"2024-10-20T06:26:06.729107Z","shell.execute_reply":"2024-10-20T06:26:06.728129Z","shell.execute_reply.started":"2024-10-20T06:25:32.333693Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["\n","Let's denote Parisa's current age as P and her mother's current age as M.\n","\n","According to the first condition, the age difference between Parisa and her mother is 40 years, so we can write:\n","\n","M = P + 40 (1)\n","\n","According to the second condition, after 15 years, the age of her mother will be three times that of Parisa. So we can write:\n","\n","M + 15 = 3 * (P + 15) (2)\n","\n","Now, let's substitute the expression for M from equation (1) into equation (2):\n","\n","(P + 40) + 15 = 3 * (P + 15)\n","\n","Now, let's solve for P:\n","\n","P + 40 + 15 = 3P + 45\n","\n","Combine like terms:\n","\n","P + 55 = 3P + 45\n","\n","Subtract P from both sides:\n","\n","55 = 2P + 45\n","\n","Subtract 45 from both sides:\n","\n","10 = 2P\n","\n","Divide both sides by 2:\n","\n","P = 5\n","\n","So, Parisa is currently 5 years old.\n","\n"]}],"source":["messages = [\n"," {\n"," \"role\": \"user\",\n"," \"content\": \"This year, the age difference between Parisa and her mother is 40 years, and after 15 years, the age of her mother will be three times that of Parisa. 
Find the age of Parisa this year.\"\n"," }\n","]\n","\n","prompt = tokenizer.apply_chat_template(messages, tokenize=False, \n"," add_generation_prompt=True)\n","\n","inputs = tokenizer(prompt, return_tensors='pt', padding=True, \n"," truncation=True).to(\"cuda\")\n","\n","# max_new_tokens bounds only the generated answer; max_length would also count the prompt tokens.\n","outputs = model.generate(**inputs, max_new_tokens=300, \n"," num_return_sequences=1)\n","\n","# Decode only the tokens generated after the prompt, rather than splitting the whole transcript\n","# on the word assistant (which picks the wrong slice if that word appears in the question).\n","prompt_length = inputs['input_ids'].shape[1]\n","text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)\n","\n","print(text)"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-20T06:13:27.806915Z","iopub.status.busy":"2024-10-20T06:13:27.806495Z","iopub.status.idle":"2024-10-20T06:14:46.231972Z","shell.execute_reply":"2024-10-20T06:14:46.231004Z","shell.execute_reply.started":"2024-10-20T06:13:27.806870Z"},"trusted":true},"outputs":[],"source":["trainer.model.save_pretrained(new_model)\n","# setup_chat_format added special tokens and resized the embeddings, so the adapter only works\n","# with this exact tokenizer - save and push it alongside the adapter weights.\n","tokenizer.save_pretrained(new_model)\n","trainer.model.push_to_hub(new_model, use_temp_dir=False)\n","tokenizer.push_to_hub(new_model, use_temp_dir=False)"]}],"metadata":{"kaggle":{"accelerator":"gpu","dataSources":[{"isSourceIdPinned":true,"modelId":91102,"modelInstanceId":68809,"sourceId":104449,"sourceType":"modelInstanceVersion"}],"dockerImageVersionId":30787,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 
3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"}},"nbformat":4,"nbformat_minor":4}
|