Robust Visual Reward Model
The Robust Visual Reward Model (RoVRM) is trained with two techniques: three-phase progressive training (i.e., pre-training with textual preference data → fine-tuning with image caption-based preference data → fine-tuning with visual preference data) and optimal transport-based selection of preference data. Together, these techniques transfer preferences from auxiliary textual data to improve the model's robustness. This repository hosts the RoVRM built on LLaVA-1.5-7B. We applied RoVRM to best-of-$n$ sampling and RL training, where it significantly improved performance and reduced hallucination in large vision-language models. Detailed training information and experimental results are available in our paper.
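Each training phase learns from pairs of preferred (chosen) and dispreferred (rejected) responses. This card does not spell out the objective. Assuming the standard pairwise (Bradley-Terry) ranking loss commonly used for reward models, a minimal sketch of the per-batch loss $-\log \sigma(r_{\text{chosen}} - r_{\text{rejected}})$ is:

```python
import math
from typing import Sequence


def pairwise_preference_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over a batch of reward scores.

    Assumption: the standard pairwise ranking loss, not necessarily RoVRM's exact objective.
    """
    if len(chosen) != len(rejected) or len(chosen) == 0:
        raise ValueError("need equally sized, non-empty lists of scores")
    total = 0.0
    for r_chosen, r_rejected in zip(chosen, rejected):
        margin = r_chosen - r_rejected
        # -log(sigmoid(m)) == softplus(-m); this form avoids overflow for large |m|
        total += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return total / len(chosen)
```

Under this assumption, the three phases would apply the same objective to textual, then caption-based, then visual preference pairs, with only the data changing between phases.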
How to use the model
We recommend running RoVRM with the Vision-LLM-Alignment system, which was also used to train it.
To evaluate a question-answer pair with RoVRM, follow two steps:
- Convert the safetensors-format model to `pytorch_model.bin` using the `convert_pytorch_bin.py` script.
- Download the Vision-LLM-Alignment repository and run the demo below from the repository's top-level directory.
```python
import argparse
import copy
import os

import torch
from PIL import Image
from torch.utils.data.dataloader import default_collate
from transformers import AutoProcessor

# These modules come from the Vision-LLM-Alignment repository,
# so run this script from the repository's top-level directory.
from training.utils.model.modeling_reward import create_reward_or_critic_model
from training.utils.model.third_party_model.hf_model.modeling_llava import LlavaForConditionalGeneration

device = torch.device("cuda:0")

# Path of the base LLaVA-1.5-7B model, which is loaded to build an initialized RoVRM.
base_path = "base_models/llava-1.5-7b-hf"
# Directory containing the converted RoVRM checkpoint (pytorch_model.bin).
ckpt_path = "models"

processor = AutoProcessor.from_pretrained(base_path)
image_processor = processor.image_processor
tokenizer = processor.tokenizer
tokenizer.add_bos_token = True
tokenizer.add_eos_token = True

args = {
    "model_architecture": "llava",
    "lang_decoder_update": False,
    "from_checkpoint": base_path
}
args = argparse.Namespace(**args)
model, image_processor, tokenizer = create_reward_or_critic_model(
    text_tokenizer=tokenizer,
    args=args)
# Load the RoVRM weights on top of the initialized model.
model.load_state_dict(
    torch.load(os.path.join(ckpt_path, "pytorch_model.bin"), map_location="cpu"),
    strict=False)
model.to(device)
model.eval()

# Input text: the USER question followed by the ASSISTANT answer to be scored.
# The <image> placeholder is required whenever an image is provided.
input_sen = "USER: ### Image:<image>\nIdentify and describe each object in the image in detail.\nASSISTANT: In the image, there is a cute, colorful cartoon girl sitting on a chair at a wooden table. She is reading a book, which is a prominent object in the scene. The table and chair are also present, adding to the overall setting. As this is a cartoon-style image, the girl and the book may have a more exaggerated or simplified design compared to real-life objects. "
img_path = "llava1.5_raw_images_00011_000118793.jpg"

# Load and preprocess the image
image = Image.open(img_path).convert("RGB")
image = image_processor(image)
try:
    # Hugging Face image processors return a dict-like object with "pixel_values"
    image = image["pixel_values"][0]
except Exception:
    # Otherwise keep the processor output as it is
    pass

# Tokenize the text; the labels are a copy of the input ids
input_sen = tokenizer(input_sen,
                      return_tensors=None,
                      padding="do_not_pad",
                      truncation=True,
                      max_length=512)
input_sen.update(labels=copy.deepcopy(input_sen["input_ids"]))
input_sen.update(image=image)

# Score the question-answer pair (batch size 1) without tracking gradients
with torch.no_grad():
    reward_scores = model(
        img=default_collate(image).reshape((-1,) + image[0].shape[-3:]).unsqueeze(0).to(device),
        lang=torch.LongTensor(input_sen["input_ids"]).unsqueeze(0).to(device),
        attention_mask=torch.LongTensor(input_sen["attention_mask"]).unsqueeze(0).to(device),
        input_labels=torch.LongTensor(input_sen["labels"]).unsqueeze(0).to(device))
print(reward_scores[0].item())
```
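The demo scores a single question-answer pair. For best-of-$n$ sampling, the same score can rank several candidate answers to one question and keep the highest-scoring one. In this minimal sketch, `format_rovrm_input` reuses the prompt template from the demo's `input_sen`, and `score_fn` is a hypothetical callable that would wrap the demo's tokenization and forward pass and return `reward_scores[0].item()`; a toy scorer stands in for it in the usage example.

```python
from typing import Callable, List, Sequence, Tuple

# Prompt template taken from the demo's input_sen string.
TEMPLATE = "USER: ### Image:<image>\n{question}\nASSISTANT: {answer}"


def format_rovrm_input(question: str, answer: str) -> str:
    """Place a question and a candidate answer into the demo's prompt template."""
    return TEMPLATE.format(question=question, answer=answer)


def best_of_n(question: str,
              candidates: Sequence[str],
              score_fn: Callable[[str], float]) -> Tuple[str, float]:
    """Return the candidate answer with the highest reward, together with its score.

    score_fn is hypothetical: it should map formatted text to a scalar reward.
    """
    if not candidates:
        raise ValueError("need at least one candidate answer")
    scored: List[Tuple[str, float]] = [
        (answer, score_fn(format_rovrm_input(question, answer)))
        for answer in candidates
    ]
    # max() returns the first candidate among equally scored ones
    return max(scored, key=lambda pair: pair[1])


if __name__ == "__main__":
    # Toy scorer for illustration only: longer text gets a higher "reward".
    answer, score = best_of_n("Describe the image.", ["A girl.", "A girl reading a book."], len)
    print(answer)
```

Keeping the scorer as a parameter separates candidate selection from model loading, so the same selection loop works with any reward model that maps text to a scalar.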
Please cite our paper if you find RoVRM helpful in your work🌹🌹🌹:
```bibtex
@misc{wang2024rovrmrobustvisualreward,
  title={RoVRM: A Robust Visual Reward Model Optimized via Auxiliary Textual Preference Data},
  author={Chenglong Wang and Yang Gan and Yifu Huo and Yongyu Mu and Murun Yang and Qiaozhi He and Tong Xiao and Chunliang Zhang and Tongran Liu and Quan Du and Di Yang and Jingbo Zhu},
  year={2024},
  eprint={2408.12109},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2408.12109},
}
```