GRPO Trainer
Overview
TRL supports the GRPO Trainer for training language models, as described in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models by Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Mingchuan Zhang, Y. K. Li, Y. Wu, Daya Guo.
The abstract from the paper is the following:
Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.
This post-training method was contributed by Quentin Gallouédec.
Quick start
This example demonstrates how to train a model using the GRPO method. We use the Qwen 0.5B model as the base model and the RM-Gemma-2B model as the reward model. We use the prompts from the TLDR dataset (the completion column is ignored). You can view the data in the trl-lib/tldr dataset on the Hugging Face Hub.
Below is the script to train the model. We use PEFT to reduce the memory requirements.
# train_grpo.py
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
# Load the dataset
dataset = load_dataset("trl-lib/tldr", split="train")
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    learning_rate=1e-5,
    logging_steps=10,
    gradient_accumulation_steps=16,
    max_completion_length=128,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)
trainer.train()
Execute the script using the following command:
accelerate launch train_grpo.py
Distributed across 8 GPUs, the training takes approximately 1 day.
Looking deeper into the GRPO method
GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: generating completions, computing the advantage, estimating the KL divergence, and computing the loss.
Generating completions
At each training step, we sample a batch of prompts and generate a set of $G$ completions for each prompt (denoted as $o_i$).
Computing the advantage
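For each of the $G$ sequences generated for a prompt, we compute the reward using a reward model. To align with the comparative nature of reward models (typically trained on datasets of comparisons between outputs for the same question), the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
where $r_i$ is the reward of the $i$-th completion and $\mathbf{r} = (r_1, \dots, r_G)$ collects the rewards of the group.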
This approach gives the method its name: Group Relative Policy Optimization (GRPO).
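The normalization can be illustrated on a single group of rewards. This is a minimal sketch rather than the trainer's implementation: the function name `group_relative_advantages`, the use of the population standard deviation, and the small `eps` added to the denominator are assumptions made for the example.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize the rewards of one group (completions of the same prompt).

    eps guards against division by zero when all rewards are equal;
    its value is an assumption made for this sketch.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; the estimator choice is an assumption
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions of the same prompt, scored by a reward model.
rewards = [1.0, 2.0, 3.0, 6.0]
advantages = group_relative_advantages(rewards)

# Completions above the group mean get a positive advantage,
# completions below it get a negative one.
for r, a in zip(rewards, advantages):
    print(f"reward={r:.1f} advantage={a:+.3f}")
```

Because every reward is compared only with the other rewards in its own group, the advantages of a group sum to zero, and a group whose completions all receive the same reward contributes no learning signal.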
Estimating the KL divergence
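KL divergence is estimated using the approximator introduced by Schulman et al. (2020). The approximator is defined as follows:
$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1$$
where $q$ is the prompt and $o_{i,t}$ is the $t$-th token of completion $o_i$.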
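For a single token, an estimator of the form $r - \log r - 1$, with $r$ the ratio of reference to policy probability, can be evaluated directly from log-probabilities. The sketch below assumes that form; the function name `kl_estimate` and its arguments are hypothetical, chosen for the example.

```python
import math


def kl_estimate(logp_policy, logp_ref):
    """Per-token KL estimate r - log(r) - 1, where r = pi_ref / pi_theta.

    Both arguments are log-probabilities of the same token under the
    trained policy and the reference policy.
    """
    log_ratio = logp_ref - logp_policy
    return math.exp(log_ratio) - log_ratio - 1.0


print(kl_estimate(-1.0, -1.0))  # identical probabilities give zero
print(kl_estimate(-1.0, -2.0))  # any mismatch gives a positive value
```

Since $e^x \ge 1 + x$ for every real $x$, the estimate is never negative and vanishes only when the two policies assign the same probability to the token.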
Computing the loss
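The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,<t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \, \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right]$$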
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
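To make the averaging structure concrete, the sketch below computes the numerical value of such a loss for one prompt. It assumes the single-update form, in which the policy ratio evaluates to 1 and only matters for the gradient, so the value reduces to the advantage minus the KL penalty. The function name `grpo_loss_value` and its inputs are hypothetical; a real implementation would operate on tensors with automatic differentiation.

```python
def grpo_loss_value(advantages, kl_per_token, beta=0.04):
    """Value of the loss for one group of G completions.

    advantages:   list of G floats, one advantage per completion.
    kl_per_token: list of G lists holding the per-token KL estimates.
    beta:         KL coefficient (0.04 is the GRPOConfig default).
    """
    total = 0.0
    for adv, kls in zip(advantages, kl_per_token):
        # Per-token objective: advantage minus the KL penalty.
        per_token = [adv - beta * kl for kl in kls]
        # Average over the tokens of this completion.
        total += sum(per_token) / len(per_token)
    # Average over the group and negate, since the objective is maximized.
    return -total / len(advantages)


loss = grpo_loss_value(
    advantages=[1.0, -1.0],
    kl_per_token=[[0.0, 0.1], [0.2]],
)
print(loss)
```

Each completion is averaged over its own length before the group average, so long completions do not dominate the loss.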
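In the original paper, this formulation is generalized to account for multiple updates after each generation by leveraging the clipped surrogate objective:
$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \hat{A}_{i,t}, \; \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \, \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right]$$
where $\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon)$ ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between $1 - \epsilon$ and $1 + \epsilon$. In TRL though, as in the original paper, we only do one update per generation, so we can simplify the loss to the first form.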
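The effect of clipping is easiest to see on a single token. The sketch below evaluates the term min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A) from a clipped surrogate objective. The function name `clipped_term` and the default `epsilon=0.2` are assumptions made for the example, not values taken from TRL.

```python
def clipped_term(ratio, advantage, epsilon=0.2):
    """Clipped surrogate term for one token.

    ratio:     pi_theta / pi_theta_old for the token.
    advantage: advantage assigned to the token.
    epsilon:   clipping range (0.2 is an assumed, commonly used value).
    """
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    return min(ratio * advantage, clipped_ratio * advantage)


# With a single update per generation the ratio is 1, so nothing is clipped.
print(clipped_term(1.0, 2.0))
# A large ratio with a positive advantage is capped at (1 + epsilon) * A.
print(clipped_term(1.5, 2.0))
# A large ratio with a negative advantage is not capped: the penalty is kept.
print(clipped_term(1.5, -1.0))
```

Taking the minimum makes the bound one-sided: the policy gains nothing from pushing the ratio past the clipping range in the favorable direction, while unfavorable moves are still penalized in full.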
Logged metrics
The GRPO Trainer logs the following metrics:
- reward: The average reward.
- reward_std: The average standard deviation within reward groups.
- kl: The average KL divergence between the model and the reference model calculated on completions.
Customization
Using a custom reward function
The GRPOTrainer supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
Input arguments:
- The function must accept two arguments: prompts and completions.
- Depending on the dataset format, the input will vary:
  - For standard format, prompts and completions will be lists of strings.
  - For conversational format, prompts and completions will be lists of conversations, where each conversation is a list of message dictionaries (with role and content keys).
Return value: The function must return a list of floats. Each float represents the reward corresponding to a single completion.
Example 1: Reward longer completions
Below is an example of a reward function for a standard format that rewards longer completions:
def reward_func(prompts, completions):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]
You can test it as follows:
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> print(reward_func(prompts, completions))
[6.0, 12.0]
Example 2: Reward completions with specific format
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the reward function used in the paper DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. It is designed for conversational format, where prompts and completions consist of structured messages.
import re

def format_reward_func(prompts, completions):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
You can test this function as follows:
>>> prompts = [
...     [{"role": "user", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "user", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts, completions)
[1.0, 0.0]
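One caveat with the pattern above is that `.` does not match newlines by default, so a completion whose reasoning spans several lines would receive a reward of 0.0. The variant below is a sketch of one possible adjustment, not part of TRL: the name `multiline_format_reward_func`, the use of `re.DOTALL`, and the optional whitespace allowed between the two blocks are all assumptions made for the example.

```python
import re

# re.DOTALL lets "." match newlines; \s* tolerates whitespace between blocks.
MULTILINE_PATTERN = re.compile(
    r"^<think>.*?</think>\s*<answer>.*?</answer>$", re.DOTALL
)


def multiline_format_reward_func(prompts, completions):
    """Format check that also accepts multi-line reasoning (conversational format)."""
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if MULTILINE_PATTERN.match(content) else 0.0 for content in contents]


prompts = [[{"role": "user", "content": "What is the result of (1 + 2) * 4?"}]] * 2
completions = [
    [{"role": "assistant", "content": "<think>\n1 + 2 = 3\n3 * 4 = 12\n</think>\n<answer>12</answer>"}],
    [{"role": "assistant", "content": "The answer is 12."}],
]
print(multiline_format_reward_func(prompts, completions))
```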
Passing the reward function to the trainer
To use your custom reward function, pass it to the GRPOTrainer as follows:
from trl import GRPOTrainer
trainer = GRPOTrainer(
reward_funcs=reward_func,
...,
)
If you have multiple reward functions, you can pass them as a list:
from trl import GRPOTrainer
trainer = GRPOTrainer(
reward_funcs=[reward_func1, reward_func2],
...,
)
and the reward will be computed as the sum of the rewards from each function.
Note that GRPOTrainer supports multiple reward functions of different types. See the parameters documentation for more details.
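The summation over reward functions can be pictured with plain Python. This sketch only illustrates the arithmetic and is not the trainer's internal code: the helper `total_rewards` and the two toy reward functions are hypothetical.

```python
def length_reward(prompts, completions):
    """Toy reward: one point per character of the completion."""
    return [float(len(completion)) for completion in completions]


def question_mark_penalty(prompts, completions):
    """Toy reward: penalize completions that contain a question mark."""
    return [-1.0 if "?" in completion else 0.0 for completion in completions]


def total_rewards(reward_funcs, prompts, completions):
    """Call every reward function and sum the results per completion."""
    per_func = [func(prompts, completions) for func in reward_funcs]
    return [sum(values) for values in zip(*per_func)]


prompts = ["The sky is", "The sun is"]
completions = [" blue.", " where?"]
combined = total_rewards([length_reward, question_mark_penalty], prompts, completions)
print(combined)
```

Because the rewards are added without weights, reward functions on very different scales can drown each other out; rescaling them inside the functions is one way to balance them.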
GRPOTrainer
class trl.GRPOTrainer
(
    model: typing.Union[str, transformers.modeling_utils.PreTrainedModel]
    reward_funcs: typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]], list[typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]]]]]
    args: GRPOConfig = None
    train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None
    eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, dict[str, typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset]], NoneType] = None
    processing_class: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None
    reward_processing_classes: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, list[transformers.tokenization_utils_base.PreTrainedTokenizerBase], NoneType] = None
    callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None
    optimizers: tuple = (None, None)
    peft_config: typing.Optional[ForwardRef('PeftConfig')] = None
)
Parameters
- model (Union[str, PreTrainedModel]) — Model to be trained. Can be either:
  - A string, being the model id of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with the keyword arguments in args.model_init_kwargs.
  - A PreTrainedModel object. Only causal language models are supported.
- reward_funcs (Union[RewardFunc, list[RewardFunc]]) — Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward functions with the prompts and completions and sum the rewards. Can be either:
  - A single reward function, such as:
    - A string: The model ID of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with num_labels=1 and the keyword arguments in args.model_init_kwargs.
    - A PreTrainedModel object: Only sequence classification models are supported.
    - A custom reward function: This should take a list of prompts and completions and return a list of rewards. For more details, see Using a custom reward function.
  - A list of reward functions, where each item can independently be any of the above types. Mixing different types within the list (e.g., a string model ID and a custom reward function) is allowed.
- args (GRPOConfig, optional, defaults to None) — Configuration for this trainer. If None, a default configuration is used.
- train_dataset (Dataset or IterableDataset) — Dataset to use for training. It must include a column "prompt". Any additional columns in the dataset are ignored. The format of the samples can be either:
  - Standard: Each sample contains plain text.
  - Conversational: Each sample contains structured messages (e.g., role and content).
- eval_dataset (Dataset, IterableDataset or dict[str, Union[Dataset, IterableDataset]]) — Dataset to use for evaluation. It must meet the same requirements as train_dataset.
- processing_class (PreTrainedTokenizerBase, optional, defaults to None) — Processing class used to process the data. The padding side must be set to "left". If None, the processing class is loaded from the model's name with from_pretrained.
- reward_processing_classes (Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]], optional, defaults to None) — Processing classes corresponding to the reward functions specified in reward_funcs. Can be either:
  - A single processing class: Used when reward_funcs contains only one reward function.
  - A list of processing classes: Must match the order and length of the reward functions in reward_funcs.
  If set to None, or if an element of the list corresponding to a PreTrainedModel is None, the tokenizer for the model is automatically loaded using from_pretrained. For elements in reward_funcs that are custom reward functions (not PreTrainedModel), the corresponding entries in reward_processing_classes are ignored.
- callbacks (list of TrainerCallback, optional, defaults to None) — List of callbacks to customize the training loop. These are added to the list of default callbacks. If you want to remove one of the default callbacks used, use the remove_callback method.
- optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], optional, defaults to (None, None)) — A tuple containing the optimizer and the scheduler to use. Will default to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup controlled by args.
- peft_config (~peft.PeftConfig, optional, defaults to None) — PEFT configuration used to wrap the model. If None, the model is not wrapped.
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.
Example:
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    train_dataset=dataset,
)
trainer.train()
create_model_card
( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )
Creates a draft of a model card using the information available to the Trainer.
GRPOConfig
class trl.GRPOConfig
< source >( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: 
bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: 
typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: typing.List[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = None push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None split_batches: typing.Optional[bool] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, typing.List[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False model_init_kwargs: typing.Optional[dict] = None max_prompt_length: typing.Optional[int] = 512 num_generations: typing.Optional[int] = 8 temperature: typing.Optional[float] = 0.9 max_completion_length: typing.Optional[int] = 256 beta: float = 0.04 )
Parameters that control the model and reference model
- model_init_kwargs (dict[str, Any] or None, optional, defaults to None) — Keyword arguments for from_pretrained, used when the model argument of the GRPOTrainer is provided as a string.
Parameters that control the data preprocessing
- max_prompt_length (int or None, optional, defaults to 512) — Maximum length of the prompt. If the prompt is longer than this value, it will be truncated from the left.
- num_generations (int or None, optional, defaults to 8) — Number of generations per prompt to sample.
- temperature (float, optional, defaults to 0.9) — Temperature for sampling. The higher the temperature, the more random the completions.
- max_completion_length (int or None, optional, defaults to 256) — Maximum length of the generated completion.
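Left truncation keeps the end of the prompt, which is the part closest to where generation starts. The sketch below shows the idea on a list of token ids. The helper `truncate_left` is hypothetical, and treating `None` as "no truncation" is an assumption made for the example.

```python
def truncate_left(token_ids, max_prompt_length):
    """Keep at most the last max_prompt_length tokens of a prompt.

    max_prompt_length=None is treated as "no limit" (an assumption).
    """
    if max_prompt_length is None:
        return list(token_ids)
    start = max(0, len(token_ids) - max_prompt_length)
    return list(token_ids[start:])


prompt_ids = [101, 102, 103, 104, 105]
print(truncate_left(prompt_ids, 3))     # the two oldest tokens are dropped
print(truncate_left(prompt_ids, 512))   # short prompts are left unchanged
```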
Parameters that control the training
- learning_rate (float, optional, defaults to 1e-6) — Initial learning rate for AdamW optimizer. The default value replaces that of TrainingArguments.
- beta (float, optional, defaults to 0.04) — KL coefficient.
Configuration class for the GRPOTrainer.
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the TrainingArguments documentation.
Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.