import os.path as osp
from typing import Any, Union, List, Dict

import torch

from models.model import ModelForSTaRKQA
from models.vss import VSS
from stark_qa.tools.api import get_openai_embeddings, get_sentence_transformer_embeddings
from stark_qa.tools.process_text import chunk_text


class MultiVSS(ModelForSTaRKQA):

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 chunk_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 aggregate: str = 'top3_avg',
                 max_k: int = 50,
                 chunk_size: int = 256,
                 device: str = 'cuda'):
        """
        Multi-vector Vector Similarity Search (VSS).

        Candidates are first retrieved with document-level VSS; each top
        candidate is then rescored by comparing the query embedding against
        embeddings of chunks of its document and aggregating the chunk-level
        similarities.

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory of query embeddings.
            candidates_emb_dir (str): Directory of candidate embeddings.
            chunk_emb_dir (str): Directory of chunk embeddings.
            emb_model (str): Embedding model name.
            aggregate (str): Aggregation method for chunk similarity scores
                ('max', 'avg', or 'top{k}_avg').
            max_k (int): Maximum number of top candidates to rerank.
            chunk_size (int): Chunk size used when splitting candidate documents.
            device (str): Device used for similarity computation.
        """
        super(MultiVSS, self).__init__(skb, query_emb_dir)
        self.skb = skb
        self.aggregate = aggregate
        self.max_k = max_k
        self.chunk_size = chunk_size
        self.emb_model = emb_model
        self.device = device

        self.query_emb_dir = query_emb_dir
        self.chunk_emb_dir = chunk_emb_dir
        self.candidates_emb_dir = candidates_emb_dir

        # First-stage retriever: plain document-level VSS. The query/candidate
        # embedding directories are expected to have the embedding model name
        # as their second-to-last path component (see modify_emb_dir).
        self.parent_vss = VSS(
            skb,
            query_emb_dir=self.modify_emb_dir(self.query_emb_dir, self.emb_model),
            candidates_emb_dir=self.modify_emb_dir(self.candidates_emb_dir, self.emb_model),
            device=device
        )

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Dict[int, float]:
        """
        Forward pass to compute predictions for the given query using MultiVSS.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            query_id (Union[int, list]): Query index or indices.

        Returns:
            pred_dict (dict): A dictionary mapping candidate ids to predicted scores.
        """
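        # Scoring sketch (illustrative numbers only): if a candidate's chunk
        # similarities were [0.9, 0.7, 0.4, 0.1], then
        #   aggregate='max'      -> 0.9
        #   aggregate='avg'      -> 0.525
        #   aggregate='top3_avg' -> mean([0.9, 0.7, 0.4]) ~= 0.667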
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model)

        # Stage 1: coarse retrieval over document-level embeddings.
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Keep only the top max_k candidates for chunk-level reranking.
        top_k_idx = torch.topk(
            torch.FloatTensor(node_scores),
            min(self.max_k, len(node_scores)),
            dim=-1
        ).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]

        # Stage 2: rescore each top candidate from its chunk embeddings.
        pred_dict = {}
        for node_id in top_k_node_ids:
            doc = self.skb.get_doc_info(node_id, add_rel=True, compact=True)
            chunks = chunk_text(doc, chunk_size=self.chunk_size)
            chunk_path = osp.join(self.chunk_emb_dir, f'{node_id}_size={self.chunk_size}.pt')
            if osp.exists(chunk_path):
                chunk_embs = torch.load(chunk_path)
            else:
                # Embed the chunks and cache them for future queries.
                chunk_embs = get_sentence_transformer_embeddings(chunks)
                torch.save(chunk_embs, chunk_path)
            print(f'chunk_embs.shape: {chunk_embs.shape}')

            # Similarity between the query and every chunk of this candidate.
            similarity = torch.matmul(
                query_emb.to(self.device), chunk_embs.to(self.device).T
            ).cpu().view(-1)
            if self.aggregate == 'max':
                pred_dict[node_id] = torch.max(similarity).item()
            elif self.aggregate == 'avg':
                pred_dict[node_id] = torch.mean(similarity).item()
            elif 'top' in self.aggregate:
                # e.g. 'top3_avg' -> average of the 3 highest chunk similarities.
                k = int(self.aggregate.split('_')[0][len('top'):])
                pred_dict[node_id] = torch.mean(
                    torch.topk(similarity, k=min(k, len(chunks)), dim=-1).values
                ).item()

        return pred_dict

    def modify_emb_dir(self, base_dir: str, emb_model: str = 'text-embedding-ada-002') -> str:
        """
        Modify the base directory to reflect the embedding model in the path,
        e.g. 'emb_root/old-model/query' -> 'emb_root/{emb_model}/query'.

        Args:
            base_dir (str): Base directory for embeddings.
            emb_model (str): Embedding model name to substitute into the path.

        Returns:
            str: Modified directory path with the embedding model name, or the
                original path if it has fewer than three components.
        """
        # Replace the second-to-last path component with the embedding model name.
        parts = base_dir.split('/')
        if len(parts) >= 3:
            parts[-2] = emb_model
        return '/'.join(parts)
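

if __name__ == '__main__':
    # Usage sketch only. The dataset name, embedding directories, and the
    # `load_skb` loader are assumptions about the surrounding STaRK setup,
    # not part of this module; replace them with your own paths.
    from stark_qa import load_skb

    skb = load_skb('amazon')
    model = MultiVSS(
        skb,
        query_emb_dir='emb/amazon/text-embedding-ada-002/query',
        candidates_emb_dir='emb/amazon/text-embedding-ada-002/doc',
        chunk_emb_dir='emb/amazon/text-embedding-ada-002/chunk',
        aggregate='top3_avg',
        max_k=50,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )

    # `query` and `query_id` would normally come from the STaRK QA split, so that
    # the cached query embedding indexed by `query_id` matches the query text.
    query, query_id = 'example query text', 0
    pred_dict = model.forward(query, query_id)
    top_ids = sorted(pred_dict, key=pred_dict.get, reverse=True)[:5]
    print(top_ids)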