# MoR / models / vss.py
# GagaLey's picture
# framework
# 7bf4b88
import os.path as osp
import torch
from typing import Any, Union, List, Dict
from models.model import ModelForSTaRKQA
from tqdm import tqdm
from stark_qa.evaluator import Evaluator
import sys
sys.path.append("stark/")
class VSS(ModelForSTaRKQA):
    """Vector Similarity Search (VSS) retriever.

    Scores every candidate by the dot product between a query embedding and a
    pre-computed candidate embedding matrix that is kept on ``device``.
    """

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 device: str = 'cuda'):
        """
        Vector Similarity Search

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory to candidate embeddings. It must
                contain ``candidate_emb_dict.pt``, a mapping from candidate id
                to its embedding tensor.
            emb_model (str): Embedding model name.
            device (str): Device on which the candidate embedding matrix is
                stored and on which similarities are computed.

        Raises:
            AssertionError: If the number of loaded embeddings differs from the
                number of candidate ids.
        """
        super().__init__(skb, query_emb_dir=query_emb_dir)
        self.emb_model = emb_model
        self.candidates_emb_dir = candidates_emb_dir
        self.device = device
        # NOTE(review): self.candidate_ids is presumably populated by the base
        # class constructor -- it is not assigned anywhere in this class.
        self.evaluator = Evaluator(self.candidate_ids, device)

        candidate_emb_path = osp.join(candidates_emb_dir, 'candidate_emb_dict.pt')
        # Deserialize onto CPU so that a file saved from a GPU run still loads
        # on a CPU-only host; the stacked matrix is moved to `device` once below.
        candidate_emb_dict = torch.load(candidate_emb_path, map_location='cpu')
        print(f'Loaded candidate_emb_dict from {candidate_emb_path}!')

        assert len(candidate_emb_dict) == len(self.candidate_ids), (
            f'{candidate_emb_path} holds {len(candidate_emb_dict)} embeddings, '
            f'expected {len(self.candidate_ids)}'
        )

        # Row i of the matrix is the embedding of self.candidate_ids[i], so the
        # columns of the similarity computed in forward() follow the same order.
        candidate_embs = [candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Union[dict, tuple]:
        """
        Forward pass to compute similarity scores for the given query.

        Args:
            query (Union[str, List[str]]): A single query string or a list of
                query strings.
            query_id (Union[int, List[int]]): Query index, or list of indices
                matching ``query``.
            **kwargs: Forwarded unchanged to ``self.get_query_emb``.

        Returns:
            If ``query`` is a ``str``: a dictionary mapping each candidate id to
            its similarity score (a 0-dim CPU tensor).
            Otherwise: a tuple ``(candidate_ids, similarity)`` where
            ``candidate_ids`` is a ``torch.LongTensor`` of all candidate ids and
            ``similarity`` is the transposed score matrix on CPU -- presumably
            of shape (num_candidates, num_queries); TODO confirm against
            ``get_query_emb``.
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model, **kwargs)
        # Dot-product similarity against every candidate; result is moved to CPU.
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()

        if isinstance(query, str):
            # Flatten so that each candidate id pairs with exactly one score.
            return dict(zip(self.candidate_ids, similarity.view(-1)))
        return torch.LongTensor(self.candidate_ids), similarity.t()