File size: 5,438 Bytes
7bf4b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import sys


import os
import os.path as osp
from typing import Any, Union, List, Dict

import torch
import torch.nn as nn
from stark_qa.tools.api import get_api_embeddings, get_sentence_transformer_embeddings, get_contriever_embeddings
from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
from stark_qa.evaluator import Evaluator


class ModelForSTaRKQA(nn.Module):
    """
    Base model for STaRK QA.

    Holds the semi-structured knowledge base (``skb``), a persistent cache of
    query embeddings keyed by query id, and an ``Evaluator`` built from the
    knowledge base's candidate ids. Subclasses must implement ``forward``.
    """

    # Default metrics for ``evaluate`` / ``evaluate_batch``. Stored as a tuple
    # and copied into a fresh list per call, so no call can mutate the default.
    _DEFAULT_METRICS = ('mrr', 'hit@3', 'recall@20')

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initializes the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information; must expose
                ``candidate_ids`` and ``num_candidates``.
            query_emb_dir (str): Directory of the query-embedding cache file
                ``query_emb_dict.pt``. If that file exists it is loaded here;
                otherwise the cache starts empty.
        """
        super().__init__()
        self.skb = skb

        self.candidate_ids = skb.candidate_ids
        self.num_candidates = skb.num_candidates
        self.query_emb_dir = query_emb_dir

        query_emb_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
        if osp.exists(query_emb_path):
            print(f'Load query embeddings from {query_emb_path}')
            # NOTE(review): torch.load unpickles the file -- only point
            # query_emb_dir at trusted locations.
            self.query_emb_dict = torch.load(query_emb_path)
        else:
            self.query_emb_dict = {}
        self.evaluator = Evaluator(self.candidate_ids)

    def forward(self,
                query: Union[str, List[str]],
                candidates: List[int] = None,
                query_id: Union[int, List[int]] = None,
                **kwargs: Any) -> Dict[str, Any]:
        """
        Forward pass to compute predictions for the given query.

        Must be implemented by subclasses.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            candidates (Union[list, None]): A list of candidate ids (optional).
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (dict): A dictionary of predicted scores or answer ids.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_query_emb(self,
                      query: Union[str, List[str]],
                      query_id: Union[int, List[int], None],
                      emb_model: str = 'text-embedding-ada-002',
                      **encode_kwargs) -> torch.Tensor:
        """
        Retrieves cached embeddings for the given queries, or computes them.

        The method handles three cases:

        - ``query_id`` is None: the queries are encoded and the cache is not
          touched.
        - Every id in ``query_id`` is cached: the cached embeddings are
          concatenated and returned.
        - Otherwise: all queries are re-encoded with ``emb_model``. Each
          embedding is then cached under its id, and the whole cache is
          written to ``<query_emb_dir>/query_emb_dict.pt``.

        Args:
            query (Union[str, List[str]]): Query string or list of query strings.
            query_id (Union[int, List[int], None]): Id(s) aligned with ``query``,
                used as cache keys.
            emb_model (str): Embedding model name passed to ``get_embeddings``.
            **encode_kwargs: Extra options forwarded to ``get_embeddings``.

        Returns:
            torch.Tensor: Query embeddings reshaped to ``(len(query), -1)``.
        """
        if isinstance(query_id, int):
            query_id = [query_id]
        if isinstance(query, str):
            query = [query]

        if query_id is None:
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
        elif all(qid in self.query_emb_dict for qid in query_id):
            # Cache hit: O(len(query_id)) dict lookups instead of building a
            # set of the whole cache on every call.
            query_emb = torch.concat([self.query_emb_dict[qid] for qid in query_id], dim=0)
        else:
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
            for qid, emb in zip(query_id, query_emb):
                self.query_emb_dict[qid] = emb.view(1, -1)
            if self.query_emb_dir:
                # torch.save fails if the target directory does not exist yet.
                os.makedirs(self.query_emb_dir, exist_ok=True)
            torch.save(self.query_emb_dict, osp.join(self.query_emb_dir, 'query_emb_dict.pt'))

        return query_emb.view(len(query), -1)

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: Union[List[str], None] = None,
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str], optional): Metrics to evaluate, e.g. 'mrr',
                'hit@k', 'recall@k', 'precision@k', 'map@k', 'ndcg@k'.
                Defaults to ['mrr', 'hit@3', 'recall@20'].

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics, as produced
            by ``self.evaluator``.
        """
        if metrics is None:
            metrics = list(self._DEFAULT_METRICS)
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: Union[List[str], None] = None,
                       **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates a batch of predictions via ``self.evaluator.evaluate_batch``.

        Args:
            pred_ids (List[int]): Candidate ids for the prediction scores
                (presumably aligned with the last dim of ``pred`` -- confirm
                against ``Evaluator.evaluate_batch``).
            pred (torch.Tensor): Prediction scores for the batch.
            answer_ids (Union[torch.LongTensor, List[Any]]): Ground truth answer ids.
            metrics (List[str], optional): Metrics to evaluate. Defaults to
                ['mrr', 'hit@3', 'recall@20'].

        Returns:
            Dict[str, float]: Whatever ``Evaluator.evaluate_batch`` returns.
        """
        if metrics is None:
            metrics = list(self._DEFAULT_METRICS)
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)


def get_embeddings(text, model_name, **encode_kwargs):
    """
    Embed ``text`` with the encoder selected by ``model_name``.

    The backend is chosen by substring match on ``model_name``, checked in
    this order: 'GritLM', 'LLM2Vec', 'all-mpnet-base-v2', 'contriever'. Any
    other name is routed to the API embedding helper.

    Args:
        text (Union[str, List[str]]): A single string or a list of strings.
        model_name (str): Model identifier used to pick the backend.
        **encode_kwargs: Extra options forwarded to the GritLM, LLM2Vec and
            API helpers (the sentence-transformer and Contriever helpers
            receive only the texts).

    Returns:
        torch.Tensor: Embeddings reshaped to ``(len(text), -1)``, where
        ``text`` is first wrapped in a list if it was a single string.
    """
    texts = [text] if isinstance(text, str) else text

    # Ordered (substring, encoder) routes -- the first matching key wins.
    # Lambdas defer name lookup until the chosen route is actually called.
    routes = (
        ('GritLM', lambda: get_gritlm_embeddings(texts, model_name, **encode_kwargs)),
        ('LLM2Vec', lambda: get_llm2vec_embeddings(texts, model_name, **encode_kwargs)),
        ('all-mpnet-base-v2', lambda: get_sentence_transformer_embeddings(texts)),
        ('contriever', lambda: get_contriever_embeddings(texts)),
    )
    fallback = lambda: get_api_embeddings(texts, model_name, **encode_kwargs)
    encode = next((fn for key, fn in routes if key in model_name), fallback)

    embeddings = encode()
    return embeddings.view(len(texts), -1)