File size: 2,719 Bytes
828992f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import string
import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
from colbert.parameters import DEVICE
class ColBERT(BertPreTrainedModel):
    """ColBERT-style late-interaction scorer built on a BERT encoder.

    Every token's hidden state is projected to ``2 * dim`` features and split
    into two ``dim``-sized vectors, so each input token contributes two
    embeddings. Queries and documents are compared with MaxSim (see ``score``).
    """

    def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):
        """Build the encoder, the projection layer and the optional punctuation skiplist.

        Args:
            config: transformers config used to construct the ``BertModel``.
            query_maxlen: stored on the instance; not read within this class.
            doc_maxlen: stored on the instance; not read within this class.
            mask_punctuation: if true, document tokens whose id is the first
                token id of a ``string.punctuation`` symbol are masked in ``doc``.
            dim: size of each output token vector.
            similarity_metric: ``'cosine'`` or ``'l2'``; consulted by ``score``.
        """
        super(ColBERT, self).__init__(config)
        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim
        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
            # Keys hold both the raw symbol and its first token id. mask() compares
            # integer token ids, so only the id entries actually take effect.
            # NOTE(review): the tokenizer name is hard-coded; if `config` describes a
            # different vocabulary, these ids may not line up — confirm.
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

        self.bert = BertModel(config)
        # Twice the output width: each token is later split into two dim-sized vectors.
        self.linear = nn.Linear(config.hidden_size, dim * 2, bias=False)
        self.init_weights()

    def forward(self, Q, D):
        """Score queries against documents.

        ``Q`` and ``D`` are argument tuples unpacked into ``query`` and ``doc``
        (i.e. ``(input_ids, attention_mask)``).
        """
        return self.score(self.query(*Q), self.doc(*D))

    def _project_tokens(self, input_ids, attention_mask):
        """Encode tokens with BERT and return two un-normalized vectors per token.

        The ``2 * dim`` projection is split in half along axis 2, and the halves
        are concatenated along axis 1 (presumably the sequence axis), which
        doubles that axis' length. The result lives on ``DEVICE``.
        """
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        hidden = self.bert(input_ids, attention_mask=attention_mask)[0]
        projected = self.linear(hidden)
        # The last axis is always 2 * dim (see self.linear), so integer halving is exact.
        halves = projected.split(projected.size(2) // 2, 2)
        return torch.cat(halves, 1)

    def query(self, input_ids, attention_mask):
        """Return L2-normalized query token embeddings (two vectors per token).

        No token masking is applied to queries.
        """
        Q = self._project_tokens(input_ids, attention_mask)
        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Return L2-normalized document token embeddings with masked tokens removed.

        Tokens rejected by ``mask`` (skiplisted ids and id 0) are multiplied by
        zero before normalization.

        Returns:
            If ``keep_dims`` is true, the full padded tensor on ``DEVICE``.
            Otherwise, a list with one float16 CPU tensor per document that
            contains only the unmasked rows.
        """
        D = self._project_tokens(input_ids, attention_mask)

        # mask() moves ids to CPU itself, so the device of input_ids does not matter here.
        mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
        # Each token now owns two rows (one per projected half), so duplicate the
        # mask along the same axis the halves were concatenated on.
        mask = torch.cat([mask, mask], 1)

        D = torch.nn.functional.normalize(D * mask, p=2, dim=2)

        if not keep_dims:
            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        return D

    def score(self, Q, D):
        """Late-interaction (MaxSim) relevance score.

        For each query vector, take its best match among the document vectors,
        then sum over the query vectors. Assumes ``Q`` and ``D`` are 3-D batched
        tensors with matching batch size (TODO confirm caller shapes).

        Raises:
            ValueError: if ``self.similarity_metric`` is neither ``'cosine'`` nor ``'l2'``.
        """
        if self.similarity_metric == 'cosine':
            # query()/doc() L2-normalize their outputs, so the dot product is cosine similarity.
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        if self.similarity_metric == 'l2':
            # Negative squared Euclidean distance, so "max" still picks the closest match.
            return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1)) ** 2).sum(-1)).max(-1).values.sum(-1)

        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which previously let unknown metrics silently fall through to the L2 path.
        raise ValueError(
            f"Unsupported similarity_metric: {self.similarity_metric!r} (expected 'cosine' or 'l2')"
        )

    def mask(self, input_ids):
        """Return a nested list of bools, True for each token to keep.

        A token is dropped if its id is in ``self.skiplist`` or equals 0.
        Id 0 is assumed to be the padding id, which holds for standard BERT
        vocabularies (TODO confirm for other configs).
        """
        return [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
|