---
license: cc-by-nc-4.0
base_model: Rijgersberg/GEITje-7B
tags:
- alignment-handbook
- trl
- sft
- geitje
- conversational
datasets:
- BramVanroy/ultrachat_200k_dutch
- BramVanroy/stackoverflow-chat-dutch
- BramVanroy/alpaca-cleaned-dutch
- BramVanroy/dolly-15k-dutch
- BramVanroy/no_robots_dutch
model-index:
- name: GEITje-ultra-sft
  results: []
pipeline_tag: text-generation
language:
- nl
---

# GEITje-ultra-sft

This model is a fine-tuned version of [Rijgersberg/GEITje-7B](https://huggingface.co/Rijgersberg/GEITje-7B) on a number of synthetic datasets including gpt-3.5-turbo and gpt-4-turbo data, multi- and single-turn conversations, and code. The training set consists of around 240M tokens. The model was trained with a context length of 8192.

> [!WARNING]
> Note that this model has not been aligned with DPO or other techniques. In practice, it is therefore recommended to use the [DPO variant](https://huggingface.co/BramVanroy/GEITje-7B-ultra) of this model.

## Citation

If you use GEITje 7B Ultra (SFT) or any of its derivatives or quantizations, please cite the following paper:

```bibtex
@misc{vanroy2024geitje7bultraconversational,
  title={GEITje 7B Ultra: A Conversational Model for Dutch},
  author={Bram Vanroy},
  year={2024},
  eprint={2412.04092},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2412.04092},
}
```

## Model description

This model is an SFT (chat-tuned) version of [Rijgersberg/GEITje-7B](https://huggingface.co/Rijgersberg/GEITje-7B), which in turn is based on Mistral 7B and further pretrained on Dutch data.

## Usage

```python
from transformers import pipeline, Conversation

# Note: `Conversation` and the "conversational" pipeline are available in
# Transformers 4.36 (the version listed under "Framework versions") but were
# removed in later releases.
# load_in_8bit: lower precision but saves a lot of memory (requires bitsandbytes)
# device_map="auto": places the model on the available devices, spreading it over multiple GPUs if needed
chatbot = pipeline(
    "conversational",
    model="BramVanroy/GEITje-ultra-sft",
    model_kwargs={"load_in_8bit": True},
    device_map="auto",
)

start_messages = [
    {"role": "system", "content": "Je bent een grappige chatbot die Bert heet. Je maakt vaak mopjes."},
    {"role": "user", "content": "Hallo, ik ben Bram. Ik wil vanavond graag een film kijken. Heb je enkele suggesties? Liefst een Disney-film."},
]
conversation = Conversation(start_messages)
conversation = chatbot(conversation)
# The last message in the conversation is the model's reply
response = conversation.messages[-1]["content"]
print(response)
# Example output:
# Hallo Bram! Wat leuk dat je vanavond een film wilt kijken. Als je van Disney-films houdt, heb ik een paar suggesties voor je.
# Een klassieker is "The Lion King", die is altijd een hit. Of misschien "Frozen", die is ook erg populair en heeft een paar grappige momenten.
# Of als je iets nieuws wilt proberen, "Raya and the Last Dragon" is een spannende avonturenfilm met een hartverwarmend verhaal. Welke film spreekt jou het meest aan?
```

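
On Transformers releases that no longer ship `Conversation` and the `"conversational"` pipeline, the same exchange can likely be run through the `"text-generation"` pipeline, which accepts a list of chat messages and applies the tokenizer's chat template. The following is a minimal sketch under that assumption, not the author's tested setup; the helper names `build_messages` and `generate_reply` and the `max_new_tokens` value are illustrative choices.

```python
def build_messages(system_prompt, user_prompt):
    """Return a chat in the role/content format used by chat templates."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def generate_reply(messages, model_id="BramVanroy/GEITje-ultra-sft", max_new_tokens=256):
    """Generate one assistant turn (downloads the 7B model on first use)."""
    # Imported lazily so that build_messages works without transformers installed.
    from transformers import pipeline

    generator = pipeline("text-generation", model=model_id, device_map="auto")
    outputs = generator(messages, max_new_tokens=max_new_tokens)
    # With chat input, "generated_text" holds the whole conversation;
    # the last entry is the newly generated assistant message.
    return outputs[0]["generated_text"][-1]["content"]


messages = build_messages(
    "Je bent een grappige chatbot die Bert heet. Je maakt vaak mopjes.",
    "Hallo, ik ben Bram. Ik wil vanavond graag een film kijken. Heb je enkele suggesties?",
)
# Uncomment to actually load the model and generate a reply:
# print(generate_reply(messages))
```

The model call is left commented out because loading a 7B model needs a GPU and a sizeable download; quantization options are omitted here on purpose to keep the sketch small.
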
## Intended uses & limitations

This model was only trained on (synthetic) chat data and not specifically aligned through reinforcement learning. The model can generate wrong, misleading, and potentially even offensive content. Use at your own risk.

Because the model was trained on synthetic data created with OpenAI/Azure services, this model cannot be used for commercial purposes.

## Training and evaluation data

Training data consists of older datasets that were translated to Dutch with OpenAI's gpt-3.5-turbo (alpaca, dolly, stackoverflow) and newer ones that were generated with gpt-4-turbo via Azure (no robots, ultrachat). In the case of no robots, the original English prompt (and, where present, the system message) was translated, and new answers were then generated with gpt-4-turbo.

The case of UltraChat may be more interesting, as multi-turn conversations were generated in one go: through prompt engineering, the model is given the original English first user message and asked to create a conversation between a user and an assistant in a single response. Additionally, and in my opinion excitingly, I created multiple personas that were randomly selected from. The user messages in the dataset are written "as if" they were created by one of the personas, in the hope that the model learns to react well to different types of users. Personas include language learners, a direct conversationalist, someone who loves details, someone who is critical, a child, an expert in the field, a joyful, chaotic mind, a generalist, and "an average user". This is described in more detail [in the dataset](https://huggingface.co/datasets/BramVanroy/ultrachat_200k_dutch).

The training set (`train_sft`) consists of 240,527,565 tokens (calculated prior to applying a chat template). The test sets (`test_sft` in the datasets) account for 26,397,086 tokens, which is around 10.97% of the training set.

Here is a breakdown of the training set (some data pages might not be available yet *but they definitely will be in the near future*).

- [BramVanroy/ultrachat_200k_dutch](https://huggingface.co/datasets/BramVanroy/ultrachat_200k_dutch) (gpt-4-turbo; multi-turn; generated): 85.42%
- [BramVanroy/no_robots_dutch](https://huggingface.co/datasets/BramVanroy/no_robots_dutch) (gpt-4-turbo; prompt translated, answer generated; some items have system messages): 2.20%
- [BramVanroy/stackoverflow-chat-dutch](https://huggingface.co/datasets/BramVanroy/stackoverflow-chat-dutch) (gpt-3.5-turbo; multi-turn; code; translated; only 50% used): 8.38%
- [BramVanroy/alpaca-cleaned-dutch](https://huggingface.co/datasets/BramVanroy/alpaca-cleaned-dutch) (gpt-3.5-turbo; translated): 2.62%
- [BramVanroy/dolly-15k-dutch](https://huggingface.co/datasets/BramVanroy/dolly-15k-dutch) (gpt-3.5-turbo; translated): 1.39%

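
The token counts and percentages above can be cross-checked with a few lines of arithmetic. The sketch below recomputes the test-to-train ratio and turns the listed shares into rough per-dataset token counts. Those per-dataset numbers are estimates derived from the rounded percentages, not figures reported by the author.

```python
TRAIN_TOKENS = 240_527_565  # train_sft, before applying the chat template
TEST_TOKENS = 26_397_086    # test_sft

# Share of each dataset in the training set, in percent (as listed above).
SHARES = {
    "ultrachat_200k_dutch": 85.42,
    "no_robots_dutch": 2.20,
    "stackoverflow-chat-dutch": 8.38,
    "alpaca-cleaned-dutch": 2.62,
    "dolly-15k-dutch": 1.39,
}

test_ratio = TEST_TOKENS / TRAIN_TOKENS * 100
print(f"test/train ratio: {test_ratio:.2f}%")

# The rounded shares add up to slightly more than 100, so normalise
# before estimating the number of tokens per dataset.
total_share = sum(SHARES.values())
approx_tokens = {
    name: round(TRAIN_TOKENS * share / total_share)
    for name, share in SHARES.items()
}
for name, tokens in approx_tokens.items():
    print(f"{name}: ~{tokens / 1e6:.1f}M tokens")
```
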
## Training procedure

The great [alignment handbook](https://github.com/huggingface/alignment-handbook/) was used for training, with a custom Slurm script for compatibility with our cluster. The model was trained in full, without LoRA or other adapters.

The model was trained in bfloat16 with flash attention 2 and a context length of 8192 on two nodes of four A100 80GB GPUs each for around 2.5 hours. I thank the [Flemish Supercomputer Center](https://www.vscentrum.be/compute) for their compute. The training logs are available on [wandb](https://wandb.ai/bramvanroy/sft-geitje-ultra).

For conversational usage, the model relies on the Zephyr chat template, which is compatible with system messages. A small portion of the data contained system messages, so the model is expected to handle system messages at least to some extent.

Recipe used with the handbook:

```yaml
# Model arguments
model_name_or_path: Rijgersberg/GEITje-7B
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

# Data training arguments
# Zephyr chat template
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
  BramVanroy/ultrachat_200k_dutch: 1.0
  BramVanroy/stackoverflow-chat-dutch: 0.5
  BramVanroy/alpaca-cleaned-dutch: 1.0
  BramVanroy/dolly-15k-dutch: 1.0
  BramVanroy/no_robots_dutch: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 8

# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: GEITje-ultra-sft
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 8192
max_steps: -1
num_train_epochs: 1
output_dir: data/GEITje-ultra-sft
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 16
push_to_hub: true
remove_unused_columns: true
report_to:
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
```

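
To make the chat template in the recipe easier to read, here is a plain-Python sketch of the prompt format it produces. It is an approximation: the `render_zephyr` helper is hypothetical, the `</s>` end-of-sequence token is an assumption based on the Mistral tokenizer, and the exact whitespace depends on how the Jinja template is rendered. In practice, `tokenizer.apply_chat_template` should be treated as the reference.

```python
def render_zephyr(messages, eos_token="</s>", add_generation_prompt=True):
    """Approximate the Zephyr chat format: a role tag, the content, then EOS per turn."""
    parts = []
    for message in messages:
        role = message["role"]
        if role not in ("system", "user", "assistant"):
            # The template only emits text for these three roles.
            continue
        parts.append(f"<|{role}|>\n{message['content']}{eos_token}\n")
    if add_generation_prompt and messages:
        # Cue the model to start an assistant turn.
        parts.append("<|assistant|>\n")
    return "".join(parts)


# Illustrative conversation (not taken from the training data).
prompt = render_zephyr([
    {"role": "system", "content": "Je bent een behulpzame assistent."},
    {"role": "user", "content": "Hallo!"},
])
print(prompt)
```

Each turn is closed with the end-of-sequence token, and the trailing `<|assistant|>` tag is only added when a new reply should be generated.
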
### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 2e-05
- train_batch_size: 4
- eval_batch_size: 4
- seed: 42
- distributed_type: multi-GPU
- num_devices: 8
- gradient_accumulation_steps: 4
- total_train_batch_size: 128
- total_eval_batch_size: 32
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 1

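
The total batch sizes follow from the per-device values. As a quick check (the helper `effective_batch_size` is only for illustration), the logged setting of 4 sequences per device with 4 accumulation steps and the recipe's setting of 16 per device with 1 accumulation step both give 128 sequences per optimizer step on 8 GPUs:

```python
def effective_batch_size(per_device, num_devices, grad_accum_steps):
    """Number of sequences that contribute to one optimizer step."""
    return per_device * num_devices * grad_accum_steps


# Values from the hyperparameter list above.
logged = effective_batch_size(per_device=4, num_devices=8, grad_accum_steps=4)
# Values from the handbook recipe (per_device_train_batch_size: 16, no accumulation).
recipe = effective_batch_size(per_device=16, num_devices=8, grad_accum_steps=1)
# Evaluation does not accumulate gradients: eval_batch_size x num_devices.
total_eval = effective_batch_size(per_device=4, num_devices=8, grad_accum_steps=1)

print(logged, recipe, total_eval)
```
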
### Training results

| Training Loss | Epoch | Step | Validation Loss |
|:-------------:|:-----:|:----:|:---------------:|
| 0.8632        | 1.0   | 238  | 0.8563          |

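
For intuition about how the learning rate evolved over the 238 steps, the sketch below implements a linear warmup followed by cosine decay, which is the kind of schedule that `lr_scheduler_type: cosine` with a warmup ratio of 0.1 typically selects. Rounding the warmup up to whole steps and decaying all the way to zero are assumptions here, not values read from the training logs.

```python
import math

BASE_LR = 2e-5
TOTAL_STEPS = 238
# Assumption: the warmup length is the warmup ratio times the total steps, rounded up.
WARMUP_STEPS = math.ceil(0.1 * TOTAL_STEPS)


def lr_at(step):
    """Learning rate at a given optimizer step (linear warmup, then cosine decay)."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    return BASE_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, WARMUP_STEPS, TOTAL_STEPS // 2, TOTAL_STEPS):
    print(step, f"{lr_at(step):.2e}")
```
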
### Framework versions

- Transformers 4.36.2
- Pytorch 2.1.2+cu121
- Datasets 2.14.6
- Tokenizers 0.15.0