import os | |
def slow_rerank(args, query, pids, passages):
    """Re-rank ``passages`` for ``query`` and return them best-first.

    The query and all passages are encoded through ``args.inference``,
    scored with ``args.colbert.score``, and ordered by descending score.

    Args:
        args: Object exposing ``colbert`` (provides ``score``), ``inference``
            (provides ``queryFromText`` and ``docFromText``) and ``bsize``
            (forwarded to ``docFromText``).
        query: The query; handed to ``queryFromText`` as a one-element list.
        pids: Passage identifiers, indexed by the same positions as
            ``passages``. They must be hashable and unique.
        passages: The passages to score.

    Returns:
        A list of ``(score, pid, passage)`` tuples sorted by descending
        score. An empty list is returned when ``passages`` is empty.

    Raises:
        AssertionError: If the ranked passage ids contain duplicates.
    """
    # Nothing to rank: return early instead of handing the encoder and the
    # scorer an empty batch. ``len`` (rather than truthiness) keeps this
    # safe for array-like inputs as well as plain lists.
    if len(passages) == 0:
        return []

    inference = args.inference
    query_embedding = inference.queryFromText([query])
    passage_embeddings = inference.docFromText(passages, bsize=args.bsize)

    # Scores are moved to the CPU before sorting.
    # NOTE(review): assumes ``score`` yields one value per passage - confirm.
    raw_scores = args.colbert.score(query_embedding, passage_embeddings).cpu()
    ordering = raw_scores.sort(descending=True)

    positions = ordering.indices.tolist()
    ranked_scores = ordering.values.tolist()
    ranked_pids = [pids[position] for position in positions]
    ranked_passages = [passages[position] for position in positions]

    # Raised explicitly (instead of via ``assert``) so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if len(ranked_pids) != len(set(ranked_pids)):
        raise AssertionError("slow_rerank received duplicate passage ids")

    return list(zip(ranked_scores, ranked_pids, ranked_passages))