---
license: llama2
datasets:
- Tevatron/msmarco-passage-aug
language:
- en
library_name: peft
---


# RankLLaMA-7B-Passage

[Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319).
Xueguang Ma, Liang Wang, Nan Yang, Furu Wei, Jimmy Lin, arXiv 2023

This model is fine-tuned from LLaMA-2-7B using LoRA for passage reranking.

## Training Data
The model is fine-tuned on the training split of the [MS MARCO Passage Ranking](https://microsoft.github.io/msmarco/Datasets) dataset for 1 epoch.
Please check our paper for details.


## Usage

The example below computes the relevance score of a single query-passage pair:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig

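def get_model(peft_model_name):
    # Read the adapter config to find the base model it was trained on
    config = PeftConfig.from_pretrained(peft_model_name)
    # Load the base model with a single-logit sequence classification head
    base_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path, num_labels=1)
    # Attach the LoRA adapter, then merge its weights into the base model
    model = PeftModel.from_pretrained(base_model, peft_model_name)
    model = model.merge_and_unload()
    model.eval()
    return model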

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = get_model('castorini/rankllama-v1-7b-lora-passage')

# Define a query-passage pair
query = "What is llama?"
title = "Llama"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."

# Tokenize the query-passage pair
inputs = tokenizer(f'query: {query}', f'document: {title} {passage}', return_tensors='pt')

# Run the model forward
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    score = logits[0][0]
    print(score)

```
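
To rerank several candidates for one query, the same per-pair scoring can be applied to each candidate and the results sorted by score. The following is a minimal sketch of that loop, not the authors' batch inference code. The helper names `format_pair`, `rerank`, and `toy_score` are hypothetical, and `toy_score` is only a word-overlap stand-in so the sketch runs without downloading the model.

```python
from typing import Callable, List, Tuple


def format_pair(query: str, title: str, passage: str) -> Tuple[str, str]:
    """Build the two text segments in the same format as the example above."""
    return f"query: {query}", f"document: {title} {passage}"


def rerank(
    query: str,
    candidates: List[Tuple[str, str]],
    score_fn: Callable[[str, str], float],
) -> List[Tuple[float, str, str]]:
    """Score every (title, passage) candidate and sort by descending score."""
    scored = []
    for title, passage in candidates:
        query_text, doc_text = format_pair(query, title, passage)
        scored.append((score_fn(query_text, doc_text), title, passage))
    return sorted(scored, key=lambda item: item[0], reverse=True)


def toy_score(query_text: str, doc_text: str) -> float:
    """Hypothetical stand-in scorer: counts shared lowercase words.

    This is NOT how RankLLaMA scores pairs; replace it with a call to the model.
    """
    query_words = set(query_text.lower().replace("?", "").split())
    doc_words = set(doc_text.lower().replace(".", "").split())
    return float(len(query_words & doc_words))


candidates = [
    ("Alpaca", "The alpaca is a camelid."),
    ("Llama", "The llama is a domesticated camelid."),
]
ranked = rerank("What is llama?", candidates, toy_score)
for score, title, _ in ranked:
    print(f"{score:.1f}\t{title}")
```

To use RankLLaMA here, pass a `score_fn` that tokenizes the two segments with `tokenizer(query_text, doc_text, return_tensors='pt')` and returns `model(**inputs).logits[0][0].item()`, as in the single-pair example.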

## Batch inference and training
An unofficial replication of the inference and training code is available in the [Tevatron repository](https://github.com/texttron/tevatron/tree/main/examples/rankllama).

## Citation

If you find our paper or models helpful, please consider citing them as follows:

```
@article{rankllama,
      title={Fine-Tuning LLaMA for Multi-Stage Text Retrieval}, 
      author={Xueguang Ma and Liang Wang and Nan Yang and Furu Wei and Jimmy Lin},
      year={2023},
      journal={arXiv:2310.08319},
}
```