|
import os.path as osp |
|
import torch |
|
from typing import Any, Union, List, Dict |
|
from models.model import ModelForSTaRKQA |
|
from tqdm import tqdm |
|
import pandas as pd |
|
import bm25s |
|
|
|
|
|
class BM25(ModelForSTaRKQA): |
|
|
|
def __init__(self, skb): |
|
|
|
super(BM25, self).__init__(skb) |
|
|
|
self.indices = skb.candidate_ids |
|
self.corpus = [skb.get_doc_info(idx) for idx in tqdm(self.indices, desc="Gathering docs")] |
|
|
|
|
|
self.retriever = bm25s.BM25(corpus=self.corpus) |
|
self.retriever.index(bm25s.tokenize(self.corpus)) |
|
|
|
|
|
self.text_to_index = {hash(text): index for text, index in zip(self.corpus, self.indices)} |
|
|
|
def forward(self, |
|
query: str, |
|
query_id: Union[int, None] = None, |
|
k: int = 100, |
|
**kwargs: Any) -> dict: |
|
""" |
|
Forward pass to compute similarity scores for the given query. |
|
|
|
Args: |
|
query (str): Query string. |
|
query_id (int): Query index. |
|
|
|
Returns: |
|
pred_dict (dict): A dictionary of candidate ids and their corresponding similarity scores. |
|
""" |
|
results, scores = self.retriever.retrieve(bm25s.tokenize(query), k=k) |
|
indices = [self.text_to_index[hash(result.item())] for result in results[0]] |
|
scores = scores[0].tolist() |
|
return dict(zip(indices, scores)) |
|
|