GEITje-7B-ultra-sft / README.md
metadata
license: cc-by-nc-4.0
base_model: Rijgersberg/GEITje-7B
tags:
  - alignment-handbook
  - trl
  - sft
  - geitje
datasets:
  - BramVanroy/ultrachat_200k_dutch
  - BramVanroy/stackoverflow-chat-dutch
  - BramVanroy/alpaca-cleaned-dutch
  - BramVanroy/dolly-15k-dutch
  - BramVanroy/no_robots_dutch
model-index:
  - name: GEITje-ultra-sft
    results: []
pipeline_tag: conversational
language:
  - nl

GEITje-ultra-sft

This model is a fine-tuned version of Rijgersberg/GEITje-7B on a number of synthetic datasets, including gpt-3.5-turbo and gpt-4-turbo data, multi- and single-turn conversations, and code. The training set consists of around 240M tokens. The model was trained with a context length of 8192 tokens.

Note that this model has not been aligned with DPO or other techniques. In practice, it is therefore recommended to use the DPO variant of this model.

Model description

This model is an SFT (chat-tuned) version of Rijgersberg/GEITje-7B, which in turn is based on Mistral 7B and was further pretrained on Dutch data.

Usage

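from transformers import pipeline, Conversation

# Requires an older transformers release, e.g. the 4.36.2 listed under Framework versions:
# the "conversational" pipeline and the Conversation class were removed in later releases.
# load_in_8bit: lower precision but saves a lot of memory (requires bitsandbytes)
# device_map="auto": spreads the model over the available GPUs, and the CPU if needed (requires accelerate)
chatbot = pipeline("conversational", model="BramVanroy/GEITje-ultra-sft", model_kwargs={"load_in_8bit": True}, device_map="auto")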

start_messages = [
    {"role": "system", "content": "Je bent een grappige chatbot die Bert heet. Je maakt vaak mopjes."},
    {"role": "user", "content": "Hallo, ik ben Bram. Ik wil vanavond graag een film kijken. Heb je enkele suggesties? Liefst een Disney-film."}
]
conversation = Conversation(start_messages)
conversation = chatbot(conversation)
response = conversation.messages[-1]["content"]
print(response)
# Hallo Bram! Wat leuk dat je vanavond een film wilt kijken. Als je van Disney-films houdt, heb ik een paar suggesties voor je.
# Een klassieker is "The Lion King", die is altijd een hit. Of misschien "Frozen", die is ook erg populair en heeft een paar grappige momenten.
# Of als je iets nieuws wilt proberen, "Raya and the Last Dragon" is een spannende avonturenfilm met een hartverwarmend verhaal. Welke film spreekt jou het meest aan?

Intended uses & limitations

This model was only trained on (synthetic) chat data and not specifically aligned through reinforcement learning. The model can generate wrong, misleading, and potentially even offensive content. Use at your own risk.

Because the model was trained on synthetic data created with OpenAI/Azure services, this model cannot be used for commercial purposes.

Training and evaluation data

The training data consists of older datasets that were translated to Dutch with OpenAI's gpt-3.5-turbo (alpaca, dolly, stackoverflow) and newer ones that were generated with gpt-4-turbo via Azure (no robots, ultrachat). For no robots, the original English prompt (and the system message, if any) was translated, and new answers were then generated with gpt-4-turbo. UltraChat is perhaps more interesting, because its multi-turn conversations were generated in one go: through prompt engineering, we provide the model with the original English first user message and ask it to create a conversation between a user and an assistant in a single response. Additionally, and excitingly in my opinion, I created multiple personas to randomly select from. The user messages in the dataset are written "as if" they were created by one of the personas, in the hope that the model learns to react well to different types of users. Personas include language learners, a direct conversationalist, someone who loves details, someone who is critical, a child, an expert in the field, a joyful, chaotic mind, a generalist, and "an average user". This is described in more detail on the dataset page.

The training set (train_sft) consists of 240,527,565 tokens (calculated prior to applying a chat template). The test sets (test_sft in the datasets) account for 26,397,086 tokens, which is around 10.97% of the training set.
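The 10.97% figure expresses the test-set size relative to the training set, not the test set's share of all tokens. A quick check of both numbers, using only the token counts stated above:

```python
train_tokens = 240_527_565  # train_sft
test_tokens = 26_397_086    # test_sft

ratio_to_train = test_tokens / train_tokens                 # test size relative to train
share_of_total = test_tokens / (train_tokens + test_tokens)  # test share of train + test
print(f"{ratio_to_train:.2%} of train, {share_of_total:.2%} of all tokens")
# 10.97% of train, 9.89% of all tokens
```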

Here is a breakdown of the training set (some data pages might not be available yet, but they definitely will be in the near future).

Training procedure

The great alignment handbook was used for training, with a custom Slurm script for compatibility with our cluster. The model was fully fine-tuned, without LoRA or other adapters.

The model was trained in bfloat16 with Flash Attention 2 and a context length of 8192 on two nodes with four A100 80GB GPUs each, for around 2.5 hours. I thank the Flemish Supercomputer Center (VSC) for their compute. You can find the wandb logs here.

For conversational usage, the model relies on the Zephyr chat template, which supports system messages. Only a small portion of the data contained system messages, so the model is assumed to handle them, at least to some extent.
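To illustrate the prompt format, here is a plain-Python re-implementation of the Zephyr template from the recipe. It is only a sketch: in practice, tokenizer.apply_chat_template renders the Jinja template itself. The sketch assumes the template is rendered with trim_blocks and lstrip_blocks enabled (as transformers does) and that the end-of-sequence token is "</s>", as in Mistral's tokenizer; format_zephyr is a hypothetical helper.

```python
ROLE_TAGS = {"system": "<|system|>", "user": "<|user|>", "assistant": "<|assistant|>"}

def format_zephyr(messages, eos_token="</s>", add_generation_prompt=False):
    """Plain-Python equivalent (sketch) of the Zephyr chat template in the recipe."""
    parts = []
    for i, message in enumerate(messages):
        tag = ROLE_TAGS.get(message["role"])
        if tag is not None:  # messages with other roles are skipped, as in the template
            parts.append(f"{tag}\n{message['content']}{eos_token}\n")
        # The template only adds the generation prompt after the last message
        if i == len(messages) - 1 and add_generation_prompt:
            parts.append("<|assistant|>\n")
    return "".join(parts)

prompt = format_zephyr(
    [
        {"role": "system", "content": "Je bent een grappige chatbot die Bert heet."},
        {"role": "user", "content": "Hallo, ik ben Bram."},
    ],
    add_generation_prompt=True,
)
print(prompt)
# <|system|>
# Je bent een grappige chatbot die Bert heet.</s>
# <|user|>
# Hallo, ik ben Bram.</s>
# <|assistant|>
```

Every turn ends with the EOS token, and the trailing "<|assistant|>" line is what prompts the model to write the next reply.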

Recipe used with the handbook:

# Model arguments
model_name_or_path: Rijgersberg/GEITje-7B
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

# Data training arguments
# Zephyr chat template
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
  BramVanroy/ultrachat_200k_dutch: 1.0
  BramVanroy/stackoverflow-chat-dutch: 0.5
  BramVanroy/alpaca-cleaned-dutch: 1.0
  BramVanroy/dolly-15k-dutch: 1.0
  BramVanroy/no_robots_dutch: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 8

# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: GEITje-ultra-sft
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 8192
max_steps: -1
num_train_epochs: 1
output_dir: data/GEITje-ultra-sft
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 16
push_to_hub: true
remove_unused_columns: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
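In the recipe, dataset_mixer gives each dataset a fraction: 1.0 keeps all of its train_sft rows, while 0.5 for stackoverflow-chat-dutch keeps half of them. The sketch below illustrates that idea in plain Python. mix_datasets is a hypothetical helper, not the alignment handbook's implementation, and details such as which rows are kept or how the test splits are treated may differ there.

```python
import random

def mix_datasets(train_splits, fractions, seed=42):
    """Hypothetical mixer: keep the first `frac` share of each train split,
    concatenate everything, then shuffle with a fixed seed."""
    mixed = []
    for name, rows in train_splits.items():
        keep = int(fractions[name] * len(rows))
        mixed.extend(rows[:keep])
    random.Random(seed).shuffle(mixed)
    return mixed

# Toy stand-ins for two of the datasets in the recipe (not real data)
splits = {
    "BramVanroy/ultrachat_200k_dutch": [f"uc-{i}" for i in range(10)],
    "BramVanroy/stackoverflow-chat-dutch": [f"so-{i}" for i in range(10)],
}
mix = mix_datasets(
    splits,
    {"BramVanroy/ultrachat_200k_dutch": 1.0, "BramVanroy/stackoverflow-chat-dutch": 0.5},
)
print(len(mix))  # 15
```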

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 2e-05
  • train_batch_size: 4
  • eval_batch_size: 4
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 8
  • gradient_accumulation_steps: 4
  • total_train_batch_size: 128
  • total_eval_batch_size: 32
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: cosine
  • lr_scheduler_warmup_ratio: 0.1
  • num_epochs: 1
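The cosine schedule with a warmup ratio of 0.1 can be sketched as a linear warmup to the peak learning rate followed by a half-cosine decay to zero. This sketch assumes the 238 optimizer steps reported under Training results are the full schedule and that warmup steps are rounded up, so warmup lasts ceil(0.1 × 238) = 24 steps; it is an illustration, not the Trainer's own code.

```python
import math

def cosine_lr_with_warmup(step, total_steps, peak_lr=2e-5, warmup_ratio=0.1):
    """Linear warmup, then half-cosine decay from peak_lr to zero (sketch)."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

total = 238
for s in (0, 12, 24, 131, 238):
    print(s, cosine_lr_with_warmup(s, total))
```

The learning rate reaches 2e-05 at step 24, is back at half the peak around step 131, and decays to zero at the final step.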

Training results

Training Loss | Epoch | Step | Validation Loss
0.8632        | 1.0   | 238  | 0.8563
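The total batch sizes follow from per-device batch size × number of GPUs × gradient accumulation steps. The recipe (16 per device, no accumulation) and the logged hyperparameters (4 per device, 4 accumulation steps) both reach a total train batch size of 128 on the eight GPUs. A small sketch of the arithmetic; the sequence count is only a rough estimate that assumes every one of the 238 steps consumed a full batch.

```python
def effective_batch_size(per_device, num_devices, grad_accum=1):
    """Sequences consumed per optimizer step in data-parallel training."""
    return per_device * num_devices * grad_accum

num_devices = 2 * 4  # two nodes with four A100s each

# Values from the hyperparameter list
train_total = effective_batch_size(4, num_devices, grad_accum=4)
eval_total = effective_batch_size(4, num_devices)  # no accumulation during evaluation
# Values from the handbook recipe
recipe_total = effective_batch_size(16, num_devices, grad_accum=1)

steps = 238
approx_sequences = steps * train_total  # rough: assumes full batches at every step
print(train_total, eval_total, recipe_total, approx_sequences)  # 128 32 128 30464
```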

Framework versions

  • Transformers 4.36.2
  • Pytorch 2.1.2+cu121
  • Datasets 2.14.6
  • Tokenizers 0.15.0