import os.path as osp
from typing import Any, Dict, List, Union

import torch

from models.model import ModelForSTaRKQA
from models.vss import VSS
from stark_qa.tools.api import get_openai_embeddings, get_sentence_transformer_embeddings
from stark_qa.tools.process_text import chunk_text


class MultiVSS(ModelForSTaRKQA):

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 chunk_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 aggregate: str = 'top3_avg',
                 max_k: int = 50,
                 chunk_size: int = 256,
                 device: str = 'cuda'):
        """
        Multi-vector similarity search (MultiVSS).

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory of query embeddings.
            candidates_emb_dir (str): Directory of candidate embeddings.
            chunk_emb_dir (str): Directory of chunk embeddings.
            emb_model (str): Embedding model name.
            aggregate (str): Aggregation method for chunk similarity scores
                ('max', 'avg', or 'top{k}_avg').
            max_k (int): Maximum number of top candidates to rerank.
            chunk_size (int): Chunk size used when splitting documents.
            device (str): Device on which similarity scores are computed.
        """
        super(MultiVSS, self).__init__(skb, query_emb_dir)
        self.skb = skb
        self.aggregate = aggregate  # 'max', 'avg', or 'top{k}_avg'
        self.max_k = max_k
        self.chunk_size = chunk_size
        self.emb_model = emb_model
        self.device = device

        self.query_emb_dir = query_emb_dir
        self.chunk_emb_dir = chunk_emb_dir
        self.candidates_emb_dir = candidates_emb_dir

        # The first-stage retriever uses the default embeddings
        # (text-embedding-ada-002), so point it at the matching embedding directories.
        self.parent_vss = VSS(skb,
                              query_emb_dir=self.modify_emb_dir(self.query_emb_dir),
                              candidates_emb_dir=self.modify_emb_dir(self.candidates_emb_dir),
                              device=device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Dict[int, float]:
        """
        Forward pass to compute predictions for the given query using MultiVSS.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            query_id (Union[int, list]): Query index.

        Returns:
            pred_dict (dict): A dictionary mapping candidate ids to predicted scores.
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model)

        # First stage: coarse retrieval over all candidates.
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Keep the ids with the top max_k highest scores.
        top_k_idx = torch.topk(
            torch.FloatTensor(node_scores),
            min(self.max_k, len(node_scores)),
            dim=-1
        ).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]

        # Second stage: rerank the top candidates with chunk-level similarity.
        pred_dict = {}
        for node_id in top_k_node_ids:
            doc = self.skb.get_doc_info(node_id, add_rel=True, compact=True)
            chunks = chunk_text(doc, chunk_size=self.chunk_size)
            chunk_path = osp.join(self.chunk_emb_dir, f'{node_id}_size={self.chunk_size}.pt')
            if osp.exists(chunk_path):
                chunk_embs = torch.load(chunk_path)
            else:
                # Embed chunks with a sentence transformer; switch to the OpenAI API if needed:
                # chunk_embs = get_openai_embeddings(chunks, model=self.emb_model)
                chunk_embs = get_sentence_transformer_embeddings(chunks)
                torch.save(chunk_embs, chunk_path)
            print(f'chunk_embs.shape: {chunk_embs.shape}')

            similarity = torch.matmul(
                query_emb.to(self.device), chunk_embs.to(self.device).T
            ).cpu().view(-1)
            if self.aggregate == 'max':
                pred_dict[node_id] = torch.max(similarity).item()
            elif self.aggregate == 'avg':
                pred_dict[node_id] = torch.mean(similarity).item()
            elif 'top' in self.aggregate:
                # e.g. 'top3_avg' -> average of the 3 highest chunk similarities
                k = int(self.aggregate.split('_')[0][len('top'):])
                pred_dict[node_id] = torch.mean(
                    torch.topk(similarity, k=min(k, len(chunks)), dim=-1).values
                ).item()
        return pred_dict

    def modify_emb_dir(self, base_dir: str, emb_model: str = 'text-embedding-ada-002') -> str:
        """
        Modify the base directory to reflect the embedding model in the path.

        Args:
            base_dir (str): Base directory for embeddings.
            emb_model (str): Embedding model name to be used in the directory.

        Returns:
            str: Modified directory path with the embedding model name.
        """
        # Replace the second-to-last path component with the embedding model name.
        # Assumes that 'base_dir' is in the format 'emb/amazon/<emb_model>/query'.
        parts = base_dir.split('/')
        if len(parts) >= 3:  # Make sure there are enough parts to replace
            parts[-2] = emb_model  # Replace the embedding model part
        return '/'.join(parts)