import string

import torch
import torch.nn as nn

from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

from colbert.parameters import DEVICE


class ColBERT(BertPreTrainedModel):
    """ColBERT late-interaction model: encodes queries and documents into
    per-token embeddings and scores a pair with the MaxSim operator."""

    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):
        super(ColBERT, self).__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
            # Mark every punctuation symbol, and its single-token id, so that
            # self.mask() can zero out those document positions.
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

        self.bert = BertModel(config)
        # Projects hidden states to 2 * dim; query()/doc() split the result
        # into two dim-sized halves concatenated along the sequence axis.
        self.linear = nn.Linear(config.hidden_size, dim * 2, bias=False)

        self.init_weights()

    def forward(self, Q, D):
        # Q and D are (input_ids, attention_mask) tuples.
        return self.score(self.query(*Q), self.doc(*D))

    def query(self, input_ids, attention_mask):
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        # Split the 2*dim projection into two dim-sized halves and stack them
        # along the sequence axis, doubling the number of query embeddings.
        Q = Q.split(Q.size(2) // 2, dim=2)
        Q = torch.cat(Q, dim=1)

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D)

        # Same split-and-concatenate as in query().
        D = D.split(D.size(2) // 2, dim=2)
        D = torch.cat(D, dim=1)

        # Zero out punctuation and [PAD] embeddings; the mask is duplicated
        # along the sequence axis to match the doubled D.
        mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
        mask = torch.cat(2 * [mask], dim=1)
        D = D * mask

        D = torch.nn.functional.normalize(D, p=2, dim=2)

        if not keep_dims:
            # Return a list of variable-length fp16 matrices, one per document,
            # with masked positions dropped.
            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        return D

    def score(self, Q, D):
        if self.similarity_metric == 'cosine':
            # MaxSim: each query embedding takes its maximum dot product over
            # all document embeddings; the maxima are summed over the query.
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        assert self.similarity_metric == 'l2'
        # Negative squared-L2 variant of MaxSim.
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)

    def mask(self, input_ids):
        # True for tokens to keep; punctuation (skiplist) and [PAD] (id 0)
        # positions are False.
        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
        return mask
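

# ----------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# A minimal, hedged example of scoring one query-passage pair. The base
# checkpoint name and max lengths below are assumptions, and this relies
# on transformers forwarding the unused from_pretrained kwargs
# (query_maxlen, doc_maxlen, mask_punctuation) to __init__; in practice
# you would load this repo's trained checkpoint rather than bare mBERT.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')

    model = ColBERT.from_pretrained(
        'bert-base-multilingual-uncased',  # assumed base model, not the trained checkpoint
        query_maxlen=32,                   # assumed query length
        doc_maxlen=180,                    # assumed passage length
        mask_punctuation=True,
    ).to(DEVICE).eval()

    Q = tokenizer('what is late interaction?', return_tensors='pt',
                  padding='max_length', max_length=32, truncation=True)
    D = tokenizer('ColBERT matches each query token against every passage token.',
                  return_tensors='pt', padding='max_length', max_length=180,
                  truncation=True)

    with torch.no_grad():
        scores = model((Q['input_ids'], Q['attention_mask']),
                       (D['input_ids'], D['attention_mask']))

    print(scores)  # one relevance score per query-passage pair in the batch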