---
license: cc-by-nc-sa-4.0
---

Our best attempt at reproducing [RankT5 Enc-Softmax](https://arxiv.org/pdf/2210.10634.pdf), with a few important differences:

1. We use a SPLADE first stage for the negatives, whereas the paper uses GTR.
2. We train using PyTorch, whereas the paper uses Flax.
3. ~~We use the original t5-3b, whereas the paper uses Flan T5-3b.~~
4. The head is not exactly the same: we add Linear -> LayerNorm -> Linear, but make the mistake of not including a nonlinearity between the two linear layers, while the original paper uses a single dense layer (see the sketch below). Fixing this should improve performance, since right now the extra layers are not used effectively.
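
To make the difference in point 4 concrete, here is a minimal sketch of the two head variants. The `d_model = 1024` value is t5-3b's hidden size, and the single-dense-layer head is our reading of the paper's description, not code taken from it:

```
import torch

d_model = 1024  # hidden size of t5-3b

# Head as described in the RankT5 paper: a single dense layer producing one relevance score
paper_head = torch.nn.Linear(d_model, 1)

# Head used in this checkpoint: Linear -> LayerNorm -> Linear,
# with no nonlinearity between the two linear layers (the mistake mentioned above)
our_head = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_model),
    torch.nn.LayerNorm(d_model, eps=1e-12),
    torch.nn.Linear(d_model, 1),
)
```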

This leads to what seems to be slightly worse performance than the paper (42.8 vs 43.?), and results on BEIR also seem slightly worse.

To use this model, first clone the Hugging Face repository:

```
git clone https://huggingface.co./naver/trecdl22-crossencoder-rankT53b-repro
```

Then we suggest loading it as follows:

```
import torch
from transformers import T5EncoderModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class T5EncoderRerank(torch.nn.Module):
    """T5 encoder with the Linear -> LayerNorm -> Linear scoring head described above."""

    def __init__(self, model_type_or_dir):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained(model_type_or_dir)
        self.config = self.model.config
        self.first_transform = torch.nn.Linear(self.config.d_model, self.config.d_model)
        self.layer_norm = torch.nn.LayerNorm(self.config.d_model, eps=1e-12)
        self.linear = torch.nn.Linear(self.config.d_model, 1)

    def forward(self, **kwargs):
        # Relevance score computed from the last hidden state of the first token
        result = self.model(**kwargs).last_hidden_state[:, 0, :]
        first_transformed = self.first_transform(result)
        layer_normed = self.layer_norm(first_transformed)
        logits = self.linear(layer_normed)
        return SequenceClassifierOutput(logits=logits)


original_model = "t5-3b"
path_checkpoint = "trecdl22-crossencoder-rankT53b-repro/pytorch_model.bin"

print("Loading")
model = T5EncoderRerank(original_model)
model.load_state_dict(torch.load(path_checkpoint, map_location=torch.device("cpu")))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(original_model)
print("loaded")
```
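
Once loaded, the model scores a query-document pair with a single forward pass. A minimal sketch is shown below; the `"Query: ... Document: ..."` input template and the 512-token limit are assumptions for illustration, not taken from this card, so adapt them to the format used at training time:

```
# Minimal scoring sketch (the input template below is an assumption, not from this card)
query = "what causes rainbows"
document = "Rainbows are caused by refraction and dispersion of sunlight in water droplets."

inputs = tokenizer(
    f"Query: {query} Document: {document}",
    return_tensors="pt",
    truncation=True,
    max_length=512,
)
inputs = {key: value.to(device) for key, value in inputs.items()}

model.eval()
with torch.no_grad():
    score = model(**inputs).logits.squeeze(-1)  # higher = more relevant
print(score.item())
```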