{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GrpoGsm8kTitle"
   },
   "source": [
    "# GRPO fine-tuning on GSM8K\n",
    "\n",
    "Fine-tunes a small instruct model with TRL's `GRPOTrainer` on GSM8K questions. The model is prompted to answer in a `<reasoning>`/`<answer>` XML layout, and the reward functions score answer correctness, whether the extracted answer is all digits, and how closely the completion follows that layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PFGREH3kAkep"
   },
   "outputs": [],
   "source": [
    "# Colab setup. NOTE(review): these installs are unpinned (transformers/accelerate/peft track git main),\n",
    "# so upstream changes can break this notebook - pin versions for reproducible runs.\n",
    "%pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/accelerate.git\n",
    "%pip install -q datasets huggingface-hub trl\n",
    "%pip install -q git+https://github.com/huggingface/peft.git"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "import re\n",
    "import torch\n",
    "from datasets import load_dataset, Dataset\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import LoraConfig\n",
    "from trl import GRPOConfig, GRPOTrainer"
   ],
   "metadata": {
    "id": "PTbegv7kAwXd"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Response layout the model is asked to follow; the answer extractor and format rewards parse these exact tags.\n",
    "SYSTEM_PROMPT = \"\"\"\n",
    "Respond in the following format:\n",
    "<reasoning>\n",
    "...\n",
    "</reasoning>\n",
    "<answer>\n",
    "...\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "# Template for a worked example in the same layout (used by the optional 1-shot prompt below).\n",
    "XML_COT_FORMAT = \"\"\"\\\n",
    "<reasoning>\n",
    "{reasoning}\n",
    "</reasoning>\n",
    "<answer>\n",
    "{answer}\n",
    "</answer>\n",
    "\"\"\"\n",
    "\n",
    "def extract_xml_answer(text: str) -> str:\n",
    "    \"\"\"Return the stripped text inside the last <answer>...</answer> block.\n",
    "\n",
    "    Falls back to everything after the last <answer> if the closing tag is\n",
    "    missing, and to the whole text if no <answer> tag is present.\n",
    "    \"\"\"\n",
    "    answer = text.split(\"<answer>\")[-1]\n",
    "    answer = answer.split(\"</answer>\")[0]\n",
    "    return answer.strip()\n",
    "\n",
    "def extract_hash_answer(text: str) -> str | None:\n",
    "    \"\"\"Return the GSM8K reference answer that follows '####', or None if absent.\"\"\"\n",
    "    if \"####\" not in text:\n",
    "        return None\n",
    "    return text.split(\"####\")[1].strip()\n",
    "\n",
    "# Sanity check: the tags above must be non-empty, otherwise str.split raises ValueError.\n",
    "assert extract_xml_answer(\"<reasoning>\\nx\\n</reasoning>\\n<answer>\\n42\\n</answer>\\n\") == \"42\"\n",
    "\n",
    "def get_gsm8k_questions(split: str = \"train\") -> Dataset:\n",
    "    \"\"\"Load a GSM8K split as chat prompts paired with the final reference answer.\n",
    "\n",
    "    Each row gets a 'prompt' (system + user messages), and its 'answer' column is\n",
    "    replaced by the text after '####'. Uncomment the middle messages below for\n",
    "    1-shot prompting.\n",
    "    \"\"\"\n",
    "    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore\n",
    "    data = data.map(lambda x: {  # type: ignore\n",
    "        'prompt': [\n",
    "            {'role': 'system', 'content': SYSTEM_PROMPT},\n",
    "            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},\n",
    "            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(\n",
    "            #    reasoning=\"9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.\",\n",
    "            #    answer=\"7\"\n",
    "            #)},\n",
    "            {'role': 'user', 'content': x['question']}\n",
    "        ],\n",
    "        'answer': extract_hash_answer(x['answer'])\n",
    "    })  # type: ignore\n",
    "    return data  # type: ignore\n",
    "\n",
    "dataset = get_gsm8k_questions()"
   ],
   "metadata": {
    "id": "L1h8s7QzAyE0"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
    "    \"\"\"Give 2.0 when the extracted <answer> text equals the reference answer, else 0.0.\n",
    "\n",
    "    Also prints the first question/answer/completion of the batch for monitoring.\n",
    "    \"\"\"\n",
    "    responses = [completion[0]['content'] for completion in completions]\n",
    "    q = prompts[0][-1]['content']\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
    "    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
    "\n",
    "def int_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Give 0.5 when the extracted <answer> text consists only of digits, else 0.0.\"\"\"\n",
    "    responses = [completion[0]['content'] for completion in completions]\n",
    "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
    "    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
    "\n",
    "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Give 0.5 when the whole completion is exactly the newline-delimited <reasoning>/<answer> layout.\"\"\"\n",
    "    # re.DOTALL lets `.*?` span multi-line reasoning; `\\n?$` also accepts a completion that\n",
    "    # stops right after </answer> without a trailing newline.\n",
    "    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n?$\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
    "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Give 0.5 when the completion starts with a <reasoning> block followed by an <answer> block.\"\"\"\n",
    "    # re.DOTALL so the tag contents may contain newlines.\n",
    "    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]\n",
    "    return [0.5 if match else 0.0 for match in matches]\n",
    "\n",
"def count_xml(text) -> float:\n",
    "    \"\"\"Score partial adherence to the <reasoning>/<answer> layout.\n",
    "\n",
    "    Each of <reasoning>, </reasoning>, <answer> and </answer> that appears exactly\n",
    "    once in its newline-delimited position adds 0.125. Text after the closing\n",
    "    </answer> tag is penalised by 0.001 per character.\n",
    "    NOTE(review): if the closing answer tag is not followed by a newline, the first\n",
    "    penalty is computed over the whole completion - confirm this is intended.\n",
    "    \"\"\"\n",
    "    count = 0.0\n",
    "    if text.count(\"<reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n</reasoning>\\n\") == 1:\n",
    "        count += 0.125\n",
    "    if text.count(\"\\n<answer>\\n\") == 1:\n",
    "        count += 0.125\n",
    "        count -= len(text.split(\"\\n</answer>\\n\")[-1])*0.001\n",
    "    if text.count(\"\\n</answer>\") == 1:\n",
    "        count += 0.125\n",
    "        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1)*0.001\n",
    "    return count\n",
    "\n",
    "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
    "    \"\"\"Return count_xml(...) of the text of each completion.\"\"\"\n",
    "    contents = [completion[0][\"content\"] for completion in completions]\n",
    "    return [count_xml(c) for c in contents]\n",
    "\n",
    "# Sanity check: a perfectly formatted completion earns the full 0.5 layout score.\n",
    "assert count_xml(\"<reasoning>\\nx\\n</reasoning>\\n<answer>\\n42\\n</answer>\\n\") == 0.5"
   ],
   "metadata": {
    "id": "R60IIK5FA3Mu"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "model_name = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n",
    "#model_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "\n",
    "# When True, train LoRA adapters (peft_config) instead of all model weights.\n",
    "USE_PEFT = False\n",
    "\n",
    "if \"SmolLM2\" in model_name:\n",
    "    output_dir = \"outputs/SmolLM2-135M-GRPO\"\n",
    "    run_name = \"SmolLM2-135M-GRPO\"\n",
    "else:\n",
    "    output_dir = \"outputs/Qwen-1.5B-GRPO\"\n",
    "    run_name = \"Qwen-1.5B-GRPO-gsm8k\""
   ],
   "metadata": {
    "id": "q-u5UG8pDLvD"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "training_args = GRPOConfig(\n",
    "    output_dir=output_dir,\n",
    "    run_name=run_name,\n",
    "    learning_rate=5e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type='cosine',\n",
    "    logging_steps=1,\n",
    "    bf16=True,  # NOTE(review): notebook metadata targets a T4 GPU, which lacks native bf16 - confirm support\n",
    "    per_device_train_batch_size=16,  # TRL requires the batch size to be divisible by num_generations\n",
    "    gradient_accumulation_steps=4,\n",
    "    num_generations=16,  # completions sampled per prompt\n",
    "    max_prompt_length=256,\n",
    "    max_completion_length=786,  # NOTE(review): possibly intended as 768\n",
    "    num_train_epochs=1,\n",
    "    save_steps=100,\n",
    "    max_grad_norm=0.1,\n",
    "    report_to=\"none\",\n",
    "    log_on_each_node=False,\n",
    ")\n",
    "\n",
    "# Only passed to the trainer when USE_PEFT is True.\n",
    "peft_config = LoraConfig(\n",
    "    r=16,\n",
    "    lora_alpha=64,\n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    lora_dropout=0.05,\n",
    ")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=None\n",
    ").to(\"cuda\")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ],
   "metadata": {
    "id": "vcZOoqyHBIBV"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        xmlcount_reward_func,\n",
    "        soft_format_reward_func,\n",
    "        strict_format_reward_func,\n",
    "        int_reward_func,\n",
    "        correctness_reward_func],\n",
    "    args=training_args,\n",
    "    train_dataset=dataset,\n",
    "    peft_config=peft_config if USE_PEFT else None,\n",
    ")"
   ],
   "metadata": {
    "id": "WLxd8ZFmBzvE"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "trainer.train()"
   ],
   "metadata": {
    "id": "RTbBtLmRIDpY"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}