|
--- |
|
license: mit |
|
datasets: |
|
- CarperAI/openai_summarize_tldr |
|
language: |
|
- en |
|
base_model: |
|
- EleutherAI/gpt-j-6b |
|
- CarperAI/openai_summarize_tldr_sft |
|
--- |
|
# ALT-Quark model |
|
This is a Quark-based baseline developed during the research carried out for the [ALT paper](https://www.arxiv.org/abs/2407.16970). The model is trained following the algorithm introduced in [Quark](https://arxiv.org/abs/2205.13636), with a slight modification: we sample multiple generations per prompt and compute the reward quantiles locally, per prompt, instead of globally across all prompts. We found this to be crucial for training.
|
Note that Quark was not originally introduced to tackle the alignment problem, but to unlearn undesired attributes in text completion tasks, such as toxicity, negative sentiment, or repetition.
|
|
|
It is a **GPT-J (6B)** model fine-tuned on the **TL;DR summarization** dataset to be better aligned with human preferences on summaries, i.e., accounting for axes such as accuracy, coverage, and coherence.
|
|
|
# Model description |
|
The alignment process starts from an [SFT checkpoint](https://huggingface.co./CarperAI/openai_summarize_tldr_sft) released by CarperAI and trained using their [trlx](https://github.com/CarperAI/trlx/tree/main/examples/summarize_rlhf) library.
|
|
|
In a nutshell, the Quark method consists of sampling new generations, scoring them with a reward model, and clustering them into reward quantiles. For each of a pre-defined number of quantiles, a new reward-quantile token is added to the tokenizer. Each generation is then mapped to its reward-quantile token, which is prepended to the input prompt for conditional language modeling training.
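
As an illustration of this procedure, here is a minimal sketch of the per-prompt quantile conditioning described above. It is not the training code used for this checkpoint: the helper `quantile_condition` and the exact binning rule are illustrative assumptions, with quantile 0 taken as the highest-reward bin, consistent with the inference usage shown below.

```python
import numpy as np

# K = 5 reward-quantile tokens, ordered from best (index 0) to worst.
QUANTILE_TOKENS = [f"_QUANTILE_TOKEN_{i}_" for i in range(5)]

def quantile_condition(prompt, generations, rewards, k=5):
    """Illustrative sketch: build conditional-LM training pairs for ONE prompt.

    The reward quantiles are computed locally, i.e., only over this prompt's
    own generations (not globally across all prompts), and the corresponding
    quantile token is prepended to the prompt.
    """
    order = np.argsort(rewards)[::-1]  # indices sorted by reward, best first
    examples = []
    for rank, idx in enumerate(order):
        quantile = min(rank * k // len(generations), k - 1)
        token = QUANTILE_TOKENS[quantile]
        examples.append((f"{token}{prompt}", generations[idx]))
    return examples
```

During training, the model then learns to generate a completion conditioned on both the quantile token and the prompt over such pairs.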
|
|
|
For a detailed description of Quark, please refer to the original paper.
|
|
|
The reward model used for scoring the generations can be found [here](https://huggingface.co./CarperAI/openai_summarize_tldr_rm_checkpoint). We used K = 5 quantile tokens, which were newly added to the tokenizer:
|
```python |
|
{'_QUANTILE_TOKEN_0_', '_QUANTILE_TOKEN_1_', '_QUANTILE_TOKEN_2_', '_QUANTILE_TOKEN_3_', '_QUANTILE_TOKEN_4_'} |
|
``` |
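
If you are reproducing the training setup from the SFT checkpoint, the tokens can be registered as in the sketch below (an assumption about the exact procedure, shown for reference); in the released checkpoint they are already part of the tokenizer, so no extra step is needed at inference time.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Reproduction sketch: register the K = 5 reward-quantile tokens on top of
# the CarperAI SFT checkpoint before Quark-style training.
sft_checkpoint = "CarperAI/openai_summarize_tldr_sft"

tokenizer = AutoTokenizer.from_pretrained(sft_checkpoint)
model = AutoModelForCausalLM.from_pretrained(sft_checkpoint)

quantile_tokens = [f"_QUANTILE_TOKEN_{i}_" for i in range(5)]
tokenizer.add_tokens(quantile_tokens)

# Grow the embedding matrix to cover the newly added token ids.
model.resize_token_embeddings(len(tokenizer))
```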
|
Thus, at inference time, the expected aligned behavior can be attained by conditioning the input on `_QUANTILE_TOKEN_0_`, i.e., the highest-reward quantile token.
|
|
|
**Related Models:** [ALT-RM](https://huggingface.co./sauc-abadal-lloret/gpt-j-6b-ALT-RM-tldr). |
|
|
|
# Intended uses & limitations |
|
This model originates from a research project focused on alignment and is intended primarily for research purposes. Commercial use as an off-the-shelf model is discouraged, as it was not designed with such applications in mind. The model is tailored specifically for the summarization task, having been trained on the TL;DR dataset, though some out-of-distribution generalization may be possible for related datasets. |
|
|
|
# How to use |
|
|
|
You should format the input by prepending the highest-reward quantile token to the prompt, as follows: `_QUANTILE_TOKEN_0_{prompt}`
|
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

checkpoint_path = "sauc-abadal-lloret/gpt-j-6b-ALT-Quark-tldr"

# The released tokenizer already includes the 5 reward-quantile tokens.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
model.eval()

# Prepend _QUANTILE_TOKEN_0_ to the prompt to condition the model on the
# aligned (highest-reward quantile) behavior.
prompt = "_QUANTILE_TOKEN_0_SUBREDDIT: r/relationship_advice\nTITLE: I'm [18M] going to a party where an old middle \
school crush [17F] is also going.\nPOST: Story time! Back in the summer after 8th grade, I hung out with my group of \
friends everyday for the whole summer. There was this girl in the group and I really liked her. Like I had the biggest \
and dumbest crush on her. I was only 13 so I didn't know shit, but I was thinking she's perfect for me, I gotta marry \
her and all this dumb stuff. The puppy love was so strong I wanted to be a part of her life and I wanted her to be a \
part of my life. I never had the courage to ask her out, and we went to different high schools. Eventually we stopped \
talking but during high school I never really liked anyone else. Every other girl felt dull compared to her. I still \
get nostalgic thinking about her and what would've been different if I had the balls to ask her out. Anyway I'm going \
to a party this Friday and I heard she's coming. I honestly don't know what to do to so this goes great and eventually \
ends up in a relationship.\nTL;DR:"

inputs = tokenizer([prompt], padding=True, truncation=True, return_tensors="pt")
input_seq_len = inputs["input_ids"].shape[1]

generation_config = GenerationConfig(
    max_length=2048,
    max_new_tokens=64,  # takes precedence over max_length
    do_sample=False,    # greedy decoding
    num_beams=1,
    bad_words_ids=None,
    num_return_sequences=1,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.pad_token_id,
)

outputs = model.generate(**inputs, generation_config=generation_config)

# Keep only the newly generated tokens, i.e., strip the prompt.
generated_input_ids = outputs["sequences"][:, input_seq_len:]
generated_text = tokenizer.batch_decode(
    generated_input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
generated_text
```
|
|
|
``` |
|
[" I'm going to a party where an old middle school crush is also going. I honestly don't know what to do to so this goes great and eventually ends up in a relationship."] |
|
``` |
|
|
|
## Training data |
|
The model was trained on the TL;DR summarization dataset introduced in Stiennon et al.'s ["Learning to Summarize from Human Feedback"](https://arxiv.org/abs/2009.01325) paper. We employed the dataset version from CarperAI, which can be found on the Hugging Face Hub [here](https://huggingface.co/datasets/CarperAI/openai_summarize_tldr).
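
For reference, the dataset can be loaded directly from the Hub; this is a small illustrative snippet, and the split and column layout are those exposed by the CarperAI dataset version.

```python
from datasets import load_dataset

# Load the CarperAI version of the TL;DR summarization dataset.
dataset = load_dataset("CarperAI/openai_summarize_tldr")

# Inspect one training example: a Reddit post formatted as a prompt ending
# in "TL;DR:" together with its reference summary.
print(dataset["train"][0])
```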
|
|
|
## Training procedure |
|
The exact training procedure and hyper-parameter configuration can be found in our paper.
|
|
|
## Variables and metrics
|
As an evaluation metric, we compute GPT-4 win-rates over PPO on a 1k random subset of the test set. We use the prompt provided in the DPO paper and ask GPT-4 to compare generations from ALT-RM and Quark against those from PPO. Furthermore, we report the following metrics computed on the whole test set: average reward model score, perplexity measured under the SFT reference policy as a proxy for fluency, and average length of the generations. In addition, we conduct an out-of-domain evaluation and compute GPT-4 win-rates on 100 articles from the test split of the CNN/DailyMail dataset.
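
As an illustration of the fluency proxy, the sketch below shows one way to compute the perplexity of a generated summary under the SFT reference policy; it is not the exact evaluation script used in the paper, and the per-example scoring details are assumptions.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Sketch: score a generated summary under the SFT reference policy
# (the CarperAI checkpoint the alignment started from).
ref_checkpoint = "CarperAI/openai_summarize_tldr_sft"
tokenizer = AutoTokenizer.from_pretrained(ref_checkpoint)
ref_model = AutoModelForCausalLM.from_pretrained(ref_checkpoint).eval()

def reference_perplexity(prompt: str, summary: str) -> float:
    """Perplexity of `summary` given `prompt` under the reference policy."""
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    summary_ids = tokenizer(summary, return_tensors="pt")["input_ids"]
    input_ids = torch.cat([prompt_ids, summary_ids], dim=1)

    # Mask out the prompt tokens so the loss is computed only over the summary.
    labels = input_ids.clone()
    labels[:, : prompt_ids.shape[1]] = -100

    with torch.no_grad():
        loss = ref_model(input_ids, labels=labels).loss  # mean NLL per summary token
    return torch.exp(loss).item()
```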
|
|
|
| **Model**       | **TL;DR** (In-domain) | **CNN/DailyMail** (Out-of-domain) |
|:---------------:|:---------------------:|:---------------------------------:|
| Quark vs PPO    | 0.36                  | 0.40                              |
| ALT-RM vs PPO   | 0.50                  | 0.48                              |
|
|
|
*Win-rates with GPT-4. TL;DR on 1000 randomly chosen test prompts and CNN/DailyMail on 100 randomly chosen test prompts.*
|
|
|
| **Model**  | **RM score** | **PPL** | **Avg. len** | **# Train prompts** |
|:----------:|:------------:|:-------:|:------------:|:-------------------:|
| SFT        | 2.89         | 1.96    | 31.25        | -                   |
| References | 2.89         | 11.84   | 32.60        | -                   |
| PPO        | 3.38         | 2.29    | 67.52        | 116k                |
| Quark      | 3.52         | 1.82    | 49.42        | 19k                 |
| ALT-RM     | 3.58         | 2.20    | 46.14        | 19k                 |
|
|
|
*TL;DR metrics on the whole test set, including avg. reward model score, perplexity, avg. generations’ length, and number of training prompts.* |
|
|
|
## BibTeX entry and citation info |
|
``` |
|
@misc{lloret2024aligninglanguagemodelstextual, |
|
title={Towards Aligning Language Models with Textual Feedback}, |
|
author={Saüc Abadal Lloret and Shehzaad Dhuliawala and Keerthiram Murugesan and Mrinmaya Sachan}, |
|
year={2024}, |
|
eprint={2407.16970}, |
|
archivePrefix={arXiv}, |
|
primaryClass={cs.CL}, |
|
url={https://arxiv.org/abs/2407.16970}, |
|
} |
|
``` |