TRL documentation

Supervised Fine-tuning Trainer

Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with a few lines of code on your dataset.

Quickstart

If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using SFTTrainer from TRL. Let us assume your dataset is imdb, the text you want to predict is inside the text field of the dataset, and you want to fine-tune the facebook/opt-350m model. The following code snippet takes care of all the data pre-processing and training for you:

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()

Make sure to pass a correct value for max_seq_length as the default value will be set to min(tokenizer.model_max_length, 1024).
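
If you are unsure what value your tokenizer reports, a quick check of tokenizer.model_max_length (a minimal sketch, assuming the quickstart model) looks like this:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Inspect the tokenizer's reported maximum length before picking max_seq_length.
print(tokenizer.model_max_length)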

You can also construct a model outside of the trainer and pass it as follows:

from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

trainer.train()

The above snippets will use the default training arguments from the transformers.TrainingArguments class. If you want to modify them, create your own TrainingArguments object and pass it to the SFTTrainer constructor, as is done in the supervised_finetuning.py script in the stack-llama example.
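
A minimal sketch, assuming the imdb dataset from the quickstart; the hyperparameter values below are illustrative, not recommendations:

from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

# Illustrative values only; tune them for your own setup.
training_args = TrainingArguments(
    output_dir="./sft_opt_350m",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    num_train_epochs=1,
    logging_steps=10,
)

trainer = SFTTrainer(
    "facebook/opt-350m",
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()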

Advanced usage

Format your input prompts

For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt and the other for the response. This allows you to format examples the way Stanford Alpaca did, as follows:

Below is an instruction ...

### Instruction
{prompt}

### Response:
{completion}

Let us assume your dataset has two fields, question and answer. Therefore you can just run:

...
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)

trainer.train()

To properly format your input, make sure to process all the examples by looping over them and returning a list of processed texts. Check out a full example of how to use the SFTTrainer on the alpaca dataset here

Packing dataset (ConstantLengthDataset)

SFTTrainer supports example packing, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the ConstantLengthDataset utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass packing=True to the SFTTrainer constructor.

...

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    packing=True
)

trainer.train()

Note that if you use a packed dataset and you pass max_steps in the training arguments, you will probably train your models for more than a few epochs, depending on how you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
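
As a rough illustration (the step count and batch size below are hypothetical, and the dataset is assumed to be loaded as in the previous snippet), fixing max_steps makes the training length explicit when packing is enabled:

from transformers import TrainingArguments
from trl import SFTTrainer

# Hypothetical values: with packing=True, how much data one "epoch" covers depends on
# how the ConstantLengthDataset packs the examples, so max_steps directly controls
# how long training runs.
training_args = TrainingArguments(
    output_dir="./sft_packed",
    max_steps=500,
    per_device_train_batch_size=4,
)

trainer = SFTTrainer(
    "facebook/opt-350m",
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    packing=True,
)
trainer.train()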

Customize your prompts using packed dataset

If your dataset has several fields that you want to combine, for example if the dataset has question and answer fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:

def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    packing=True,
    formatting_func=formatting_func
)

trainer.train()

You can also customize the ConstantLengthDataset much more by directly passing the arguments to the SFTTrainer constructor. Please refer to that class’ signature for more information.
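
For instance, continuing the previous snippet, a sketch that forwards the packing-related arguments exposed on the constructor (the values are illustrative):

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    packing=True,
    formatting_func=formatting_func,
    # Forwarded to the underlying ConstantLengthDataset (illustrative values):
    max_seq_length=1024,     # length of each packed chunk of tokens
    num_of_sequences=1024,   # number of token sequences kept in the buffer
    chars_per_token=3.6,     # characters per token, used to size the text buffer
    infinite=False,          # stop once the underlying dataset is exhausted
)

trainer.train()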

Control over the pretrained model

You can directly pass the kwargs of the from_pretrained() method to the SFTTrainer. For example, if you want to load a model in a different precision, analogous to

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)

...

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    torch_dtype=torch.bfloat16,
)

trainer.train()

Note that all keyword arguments of from_pretrained() are supported.

Training adapters

We also support a tight integration with the 🤗 PEFT library, so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("imdb", split="train")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config
)

trainer.train()

Note that when training adapters, we manually add a saving callback to automatically save only the adapters:

import os
from transformers import TrainerCallback

class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # Save only the adapter weights into the checkpoint folder.
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        # Drop the full model weights if they were also written to the checkpoint.
        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))

If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training.

...

callbacks = [YourCustomCallback(), PeftSavingCallback()]

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    dataset_text_field="text",
    torch_dtype=torch.bfloat16,
    peft_config=peft_config,
    callbacks=callbacks
)

trainer.train()

Training adapters with base 8-bit models

For that, you first need to load your 8-bit model outside the trainer and pass a PeftConfig to the trainer. For example:

...

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    load_in_8bit=True,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
)

trainer.train()

Best practices

Pay attention to the following best practices when training a model with that trainer:

  • SFTTrainer always pads sequences by default to the max_seq_length argument of the SFTTrainer. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
  • For training adapters in 8-bit, you might need to tweak the arguments of the prepare_model_for_int8_training method from PEFT, hence we advise users to use the prepare_in_int8_kwargs field, or to create the PeftModel outside the SFTTrainer and pass it.
  • For more memory-efficient training using adapters, you can load the base model in 8-bit: simply add the load_in_8bit argument when creating the SFTTrainer (see the sketch after this list), or create a base model in 8-bit outside the trainer and pass it.
  • If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are meant for the from_pretrained() method.
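
As a sketch of the 8-bit adapter point above (the model name and LoRA values are illustrative, and dataset is assumed to be loaded as in the earlier examples):

from trl import SFTTrainer
from peft import LoraConfig

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
    # Forwarded to from_pretrained() because the model is passed as a string.
    load_in_8bit=True,
    device_map="auto",
)
trainer.train()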

SFTTrainer

class trl.SFTTrainer

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str] = None args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, typing.Dict[str, datasets.arrow_dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None peft_config: typing.Optional[typing.Dict] = None dataset_text_field: typing.Optional[str] = None packing: typing.Optional[bool] = False formatting_func: typing.Optional[typing.Callable] = None max_seq_length: typing.Optional[int] = None infinite: typing.Optional[bool] = False num_of_sequences: typing.Optional[int] = 1024 chars_per_token: typing.Optional[float] = 3.6 )

Parameters

  • model (Union[transformers.PreTrainedModel, nn.Module, str]) — The model to train, can be a PreTrainedModel, a torch.nn.Module or a string with the model name to load from cache or download. The model can be also converted to a PeftModel if a PeftConfig object is passed to the peft_config argument.
  • args (Optional[transformers.TrainingArguments]) — The arguments to tweak for training. Please refer to the official documentation of transformers.TrainingArguments for more information.
  • data_collator (Optional[transformers.DataCollator]) — The data collator to use for training.
  • train_dataset (Optional[datasets.Dataset]) — The dataset to use for training. We recommend users to use trl.trainer.ConstantLengthDataset to create their dataset.
  • eval_dataset (Optional[Union[datasets.Dataset, Dict[str, datasets.Dataset]]]) — The dataset to use for evaluation. We recommend users to use trl.trainer.ConstantLengthDataset to create their dataset.
  • tokenizer (Optional[transformers.PreTrainedTokenizer]) — The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used.
  • model_init (Callable[[], transformers.PreTrainedModel]) — The model initializer to use for training. If None is specified, the default model initializer will be used.
  • compute_metrics (Callable[[transformers.EvalPrediction], Dict], optional, defaults to compute_accuracy) — The metrics to use for evaluation. If no metrics are specified, the default metric (compute_accuracy) will be used.
  • callbacks (List[transformers.TrainerCallback]) — The callbacks to use for training.
  • optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.
  • peft_config (Optional[PeftConfig]) — The PeftConfig object to use to initialize the PeftModel.
  • dataset_text_field (Optional[str]) — The name of the text field of the dataset. If this is passed, the trainer will automatically create a ConstantLengthDataset based on the dataset_text_field argument.
  • formatting_func (Optional[Callable]) — The formatting function to be used for creating the ConstantLengthDataset.
  • max_seq_length (Optional[int]) — The maximum sequence length to use for the ConstantLengthDataset and for automatically creating the Dataset. Defaults to 512.
  • infinite (Optional[bool]) — Whether to use an infinite dataset or not. Defaults to False.
  • num_of_sequences (Optional[int]) — The number of sequences to use for the ConstantLengthDataset. Defaults to 1024.
  • chars_per_token (Optional[float]) — The number of characters per token to use for the ConstantLengthDataset. Defaults to 3.6. You can check how this is computed in the stack-llama example: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
  • packing (Optional[bool]) — Used only in case dataset_text_field is passed. This argument is used by the ConstantLengthDataset to pack the sequences of the dataset.

Class definition of the Supervised Finetuning Trainer (SFT Trainer). This class is a wrapper around the transformers.Trainer class and inherits all of its attributes and methods. The trainer takes care of properly initializing the PeftModel in case a user passes a PeftConfig object.

ConstantLengthDataset

class trl.trainer.ConstantLengthDataset

( *args **kwds )

Parameters

  • tokenizer (transformers.PreTrainedTokenizer) — The processor used for processing the data.
  • dataset (dataset.Dataset) — Dataset with text files.
  • dataset_text_field (str, optional) — Name of the field in the dataset that contains the text. Used only if formatting_func is None.
  • formatting_func (Callable, optional) — Function that formats the text before tokenization. Usually it is recommended to follow a pattern such as "### Question: {question}\n### Answer: {answer}".
  • infinite (bool, optional, defaults to False) — If True, the iterator is reset after the dataset reaches its end, otherwise iteration stops.
  • seq_length (int, optional, defaults to 1024) — Length of token sequences to return.
  • num_of_sequences (int, optional, defaults to 1024) — Number of token sequences to keep in the buffer.
  • chars_per_token (float, optional, defaults to 3.6) — Number of characters per token used to estimate the number of tokens in the text buffer.
  • eos_token_id (int, optional, defaults to 0) — Id of the end-of-sequence token to use if the passed tokenizer does not have an EOS token.
  • shuffle (bool, optional, defaults to True) — Whether to shuffle the examples before they are returned.

Iterable dataset that returns constant-length chunks of tokens from a stream of text files. The dataset also formats the text before tokenization with a specific format that is provided by the user.
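
A minimal sketch of constructing the dataset directly, assuming a tokenizer and the imdb dataset with a text field (the values are illustrative):

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train")

packed_dataset = ConstantLengthDataset(
    tokenizer,
    dataset,
    dataset_text_field="text",   # used because no formatting_func is given
    seq_length=1024,             # length of each constant-length chunk
    infinite=False,              # stop once the underlying dataset is exhausted
)

# Inspect the first packed sample (a dict of constant-length token tensors).
for sample in packed_dataset:
    print(sample)
    break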