mColBERT / colbert /modeling /inference.py
vjeronymo2's picture
Adding model and checkpoint
828992f
raw
history blame
3.22 kB
import torch
from colbert.modeling.colbert import ColBERT
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer
from colbert.utils.amp import MixedPrecisionManager
from colbert.parameters import DEVICE
class ModelInference():
    """Inference-time wrapper around a ColBERT model.

    Encodes queries and documents with gradient tracking disabled and inside
    the context provided by ``MixedPrecisionManager``, and scores document
    embeddings against query embeddings.
    """

    def __init__(self, colbert: "ColBERT", amp=False):
        """Wrap *colbert* for inference.

        Args:
            colbert: Model to wrap. It must not be in training mode; its
                ``query_maxlen`` and ``doc_maxlen`` attributes configure the
                query and document tokenizers.
            amp: Forwarded to ``MixedPrecisionManager`` (presumably toggles
                automatic mixed precision -- confirm against that class).
        """
        # Sanity check: the caller is expected to have switched the model to
        # evaluation mode already.
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)

        self.amp_manager = MixedPrecisionManager(amp)

    @staticmethod
    def _maybe_to_cpu(encoded, to_cpu):
        """Return *encoded* moved to the CPU when *to_cpu* is true.

        Accepts a single tensor or a list/tuple of tensors; a sequence is
        returned as a new list with every element moved.
        """
        if not to_cpu:
            return encoded
        if isinstance(encoded, (list, tuple)):
            return [item.cpu() for item in encoded]
        return encoded.cpu()

    def query(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.query`` without gradients.

        Args:
            *args, **kw_args: Forwarded unchanged to ``self.colbert.query``.
            to_cpu: When true, move the result to the CPU before returning.

        Returns:
            Whatever ``self.colbert.query`` returns, optionally moved to CPU.
        """
        with torch.no_grad(), self.amp_manager.context():
            Q = self.colbert.query(*args, **kw_args)
            return self._maybe_to_cpu(Q, to_cpu)

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run ``colbert.doc`` without gradients.

        Args:
            *args, **kw_args: Forwarded unchanged to ``self.colbert.doc``.
            to_cpu: When true, move the result to the CPU before returning.

        Returns:
            Whatever ``self.colbert.doc`` returns, optionally moved to CPU.
        """
        with torch.no_grad(), self.amp_manager.context():
            D = self.colbert.doc(*args, **kw_args)
            # ``docFromText`` iterates each batch result as a sequence of
            # per-document tensors when keep_dims is false, so the result may
            # be a list rather than a tensor; the helper handles both.
            return self._maybe_to_cpu(D, to_cpu)

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        """Tokenize and encode *queries*.

        Args:
            queries: Input accepted by ``query_tokenizer.tensorize``.
            bsize: If truthy, encode in batches of this size and concatenate
                the per-batch results with ``torch.cat``.
            to_cpu: When true, move the encoded output to the CPU.

        Returns:
            The encoded queries.
        """
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        # Forward to_cpu here as well so the flag is honoured regardless of
        # whether batching is used.
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        """Tokenize and encode *docs*.

        Args:
            docs: Input accepted by ``doc_tokenizer.tensorize``.
            bsize: If truthy, encode in batches of this size. The tokenizer
                then also returns ``reverse_indices``, which is used to index
                the results back into the order the tokenizer reports.
            keep_dims: Forwarded to ``self.doc``. In the batched path, a true
                value stacks the batch outputs into one zero-padded tensor;
                a false value returns a flat list with one entry per document.
            to_cpu: When true, move the encoded output to the CPU.

        Returns:
            A tensor when ``keep_dims`` is true in the batched path, a list
            when it is false, or the direct result of ``self.doc`` when no
            batching is requested.
        """
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        # Forward to_cpu here as well so the flag is honoured regardless of
        # whether batching is used.
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        """Score documents *D* against query *Q*.

        Computes ``D @ Q``, takes the maximum over dimension 1 and sums the
        maxima over the last dimension.

        Args:
            Q: Query embeddings, the right operand of ``D @ Q``.
            D: Document embeddings; dimension 1 is the one reduced by ``max``
                (presumably the token axis -- TODO confirm).
            mask: Optional mask multiplied into the scores along dimension 1.
            lengths: Optional per-document lengths; positions at index
                ``>= length`` along dimension 1 are masked out. Mutually
                exclusive with *mask*.
            explain: Not implemented.

        Returns:
            The scores as a CPU tensor.

        Raises:
            AssertionError: If both *mask* and *lengths* are given, or if
                *explain* is true.
        """
        if lengths is not None:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type and message are unchanged.
            if mask is not None:
                raise AssertionError("don't supply both mask and lengths")

            # Build the mask on D's own device so it always matches the
            # scores it is multiplied into.
            positions = torch.arange(D.size(1), device=D.device) + 1
            mask = positions.unsqueeze(0) <= lengths.to(D.device).unsqueeze(-1)

        if explain:
            # Fail before doing the matmul; explanation output is still TODO.
            raise AssertionError("TODO")

        scores = D @ Q
        if mask is not None:
            # NOTE(review): masked positions become 0 rather than -inf, so a
            # masked slot could win the max if every real score is negative.
            scores = scores * mask.unsqueeze(-1)

        return scores.max(1).values.sum(-1).cpu()
def _stack_3D_tensors(groups):
    """Concatenate 3-D tensors along dim 0, zero-padding dim 1.

    Each tensor in *groups* is copied into a freshly allocated zero tensor
    whose first dimension is the sum of the inputs' first dimensions and whose
    second dimension is the largest second dimension among them. The third
    dimension, dtype and device are taken from the first tensor.
    """
    total_rows = sum(chunk.size(0) for chunk in groups)
    widest = max(chunk.size(1) for chunk in groups)
    head = groups[0]

    stacked = torch.zeros(
        total_rows, widest, head.size(2), dtype=head.dtype, device=head.device
    )

    start = 0
    for chunk in groups:
        rows, width = chunk.size(0), chunk.size(1)
        # Fill only the leading `width` slots; the remainder stays zero.
        stacked[start:start + rows, :width] = chunk
        start += rows

    return stacked