Introduction

The Generalizable Reward Model (GRM) aims to improve the generalization ability of reward models for LLMs by regularizing their hidden states.

Paper: Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs.

[Figure: overview of the GRM framework]

The framework is shown above. The introduced text-generation regularization markedly improves the accuracy of the learned reward models across a variety of out-of-distribution tasks and effectively alleviates the over-optimization issue in RLHF (even with corrupted preference data), offering a more reliable and robust preference learning paradigm.
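
To make the idea concrete, here is a minimal PyTorch sketch of a hidden-state-regularized reward objective, under stated assumptions: a shared backbone feeds both a scalar reward head (`v_head`) and a language-model head, and training combines a Bradley–Terry loss on chosen/rejected pairs with a next-token (SFT) loss on the chosen sequence. The toy `TwoHeadRewardModel`, the `grm_style_loss` helper, the `(1 - alpha)` / `alpha` weighting, and the default `alpha=0.01` are illustrative assumptions, not the authors' training code; the embedding layer merely stands in for the LLM backbone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadRewardModel(nn.Module):
    """Hypothetical toy model: shared hidden states feed a language-model head
    (kept for the text-generation regularizer) and a scalar reward head."""

    def __init__(self, vocab_size: int = 100, hidden: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)  # stand-in for the LLM backbone
        self.lm_head = nn.Linear(hidden, vocab_size)   # text-generation head
        self.v_head = nn.Linear(hidden, 1)             # reward head

    def forward(self, input_ids, attention_mask):
        h = self.embed(input_ids)                      # (B, T, H) hidden states
        logits = self.lm_head(h)                       # (B, T, V)
        # Assumes right padding: the last real token sits at (number of 1s) - 1.
        last = attention_mask.sum(dim=1) - 1
        h_last = h[torch.arange(h.size(0)), last]      # (B, H)
        reward = self.v_head(h_last).squeeze(-1)       # (B,)
        return logits, reward


def grm_style_loss(model, chosen_ids, chosen_mask, rejected_ids, rejected_mask, alpha=0.01):
    """Bradley-Terry loss plus an SFT regularizer on the chosen sequence.
    The weighting scheme and alpha value are assumptions for illustration."""
    logits_c, r_c = model(chosen_ids, chosen_mask)
    _, r_r = model(rejected_ids, rejected_mask)
    bt_loss = -F.logsigmoid(r_c - r_r).mean()
    # Next-token prediction: position t predicts token t + 1; padding is ignored.
    labels = chosen_ids[:, 1:].masked_fill(chosen_mask[:, 1:] == 0, -100)
    sft_loss = F.cross_entropy(
        logits_c[:, :-1].reshape(-1, logits_c.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
    )
    return (1 - alpha) * bt_loss + alpha * sft_loss
```

For brevity the regularizer here covers the whole chosen sequence; restricting it to response tokens would only require setting the prompt positions in `labels` to `-100`.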

This reward model is fine-tuned from llama3_8b_instruct on the hendrydong/preference_700K dataset.

A distilled BT model using the features of this GRM can be found at Ray2333/GRM-llama3-8B-distill.

Evaluation

We evaluate GRM on RewardBench. GRM reaches an average score of 87.0, compared with 84.7 for the previous state-of-the-art 8B Bradley–Terry reward model (sfairXC/FsfairX-LLaMA3-RM-v0.1).

Model                                      Average  Chat  Chat Hard  Safety  Reasoning
Ray2333/GRM-llama3-8B-sftreg (Ours, 8B)    87.0     98.6  67.8       89.4    92.3
Ray2333/GRM-llama3-8B-distill (Ours, 8B)   86.1     98.3  68.4       86.1    91.3
openai/gpt-4-0125-preview                  85.9     95.3  74.3       87.2    86.9
sfairXC/FsfairX-LLaMA3-RM-v0.1 (8B)        84.7     99.4  65.1       87.8    86.4
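
As a quick arithmetic check, the sketch below recomputes each model's average, assuming that "Average" is the unweighted mean of the four category scores (an assumption; the benchmark's own aggregation is not described here). Under that assumption the sftreg, GPT-4, and FsfairX averages are reproduced to one decimal, while the distill row comes out at 86.0 rather than 86.1, possibly because the published average was computed from unrounded category scores.

```python
# Category scores copied from the table: (Chat, Chat Hard, Safety, Reasoning).
SCORES = {
    "Ray2333/GRM-llama3-8B-sftreg": (98.6, 67.8, 89.4, 92.3),
    "Ray2333/GRM-llama3-8B-distill": (98.3, 68.4, 86.1, 91.3),
    "openai/gpt-4-0125-preview": (95.3, 74.3, 87.2, 86.9),
    "sfairXC/FsfairX-LLaMA3-RM-v0.1": (99.4, 65.1, 87.8, 86.4),
}


def category_mean(scores):
    """Unweighted mean of the category scores (assumed definition of 'Average')."""
    return sum(scores) / len(scores)


for name, cats in SCORES.items():
    print(f"{name}: {category_mean(cats):.1f}")
```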

Usage

Note 1: Please download the model.py file from this repository so that the model structure is loaded correctly, and verify that the v_head is properly initialized.

If you use the following example, the warning "Some weights of the model checkpoint at ... were not used when initializing LlamaForCausalLM" can be safely ignored. If you use customized loading code, I suggest comparing the state_dict of the loaded model with the tensors loaded via safetensors.safe_open(xx.safetensors) or torch.load(xx.bin) to confirm that all weights, especially the v_head, are in place.

Note 2: Loading the Llama 3 model in 8-bit precision may degrade performance.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = 'cuda:0'  # change to the GPU you want to use
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('Ray2333/GRM-llama3-8B-sftreg')
reward_model = AutoModelForSequenceClassification.from_pretrained(
    'Ray2333/GRM-llama3-8B-sftreg',
    torch_dtype=torch.float16,
    trust_remote_code=True,  # required to run the custom model code shipped with the repository
    device_map=device,
)
message = [
  {'role': 'user', 'content': "I'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone.  But I can't do that while I'm at the movie.  Can you help by impersonating me by chat with her?"},
  {'role': 'assistant', 'content': "Sorry, I'm not comfortable impersonating you in that way.  I'm not willing to behave so dishonestly.  Maybe you can just find a way to bring her to the movie, or you can find a babysitter?"}
]
message_template = tokenizer.apply_chat_template(message, tokenize=False)
# it will look like this: "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nI'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone.  But I can't do that while I'm at the movie.  Can you help by impersonating me by chat with her?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nSorry, I'm not comfortable impersonating you in that way.  I'm not willing to behave so dishonestly.  Maybe you can just find a way to bring her to the movie, or you can find a babysitter?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".

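# 'longest' adds no padding to a single sequence; 'max_length' would pad it to tokenizer.model_max_length
kwargs = {"padding": 'longest', "truncation": True, "return_tensors": "pt"}
tokens = tokenizer.encode_plus(message_template, **kwargs)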

with torch.no_grad():
  # input_ids and attention_mask already have shape (1, seq_len)
  _, _, reward_tensor = reward_model(tokens["input_ids"].to(device), attention_mask=tokens["attention_mask"].to(device))
  reward = reward_tensor.item()
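
Reward models are typically used to compare responses. The sketch below wraps the usage example in a hypothetical `score` helper and converts two rewards into a Bradley–Terry preference probability, sigmoid(r_A − r_B). It assumes, as the example does, that the repository's custom model returns a 3-tuple whose last element is the reward; `score` and `preference_probability` are illustrative names, not part of the repository.

```python
import math

import torch


def score(reward_model, tokenizer, messages, device):
    """Hypothetical helper: reward for one conversation, following the usage example."""
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    # A single sequence needs no padding.
    tokens = tokenizer.encode_plus(text, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # Assumes the custom model returns a 3-tuple whose last element is the reward.
        _, _, reward_tensor = reward_model(
            tokens["input_ids"].to(device),
            attention_mask=tokens["attention_mask"].to(device),
        )
    return reward_tensor.item()


def preference_probability(reward_a, reward_b):
    """Bradley-Terry probability that response A is preferred over response B.

    Computes sigmoid(reward_a - reward_b) without overflowing for large gaps."""
    d = reward_a - reward_b
    if d >= 0:
        return 1.0 / (1.0 + math.exp(-d))
    e = math.exp(d)
    return e / (1.0 + e)
```

For example, with `message` from above and a hypothetical `other_message` that keeps the same user turn but changes the assistant reply, `preference_probability(score(reward_model, tokenizer, message, device), score(reward_model, tokenizer, other_message, device))` estimates how strongly the model prefers the original reply.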

Citation

If you find this model helpful for your research, please cite GRM:

@article{yang2024regularizing,
  title={Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs},
  author={Yang, Rui and Ding, Ruomeng and Lin, Yong and Zhang, Huan and Zhang, Tong},
  journal={arXiv preprint arXiv:2406.10216},
  year={2024}
}
Model size: 8.03B params (Safetensors; tensor types: F32, BF16)