File size: 5,438 Bytes
7bf4b88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import sys
import os
import os.path as osp
from typing import Any, Union, List, Dict
import torch
import torch.nn as nn
from stark_qa.tools.api import get_api_embeddings, get_sentence_transformer_embeddings, get_contriever_embeddings
from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
from stark_qa.evaluator import Evaluator
class ModelForSTaRKQA(nn.Module):
    """
    Base class for models that answer queries over a knowledge base.

    The class keeps a query-embedding cache (in memory and mirrored to
    ``<query_emb_dir>/query_emb_dict.pt``) and an ``Evaluator`` built from the
    knowledge base's candidate ids. ``forward`` is left for subclasses.
    """

    # Metrics used by `evaluate` / `evaluate_batch` when the caller passes none.
    # Stored as an immutable tuple so the default cannot be mutated between calls.
    _DEFAULT_METRICS = ('mrr', 'hit@3', 'recall@20')

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initializes the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information. It must
                expose the attributes `candidate_ids` and `num_candidates`.
            query_emb_dir (str): Directory in which the query embedding cache
                file `query_emb_dict.pt` is looked up and saved.
        """
        super().__init__()
        self.skb = skb
        self.candidate_ids = skb.candidate_ids
        self.num_candidates = skb.num_candidates
        self.query_emb_dir = query_emb_dir

        query_emb_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
        if os.path.exists(query_emb_path):
            print(f'Load query embeddings from {query_emb_path}')
            # NOTE(review): torch.load unpickles the file; only point
            # `query_emb_dir` at cache files from a trusted source.
            self.query_emb_dict = torch.load(query_emb_path)
        else:
            self.query_emb_dict = {}
        self.evaluator = Evaluator(self.candidate_ids)

    def forward(self,
                query: Union[str, List[str]],
                candidates: Union[List[int], None] = None,
                query_id: Union[int, List[int], None] = None,
                **kwargs: Any) -> Dict[str, Any]:
        """
        Forward pass to compute predictions for the given query.

        Subclasses must override this method; the base implementation raises.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            candidates (Union[list, None]): A list of candidate ids (optional).
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (dict): A dictionary of predicted scores or answer ids.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_query_emb(self,
                      query: Union[str, List[str]],
                      query_id: Union[int, List[int], None] = None,
                      emb_model: str = 'text-embedding-ada-002',
                      **encode_kwargs) -> torch.Tensor:
        """
        Retrieves or computes the embedding for the given query.

        If `query_id` is None the embeddings are computed and not cached.
        If every id in `query_id` is already in the cache, the cached
        embeddings are returned. Otherwise embeddings are computed for all
        queries, stored in the cache under their ids, and the whole cache is
        written back to `<query_emb_dir>/query_emb_dict.pt`.

        Args:
            query (Union[str, List[str]]): Query string or list of query strings.
            query_id (Union[int, List[int], None]): Query index / indices,
                aligned one-to-one with `query`.
            emb_model (str): Embedding model to use.
            **encode_kwargs: Extra keyword arguments forwarded to `get_embeddings`.

        Returns:
            query_emb (torch.Tensor): Query embeddings viewed as
                `(len(query), -1)`.

        Raises:
            ValueError: If `query_id` is given but its length differs from the
                number of queries.
        """
        if isinstance(query_id, int):
            query_id = [query_id]
        if isinstance(query, str):
            query = [query]

        if query_id is None:
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
        else:
            # A length mismatch would make zip() silently truncate and cache
            # embeddings under the wrong ids, so reject it up front.
            if len(query_id) != len(query):
                raise ValueError(
                    f'Got {len(query)} queries but {len(query_id)} query ids; '
                    'they must be aligned one-to-one.'
                )
            if all(qid in self.query_emb_dict for qid in query_id):
                query_emb = torch.concat(
                    [self.query_emb_dict[qid] for qid in query_id], dim=0
                )
            else:
                # At least one id is missing: recompute the whole batch and
                # refresh every entry.
                query_emb = get_embeddings(query, emb_model, **encode_kwargs)
                for qid, emb in zip(query_id, query_emb):
                    self.query_emb_dict[qid] = emb.view(1, -1)
                # Make sure the cache directory exists so freshly computed
                # embeddings are not lost to a failed save.
                if self.query_emb_dir:
                    os.makedirs(self.query_emb_dir, exist_ok=True)
                torch.save(self.query_emb_dict,
                           osp.join(self.query_emb_dir, 'query_emb_dict.pt'))

        query_emb = query_emb.view(len(query), -1)
        return query_emb

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: Union[List[str], None] = None,
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): A list of metrics to be evaluated, including 'mrr', 'hit@k', 'recall@k',
                'precision@k', 'map@k', 'ndcg@k'. Defaults to
                ['mrr', 'hit@3', 'recall@20'] when None.

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics, as returned
                by `self.evaluator`.
        """
        if metrics is None:
            # Fresh list per call: avoids sharing one mutable default object.
            metrics = list(self._DEFAULT_METRICS)
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: Union[List[str], None] = None,
                       **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates a batch of predictions by delegating to
        `self.evaluator.evaluate_batch`.

        Args:
            pred_ids (List[int]): Ids the predictions refer to
                (presumably candidate ids -- confirm against `Evaluator`).
            pred (torch.Tensor): Predictions for the batch.
            answer_ids (Union[torch.LongTensor, List[Any]]): Ground truth answer ids.
            metrics (List[str]): Metrics to evaluate. Defaults to
                ['mrr', 'hit@3', 'recall@20'] when None.

        Returns:
            Dict[str, float]: Whatever `self.evaluator.evaluate_batch` returns.
        """
        if metrics is None:
            # Fresh list per call: avoids sharing one mutable default object.
            metrics = list(self._DEFAULT_METRICS)
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
def get_embeddings(text, model_name, **encode_kwargs):
    """
    Embed text with the backend selected by `model_name`.

    The backend is chosen by substring match on `model_name`, checked in this
    order: 'GritLM', 'LLM2Vec', 'all-mpnet-base-v2', 'contriever'. Any other
    name falls through to the API embedding backend.

    Args:
        text (Union[str, List[str]]): The input text to be embedded; a single
            string is treated as a one-element list.
        model_name (str): Model name.
        **encode_kwargs: Extra keyword arguments forwarded to the GritLM,
            LLM2Vec and API backends (the sentence-transformer and contriever
            backends receive only the text).

    Returns:
        torch.Tensor: Embeddings viewed as `(len(text), -1)`.
    """
    texts = [text] if isinstance(text, str) else text

    def _encode():
        # Guard clauses, in the same priority order as the backend list above.
        if 'GritLM' in model_name:
            return get_gritlm_embeddings(texts, model_name, **encode_kwargs)
        if 'LLM2Vec' in model_name:
            return get_llm2vec_embeddings(texts, model_name, **encode_kwargs)
        if 'all-mpnet-base-v2' in model_name:
            return get_sentence_transformer_embeddings(texts)
        if 'contriever' in model_name:
            return get_contriever_embeddings(texts)
        return get_api_embeddings(texts, model_name, **encode_kwargs)

    return _encode().view(len(texts), -1)
|