import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))

from stark_qa import load_skb
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm
import numpy as np
import torch.nn as nn
from Reranking.utils import move_to_cuda, seed_everything
from Reranking.rerankers.path import PathReranker
import torch.nn.functional as F
import argparse
import pickle as pkl


class TestDataset(Dataset):
    """Test-time dataset that feeds pre-computed candidate features to the reranker.

    Each element of ``saved_data['data']`` is a dict of the form::

        {
            "query": query,
            "pred_dict": {node_id: score},
            "score_vector_dict": {node_id: [score_1, ..., score_k]},
            "symb_enc_dict": {node_id: symbolic_encoding},
            "text_emb_dict": {node_id: key into saved_data['text2emb_dict']},
            "ans_ids": [...],
        }

    ``score_vector_dict`` holds per-candidate similarity scores (presumably one
    per retriever, e.g. BM25 / ada -- TODO confirm); only the first
    ``args.vector_dim`` entries are used at collate time.

    ``saved_data['text2emb_dict']`` maps a key to an embedding tensor. All
    embeddings are concatenated along dim 0 into ``text_emb_matrix`` so that
    samples can refer to them by integer row index.
    """

    # Number of highest-scoring candidates (by ``pred_dict``) kept per query
    # for non-amazon datasets.
    top_k = 50

    def __init__(self, saved_data, args):
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']
        # NOTE(review): assumes every embedding is shaped (1, dim) so the
        # concatenation yields (num_texts, dim) -- confirm against the producer.
        self.text_emb_matrix = torch.concat(list(self.text2emb_dict.values()), dim=0)
        # Row index in text_emb_matrix for every key of text2emb_dict.
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}
        self.args = args

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` with text-embedding keys replaced by row indices.

        A fresh (shallow-copied) dict is returned and ``self.data`` is never
        modified. Fetching the same index repeatedly is therefore safe; the
        previous in-place rewrite raised ``KeyError`` on a second access
        because the values had already been converted to ints.

        For non-amazon datasets only the ``top_k`` candidates with the highest
        ``pred_dict`` score are kept. Their score vectors, symbolic encodings
        and text embeddings are all emitted in that same descending order.
        """
        sample = self.data[idx]
        item = dict(sample)
        text_emb_dict = sample['text_emb_dict']

        if self.args.dataset_name == 'amazon':
            item['text_emb_dict'] = {
                key: self.text2idx_dict[value] for key, value in text_emb_dict.items()
            }
            return item

        pred_dict = sample['pred_dict']
        kept_ids = sorted(pred_dict, key=pred_dict.get, reverse=True)[:self.top_k]
        item['score_vector_dict'] = {key: sample['score_vector_dict'][key] for key in kept_ids}
        item['symb_enc_dict'] = {key: sample['symb_enc_dict'][key] for key in kept_ids}
        item['text_emb_dict'] = {key: self.text2idx_dict[text_emb_dict[key]] for key in kept_ids}
        return item

    def collate_batch(self, batch):
        """Stack a list of samples into the reranker's feed dict.

        The candidate order of each sample is the key order of its
        ``score_vector_dict``. Symbolic encodings and text embeddings are
        looked up by those same ids, so every feature stays aligned with
        ``pred_ids`` even if the source dicts were built in different orders.
        All samples must have the same number of candidates, because the
        features are stacked into dense tensors.

        Returns:
            dict with keys 'query', 'c_score_vector', 'c_text_emb',
            'c_symb_enc', 'ans_ids', 'pred_ids'.
        """
        q_text = [sample['query'] for sample in batch]
        cand_ids = [list(sample['score_vector_dict'].keys()) for sample in batch]
        # [bs, num_cand]; round-tripping through a tensor normalises ids to Python ints.
        batch_c = torch.tensor(cand_ids)

        c_score_vector = torch.tensor([
            [sample['score_vector_dict'][cid] for cid in ids]
            for sample, ids in zip(batch, cand_ids)
        ])
        c_score_vector = c_score_vector[:, :, :self.args.vector_dim]  # [bs, num_cand, vector_dim]

        c_symb_enc = torch.tensor([
            [sample['symb_enc_dict'][cid] for cid in ids]
            for sample, ids in zip(batch, cand_ids)
        ])

        c_text_emb = torch.stack([
            self.text_emb_matrix[[sample['text_emb_dict'][cid] for cid in ids]]
            for sample, ids in zip(batch, cand_ids)
        ], dim=0)  # [bs, num_cand, emb_dim]

        ans_ids = [sample['ans_ids'] for sample in batch]
        pred_ids = batch_c.tolist()

        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids
        }
        return feed_dict


# ***** batch_evaluator *****
def batch_evaluator(skb, scores_cand, ans_ids, batch, id_to_idx=None):
    """Compute per-query ranking metrics over the full candidate set of ``skb``.

    Args:
        skb: knowledge base; only ``skb.candidate_ids`` is used.
        scores_cand: reranker scores for the candidates listed in
            ``batch['pred_ids']``, in the same order. Reshaped to
            (bs, num_pred), so (bs, num_pred) and (bs, num_pred, 1) both work.
        ans_ids: per-query lists of gold answer ids (each must be in
            ``skb.candidate_ids``).
        batch: collated batch; ``batch['pred_ids']`` holds the scored ids.
        id_to_idx: optional precomputed mapping candidate id -> column index.
            Built from ``skb.candidate_ids`` when omitted. Pass it in to avoid
            rebuilding it for every batch.

    Returns:
        dict mapping 'hit@1', 'hit@5', 'recall@20', 'mrr' to a list with one
        value per query.

    Candidates the reranker did not score receive a value strictly below that
    row's minimum score (as the STaRK evaluator does). Previously they were
    left at 0, which made them outrank every scored candidate whenever the
    reranker produced negative scores.
    A query with no answers yields NaN recall@20 (0/0), as before.
    """
    device = scores_cand.device
    bs = scores_cand.shape[0]
    candidate_ids = skb.candidate_ids
    if id_to_idx is None:
        id_to_idx = {cand_id: idx for idx, cand_id in enumerate(candidate_ids)}
    num_candidates = len(candidate_ids)
    rows = torch.arange(bs, device=device).unsqueeze(1)  # [bs, 1]

    # ---- dense prediction matrix [bs, num_candidates] ----
    scores = scores_cand.detach().reshape(bs, -1).float()  # [bs, num_pred]
    fill = scores.min(dim=1, keepdim=True).values - 1.0
    pred_matrix = fill.expand(bs, num_candidates).clone()
    flat_pred_ids = torch.as_tensor(batch['pred_ids']).flatten().tolist()
    pred_idx = torch.tensor(
        [id_to_idx[pred_id] for pred_id in flat_pred_ids], dtype=torch.long, device=device
    ).reshape(bs, -1)
    pred_matrix[rows, pred_idx] = scores

    # ---- binary answer matrix [bs, num_candidates] ----
    ans_matrix = torch.zeros((bs, num_candidates), device=device)
    flat_ans_idx = torch.tensor(
        [id_to_idx[a_id] for answers in ans_ids for a_id in answers],
        dtype=torch.long, device=device,
    )
    ans_rows = torch.repeat_interleave(
        torch.arange(len(ans_ids), device=device),
        torch.tensor([len(answers) for answers in ans_ids], dtype=torch.long, device=device),
    )
    ans_matrix[ans_rows, flat_ans_idx] = 1

    # hit@1: is the top-scored candidate a gold answer?
    top1_idx = torch.max(pred_matrix, dim=1).indices
    hit1 = ans_matrix[torch.arange(bs, device=device), top1_idx]

    # hit@5: any gold answer among the top 5?
    _, top5_idx = torch.topk(pred_matrix, min(5, num_candidates), dim=1)
    hit5 = ans_matrix.gather(1, top5_idx).max(dim=1).values

    # recall@20: fraction of gold answers among the top 20.
    _, top20_idx = torch.topk(pred_matrix, min(20, num_candidates), dim=1)
    recall20 = ans_matrix.gather(1, top20_idx).sum(dim=1) / ans_matrix.sum(dim=1)

    # mrr: reciprocal (1-based) rank of the first gold answer in the full ranking.
    _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)
    first_hit = torch.argmax(ans_matrix.gather(1, rank_idx), dim=1) + 1
    mrr = 1 / first_hit.float()

    return {
        'hit@1': hit1.tolist(),
        'hit@5': hit5.tolist(),
        'recall@20': recall20.tolist(),
        'mrr': mrr.tolist(),
    }


# ***** evaluate *****
@torch.no_grad()
def evaluate(router, test_loader, skb):
    """Run the reranker over ``test_loader`` and return the mean of each metric.

    ``router`` may be wrapped in ``nn.DataParallel``; its underlying module's
    ``eval_batch`` is used to score each batch. The candidate-id -> column
    mapping is built once here and reused for every batch.
    """
    router.eval()
    metric_names = ("hit@1", "hit@5", "recall@20", "mrr")
    all_results = {name: [] for name in metric_names}

    id_to_idx = {cand_id: idx for idx, cand_id in enumerate(skb.candidate_ids)}
    model = router.module if isinstance(router, nn.DataParallel) else router

    print(f"Start evaluating...")
    for batch in tqdm(test_loader, desc='Evaluating', position=0):
        batch = move_to_cuda(batch)
        scores_cand = model.eval_batch(batch)
        results = batch_evaluator(skb, scores_cand, batch['ans_ids'], batch, id_to_idx=id_to_idx)
        for name, values in results.items():
            all_results[name].extend(values)

    avg_results = {name: np.mean(all_results[name]) for name in metric_names}
    print(f"Results: {avg_results}")
    return avg_results


def parse_args():
    """Parse command-line options for the test-time evaluation script."""
    parser = argparse.ArgumentParser(description="Run PathRouter with dynamic combinations of embeddings.")
    parser.add_argument("--dataset_name", type=str, default="mag", help="Name of the dataset.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")
    parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
    parser.add_argument("--checkpoint_path", type=str, default="./data/checkpoints", help="Path saves the checkpoints.")
    parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")
    return parser.parse_args()


def get_concat_num(combo):
    """Return concat_num for an embedding combination.

    Each enabled (truthy) flag contributes a fixed amount:
    ``score_vec`` +1, ``text_emb`` +1, ``symb_enc`` +3. Missing flags count
    as disabled.
    """
    contributions = {"score_vec": 1, "text_emb": 1, "symb_enc": 3}
    return sum(amount for flag, amount in contributions.items() if combo.get(flag, False))


def run(test_data, skb, dataset_name, args=None):
    """Build the test loader, load ``<checkpoint_path>/<dataset_name>/best.pth`` and evaluate.

    Args:
        test_data: the unpickled ``{'text2emb_dict': ..., 'data': ...}`` payload.
        skb: knowledge base passed to the evaluator.
        dataset_name: used to locate the checkpoint directory.
        args: run configuration. When omitted, the module-level ``args``
            (created by the ``__main__`` block) is used, for backward
            compatibility. The caller's namespace is not modified.

    Returns:
        dict of averaged metrics from ``evaluate``.
    """
    if args is None:
        args = globals().get("args")
        if args is None:
            raise ValueError("run() requires `args`: none was passed and no module-level `args` exists.")

    test_size = 64
    test_dataset = TestDataset(test_data, args=args)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32,
                             collate_fn=test_dataset.collate_batch)

    print(f"Load the model...")
    # Copy so repeated calls do not keep appending to the caller's checkpoint_path.
    model_args = argparse.Namespace(**vars(args))
    model_args.checkpoint_path = os.path.join(args.checkpoint_path, dataset_name, "best.pth")
    # The score-vector width must match the collate truncation to args.vector_dim.
    router = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768,
                          symb_enc_dim=3, args=model_args)
    # Load onto CPU first so GPU-saved checkpoints also load on CPU-only hosts.
    checkpoint = torch.load(model_args.checkpoint_path, map_location="cpu")
    router.load_state_dict(checkpoint)
    router = router.to(args.device)

    test_results = evaluate(router, test_loader, skb)
    print(f"Test evaluation")
    print(test_results)
    return test_results


if __name__ == "__main__":
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)
    base_args = parse_args()
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num
    dataset_name = args.dataset_name

    test_data_path = f"../{dataset_name}_test.pkl"
    # NOTE: pickle can execute arbitrary code -- only load files you produced/trust.
    with open(test_data_path, 'rb') as f:
        test_data = pkl.load(f)
    skb = load_skb(dataset_name)
    results = run(test_data, skb, dataset_name, args=args)