# MoR / Reranking / utils.py
# GagaLey's picture
# framework
# 7bf4b88
import sys
from pathlib import Path
# Resolve the absolute path of this file (``Path.resolve`` also follows symlinks).
current_file = Path(__file__).resolve()
# ``parents[1]`` is two levels up: the parent of the directory that contains this
# file. That directory is treated as the project root.
project_root = current_file.parents[1]
# Append the project root to ``sys.path`` so it is searched when resolving imports.
# It is appended (not inserted first), so existing entries keep precedence.
sys.path.append(str(project_root))
import random
import torch
import os
from stark_qa.evaluator import Evaluator
import torch.nn as nn
from typing import Any, Union, List, Dict
class ModelForSTaRKQA(nn.Module):
    """``nn.Module`` base that stores a knowledge base and wraps an ``Evaluator``.

    The class defines no ``forward``; it keeps the knowledge base, its candidate
    ids, and an ``Evaluator`` built from those ids, and exposes two thin
    evaluation wrappers around that evaluator.
    """

    # Metrics used when the caller does not pass ``metrics``. Stored as a tuple so
    # the defaults themselves can never be mutated; every call gets a fresh list.
    DEFAULT_METRICS = ('mrr', 'hit@3', 'recall@20')

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initializes the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information. It must expose
                a ``candidate_ids`` attribute.
            query_emb_dir: Accepted for interface compatibility; this class does
                not use it.
        """
        super().__init__()
        self.skb = skb
        self.candidate_ids = skb.candidate_ids
        self.evaluator = Evaluator(self.candidate_ids)

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: Union[List[str], None] = None,
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): A list of metrics to be evaluated, including 'mrr', 'hit@k', 'recall@k',
                'precision@k', 'map@k', 'ndcg@k'. When omitted (``None``), a fresh copy of
                ``DEFAULT_METRICS`` is used. A list supplied by the caller is forwarded as-is.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        if metrics is None:
            # Build a new list per call: a shared default list would carry any
            # in-place modification over to later calls.
            metrics = list(self.DEFAULT_METRICS)
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: Union[List[str], None] = None,
                       **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates a batch of predictions by delegating to ``Evaluator.evaluate_batch``.

        Args:
            pred_ids (List[int]): Forwarded unchanged to the evaluator
                (presumably the candidate ids that ``pred`` scores; confirm against the evaluator).
            pred (torch.Tensor): Forwarded unchanged to the evaluator.
            answer_ids (torch.LongTensor): Ground truth answer ids, forwarded unchanged.
            metrics (List[str]): Metric names to compute. When omitted (``None``), a fresh
                copy of ``DEFAULT_METRICS`` is used. A list supplied by the caller is forwarded as-is.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Dict[str, float]: Whatever ``self.evaluator.evaluate_batch`` returns.
        """
        if metrics is None:
            # Same reasoning as in ``evaluate``: never share one default list.
            metrics = list(self.DEFAULT_METRICS)
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
def seed_everything(seed=0):
    """Seed the Python and PyTorch RNGs and set the cuDNN determinism flags.

    Args:
        seed: Value passed to every seeding call below; its string form is
            also written to the ``PYTHONHASHSEED`` environment variable.
    """
    # Python's ``random``, the default torch generator, then all CUDA devices.
    for apply_seed in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)

    seed_text = str(seed)
    os.environ['PYTHONHASHSEED'] = seed_text

    # cuDNN: request deterministic algorithms and turn off benchmark mode.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def move_to_cuda(sample):
    """Return ``sample`` with every tensor found in (nested) dicts moved via ``.cuda()``.

    Args:
        sample: A sized object (``len`` is called on it). Typically a dict whose
            values are tensors, nested dicts, or arbitrary other objects.

    Returns:
        ``{}`` when ``sample`` is empty (whatever its type). Otherwise a tensor is
        returned as ``tensor.cuda()``, a dict is rebuilt as a new plain dict with
        each value converted recursively, and anything else is returned as-is.
    """
    def _transfer(obj):
        if torch.is_tensor(obj):
            return obj.cuda()
        if not isinstance(obj, dict):
            # NOTE: lists (and every other non-dict container) are passed through
            # untouched; only dicts are traversed.
            return obj
        converted = {}
        for name, item in obj.items():
            converted[name] = _transfer(item)
        return converted

    # ``len`` is used on purpose instead of truthiness: ``bool()`` of a
    # multi-element tensor raises.
    if len(sample) == 0:
        return {}
    return _transfer(sample)
if __name__ == "__main__":
    # Running this module directly only prints a marker line; no checks are run.
    print("Testing Utils")