{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-21T05:02:45.320064Z","iopub.status.busy":"2024-10-21T05:02:45.319624Z","iopub.status.idle":"2024-10-21T05:02:46.019192Z","shell.execute_reply":"2024-10-21T05:02:46.018223Z","shell.execute_reply.started":"2024-10-21T05:02:45.320020Z"},"trusted":true},"outputs":[],"source":["# Authenticate with the Hugging Face Hub using the token stored in the Kaggle secret \"HF_TOKEN\";\n","# the push_to_hub calls at the end of the notebook need it.\n","from huggingface_hub import login\n","from kaggle_secrets import UserSecretsClient\n","\n","login(token=UserSecretsClient().get_secret(\"HF_TOKEN\"))"]},{"cell_type":"code","execution_count":2,"metadata":{"execution":{"iopub.execute_input":"2024-10-21T05:02:46.024625Z","iopub.status.busy":"2024-10-21T05:02:46.024248Z","iopub.status.idle":"2024-10-21T05:02:46.028844Z","shell.execute_reply":"2024-10-21T05:02:46.027966Z","shell.execute_reply.started":"2024-10-21T05:02:46.024574Z"},"trusted":true},"outputs":[],"source":["# Kaggle input paths: the Llama 3.1 8B Instruct base weights, and the fine-tuned PEFT adapter\n","# that is loaded on top of them and merged in below.\n","base_model = \"/kaggle/input/llama-3.1/transformers/8b-instruct/2\"\n","new_model = \"/kaggle/input/llama-3-fine-tune-math-problems/llama-3.1-8b-chat-math-teacher/\""]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2024-10-21T05:02:46.030396Z","iopub.status.busy":"2024-10-21T05:02:46.030104Z","iopub.status.idle":"2024-10-21T05:02:53.155455Z","shell.execute_reply":"2024-10-21T05:02:53.154615Z","shell.execute_reply.started":"2024-10-21T05:02:46.030364Z"},"trusted":true},"outputs":[],"source":["import torch\n","from peft import PeftModel\n","from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n","from trl import setup_chat_format"]},{"cell_type":"markdown","metadata":{},"source":["## Note\n","Depending on the base model size, the following commands may require a GPU with more memory or multiple 
GPUs."]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-21T05:02:53.158650Z","iopub.status.busy":"2024-10-21T05:02:53.157906Z","iopub.status.idle":"2024-10-21T05:04:28.373624Z","shell.execute_reply":"2024-10-21T05:04:28.372521Z","shell.execute_reply.started":"2024-10-21T05:02:53.158601Z"},"trusted":true},"outputs":[],"source":["# Reload the tokenizer and the fp16 base model. trust_remote_code is deliberately not set:\n","# Llama is a built-in transformers architecture, and enabling it would allow the model\n","# repository to execute arbitrary Python code on this machine.\n","tokenizer = AutoTokenizer.from_pretrained(base_model)\n","\n","base_model_reload = AutoModelForCausalLM.from_pretrained(\n","    base_model,\n","    return_dict=True,\n","    low_cpu_mem_usage=True,\n","    torch_dtype=torch.float16,\n","    device_map=\"auto\",\n","    offload_buffers=True,\n",")\n","\n","# Add the ChatML special tokens / chat template and resize the embeddings. This presumably\n","# mirrors the fine-tuning setup so the adapter's vocabulary size matches (TODO confirm).\n","base_model_reload, tokenizer = setup_chat_format(base_model_reload, tokenizer)\n","\n","# Load the adapter on top of the base model, then fold its weights in so that `model`\n","# is a plain transformers model that can be saved and pushed on its own.\n","model = PeftModel.from_pretrained(base_model_reload, new_model)\n","model = model.merge_and_unload()"]},{"cell_type":"code","execution_count":7,"metadata":{"execution":{"iopub.execute_input":"2024-10-21T05:07:43.870798Z","iopub.status.busy":"2024-10-21T05:07:43.869721Z","iopub.status.idle":"2024-10-21T05:08:04.414201Z","shell.execute_reply":"2024-10-21T05:08:04.413170Z","shell.execute_reply.started":"2024-10-21T05:07:43.870731Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["<|im_start|>user\n","Find the sum of all multiples of 9 that are less than 100.<|im_end|>\n","<|im_start|>assistant\n","To find the sum of all multiples of 9 that are less than 100, we first need to identify the last multiple of 9 that is less than 100. Since 99 is the largest multiple of 9 less than 100, we will use that number.\n","\n","The formula to find the sum of an arithmetic series is:\n","\n","Sum = n/2 * (first term + last term)\n","\n","Where n is the number of terms in the series.\n","\n","The first term is the first multiple of 9, which is 9, and the last term is 99. 
To find the number of terms (n), we can use the formula:\n","\n","n = (last term - first term) / common difference + 1\n","\n","In this case, the common difference is 9, since we are adding 9 to get from one multiple to the next.\n","\n","So, let's calculate:\n","\n","n = (99 - 9) / 9 + 1\n","n = 90 / 9 + 1\n","n = 10 + 1\n","n = 11\n","\n","Now we can find the sum:\n","\n","Sum = 11/2 * (9 + 99)\n","Sum = 11/2 * 108\n","Sum = 5.5 * 108\n","Sum = 594\n","\n","Therefore, the sum of all multiples of 9 that are less than 100 is 594.\n","assistant\n","The sum of all multiples of 9 that are less than \n"]}],"source":["# Smoke-test the merged model with one question from the fine-tuning domain.\n","messages = [{\"role\": \"user\", \"content\": \"Find the sum of all multiples of 9 that are less than 100.\"}]\n","\n","prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n","\n","# `model` was already loaded in fp16 and dispatched with device_map=\"auto\", so the\n","# pipeline needs no dtype/device arguments of its own (they are ignored for a model instance).\n","pipe = pipeline(\n","    \"text-generation\",\n","    model=model,\n","    tokenizer=tokenizer,\n",")\n","\n","# Stop as soon as the assistant turn ends. setup_chat_format made <|im_end|> the tokenizer's\n","# EOS, but the model may also close a turn with Llama 3.1's own <|eot_id|>; without explicit\n","# stop ids, sampling can run past the answer into a spurious new \"assistant\" turn.\n","stop_token_ids = [\n","    token_id\n","    for token_id in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\"))\n","    if token_id is not None\n","]\n","\n","torch.manual_seed(42)  # sampling is random; fix the seed so reruns are repeatable\n","outputs = pipe(\n","    prompt,\n","    max_new_tokens=300,\n","    do_sample=True,\n","    temperature=0.7,\n","    top_k=50,\n","    top_p=0.95,\n","    eos_token_id=stop_token_ids,\n","    pad_token_id=tokenizer.pad_token_id,\n",")\n","print(outputs[0][\"generated_text\"])"]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2024-10-21T05:11:53.899166Z","iopub.status.busy":"2024-10-21T05:11:53.898412Z","iopub.status.idle":"2024-10-21T05:12:51.165328Z","shell.execute_reply":"2024-10-21T05:12:51.164303Z","shell.execute_reply.started":"2024-10-21T05:11:53.899120Z"},"trusted":true},"outputs":[{"data":{"text/plain":["('llama-3.1-8b-chat-math-teacher/tokenizer_config.json',\n"," 'llama-3.1-8b-chat-math-teacher/special_tokens_map.json',\n"," 
'llama-3.1-8b-chat-math-teacher/tokenizer.json')"]},"execution_count":8,"metadata":{},"output_type":"execute_result"}],"source":["# Write the merged weights and the chat-format tokenizer into one local folder; the same\n","# name is reused as the Hub repository id in the next cell.\n","OUTPUT_NAME = \"llama-3.1-8b-chat-math-teacher\"\n","\n","model.save_pretrained(OUTPUT_NAME)\n","tokenizer.save_pretrained(OUTPUT_NAME)"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2024-10-21T05:13:06.910077Z","iopub.status.busy":"2024-10-21T05:13:06.909659Z","iopub.status.idle":"2024-10-21T05:17:56.332173Z","shell.execute_reply":"2024-10-21T05:17:56.331177Z","shell.execute_reply.started":"2024-10-21T05:13:06.910038Z"},"trusted":true},"outputs":[],"source":["# Upload the merged model and its tokenizer to the Hub repository named OUTPUT_NAME\n","# (requires the login performed in the first cell).\n","model.push_to_hub(OUTPUT_NAME, use_temp_dir=False)\n","tokenizer.push_to_hub(OUTPUT_NAME, use_temp_dir=False)"]}],"metadata":{"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"sourceId":202208119,"sourceType":"kernelVersion"},{"isSourceIdPinned":true,"modelId":91102,"modelInstanceId":68809,"sourceId":104449,"sourceType":"modelInstanceVersion"}],"dockerImageVersionId":30787,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"}},"nbformat":4,"nbformat_minor":4}