|
import sys |
|
import os |
|
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd()))) |
|
from stark_qa import load_skb |
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
import torch |
|
from tqdm import tqdm |
|
import numpy as np |
|
import torch.nn as nn |
|
|
|
from Reranking.utils import move_to_cuda, seed_everything |
|
from Reranking.rerankers.path import PathReranker |
|
import torch.nn.functional as F |
|
import argparse |
|
import pickle as pkl |
|
|
|
|
|
|
|
|
|
class TestDataset(Dataset):
    """
    Test-time dataset over precomputed retrieval candidates.

    data format: {
        "query": query,
        "pred_dict": {node_id: score},
        'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
        "text_emb_dict": {node_id: text_emb},
        "ans_ids": [],
    }

    """

    # Number of top-ranked candidates kept per query (non-amazon datasets);
    # was a hard-coded 50 inside __getitem__.
    TOP_K = 50

    def __init__(self, saved_data, args):
        """
        Args:
            saved_data: dict with keys 'text2emb_dict' (text -> 1 x d tensor)
                and 'data' (list of per-query dicts, see class docstring).
            args: namespace; this class reads `dataset_name` and `vector_dim`.
        """
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        # Stack all text embeddings into one (num_texts, d) matrix so
        # candidates can be gathered by row index at collate time.
        self.text_emb_matrix = torch.concat(list(self.text2emb_dict.values()), dim=0)

        # text key -> row index into text_emb_matrix.
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}

        self.args = args

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # BUGFIX: the original mutated self.data[idx] in place, so a second
        # access of the same item (e.g. a second pass over the DataLoader)
        # fed already-mapped int indices back into text2idx_dict and raised
        # KeyError. Work on a shallow copy instead; the backing list is
        # never modified and __getitem__ is now idempotent.
        item = dict(self.data[idx])

        if self.args.dataset_name == 'amazon':
            # Replace embedding texts with their row indices.
            item['text_emb_dict'] = {key: self.text2idx_dict[value]
                                     for key, value in item['text_emb_dict'].items()}
        else:
            # Keep only the TOP_K candidates by retrieval score (descending).
            pred_dict = item['pred_dict']
            sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)
            sorted_ids = sorted_ids[:self.TOP_K]

            item['score_vector_dict'] = {key: item['score_vector_dict'][key]
                                         for key in sorted_ids}
            item['symb_enc_dict'] = {key: item['symb_enc_dict'][key]
                                     for key in sorted_ids}
            # Map texts to row indices, restricted to the kept candidates.
            item['text_emb_dict'] = {key: self.text2idx_dict[item['text_emb_dict'][key]]
                                     for key in sorted_ids}

        return item

    def collate_batch(self, batch):
        """Collate per-query dicts into batched tensors for the reranker.

        Returns a feed dict with:
            query: list[str], length B
            c_score_vector: (B, K, vector_dim) float tensor
            c_text_emb: (B, K, d) float tensor
            c_symb_enc: (B, K, 3) tensor
            ans_ids: list of answer-id lists
            pred_ids: list of candidate-id lists (B x K)
        """
        q_text = [sample['query'] for sample in batch]

        # Candidate ids; assumes every sample keeps the same number K of
        # candidates so the tensor is rectangular — TODO confirm upstream.
        batch_c = torch.tensor([list(sample['score_vector_dict'].keys())
                                for sample in batch])

        # Similarity-score vectors, truncated to the configured dimension.
        c_score_vector = torch.tensor([list(sample['score_vector_dict'].values())
                                       for sample in batch])
        c_score_vector = c_score_vector[:, :, :self.args.vector_dim]

        # Symbolic encodings per candidate.
        c_symb_enc = torch.tensor([list(sample['symb_enc_dict'].values())
                                   for sample in batch])

        # Gather candidate text embeddings from the shared matrix by index.
        c_text_emb = torch.concat(
            [self.text_emb_matrix[list(sample['text_emb_dict'].values())].unsqueeze(0)
             for sample in batch], dim=0)

        ans_ids = [sample['ans_ids'] for sample in batch]
        pred_ids = batch_c.tolist()

        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids
        }

        return feed_dict
|
|
|
|
|
|
|
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """Compute per-query retrieval metrics for one batch.

    Args:
        skb: knowledge base; only `candidate_ids` is read.
        scores_cand: (B, K, 1) reranker scores for the K predicted candidates.
        ans_ids: list (length B) of ground-truth answer-id lists.
        batch: feed dict; only `pred_ids` (B x K candidate ids) is read.

    Returns:
        dict with per-query lists for 'hit@1', 'hit@5', 'recall@20', 'mrr'.
    """
    num_queries = scores_cand.shape[0]
    device = scores_cand.device

    # Dense column index for every candidate id in the KB.
    candidates_ids = skb.candidate_ids
    id_to_idx = {cid: col for col, cid in enumerate(candidates_ids)}

    # Scatter the predicted scores into a (B, num_candidates) matrix;
    # candidates that were not predicted keep score 0.
    flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()
    pred_cols = torch.tensor([id_to_idx[pid] for pid in flat_pred_ids])
    pred_cols = pred_cols.reshape(num_queries, -1)
    pred_matrix = torch.zeros((num_queries, len(candidates_ids)), device=device)
    pred_matrix[torch.arange(num_queries).unsqueeze(1), pred_cols] = scores_cand.squeeze(-1)

    # Binary (B, num_candidates) answer matrix: 1 where the column is an answer.
    flat_ans_cols = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]
    row_indices = torch.repeat_interleave(
        torch.arange(len(ans_ids)),
        torch.tensor([len(sublist) for sublist in ans_ids]),
    )
    ans_matrix = torch.zeros((num_queries, len(candidates_ids)), device=device)
    ans_matrix[row_indices, torch.tensor(flat_ans_cols, device=device)] = 1

    rows = torch.arange(num_queries)

    # hit@1: is the single best-scored candidate an answer?
    _, top1_idx = torch.max(pred_matrix, dim=1)
    hit1_list = ans_matrix[rows, top1_idx].tolist()

    # hit@5: is any answer among the 5 best-scored candidates?
    _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
    hit5_list = torch.max(ans_matrix[rows.unsqueeze(1), top5_idx], dim=1)[0].tolist()

    # recall@20: fraction of this query's answers inside the top 20.
    _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
    found20 = torch.sum(ans_matrix[rows.unsqueeze(1), top20_idx], dim=1)
    recall20_list = (found20 / torch.sum(ans_matrix, dim=1)).tolist()

    # MRR: reciprocal rank of the first answer in the full score ordering.
    _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
    first_hit = torch.argmax(ans_matrix[rows.unsqueeze(1), rank_idx], dim=1)
    mrr_list = (1 / (first_hit + 1).float()).tolist()

    return {
        'hit@1': hit1_list,
        'hit@5': hit5_list,
        'recall@20': recall20_list,
        'mrr': mrr_list,
    }
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate(router, test_loader, skb):
    """Run the reranker over the full test loader and average the metrics.

    Args:
        router: the reranking model (possibly wrapped in nn.DataParallel).
        test_loader: DataLoader yielding feed dicts from TestDataset.
        skb: knowledge base passed through to batch_evaluator.

    Returns:
        dict mapping 'hit@1' / 'hit@5' / 'recall@20' / 'mrr' to their means.
    """
    router.eval()

    metric_names = ["hit@1", "hit@5", "recall@20", "mrr"]
    all_results = {name: [] for name in metric_names}

    pred_list = []
    scores_cand_list = []
    ans_ids_list = []
    print(f"Start evaluating...")

    for batch in tqdm(test_loader, desc='Evaluating', position=0):
        batch = move_to_cuda(batch)

        # Unwrap DataParallel to reach the model's own eval_batch.
        model = router.module if isinstance(router, nn.DataParallel) else router
        scores_cand = model.eval_batch(batch)

        ans_ids = batch['ans_ids']
        batch_results = batch_evaluator(skb, scores_cand, ans_ids, batch)

        for name, values in batch_results.items():
            all_results[name].extend(values)

        pred_list.extend(batch['pred_ids'])
        scores_cand_list.extend(scores_cand.cpu().tolist())
        ans_ids_list.extend(ans_ids)

    avg_results = {name: np.mean(all_results[name]) for name in metric_names}
    print(f"Results: {avg_results}")

    return avg_results
|
|
|
|
|
def parse_args():
    """Parse the command-line options for a reranking evaluation run."""
    parser = argparse.ArgumentParser(
        description="Run PathRouter with dynamic combinations of embeddings.")
    parser.add_argument("--dataset_name", type=str, default="mag",
                        help="Name of the dataset.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run the model (e.g., 'cuda' or 'cpu').")
    parser.add_argument("--concat_num", type=int, default=0,
                        help="Number of concatenation of embeddings.")
    parser.add_argument("--checkpoint_path", type=str, default="./data/checkpoints",
                        help="Path saves the checkpoints.")
    parser.add_argument("--vector_dim", type=int, default=4,
                        help="Dimension of the similarity vector.")
    return parser.parse_args()
|
|
|
|
|
def get_concat_num(combo):
    """
    Determine the value of concat_num based on the combination of embeddings.

    - score_vec adds +1
    - text_emb adds +1
    - symb_enc adds +3
    """
    weights = {"score_vec": 1, "text_emb": 1, "symb_enc": 3}
    return sum(w for name, w in weights.items() if combo.get(name, False))
|
|
|
|
|
def run(test_data, skb, dataset_name):
    """Build the test loader, restore the best checkpoint, and evaluate.

    Reads the module-level `args` namespace (set in __main__) for the
    dataset/model configuration.

    Args:
        test_data: saved test split (dict consumed by TestDataset).
        skb: knowledge base for metric computation.
        dataset_name: dataset identifier used to locate the checkpoint.

    Returns:
        dict of averaged metrics from evaluate().
    """
    test_size = 64
    test_dataset = TestDataset(test_data, args=args)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32,
                             collate_fn=test_dataset.collate_batch)

    print(f"Load the model...")
    # BUGFIX: the original reassigned args.checkpoint_path in place, so a
    # second call to run() produced ".../best.pth/<ds>/best.pth". Resolve
    # the checkpoint file locally and leave the global args untouched.
    checkpoint_file = os.path.join(args.checkpoint_path, dataset_name, "best.pth")
    router = PathReranker(socre_vector_input_dim=4, text_emb_input_dim=768,
                          symb_enc_dim=3, args=args)
    # map_location lets a CUDA-trained checkpoint load on the requested device.
    checkpoint = torch.load(checkpoint_file, map_location=args.device)
    router.load_state_dict(checkpoint)
    router = router.to(args.device)

    test_results = evaluate(router, test_loader, skb)
    print(f"Test evaluation")
    print(test_results)

    return test_results
|
|
|
if __name__ == "__main__":

    # Which embedding inputs the reranker consumes; each enabled flag
    # contributes to concat_num (score_vec +1, text_emb +1, symb_enc +3).
    combo = {

        "text_emb": True,

        "score_vec": True,

        "symb_enc": True

    }

    concat_num = get_concat_num(combo)



    # Merge the CLI options with the combo flags into one namespace.
    # NOTE: `args` is module-level and is read as a global by run().
    base_args = parse_args()

    args = argparse.Namespace(**vars(base_args), **combo)

    args.concat_num = concat_num

    dataset_name = args.dataset_name



    # Load the pickled test split (expected one directory up) and the
    # semi-structured knowledge base, then evaluate.
    test_data_path = f"../{dataset_name}_test.pkl"

    with open(test_data_path, 'rb') as f:

        test_data = pkl.load(f)

    skb = load_skb(dataset_name)

    results = run(test_data, skb, dataset_name)