---
license: cc-by-nc-sa-4.0
---

Our best attempt at reproducing [RankT5 Enc-Softmax](https://arxiv.org/pdf/2210.10634.pdf), with a few important differences:

1. We use a SPLADE first stage to mine the negatives, whereas the paper uses GTR.
2. We train with PyTorch, whereas the paper uses Flax.
3. ~~We use the original t5-3b vs Flan-T5-3b in the paper~~ -> Actually, the paper also uses t5-3b.
4. The head is not exactly the same. The original paper uses a single dense layer. Here we use Linear -> LayerNorm -> Linear and, by mistake, did not include a nonlinearity. Fixing this should improve our performance, because the extra layers are currently not used as intended.

These differences seem to lead to slightly worse performance (42.8 vs 43.? in the paper), and results on BEIR also appear slightly worse.

To use this model, first clone the Hugging Face repo:

```
git clone https://huggingface.co/naver/trecdl22-crossencoder-rankT53b-repro
```

Then we suggest loading it as follows:

```python
import torch
from transformers import T5EncoderModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class T5EncoderRerank(torch.nn.Module):
    """T5 encoder with the Linear -> LayerNorm -> Linear scoring head used by this checkpoint."""

    def __init__(self, model_type_or_dir):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained(model_type_or_dir)
        self.config = self.model.config
        self.first_transform = torch.nn.Linear(self.config.d_model, self.config.d_model)
        self.layer_norm = torch.nn.LayerNorm(self.config.d_model, eps=1e-12)
        self.linear = torch.nn.Linear(self.config.d_model, 1)

    def forward(self, **kwargs):
        # Representation of the first token of the encoder output
        result = self.model(**kwargs).last_hidden_state[:, 0, :]
        # No nonlinearity between the layers (see point 4 above); kept as-is
        # so that the checkpoint weights load and behave as trained
        first_transformed = self.first_transform(result)
        layer_normed = self.layer_norm(first_transformed)
        logits = self.linear(layer_normed)  # one relevance score per input
        return SequenceClassifierOutput(logits=logits)


original_model = "t5-3b"
path_checkpoint = "trecdl22-crossencoder-rankT53b-repro/pytorch_model.bin"

print("Loading")
model = T5EncoderRerank(original_model)
model.load_state_dict(torch.load(path_checkpoint, map_location=torch.device("cpu")))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()  # make sure the whole wrapper is in inference mode (no dropout)
tokenizer = AutoTokenizer.from_pretrained(original_model)
print("loaded")
```
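Once `model`, `tokenizer` and `device` are loaded, candidate passages for a query can be scored and reranked. The sketch below is hypothetical and not part of this repo: the helper names (`build_inputs`, `make_score_fn`, `rerank`), the batch size, the 512-token limit and, most importantly, the monoT5/RankT5-style `Query: ... Document: ...` input template are assumptions, since this card does not state the input format used in training. It also assumes that a higher logit means a more relevant passage.

```python
from typing import Callable, List, Sequence, Tuple

# ASSUMPTION: this card does not specify the training input format;
# adjust the template if the checkpoint was trained with a different one.
DEFAULT_TEMPLATE = "Query: {query} Document: {passage}"


def build_inputs(query: str, passages: Sequence[str],
                 template: str = DEFAULT_TEMPLATE) -> List[str]:
    """Format one query against each candidate passage (hypothetical helper)."""
    return [template.format(query=query, passage=p) for p in passages]


def make_score_fn(model, tokenizer, device, max_length: int = 512,
                  batch_size: int = 8) -> Callable[[List[str]], List[float]]:
    """Wrap the loaded T5EncoderRerank model as a function: texts -> scores."""
    import torch  # imported lazily so the pure-Python helpers work without torch

    model.eval()

    def score(texts: List[str]) -> List[float]:
        scores: List[float] = []
        with torch.no_grad():
            for start in range(0, len(texts), batch_size):
                # Tokenize one batch and move input_ids / attention_mask to the device
                batch = tokenizer(
                    texts[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                ).to(device)
                logits = model(**batch).logits  # shape: (batch, 1)
                scores.extend(logits.squeeze(-1).tolist())
        return scores

    return score


def rerank(query: str, passages: Sequence[str],
           score_fn: Callable[[List[str]], List[float]],
           template: str = DEFAULT_TEMPLATE) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and sort passages by descending score."""
    scores = score_fn(build_inputs(query, passages, template))
    return sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)
```

For example, `rerank(query, passages, make_score_fn(model, tokenizer, device))` returns `(passage, score)` pairs, best first. Passing the scorer in as a callable keeps the formatting and sorting logic independent of torch, so the input template can be checked or swapped without loading the 3B model.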