import os.path as osp
import torch
from typing import Any, Union, List, Dict, Tuple
from models.model import ModelForSTaRKQA
from tqdm import tqdm
from stark_qa.evaluator import Evaluator
import sys

sys.path.append("stark/")


class VSS(ModelForSTaRKQA):

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 device: str = 'cuda'):
        """
        Vector Similarity Search.

        Scores every candidate by the dot product between its precomputed
        embedding and the query embedding.

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory containing
                ``candidate_emb_dict.pt``, a mapping from candidate id to its
                embedding tensor.
            emb_model (str): Embedding model name.
            device (str): Device on which the candidate embedding matrix is
                kept and on which similarities are computed.

        Raises:
            ValueError: If the number of stored candidate embeddings differs
                from the number of candidate ids.
            KeyError: If a candidate id has no entry in the embedding dict.
        """
        super().__init__(skb, query_emb_dir=query_emb_dir)
        self.emb_model = emb_model
        self.candidates_emb_dir = candidates_emb_dir
        self.device = device
        # NOTE(review): self.candidate_ids is not assigned in this class;
        # presumably it is populated by ModelForSTaRKQA.__init__ — confirm.
        self.evaluator = Evaluator(self.candidate_ids, device)

        candidate_emb_path = osp.join(candidates_emb_dir, 'candidate_emb_dict.pt')
        # Load onto CPU first: the file may have been saved from a GPU that is
        # not present here. The stacked matrix is moved to `device` below.
        candidate_emb_dict = torch.load(candidate_emb_path, map_location='cpu')
        print(f'Loaded candidate_emb_dict from {candidate_emb_path}!')
        # Explicit check (not `assert`) so it still runs under `python -O`.
        if len(candidate_emb_dict) != len(self.candidate_ids):
            raise ValueError(
                f'candidate_emb_dict at {candidate_emb_path} has '
                f'{len(candidate_emb_dict)} entries, expected '
                f'{len(self.candidate_ids)} (one per candidate id)'
            )
        # Flatten each embedding to a row and stack in candidate_ids order, so
        # row i of self.candidate_embs corresponds to self.candidate_ids[i].
        candidate_embs = [candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Union[dict, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass to compute similarity scores for the given query.

        Args:
            query (str or List[str]): A single query string or a batch of them.
            query_id (int or List[int]): Query index (or indices), forwarded to
                ``get_query_emb``.
            **kwargs: Extra arguments forwarded to ``get_query_emb``.

        Returns:
            For a single ``str`` query: a dict mapping each candidate id to its
            similarity score (a 0-d tensor).
            For a list of queries: a tuple ``(candidate_ids, scores)`` where
            ``candidate_ids`` is a LongTensor of all candidate ids and
            ``scores`` is the transposed similarity matrix, i.e. one row per
            candidate (assuming ``get_query_emb`` returns a
            ``(num_queries, dim)`` tensor — TODO confirm).
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model, **kwargs)
        # Dot product of each query embedding against every candidate row;
        # result is moved back to CPU for downstream consumers.
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
        if isinstance(query, str):
            return dict(zip(self.candidate_ids, similarity.view(-1)))
        return torch.LongTensor(self.candidate_ids), similarity.t()