簡単な算数問題を解けるように GRPO で学習してみた。学習コードは下の方にあります。

学習データは簡単な問題なのでその場で合成したものを使いました。(コード参照)

prompt format:

あなたはアシスタントとして回答します。
ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
「ユーザー」の質問の後に、「アシスタント」が回答します。
ユーザー:
次の ? に入る数値を計算して回答してください。
{formula}

アシスタント:

example formula:

4 + 3 * 2 = ?

expected output:

<think>思考内容</think><answer>解答</answer>

Example

from transformers import pipeline

formula = "9 + 3 * 5 = ?" # A + B * C か A * B + C の形式のみ対応

prompt = f"""\
あなたはアシスタントとして回答します。
ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
「ユーザー」の質問の後に、「アシスタント」が回答します。
ユーザー:
次の ? に入る数値を計算して回答してください。
{formula}

アシスタント:
"""

print(pipe(prompt)[0]["generated_text"][len(prompt):])
# <think>9 + 3 * 5 = 9 + 15 = 24</think><answer>24</answer>

Training information

Wandb log: https://wandb.ai/p1atdev/grpo-math-01/runs/ytv8wxll

Training code

import random
import re

import torch
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM
import wandb

SYSTEM_PROMPT = """命令:
あなたはアシスタントとして回答します。
ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
「ユーザー」の質問の後に、「アシスタント」が回答します。
ユーザー:
"""
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

def generate_problem():
    # written by ChatGPT
    # 1~10 の間のランダムな整数を3つ生成
    a = random.randint(1, 10)
    b = random.randint(1, 10)
    c = random.randint(1, 10)

    # 足し算と掛け算の両方を含むように、2通りのパターンからランダムに選択
    if random.randint(0, 1) == 0:
        # パターン1: 足し算+掛け算 => 例: a + b * c
        expression = f"{a} + {b} * {c}"
    else:
        # パターン2: 掛け算+足し算 => 例: a * b + c
        expression = f"{a} * {b} + {c}"

    # Python の eval() を用いて答えを計算(演算子の優先順位に従う)
    answer = eval(expression)

    return f"{expression} = ?", answer


def generate_random_pair(max_count: int):
    for i in range(max_count):
        formula, answer = generate_problem()
        question = f"""{SYSTEM_PROMPT}
次の ? に入る数値を計算して回答してください。
{formula}

アシスタント:
"""
        yield {"id": i, "prompt": question, "ground_truth": answer}


# format reward
FORMAT_PATTERN = re.compile(r"^<think>.*?</think><answer>.*?</answer>$")

def format_reward_func(completions: list[str], **kwargs):
    """Reward function that checks if the completion has a specific format."""
    matches = [FORMAT_PATTERN.match(content) for content in completions]
    return [1.0 if match else 0.0 for match in matches]


# answer reward
ANSWER_PATTERN = re.compile(r"<answer>(\d+)</answer>")

def answer_reward_func(completions: list[str], ground_truth: list[str], **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [ANSWER_PATTERN.search(completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == str(gt) else 0.0 for c, gt in zip(contents, ground_truth)]


def main():
    ds = Dataset.from_generator(generate_random_pair, gen_kwargs={"max_count": 100000}) # 100000 is too many, we don't need so much for this task
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    project_name = YOUR_WANDB_PROJECT_NAME
    push_hub_name = YOUR_PUSH_HUB_NAME

    wandb.init(project=project_name)
    train_args = GRPOConfig(
        output_dir="./grpo-01", #! output path
        use_vllm=False,  # True to use vLLM
        overwrite_output_dir=True,
        num_train_epochs=10,
        num_generations=4,
        per_device_train_batch_size=16,
        # per_device_eval_batch_size=4,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        learning_rate=1e-4, # maybe a bit high
        warmup_ratio=0.01,
        weight_decay=0.01,
        optim="adamw_8bit",
        adam_epsilon=1e-8,
        lr_scheduler_type="cosine_with_min_lr",
        lr_scheduler_kwargs={
            "min_lr": 5e-5,
            "num_cycles": 0.5,
        },
        # eval_strategy="steps", # eval did not work well
        # eval_steps=10,
        save_steps=10,
        save_total_limit=2,
        logging_steps=1,
        logging_first_step=True,
        # load_best_model_at_end=True,
        # metric_for_best_model="eval_loss",
        torch_compile=False,  # compile does not work
        fp16=False,
        bf16=True,
        report_to=["wandb"],
        hub_model_id=push_hub_name,
        hub_private_repo=True,
        push_to_hub=True,
        save_safetensors=True,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        # eval_dataset=ds["test"], 
        reward_funcs=[format_reward_func, answer_reward_func],
        args=train_args,
    )

    trainer.train()


if __name__ == "__main__":
    main()
Downloads last month
22
Safetensors
Model size
494M params
Tensor type
BF16
·
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.

Model tree for p1atdev/qwen2.5-0.5b-grpo-math-01

Base model

Qwen/Qwen2.5-0.5B
Finetuned
(108)
this model