MoR / models /multi_vss.py
GagaLey's picture
framework
7bf4b88
import os
import os.path as osp
from typing import Any, Union, List, Dict

import torch
from stark_qa.tools.api import get_openai_embeddings, get_sentence_transformer_embeddings
from stark_qa.tools.process_text import chunk_text

from models.model import ModelForSTaRKQA
from models.vss import VSS
class MultiVSS(ModelForSTaRKQA):
    """
    Multi-vector Vector Similarity Search.

    A parent ``VSS`` model scores the candidates first; the ``max_k`` best of
    them are then re-scored by comparing the query embedding against the
    embeddings of each candidate's text chunks and aggregating the per-chunk
    similarities into one score.
    """

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 chunk_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 aggregate: str = 'top3_avg',
                 max_k: int = 50,
                 chunk_size: int = 256,
                 device: str = 'cuda'):
        """
        Multivector Vector Similarity Search

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory to candidate embeddings.
            chunk_emb_dir (str): Directory to chunk embeddings (used as an on-disk cache).
            emb_model (str): Embedding model name.
            aggregate (str): Aggregation method for similarity scores ('max', 'avg', 'top{k}_avg').
            max_k (int): Maximum number of top candidates to consider.
            chunk_size (int): Size of chunks for text processing.
            device (str): Device on which the similarity is computed; also
                forwarded to the parent VSS model.
        """
        super().__init__(skb, query_emb_dir)
        self.skb = skb
        self.aggregate = aggregate  # 'max', 'avg', 'top{k}_avg'
        self.max_k = max_k
        self.chunk_size = chunk_size
        self.emb_model = emb_model
        self.query_emb_dir = query_emb_dir
        self.chunk_emb_dir = chunk_emb_dir
        self.candidates_emb_dir = candidates_emb_dir
        # Remember the device so forward() honours it instead of assuming CUDA.
        self.device = device

        # The parent VSS is deliberately built without `emb_model` and pointed
        # at the 'text-embedding-ada-002' directories, i.e. the first-stage
        # ranking always uses the default embeddings.
        self.parent_vss = VSS(skb,
                              query_emb_dir=self.modify_emb_dir(self.query_emb_dir),
                              candidates_emb_dir=self.modify_emb_dir(self.candidates_emb_dir),
                              device=device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Dict[int, float]:
        """
        Forward pass to compute predictions for the given query using MultiVSS.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            query_id (Union[int, list]): Query index.

        Returns:
            pred_dict (dict): Maps each of the top ``max_k`` node ids (ranked by
                the parent VSS) to its aggregated chunk-level similarity score.

        Raises:
            ValueError: If ``self.aggregate`` is not a supported method.
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model)
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Get the ids with top k highest scores
        top_k_idx = torch.topk(
            torch.FloatTensor(node_scores),
            min(self.max_k, len(node_scores)),
            dim=-1
        ).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]

        # Loop-invariant: move the query embedding to the target device once.
        query_emb = query_emb.to(self.device)

        pred_dict = {}
        for node_id in top_k_node_ids:
            chunk_embs = self._load_chunk_embs(node_id)
            # NOTE(review): query_emb comes from `self.emb_model` while freshly
            # computed chunk embeddings come from the sentence-transformer
            # helper; confirm both live in the same embedding space.
            similarity = torch.matmul(
                query_emb, chunk_embs.to(self.device).T
            ).cpu().view(-1)
            pred_dict[node_id] = self._aggregate_similarity(similarity)
        return pred_dict

    def _load_chunk_embs(self, node_id: int) -> torch.Tensor:
        """Return the chunk embeddings of `node_id`, computing and caching them on a miss."""
        chunk_path = osp.join(self.chunk_emb_dir, f'{node_id}_size={self.chunk_size}.pt')
        if osp.exists(chunk_path):
            return torch.load(chunk_path)

        # Only fetch and chunk the document when there is no cached embedding.
        doc = self.skb.get_doc_info(node_id, add_rel=True, compact=True)
        chunks = chunk_text(doc, chunk_size=self.chunk_size)
        # using sentence transformer first, then using api
        # (alternative: get_openai_embeddings(chunks, model=self.emb_model))
        chunk_embs = get_sentence_transformer_embeddings(chunks)

        # Make sure the cache directory exists before writing into it.
        cache_dir = osp.dirname(chunk_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(chunk_embs, chunk_path)
        return chunk_embs

    def _aggregate_similarity(self, similarity: torch.Tensor) -> float:
        """Collapse the flat per-chunk similarity tensor into one score per `self.aggregate`."""
        if self.aggregate == 'max':
            return torch.max(similarity).item()
        if self.aggregate == 'avg':
            return torch.mean(similarity).item()
        if self.aggregate.startswith('top'):
            # 'top{k}_avg' -> k
            k = int(self.aggregate.split('_')[0][len('top'):])
            if k < 1:
                raise ValueError(
                    f"Invalid aggregate '{self.aggregate}': k must be a positive integer."
                )
            # Clamp against the number of scores actually available, which is
            # what topk operates on (a cached embedding file may hold a
            # different number of chunks than a fresh chunking would yield).
            k = min(k, similarity.numel())
            return torch.mean(torch.topk(similarity, k=k, dim=-1).values).item()
        raise ValueError(
            f"Unsupported aggregate '{self.aggregate}'; "
            "expected 'max', 'avg' or 'top{k}_avg'."
        )

    def modify_emb_dir(self, base_dir: str, emb_model: str = 'text-embedding-ada-002') -> str:
        """
        Modify the base directory to reflect the embedding model in the path.

        Args:
            base_dir (str): Base directory for embeddings.
            emb_model (str): Embedding model name to be used in the directory.

        Returns:
            str: Modified directory path with the embedding model name. Paths
                with fewer than three '/'-separated components are returned
                unchanged.
        """
        # Assumes that 'base_dir' is in the format 'emb/amazon/<embedding_model>/query'.
        # Set trailing slashes aside first so that 'emb/amazon/<model>/query/'
        # still has '<model>' (and not 'query') as its second-to-last component.
        stripped = base_dir.rstrip('/')
        trailing = base_dir[len(stripped):]
        parts = stripped.split('/')
        if len(parts) >= 3:  # Make sure there are enough parts to replace
            parts[-2] = emb_model  # Replace the embedding model part
        return '/'.join(parts) + trailing