# MoR / Reranking / train_eval_path_prime.py
# (repository listing header: author GagaLey, commit 7bf4b88 "framework")
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
import pickle as pkl
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm
import wandb
import numpy as np
import time
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import random
from collections import defaultdict
from Reranking.utils import move_to_cuda, seed_everything
from Reranking.rerankers.path import PathReranker
from stark_qa import load_qa, load_skb
import torch.nn.functional as F
import argparse
import json
import time
# Set the global RNG seed (python/numpy/torch) so runs are reproducible;
# seed_everything is provided by Reranking.utils.
seed_everything(42)
# ***** Dataset *****
class TrainDataset(Dataset):
    """
    Custom Dataset for the training data.

    Each raw item carries a query plus a scored candidate pool. This class
    flattens it into (query, positive-candidate) training pairs and keeps,
    per query, a retriever-score-sorted list of negative candidates that
    collate_batch samples from at batch time.
    """
    def __init__(self, saved_data, max_neg_candidates=100):
        """
        Build the flattened (query, positive) training pairs.

        Args:
            saved_data: dict with
                'text2emb_dict': path-string -> text-embedding tensor,
                'data': list of per-query dicts containing 'query',
                    'pred_dict' (candidate id -> retriever score), 'ans_ids',
                    'score_vector_dict', 'text_emb_dict', 'symb_enc_dict'.
            max_neg_candidates: number of negatives sampled per positive pair
                in collate_batch.

        Takes roughly 10s for 1000 items.
        """
        print(f"start processing training dataset...")
        s_time = time.time()
        self.max_neg_candidates = max_neg_candidates
        self.sorted_query2neg = defaultdict(list)
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']
        # separate neg and pos candidates, and prepare (query, pos) pairs
        new_data = []
        for i in range(len(self.data)):
            neg_ids = []
            pos_ids = []
            item = self.data[i]
            candidates_dict = item['pred_dict']
            ans_ids = item['ans_ids']
            # positives = answer ids that actually appear in the candidate pool
            for ans_id in ans_ids:
                if ans_id in candidates_dict.keys():
                    pos_ids.append(ans_id)
            neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))
            # per-candidate score vectors
            score_vector_dict = item['score_vector_dict']
            # per-candidate path key (str) into self.text2emb_dict
            text_emb_dict = item['text_emb_dict']
            # per-candidate symbolic encoding
            symb_enc_dict = item['symb_enc_dict']
            self.data[i]['pos_ids'] = pos_ids
            self.data[i]['neg_ids'] = neg_ids
            query = item['query']
            # one training example per positive candidate
            for pos_id in pos_ids:
                new_data.append((query, score_vector_dict[pos_id], self.text2emb_dict[text_emb_dict[pos_id]], symb_enc_dict[pos_id]))
            # negatives are stored sorted by retriever score (highest first)
            neg_dict = {neg_id: candidates_dict[neg_id] for neg_id in neg_ids}
            sorted_neg_ids = sorted(neg_dict.keys(), key=lambda x: neg_dict[x], reverse=True)  # returns a list
            self.sorted_query2neg[query] = [(score_vector_dict[neg_id], self.text2emb_dict[text_emb_dict[neg_id]], symb_enc_dict[neg_id]) for neg_id in sorted_neg_ids]
        self.data = new_data
        print(f"Complete data preparation")
        print(f"Time: {time.time() - s_time}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # each entry is (query_text, pos_score_vector, pos_text_emb, pos_symb_enc)
        return self.data[idx]

    def collate_batch(self, pairs):
        """
        Collate (query, positive) pairs into a feed dict, sampling
        self.max_neg_candidates negatives per query with replacement.

        NOTE(review): reads the module-level `args.vector_dim` to truncate
        score vectors — confirm `args` is defined before DataLoader workers
        invoke this (it is set in the __main__ block).
        """
        s_time = time.time()
        # queries are raw text strings
        batch_q = [pair[0] for pair in pairs]  # q is text
        q_text = batch_q
        # positive score vectors, truncated to the first args.vector_dim entries
        batch_p_score_vector = [pair[1] for pair in pairs]  # p is score vector
        batch_p_score_vector = torch.tensor(batch_p_score_vector)  # [bs, 4]
        batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]
        # positive text embeddings
        batch_p_text_emb = [pair[2] for pair in pairs]  # p is text emb
        batch_p_text_emb = torch.concat(batch_p_text_emb, dim=0)  # [bs, 768]
        # positive symbolic encodings
        batch_p_symb_enc = [pair[3] for pair in pairs]  # p is symb_enc
        batch_p_symb_enc = torch.tensor(batch_p_symb_enc)  # [bs, 3]
        # Negative samples, drawn with replacement (duplicates allowed)
        batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates) for query in batch_q]
        # negative score vectors
        batch_n_score_vector = [pair[0] for sublist in batch_n for pair in sublist]
        batch_n_score_vector = torch.tensor(batch_n_score_vector)  # [bs*100, 4]
        batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 4]
        batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]
        # negative text embeddings
        batch_n_text_emb = [pair[1] for sublist in batch_n for pair in sublist]
        batch_n_text_emb = torch.concat(batch_n_text_emb, dim=0)  # [bs*100, 768]
        batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 768]
        # negative symbolic encodings
        batch_n_symb_enc = [pair[2] for sublist in batch_n for pair in sublist]
        batch_n_symb_enc = torch.tensor(batch_n_symb_enc)  # [bs*100, 3]
        batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1)  # [bs, 100, 3]
        # Create a dictionary for the batch
        feed_dict = {
            'query': q_text,
            'p_score_vector': batch_p_score_vector,
            'p_text_emb': batch_p_text_emb,
            'p_symb_enc': batch_p_symb_enc,
            'n_score_vector': batch_n_score_vector,
            'n_text_emb': batch_n_text_emb,
            'n_symb_enc': batch_n_symb_enc,
        }
        return feed_dict
class TestDataset(Dataset):
    """
    Dataset for evaluation.

    Each raw item is a dict:
        {
            "query": query,
            "pred_dict": {node_id: score},
            'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
            "text_emb_dict": {node_id: path-string key into text2emb_dict},
            "symb_enc_dict": {node_id: symbolic encoding},
            "ans_ids": [],
        }

    __getitem__ keeps only the top-50 candidates by retriever score and maps
    each candidate's path-string to its row index in self.text_emb_matrix.
    """
    def __init__(self, saved_data):
        """
        Args:
            saved_data: dict with 'text2emb_dict' (path-string -> embedding
                tensor of shape [1, D]) and 'data' (list of per-query dicts,
                see class docstring).
        """
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']
        # stack all path embeddings into one matrix so collate can gather rows
        self.text_emb_matrix = torch.concat(list(self.text2emb_dict.values()), dim=0)
        # map each path-string key to its row index in text_emb_matrix
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}
        print(f"Complete data preparation: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return a copy of item `idx`, restricted to its top-50 candidates.

        BUG FIX: the original mutated self.data[idx] in place; fetching the
        same item twice (e.g. evaluating every epoch with num_workers=0)
        crashed with a KeyError because the already-index-mapped
        'text_emb_dict' values were looked up in text2idx_dict again. This
        version builds fresh dicts and leaves the stored item untouched.
        """
        item = self.data[idx]
        pred_dict = item['pred_dict']
        # rank candidate ids by retriever score (descending), keep the top 50
        sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)[:50]
        out = dict(item)  # shallow copy; the sub-dicts below are rebuilt
        out['score_vector_dict'] = {key: item['score_vector_dict'][key] for key in sorted_ids}
        out['symb_enc_dict'] = {key: item['symb_enc_dict'][key] for key in sorted_ids}
        # path-string -> row index into self.text_emb_matrix
        out['text_emb_dict'] = {key: self.text2idx_dict[item['text_emb_dict'][key]] for key in sorted_ids}
        return out

    def collate_batch(self, batch):
        """
        Collate evaluation items into a feed dict of batched tensors.

        NOTE(review): reads the module-level `args.vector_dim` to truncate
        score vectors — confirm `args` exists in DataLoader workers.
        Assumes every item has the same number of candidates (needed for
        torch.tensor over the nested lists) — TODO confirm upstream.
        """
        # queries as raw text
        batch_q = [batch[i]['query'] for i in range(len(batch))]
        q_text = batch_q
        # candidate ids, one row per query
        batch_c = [list(batch[i]['score_vector_dict'].keys()) for i in range(len(batch))]
        batch_c = torch.tensor(batch_c)
        # candidate score vectors, truncated to args.vector_dim
        c_score_vector = [list(batch[i]['score_vector_dict'].values()) for i in range(len(batch))]
        c_score_vector = torch.tensor(c_score_vector)[:, :, :args.vector_dim]
        # gather candidate text embeddings from the shared matrix
        c_text_emb = [self.text_emb_matrix[list(batch[i]['text_emb_dict'].values())].unsqueeze(0) for i in range(len(batch))]
        c_text_emb = torch.concat(c_text_emb, dim=0)  # [bs, 50, 768]
        # symbolic encodings
        c_symb_enc = [list(batch[i]['symb_enc_dict'].values()) for i in range(len(batch))]
        c_symb_enc = torch.tensor(c_symb_enc)  # [bs, 50, 3]
        # ground-truth answer ids (ragged list per query)
        ans_ids = [batch[i]['ans_ids'] for i in range(len(batch))]
        # candidate ids back as plain lists for the evaluator
        pred_ids = batch_c.tolist()
        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids
        }
        return feed_dict
# ******* loss function ********
def loss_fn(scores_pos, scores_neg):
    """
    Listwise cross-entropy (InfoNCE-style) ranking loss.

    Each row's positive score is placed in column 0, followed by that row's
    negative scores; the loss then treats index 0 as the target class.

    Args:
        scores_pos: positive-candidate scores — presumably [B, 1]; verify
            against PathReranker's output.
        scores_neg: negative-candidate scores — presumably [B, K, 1], the
            trailing dim is squeezed away.
    Returns:
        Scalar mean cross-entropy loss tensor.
    """
    logits = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1)  # B x (1 + K)
    # the positive candidate always sits at column 0
    target = logits.new_zeros(logits.size(0), dtype=torch.long)
    return F.cross_entropy(logits, target, ignore_index=-1)
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """
    Compute per-query ranking metrics (hit@1, hit@5, recall@20, MRR) for one
    evaluation batch.

    Args:
        skb: knowledge base object exposing `candidate_ids` (full id list).
        scores_cand: [bs, K, 1] reranker scores for the K candidates per query.
        ans_ids: list (len bs) of lists of ground-truth answer ids.
        batch: feed dict containing 'pred_ids', the [bs, K] candidate id lists.

    Returns:
        dict mapping 'hit@1' / 'hit@5' / 'recall@20' / 'mrr' to a list with
        one value per query in the batch.
    """
    bs = scores_cand.shape[0]
    device = scores_cand.device
    all_ids = skb.candidate_ids
    idx_of = {cand_id: pos for pos, cand_id in enumerate(all_ids)}

    # Scatter the candidate scores into a dense [bs, num_candidates] matrix;
    # candidates outside the reranked pool keep a score of 0.
    flat_preds = torch.tensor(batch['pred_ids']).flatten().tolist()
    pred_cols = torch.tensor([idx_of[p] for p in flat_preds]).reshape(bs, -1)  # [bs, K]
    row_sel = torch.arange(bs).unsqueeze(1)
    score_matrix = torch.zeros((bs, len(all_ids))).to(device)
    score_matrix[row_sel, pred_cols] = scores_cand.squeeze(-1)

    # Dense 0/1 answer matrix aligned with score_matrix's columns.
    flat_ans_cols = torch.tensor([idx_of[a] for answers in ans_ids for a in answers], device=device)
    ans_rows = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(answers) for answers in ans_ids]))
    label_matrix = torch.zeros((bs, len(all_ids)), device=device)
    label_matrix[ans_rows, flat_ans_cols] = 1

    # hit@1: is the single top-scored candidate an answer?
    top1_idx = torch.max(score_matrix, dim=1)[1]
    hit1 = label_matrix[torch.arange(bs), top1_idx].tolist()

    # hit@5: does any of the top-5 candidates hit an answer?
    top5_idx = torch.topk(score_matrix, 5, dim=1)[1]
    hit5 = torch.max(label_matrix[row_sel, top5_idx], dim=1)[0].tolist()

    # recall@20: fraction of this query's answers recovered in the top-20.
    top20_idx = torch.topk(score_matrix, 20, dim=1)[1]
    found = torch.sum(label_matrix[row_sel, top20_idx], dim=1)
    recall20 = (found / torch.sum(label_matrix, dim=1)).tolist()

    # MRR: reciprocal rank of the first answer in the full score ordering.
    order = torch.sort(score_matrix, dim=1, descending=True)[1]
    first_rank = torch.argmax(label_matrix[row_sel, order], dim=1) + 1
    mrr = (1 / first_rank.float()).tolist()

    return {'hit@1': hit1, 'hit@5': hit5, 'recall@20': recall20, 'mrr': mrr}
# ***** evaluate *****
@torch.no_grad()
def evaluate(reranker, test_loader):
    """
    Run the reranker over test_loader and return the mean hit@1, hit@5,
    recall@20 and MRR across all queries.

    NOTE(review): `skb` is read from module scope (set in the __main__
    block), not passed as a parameter — this function only works when the
    script is run as a program; confirm before importing it elsewhere.
    """
    reranker.eval()
    # per-query metric values accumulated over all batches
    all_results = {
        "hit@1": [],
        "hit@5": [],
        "recall@20": [],
        "mrr": []
    }
    avg_results = {
        "hit@1": 0,
        "hit@5": 0,
        "recall@20": 0,
        "mrr": 0
    }
    # raw predictions/scores/answers are collected below but never returned
    # or persisted here — kept as-is for parity with the original code
    pred_list = []
    scores_cand_list = []
    ans_ids_list = []
    # use tqdm to show the progress
    for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
        batch = move_to_cuda(batch)
        # unwrap DataParallel to reach the model's eval_batch method
        if isinstance(reranker, nn.DataParallel):
            scores_cand = reranker.module.eval_batch(batch)  # scores for each candidate
        else:
            scores_cand = reranker.eval_batch(batch)
        # ground-truth answer ids for this batch
        ans_ids = batch['ans_ids']
        results = batch_evaluator(skb, scores_cand, ans_ids, batch)
        for key in results.keys():
            all_results[key].extend(results[key])
        # save the scores and ans_ids, and pred_ids
        pred_list.extend(batch['pred_ids'])
        scores_cand_list.extend(scores_cand.cpu().tolist())
        ans_ids_list.extend(ans_ids)
    # average each metric over all queries
    for key in avg_results.keys():
        avg_results[key] = np.mean(all_results[key])
    print(f"Results: {avg_results}")
    return avg_results
# ***** train *****
def main(train_data, val_data, test_data, skb, dataset_name, args):
    """
    Train PathReranker on the prepared data, evaluating on val and test
    every epoch, checkpointing whenever val hit@1 improves, and writing the
    best test results plus the run config to a timestamped JSON file.

    Args:
        train_data / val_data / test_data: unpickled dataset dicts
            (see TrainDataset / TestDataset for the expected structure).
        skb: knowledge base (passed through; note evaluate() actually reads
            the module-level global of the same name).
        dataset_name: used to build checkpoint/output directory paths.
        args: namespace with epochs, device, train_batch_size, lr,
            vector_dim, and the embedding-combination flags.
    """
    epochs = args.epochs
    device = args.device
    train_size = args.train_batch_size
    test_size = 64
    train_dataset = TrainDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)
    test_dataset = TestDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)
    val_dataset = TestDataset(val_data)
    val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)
    # ***** Model *****
    # NOTE(review): 'socre_vector_input_dim' is spelled this way in
    # PathReranker's signature — do not "fix" it here without changing the
    # model class as well.
    reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
    save_dir = f"./data/checkpoints/{dataset_name}/path"
    os.makedirs(save_dir, exist_ok=True)
    reranker.to(device)
    # wrap for multi-GPU data parallelism
    reranker = nn.DataParallel(reranker)
    optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
    best_val_hit1 = float('-inf')
    # baseline evaluation before any training
    val_results = evaluate(reranker, val_loader)
    print(f"Val evaluation")
    print(val_results)
    test_results = evaluate(reranker, test_loader)
    print(f"Test evaluation")
    print(test_results)
    # log both val and test results
    wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
               'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})
    best_test_results = {}
    for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
        total_loss = 0.0
        reranker.train()
        count = 0
        total_instances = 0
        for batch in tqdm(train_loader):
            batch = move_to_cuda(batch)
            scores_pos, scores_neg = reranker(batch)
            batch_loss = loss_fn(scores_pos, scores_neg)
            # standard optimizer step
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            count += 1
            # accumulate for the per-instance average loss
            total_instances += scores_pos.shape[0]
            total_loss += batch_loss.item()
        train_loss = total_loss / total_instances
        print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")
        val_results = evaluate(reranker, val_loader)
        print(f"Val evaluation")
        print(val_results)
        test_results = evaluate(reranker, test_loader)
        print(f"Test evaluation")
        print(test_results)
        # log both val and test results
        wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
                   'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
                   'train_loss': train_loss})
        # save the best model when val hit@1 is the highest so far
        hit1 = val_results['hit@1']
        if best_val_hit1 < hit1:
            best_val_hit1 = hit1
            save_path = f"{save_dir}/best_{best_val_hit1}.pth"
            # unwrap DataParallel so the checkpoint loads without the wrapper
            if isinstance(reranker, nn.DataParallel):
                torch.save(reranker.module.state_dict(), save_path)
            else:
                torch.save(reranker.state_dict(), save_path)
            print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")
            args.checkpoint_path = save_path
            best_test_results = test_results
    # save the last-epoch checkpoint (hit1 here is the final epoch's val hit@1)
    save_path = f"{save_dir}/last_{hit1}.pth"
    if isinstance(reranker, nn.DataParallel):
        torch.save(reranker.module.state_dict(), save_path)
    else:
        torch.save(reranker.state_dict(), save_path)
    print(f"Final checkpoint saved at {save_path}")
    # ***** save the results *****
    results = []
    results.append(
        {
            "config": vars(args),
            "test_results": best_test_results
        }
    )
    # save the config + best test results to a timestamped JSON file
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    output_dir = f"./data/outputs/{dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
        json.dump(results, f, indent=4)
    print(best_test_results)
def get_concat_num(combo):
    """
    Determine concat_num from the embedding-combination flags.

    Each enabled component contributes a fixed width:
    score_vec -> +1, text_emb -> +1, symb_enc -> +3.

    Args:
        combo: dict of flag name -> bool (missing keys count as False).
    Returns:
        Total width as an int.
    """
    contributions = {"score_vec": 1, "text_emb": 1, "symb_enc": 3}
    return sum(width for name, width in contributions.items() if combo.get(name, False))
def parse_args():
    """
    Define and parse the command-line configuration for a reranker run.

    Returns:
        argparse.Namespace with training, dataset, path and checkpoint
        options.
    """
    parser = argparse.ArgumentParser(description="Run Pathreranker with dynamic combinations of embeddings.")
    # training / optimization configuration
    parser.add_argument("--train_batch_size", type=int, default=256,
                        help="Batch size for training or evaluation.")
    parser.add_argument("--lr", type=float, default=3e-5,
                        help="Learning rate for optimizer.")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of epochs to train the model.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run the model (e.g., 'cuda' or 'cpu').")
    # dataset selection
    parser.add_argument("--dataset_name", type=str, default="prime",
                        help="Name of the dataset to use.")
    # data file locations
    parser.add_argument("--train_path", type=str, default="../prime_train.pkl",
                        help="Path to the training data.")
    parser.add_argument("--test_path", type=str, default="../prime_test.pkl",
                        help="Path to the test data.")
    parser.add_argument("--val_path", type=str, default="../prime_val.pkl",
                        help="Path to the validation data.")
    # embedding-combination and checkpoint options
    parser.add_argument("--concat_num", type=int, default=0,
                        help="Number of concatenation of embeddings.")
    parser.add_argument("--checkpoint_path", type=str, default="",
                        help="Path to save the checkpoints.")
    parser.add_argument("--vector_dim", type=int, default=4,
                        help="Dimension of the similarity vector.")
    return parser.parse_args()
if __name__ == "__main__":
    base_args = parse_args()
    dataset_name = base_args.dataset_name
    train_path = base_args.train_path
    test_path = base_args.test_path
    val_path = base_args.val_path
    # NOTE(review): pickle.load executes arbitrary code from the file —
    # only load dataset pickles produced by this project.
    with open(test_path, "rb") as f:
        test_data = pkl.load(f)
    with open(train_path, "rb") as f:
        train_data = pkl.load(f)
    with open(val_path, "rb") as f:
        val_data = pkl.load(f)
    # load the semi-structured knowledge base; evaluate() reads this
    # module-level `skb` global
    skb = load_skb(dataset_name)
    # enable all three candidate representations
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)
    wandb.init(project=f'Reranking-{dataset_name}', name=f"path")
    # merge CLI args with the combo flags; the resulting module-level `args`
    # is also read by the collate functions and loss truncation logic
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num
    main(train_data, val_data, test_data, skb, dataset_name, args)