簡単な算数問題を解けるように GRPO で学習してみた。学習コードは下の方にあります。
学習データは簡単な問題なのでその場で合成したものを使いました。(コード参照)
prompt format:
あなたはアシスタントとして回答します。
ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
「ユーザー」の質問の後に、「アシスタント」が回答します。
ユーザー:
次の ? に入る数値を計算して回答してください。
{formula}
アシスタント:
example formula
:
4 + 3 * 2 = ?
expected output:
<think>思考内容</think><answer>解答</answer>
Example
from transformers import pipeline
formula = "9 + 3 * 5 = ?" # A + B * C か A * B + C の形式のみ対応
prompt = f"""\
あなたはアシスタントとして回答します。
ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
「ユーザー」の質問の後に、「アシスタント」が回答します。
ユーザー:
次の ? に入る数値を計算して回答してください。
{formula}
アシスタント:
"""
print(pipe(prompt)[0]["generated_text"][len(prompt):])
# <think>9 + 3 * 5 = 9 + 15 = 24</think><answer>24</answer>
Training information
- Base model: Qwen/Qwen2.5-0.5B
- Device: 1x A100 80G
- GPU Hour: about 1 hour
- Total training steps: 140 steps (the last checkpoint)
Wandb log: https://wandb.ai/p1atdev/grpo-math-01/runs/ytv8wxll
Training code
import random
import re
import torch
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM
import wandb
SYSTEM_PROMPT = """命令:
あなたはアシスタントとして回答します。
ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
「ユーザー」の質問の後に、「アシスタント」が回答します。
ユーザー:
"""
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
def generate_problem():
# written by ChatGPT
# 1~10 の間のランダムな整数を3つ生成
a = random.randint(1, 10)
b = random.randint(1, 10)
c = random.randint(1, 10)
# 足し算と掛け算の両方を含むように、2通りのパターンからランダムに選択
if random.randint(0, 1) == 0:
# パターン1: 足し算+掛け算 => 例: a + b * c
expression = f"{a} + {b} * {c}"
else:
# パターン2: 掛け算+足し算 => 例: a * b + c
expression = f"{a} * {b} + {c}"
# Python の eval() を用いて答えを計算(演算子の優先順位に従う)
answer = eval(expression)
return f"{expression} = ?", answer
def generate_random_pair(max_count: int):
for i in range(max_count):
formula, answer = generate_problem()
question = f"""{SYSTEM_PROMPT}
次の ? に入る数値を計算して回答してください。
{formula}
アシスタント:
"""
yield {"id": i, "prompt": question, "ground_truth": answer}
# format reward
FORMAT_PATTERN = re.compile(r"^<think>.*?</think><answer>.*?</answer>$")
def format_reward_func(completions: list[str], **kwargs):
"""Reward function that checks if the completion has a specific format."""
matches = [FORMAT_PATTERN.match(content) for content in completions]
return [1.0 if match else 0.0 for match in matches]
# answer reward
ANSWER_PATTERN = re.compile(r"<answer>(\d+)</answer>")
def answer_reward_func(completions: list[str], ground_truth: list[str], **kwargs):
# Regular expression to capture content inside \boxed{}
matches = [ANSWER_PATTERN.search(completion) for completion in completions]
contents = [match.group(1) if match else "" for match in matches]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
return [1.0 if c == str(gt) else 0.0 for c, gt in zip(contents, ground_truth)]
def main():
ds = Dataset.from_generator(generate_random_pair, gen_kwargs={"max_count": 100000}) # 100000 is too many, we don't need so much for this task
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
project_name = YOUR_WANDB_PROJECT_NAME
push_hub_name = YOUR_PUSH_HUB_NAME
wandb.init(project=project_name)
train_args = GRPOConfig(
output_dir="./grpo-01", #! output path
use_vllm=False, # True to use vLLM
overwrite_output_dir=True,
num_train_epochs=10,
num_generations=4,
per_device_train_batch_size=16,
# per_device_eval_batch_size=4,
gradient_accumulation_steps=1,
gradient_checkpointing=True,
learning_rate=1e-4, # maybe a bit high
warmup_ratio=0.01,
weight_decay=0.01,
optim="adamw_8bit",
adam_epsilon=1e-8,
lr_scheduler_type="cosine_with_min_lr",
lr_scheduler_kwargs={
"min_lr": 5e-5,
"num_cycles": 0.5,
},
# eval_strategy="steps", # eval did not work well
# eval_steps=10,
save_steps=10,
save_total_limit=2,
logging_steps=1,
logging_first_step=True,
# load_best_model_at_end=True,
# metric_for_best_model="eval_loss",
torch_compile=False, # compile does not work
fp16=False,
bf16=True,
report_to=["wandb"],
hub_model_id=push_hub_name,
hub_private_repo=True,
push_to_hub=True,
save_safetensors=True,
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
train_dataset=ds,
# eval_dataset=ds["test"],
reward_funcs=[format_reward_func, answer_reward_func],
args=train_args,
)
trainer.train()
if __name__ == "__main__":
main()
- Downloads last month
- 22
Inference Providers
NEW
This model is not currently available via any of the supported Inference Providers.
Model tree for p1atdev/qwen2.5-0.5b-grpo-math-01
Base model
Qwen/Qwen2.5-0.5B