import os
import os.path as osp
import torch
from typing import Any, Dict, List, Union
from models.model import ModelForSTaRKQA
from models.vss import VSS
from stark_qa.tools.api import get_openai_embeddings, get_sentence_transformer_embeddings
from stark_qa.tools.process_text import chunk_text


class MultiVSS(ModelForSTaRKQA):
    
    def __init__(self, 
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 chunk_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 aggregate: str = 'top3_avg',
                 max_k: int = 50,
                 chunk_size: int = 256,
                 device: str = 'cuda'):
        """
        Multi-vector similarity search: rerank first-stage candidates by comparing
        the query embedding against embeddings of each candidate's text chunks.

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory of query embeddings.
            candidates_emb_dir (str): Directory of candidate embeddings.
            chunk_emb_dir (str): Directory of chunk embeddings.
            emb_model (str): Embedding model name.
            aggregate (str): Aggregation method for chunk similarity scores
                ('max', 'avg', or 'top{k}_avg').
            max_k (int): Maximum number of top candidates to rerank.
            chunk_size (int): Size of chunks for text processing.
            device (str): Device on which similarity computation runs.
        """
        super(MultiVSS, self).__init__(skb, query_emb_dir)
        self.skb = skb
        self.aggregate = aggregate  # 'max', 'avg', or 'top{k}_avg'
        self.max_k = max_k
        self.chunk_size = chunk_size
        self.emb_model = emb_model
        self.device = device
        self.query_emb_dir = query_emb_dir
        self.chunk_emb_dir = chunk_emb_dir
        self.candidates_emb_dir = candidates_emb_dir
        # Cache directory for chunk embeddings; created on demand.
        os.makedirs(self.chunk_emb_dir, exist_ok=True)
        # First-stage retrieval runs on the default embeddings (text-embedding-ada-002),
        # so point VSS at the matching embedding directories.
        self.parent_vss = VSS(skb,
                              query_emb_dir=self.modify_emb_dir(self.query_emb_dir),
                              candidates_emb_dir=self.modify_emb_dir(self.candidates_emb_dir),
                              device=device)

    def forward(self, 
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Dict[int, float]:
        """
        Forward pass to compute predictions for the given query using MultiVSS.

        Args:
            query (Union[str, List[str]]): Query string or list of query strings.
            query_id (Union[int, List[int]]): Query ID or list of query IDs.

        Returns:
            pred_dict (Dict[int, float]): Mapping from candidate node ID to its
                aggregated similarity score.
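                Example (illustrative values): {1024: 0.83, 57: 0.79}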
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model)

        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Select the max_k candidates with the highest first-stage scores
        top_k_idx = torch.topk(
            torch.FloatTensor(node_scores),
            min(self.max_k, len(node_scores)),
            dim=-1
        ).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]

        pred_dict = {}
        for node_id in top_k_node_ids:
            doc = self.skb.get_doc_info(node_id, add_rel=True, compact=True)
            chunks = chunk_text(doc, chunk_size=self.chunk_size)
            chunk_path = osp.join(self.chunk_emb_dir, f'{node_id}_size={self.chunk_size}.pt')
            if osp.exists(chunk_path):
                chunk_embs = torch.load(chunk_path)
            else:
                # Embed chunks with a sentence transformer; the OpenAI API
                # (get_openai_embeddings) is a drop-in alternative.
                chunk_embs = get_sentence_transformer_embeddings(chunks)
                torch.save(chunk_embs, chunk_path)

            similarity = torch.matmul(
                query_emb.to(self.device), chunk_embs.to(self.device).T
            ).cpu().view(-1)
            if self.aggregate == 'max':
                pred_dict[node_id] = torch.max(similarity).item()
            elif self.aggregate == 'avg':
                pred_dict[node_id] = torch.mean(similarity).item()
            elif self.aggregate.startswith('top'):
                # e.g. 'top3_avg': average the 3 highest chunk similarities
                k = int(self.aggregate.split('_')[0][len('top'):])
                pred_dict[node_id] = torch.mean(
                    torch.topk(similarity, k=min(k, len(chunks)), dim=-1).values
                ).item()
            else:
                raise ValueError(f"Unknown aggregate method: {self.aggregate}")

        return pred_dict

    def modify_emb_dir(self, base_dir: str, emb_model: str = 'text-embedding-ada-002') -> str:
        """
        Modify the base directory to reflect the embedding model in the path.

        Args:
            base_dir (str): Base directory for embeddings.
            emb_model (str): Embedding model name to be used in the directory.

        Returns:
            str: Modified directory path with the embedding model name.
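
        Example (illustrative path; the actual layout depends on your setup):
            modify_emb_dir('emb/amazon/other-model/query')
            # -> 'emb/amazon/text-embedding-ada-002/query'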
        """
        # Replace the last part of the directory with the embedding model name
        # Assumes that 'base_dir' is in the format 'emb/amazon/<embedding_model>/query'
        parts = base_dir.split('/')
        if len(parts) >= 3:  # Make sure there are enough parts to replace
            parts[-2] = emb_model  # Replace the embedding model part
        return '/'.join(parts)
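

if __name__ == '__main__':
    # Minimal usage sketch. The dataset name and embedding-directory layout
    # below are assumptions for illustration, not guarantees about this repo.
    from stark_qa import load_skb

    skb = load_skb('amazon')  # assumed dataset name
    model = MultiVSS(
        skb,
        query_emb_dir='emb/amazon/text-embedding-ada-002/query',     # assumed path
        candidates_emb_dir='emb/amazon/text-embedding-ada-002/doc',  # assumed path
        chunk_emb_dir='emb/amazon/text-embedding-ada-002/chunk',     # assumed path
        aggregate='top3_avg',
        max_k=50,
        device='cuda',
    )
    pred_dict = model.forward('What fits a 2015 MacBook Pro?', query_id=0)
    # Highest-scoring candidates first
    top5 = sorted(pred_dict.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(top5)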