{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PFGREH3kAkep"
},
"outputs": [],
"source": [
"# Install dependencies (Colab). Use %pip (not !pip) so the packages are\n",
"# installed into the running kernel's environment.\n",
"# NOTE(review): the git installs are unpinned; pin commits/versions for\n",
"# reproducible runs.\n",
"%pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/accelerate.git\n",
"%pip install -q datasets huggingface-hub trl\n",
"%pip install -q git+https://github.com/huggingface/peft.git"
]
},
{
"cell_type": "code",
"source": [
"# All imports in one cell: re for reward-format parsing, torch for the model,\n",
"# datasets for GSM8K, transformers for model/tokenizer, peft for the LoRA\n",
"# config, trl for GRPO training.\n",
"import re\n",
"import torch\n",
"from datasets import load_dataset, Dataset\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from peft import LoraConfig\n",
"from trl import GRPOConfig, GRPOTrainer"
],
"metadata": {
"id": "PTbegv7kAwXd"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Prompt templates and GSM8K preprocessing.\n",
"# FIX(review): the <reasoning>/<answer> XML tags had been stripped from this\n",
"# file (HTML-like sanitization). Without them, extract_xml_answer called\n",
"# str.split(\"\") which raises ValueError: empty separator. Tags restored.\n",
"SYSTEM_PROMPT = \"\"\"\n",
"Respond in the following format:\n",
"<reasoning>\n",
"...\n",
"</reasoning>\n",
"<answer>\n",
"...\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"XML_COT_FORMAT = \"\"\"\\\n",
"<reasoning>\n",
"{reasoning}\n",
"</reasoning>\n",
"<answer>\n",
"{answer}\n",
"</answer>\n",
"\"\"\"\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n",
"    \"\"\"Return the text between the last <answer> tag and its closing tag, stripped.\"\"\"\n",
"    answer = text.split(\"<answer>\")[-1]\n",
"    answer = answer.split(\"</answer>\")[0]\n",
"    return answer.strip()\n",
"\n",
"def extract_hash_answer(text: str) -> str | None:\n",
"    \"\"\"Return the GSM8K gold answer (the text after '####'), or None if absent.\"\"\"\n",
"    if \"####\" not in text:\n",
"        return None\n",
"    return text.split(\"####\")[1].strip()\n",
"\n",
"# uncomment middle messages for 1-shot prompting\n",
"def get_gsm8k_questions(split = \"train\") -> Dataset:\n",
"    \"\"\"Load GSM8K and map each row to a chat-style prompt plus extracted gold answer.\"\"\"\n",
"    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n",
"    data = data.map(lambda x: { # type: ignore\n",
"        'prompt': [\n",
"            {'role': 'system', 'content': SYSTEM_PROMPT},\n",
"            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},\n",
"            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(\n",
"            #    reasoning=\"9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.\",\n",
"            #    answer=\"7\"\n",
"            #)},\n",
"            {'role': 'user', 'content': x['question']}\n",
"        ],\n",
"        'answer': extract_hash_answer(x['answer'])\n",
"    }) # type: ignore\n",
"    return data # type: ignore\n",
"\n",
"dataset = get_gsm8k_questions()"
],
"metadata": {
"id": "L1h8s7QzAyE0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Reward functions for GRPO.\n",
"# FIX(review): the <reasoning>/<answer> tags had been stripped from the regex\n",
"# patterns and count_xml markers, making the format rewards degenerate (the\n",
"# soft pattern had collapsed to '.*?\\\\s*.*?', which matches any string, and\n",
"# count_xml was counting bare newlines with duplicated conditions). Restored.\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
"    \"\"\"2.0 per completion whose extracted <answer> equals the gold answer, else 0.0.\"\"\"\n",
"    responses = [completion[0]['content'] for completion in completions]\n",
"    q = prompts[0][-1]['content']\n",
"    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
"    print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
"    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
"    \"\"\"0.5 if the extracted answer is all digits (GSM8K answers are integers).\"\"\"\n",
"    responses = [completion[0]['content'] for completion in completions]\n",
"    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
"    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
"\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
"    \"\"\"0.5 if the completion matches the exact tag-per-line layout.\"\"\"\n",
"    # NOTE(review): without re.DOTALL, '.' does not cross newlines, so\n",
"    # multi-line reasoning fails this check — confirm that is intended.\n",
"    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
"    responses = [completion[0][\"content\"] for completion in completions]\n",
"    matches = [re.match(pattern, r) for r in responses]\n",
"    return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
"    \"\"\"0.5 if both tag pairs appear in order, with looser whitespace rules.\"\"\"\n",
"    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
"    responses = [completion[0][\"content\"] for completion in completions]\n",
"    matches = [re.match(pattern, r) for r in responses]\n",
"    return [0.5 if match else 0.0 for match in matches]\n",
"\n",
"def count_xml(text) -> float:\n",
"    \"\"\"Partial credit: 0.125 per correctly-placed tag, minus a small penalty\n",
"    for any trailing text after the closing </answer> tag.\"\"\"\n",
"    count = 0.0\n",
"    if text.count(\"<reasoning>\\n\") == 1:\n",
"        count += 0.125\n",
"    if text.count(\"\\n</reasoning>\\n\") == 1:\n",
"        count += 0.125\n",
"    if text.count(\"\\n<answer>\\n\") == 1:\n",
"        count += 0.125\n",
"        count -= len(text.split(\"\\n</answer>\\n\")[-1])*0.001\n",
"    if text.count(\"\\n</answer>\") == 1:\n",
"        count += 0.125\n",
"        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1)*0.001\n",
"    return count\n",
"\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
"    \"\"\"Apply count_xml to each completion's content.\"\"\"\n",
"    contents = [completion[0][\"content\"] for completion in completions]\n",
"    return [count_xml(c) for c in contents]"
],
"metadata": {
"id": "R60IIK5FA3Mu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Base model selection; swap the comment to train the larger Qwen model.\n",
"model_name = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n",
"#model_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
"\n",
"# Derive run bookkeeping (output dir, run name) from the chosen model family.\n",
"is_smollm = \"SmolLM2\" in model_name\n",
"output_dir = \"outputs/SmolLM2-135M-GRPO\" if is_smollm else \"outputs/Qwen-1.5B-GRPO\"\n",
"run_name = \"SmolLM2-135M-GRPO\" if is_smollm else \"Qwen-1.5B-GRPO-gsm8k\"\n"
],
"metadata": {
"id": "q-u5UG8pDLvD"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# GRPO training hyperparameters.\n",
"training_args = GRPOConfig(\n",
" output_dir=output_dir,\n",
" run_name=run_name,\n",
" learning_rate=5e-6,\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.99,\n",
" weight_decay=0.1,\n",
" warmup_ratio=0.1,\n",
" lr_scheduler_type='cosine',\n",
" logging_steps=1,\n",
" bf16=True,\n",
" per_device_train_batch_size=16, # Fixed: Must be divisible by num_generations\n",
" gradient_accumulation_steps=4,\n",
" num_generations=16, # Kept as is\n",
" max_prompt_length=256,\n",
" # NOTE(review): 786 looks like a typo for 768 — confirm intended length.\n",
" max_completion_length=786,\n",
" num_train_epochs=1,\n",
" save_steps=100,\n",
" max_grad_norm=0.1,\n",
" report_to=\"none\",\n",
" log_on_each_node=False,\n",
")\n",
"\n",
"# LoRA adapter configuration. NOTE(review): currently unused — the trainer\n",
"# cell below has peft_config commented out, so training is full-parameter.\n",
"peft_config = LoraConfig(\n",
" r=16,\n",
" lora_alpha=64,\n",
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
" task_type=\"CAUSAL_LM\",\n",
" lora_dropout=0.05,\n",
")\n",
"\n",
"# Load the base model in bfloat16 and place it on the GPU explicitly\n",
"# (device_map=None so .to(\"cuda\") controls placement).\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=None\n",
").to(\"cuda\")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"# Use EOS as the pad token (common for causal LMs shipped without one).\n",
"tokenizer.pad_token = tokenizer.eos_token"
],
"metadata": {
"id": "vcZOoqyHBIBV"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Build the GRPO trainer with the reward functions defined above.\n",
"# peft_config is commented out, so the LoraConfig defined earlier is unused\n",
"# and the full model is trained.\n",
"trainer = GRPOTrainer(\n",
" model=model,\n",
" processing_class=tokenizer,\n",
" reward_funcs=[\n",
" xmlcount_reward_func,\n",
" soft_format_reward_func,\n",
" strict_format_reward_func,\n",
" int_reward_func,\n",
" correctness_reward_func],\n",
" args=training_args,\n",
" train_dataset=dataset,\n",
" #peft_config=peft_config\n",
")"
],
"metadata": {
"id": "WLxd8ZFmBzvE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Kick off training. Output is verbose: logging_steps=1 plus the print()\n",
"# inside correctness_reward_func emit text every step.\n",
"trainer.train()"
],
"metadata": {
"id": "RTbBtLmRIDpY"
},
"execution_count": null,
"outputs": []
},
]
}