GRPO Trainer

Overview

TRL supports the GRPO Trainer for training language models, as described in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models by Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Mingchuan Zhang, Y. K. Li, Y. Wu, Daya Guo.

The abstract from the paper is the following:

Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.

This post-training method was contributed by Quentin Gallouédec.

Quick start

This example demonstrates how to train a model using the GRPO method. We use the Qwen2 0.5B instruct model as the base model and the RM-Gemma-2B model as the reward model. We use the prompts from the TLDR dataset (the completion column is ignored). You can take a quick look at the data as shown below:
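A minimal sketch for inspecting a few rows of the dataset with the datasets library (only the prompt column is used by the trainer):

from datasets import load_dataset

dataset = load_dataset("trl-lib/tldr", split="train")
print(dataset)               # column names and number of rows
print(dataset[0]["prompt"])  # first prompt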

Below is the script to train the model. We use PEFT to reduce the memory requirements.

# train_grpo.py
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load the dataset
dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    learning_rate=1e-5,
    logging_steps=10,
    gradient_accumulation_steps=16,
    max_completion_length=128,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)

trainer.train()

Execute the script using the following command:

accelerate launch train_grpo.py

Distributed across 8 GPUs, the training takes approximately 1 day.

Looking deeper into the GRPO method

GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: generating completions, computing the advantage, estimating the KL divergence, and computing the loss.

Generating completions

At each training step, we sample a batch of prompts and generate a set of $G$ completions for each prompt (denoted as $o_i$).
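To make this step concrete, here is an illustrative sketch of sampling a group of $G$ completions for a single prompt with the transformers generate API. The trainer performs generation internally; this snippet only shows what a group of completions looks like, using the model name and sampling values from the quick start.

# Illustrative sketch only: GRPOTrainer generates completions internally.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

G = 8  # corresponds to GRPOConfig.num_generations
inputs = tokenizer(["The sky is"], return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.9,         # corresponds to GRPOConfig.temperature
    max_new_tokens=128,      # corresponds to GRPOConfig.max_completion_length
    num_return_sequences=G,  # G completions for the same prompt
)
# Strip the prompt tokens to keep only the completions o_1, ..., o_G
completions = tokenizer.batch_decode(
    outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)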

Computing the advantage

For each of the $G$ sequences, we compute the reward using a reward model. To align with the comparative nature of reward models (typically trained on datasets of comparisons between outputs for the same question), the advantage is calculated to reflect these relative comparisons. It is normalized as follows:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$

This approach gives the method its name: Group Relative Policy Optimization (GRPO).
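As a concrete illustration, the group-normalized advantage for one prompt can be computed from the $G$ rewards of its completions as follows (an illustrative sketch with made-up reward values; the trainer does this internally):

import torch

# Rewards r_1, ..., r_G for the G completions of a single prompt (made-up values)
rewards = torch.tensor([0.2, 1.3, 0.7, -0.4])

# Group-relative advantage: every token of completion i shares A_hat_{i,t}
advantages = (rewards - rewards.mean()) / rewards.std()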

Estimating the KL divergence

KL divergence is estimated using the approximator introduced by Schulman et al. (2020). The approximator is defined as follows:

$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \,\|\, \pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1$$
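In code, this per-token estimator can be sketched from the per-token log-probabilities of the trained model and the reference model (an illustrative sketch, not the trainer's exact implementation):

import torch

def kl_estimate(logps: torch.Tensor, ref_logps: torch.Tensor) -> torch.Tensor:
    """Per-token KL estimate; logps and ref_logps have shape (batch, seq_len)."""
    ratio = torch.exp(ref_logps - logps)      # pi_ref / pi_theta
    return ratio - (ref_logps - logps) - 1.0  # always >= 0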

Computing the loss

The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,<t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta\, \mathbb{D}_{\text{KL}}\left[\pi_\theta \,\|\, \pi_{\text{ref}}\right] \right],$$

where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.

In the original paper, this formulation is generalized to account for multiple updates after each generation by leveraging the clipped surrogate objective:
$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \hat{A}_{i,t},\ \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}, 1-\epsilon, 1+\epsilon \right) \hat{A}_{i,t} \right) - \beta\, \mathbb{D}_{\text{KL}}\left[\pi_\theta \,\|\, \pi_{\text{ref}}\right] \right],$$

where $\text{clip}(\cdot, 1-\epsilon, 1+\epsilon)$ ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between $1-\epsilon$ and $1+\epsilon$. In TRL, though, as in the original paper, we only do one update per generation, so we can simplify the loss to the first form.
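Putting the pieces together, the simplified single-update loss can be sketched per token as follows (an illustrative sketch under the tensor shapes assumed in the comments, not the trainer's exact implementation):

import torch

def grpo_loss(logps, ref_logps, advantages, mask, beta=0.04):
    # logps, ref_logps: per-token log-probabilities, shape (B * G, T)
    # advantages: one advantage per completion, shape (B * G, 1), broadcast over tokens
    # mask: 1 for completion tokens, 0 for padding, shape (B * G, T)
    ratio = torch.exp(logps - logps.detach())  # pi_theta / [pi_theta]_no-grad: value 1, keeps the gradient
    kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1.0
    per_token = ratio * advantages - beta * kl
    # Average over the tokens of each completion, then over the completions
    return -((per_token * mask).sum(dim=1) / mask.sum(dim=1)).mean()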

Logged metrics

The GRPO Trainer logs the following metrics:

  • reward: The average reward.
  • reward_std: The average standard deviation of rewards within each group.
  • kl: The average KL divergence between the model and the reference model, calculated on completions.

Customization

Using a custom reward function

The GRPOTrainer supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

  1. Input arguments:

    • The function must accept two arguments: prompts and completions.
    • Depending on the dataset format, the input will vary:
      • For standard format, prompts and completions will be lists of strings.
      • For conversational format, prompts and completions will be lists of message dictionaries.
  2. Return value: The function must return a list of floats. Each float represents the reward corresponding to a single completion.

Example 1: Reward longer completions

Below is an example of a reward function for a standard format that rewards longer completions:

def reward_func(prompts, completions):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]

You can test it as follows:

>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> print(reward_func(prompts, completions))
[6.0, 12.0]

Example 2: Reward completions with specific format

Below is an example of a reward function that checks whether the completion has a specific format. This example is inspired by the reward function used in the paper DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. It is designed for the conversational format, where prompts and completions consist of structured messages.

import re

def format_reward_func(prompts, completions):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

You can test this function as follows:

>>> prompts = [
...     [{"role": "user", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "user", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts, completions)
[1.0, 0.0]

Passing the reward function to the trainer

To use your custom reward function, pass it to the GRPOTrainer as follows:

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)

If you have multiple reward functions, you can pass them as a list:

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)

and the reward will be computed as the sum of the rewards from each function.

Note that GRPOTrainer supports multiple reward functions of different types. See the parameters documentation for more details.
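For instance, a reward model (given by its model ID, here the one from the quick start) and the custom reward function from Example 1 can be combined in the same list:

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=["weqweasdas/RM-Gemma-2B", reward_func],  # model ID + custom function
    ...,
)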

GRPOTrainer

class trl.GRPOTrainer


( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] reward_funcs: typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]], list[typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]]]]] args: GRPOConfig = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, dict[str, typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset]], NoneType] = None processing_class: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None reward_processing_classes: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, list[transformers.tokenization_utils_base.PreTrainedTokenizerBase], NoneType] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) peft_config: typing.Optional[ForwardRef('PeftConfig')] = None )

Parameters

  • model (Union[str, PreTrainedModel]) — Model to be trained. Can be either:

    • A string, being the model id of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with the keyword arguments in args.model_init_kwargs.
    • A PreTrainedModel object. Only causal language models are supported.
  • reward_funcs (Union[RewardFunc, list[RewardFunc]]) — Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward functions with the prompts and completions and sum the rewards. Can be either:

    • A single reward function, such as:
      • A string: The model ID of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with num_labels=1 and the keyword arguments in args.model_init_kwargs.
      • A PreTrainedModel object: Only sequence classification models are supported.
      • A custom reward function: This should take a list of prompts and completions and return a list of rewards. For more details, see Using a custom reward function.
    • A list of reward functions, where each item can independently be any of the above types. Mixing different types within the list (e.g., a string model ID and a custom reward function) is allowed.
  • args (GRPOConfig, optional, defaults to None) — Configuration for this trainer. If None, a default configuration is used.
  • train_dataset (Dataset or IterableDataset) — Dataset to use for training. It must include a column "prompt". Any additional columns in the dataset are ignored. The format of the samples can be either:

    • Standard: Each sample contains plain text.
    • Conversational: Each sample contains structured messages (e.g., role and content).
  • eval_dataset (Dataset, IterableDataset or dict[str, Union[Dataset, IterableDataset]]) — Dataset to use for evaluation. It must meet the same requirements as train_dataset.
  • processing_class (PreTrainedTokenizerBase, optional, defaults to None) — Processing class used to process the data. The padding side must be set to “left”. If None, the processing class is loaded from the model’s name with from_pretrained.
  • reward_processing_classes (Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]], optional, defaults to None) — Processing classes corresponding to the reward functions specified in reward_funcs. Can be either:

    • A single processing class: Used when reward_funcs contains only one reward function.
    • A list of processing classes: Must match the order and length of the reward functions in reward_funcs. If set to None, or if an element of the list corresponding to a PreTrainedModel is None, the tokenizer for the model is automatically loaded using from_pretrained. For elements in reward_funcs that are custom reward functions (not PreTrainedModel), the corresponding entries in reward_processing_classes are ignored.
  • callbacks (list of TrainerCallback, optional, defaults to None) — List of callbacks to customize the training loop. These will be added to the list of default callbacks detailed here.

    If you want to remove one of the default callbacks used, use the remove_callback method.

  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], optional, defaults to (None, None)) — A tuple containing the optimizer and the scheduler to use. Will default to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup controlled by args.
  • peft_config (~peft.PeftConfig, optional, defaults to None) — PEFT configuration used to wrap the model. If None, the model is not wrapped.

Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.

Example:

from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    train_dataset=dataset,
)

trainer.train()

create_model_card


( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )

Parameters

  • model_name (str or None, optional, defaults to None) — Name of the model.
  • dataset_name (str or None, optional, defaults to None) — Name of the dataset used for training.
  • tags (str, list[str] or None, optional, defaults to None) — Tags to be associated with the model card.

Creates a draft of a model card using the information available to the Trainer.

GRPOConfig

class trl.GRPOConfig


( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, typing.List[str]] = None ddp_find_unused_parameters: 
typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: typing.List[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = None push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None split_batches: typing.Optional[bool] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, typing.List[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False model_init_kwargs: typing.Optional[dict] = None max_prompt_length: typing.Optional[int] = 512 num_generations: typing.Optional[int] = 8 temperature: typing.Optional[float] = 0.9 max_completion_length: typing.Optional[int] = 256 beta: float = 0.04 )

Parameters that control the model and reference model

  • model_init_kwargs (dict[str, Any] or None, optional, defaults to None) — Keyword arguments for from_pretrained, used when the model argument of the GRPOTrainer is provided as a string.

Parameters that control the data preprocessing

  • max_prompt_length (int or None, optional, defaults to 512) — Maximum length of the prompt. If the prompt is longer than this value, it will be truncated from the left.
  • num_generations (int or None, optional, defaults to 8) — Number of generations per prompt to sample.
  • temperature (float, optional, defaults to 0.9) — Temperature for sampling. The higher the temperature, the more random the completions.
  • max_completion_length (int or None, optional, defaults to 256) — Maximum length of the generated completion.

Parameters that control the training

  • learning_rate (float, optional, defaults to 1e-6) — Initial learning rate for AdamW optimizer. The default value replaces that of TrainingArguments.
  • beta (float, optional, defaults to 0.04) — KL coefficient.

Configuration class for the GRPOTrainer.

Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the TrainingArguments documentation.
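For instance, the GRPO-specific parameters listed above can be set together as follows (the values are illustrative):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    max_prompt_length=512,      # prompts longer than this are truncated from the left
    num_generations=8,          # G completions sampled per prompt
    temperature=0.9,            # sampling temperature for generation
    max_completion_length=256,  # maximum length of generated completions
    beta=0.04,                  # KL coefficient
)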

Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
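For example, a training script can expose these options on the command line as follows (a minimal sketch; the script name and flags in the comment are illustrative):

# Minimal sketch: expose GRPOConfig fields as command-line arguments.
from transformers import HfArgumentParser
from trl import GRPOConfig

parser = HfArgumentParser(GRPOConfig)
(training_args,) = parser.parse_args_into_dataclasses()
# e.g. python train_grpo.py --output_dir Qwen2-0.5B-GRPO --num_generations 8 --beta 0.04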
