|
import os.path as osp |
|
import torch |
|
from typing import Any, Union, List, Dict |
|
from models.model import ModelForSTaRKQA |
|
from tqdm import tqdm |
|
from stark_qa.evaluator import Evaluator |
|
import sys |
|
sys.path.append("stark/") |
|
|
|
|
|
class VSS(ModelForSTaRKQA):
    """Vector Similarity Search (VSS) retriever over precomputed embeddings.

    All candidate embeddings are loaded once at construction time and stacked
    into a single matrix, so scoring a query is one matrix multiplication
    between the query embedding(s) and that matrix.
    """

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 device: str = 'cuda'):
        """
        Vector Similarity Search

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory to candidate embeddings. It
                must contain a file named ``candidate_emb_dict.pt``.
            emb_model (str): Embedding model name.
            device (str): Torch device on which the candidate embedding
                matrix is stored and on which similarities are computed.

        Raises:
            AssertionError: If the number of stored candidate embeddings does
                not match ``len(self.candidate_ids)``.
            KeyError: If an id in ``self.candidate_ids`` has no entry in the
                loaded embedding dict.
        """
        # NOTE(review): `self.candidate_ids` is not set in this class; it is
        # presumably populated by the base-class constructor -- confirm.
        super().__init__(skb, query_emb_dir=query_emb_dir)
        self.emb_model = emb_model
        self.candidates_emb_dir = candidates_emb_dir
        self.device = device
        self.evaluator = Evaluator(self.candidate_ids, device)

        candidate_emb_path = osp.join(candidates_emb_dir, 'candidate_emb_dict.pt')
        # Load onto the CPU regardless of the device the tensors were saved
        # from: this keeps the file loadable on CPU-only hosts and avoids
        # holding a per-candidate copy on the GPU next to the stacked matrix.
        # NOTE(security): torch.load unpickles the file; only point
        # `candidates_emb_dir` at trusted data.
        candidate_emb_dict = torch.load(candidate_emb_path, map_location='cpu')
        print(f'Loaded candidate_emb_dict from {candidate_emb_path}!')

        assert len(candidate_emb_dict) == len(self.candidate_ids), (
            f'candidate_emb_dict has {len(candidate_emb_dict)} entries but '
            f'there are {len(self.candidate_ids)} candidate ids'
        )
        # Flatten every embedding to a single row and stack the rows in
        # `self.candidate_ids` order, so row i of the matrix corresponds to
        # self.candidate_ids[i]. The matrix is moved to `device` exactly once.
        candidate_embs = [candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Union[Dict[int, torch.Tensor], tuple]:
        """
        Forward pass to compute similarity scores for the given query.

        Scores are plain dot products between the query embedding(s) and each
        candidate embedding; no normalization is applied here.

        Args:
            query (Union[str, List[str]]): A single query string, or a list
                of query strings.
            query_id (Union[int, List[int]]): Query index (or indices)
                matching ``query``.
            **kwargs: Forwarded unchanged to ``self.get_query_emb``.

        Returns:
            If ``query`` is a ``str``: pred_dict (dict), a dictionary of
            candidate ids and their corresponding similarity scores (each
            score is a 0-dim CPU tensor).

            Otherwise: a tuple ``(candidate_ids, similarity)`` where
            ``candidate_ids`` is a ``torch.LongTensor`` built from
            ``self.candidate_ids`` and ``similarity`` is the transposed CPU
            score matrix.
        """
        # NOTE(review): assumes get_query_emb returns a tensor whose last
        # dimension equals the candidate embedding size, i.e. (1, dim) for a
        # single query and (num_queries, dim) for a batch -- confirm against
        # the base class.
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model, **kwargs)
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
        if isinstance(query, str):
            # Single query: flatten the scores and pair them with the ids.
            return dict(zip(self.candidate_ids, similarity.view(-1)))
        else:
            # Batched queries: the transpose puts candidates on the first
            # axis (presumably (num_candidates, num_queries)).
            return torch.LongTensor(self.candidate_ids), similarity.t()
|
|