File size: 1,476 Bytes
7bf4b88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import os.path as osp
import torch
from typing import Any, Union, List, Dict
from models.model import ModelForSTaRKQA
from tqdm import tqdm
import pandas as pd
import bm25s
class BM25(ModelForSTaRKQA):
    """Lexical BM25 retriever over the candidate documents of a knowledge base."""

    def __init__(self, skb):
        """
        Build a BM25 index over every candidate document of ``skb``.

        Args:
            skb: Knowledge base exposing ``candidate_ids`` and ``get_doc_info(idx)``.
        """
        super().__init__(skb)
        self.indices = skb.candidate_ids
        self.corpus = [skb.get_doc_info(idx) for idx in tqdm(self.indices, desc="Gathering docs")]
        # Build the retriever WITHOUT attaching the corpus, so that ``retrieve``
        # returns integer row positions into ``self.corpus`` / ``self.indices``
        # rather than document text. Mapping text back to ids via a hash table
        # merged candidates whose document text is identical (last one won).
        self.retriever = bm25s.BM25()
        self.retriever.index(bm25s.tokenize(self.corpus))
        # Retained for backward compatibility with external users of this
        # attribute; forward() no longer relies on it.
        self.text_to_index = {hash(text): index for text, index in zip(self.corpus, self.indices)}

    def forward(self,
                query: str,
                query_id: Union[int, None] = None,
                k: int = 100,
                **kwargs: Any) -> dict:
        """
        Forward pass to compute BM25 similarity scores for the given query.

        Args:
            query (str): Query string.
            query_id (int, optional): Query index; not used by BM25.
            k (int): Number of top-scoring candidates to return. Clamped to the
                number of indexed documents, since bm25s rejects a larger ``k``.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            pred_dict (dict): A dictionary mapping candidate ids to their BM25
            similarity scores for the top-k candidates.
        """
        k = min(k, len(self.indices))
        if k <= 0:
            return {}
        # No corpus is attached to the retriever, so ``positions`` holds row
        # positions into self.indices; row 0 corresponds to the single query.
        positions, scores = self.retriever.retrieve(bm25s.tokenize(query), k=k)
        candidate_ids = [self.indices[int(pos)] for pos in positions[0]]
        return dict(zip(candidate_ids, scores[0].tolist()))
|