|
import sys |
|
import os |
|
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd()))) |
|
|
|
import pickle as pkl |
|
from torch.utils.data import Dataset, DataLoader |
|
import torch |
|
from tqdm import tqdm |
|
import wandb |
|
import numpy as np |
|
import time |
|
import torch.nn as nn |
|
from torch.nn import CrossEntropyLoss |
|
import random |
|
from collections import defaultdict |
|
|
|
|
|
from Reranking.utils import move_to_cuda, seed_everything |
|
from Reranking.rerankers.path import PathReranker |
|
from stark_qa import load_qa, load_skb |
|
import torch.nn.functional as F |
|
import argparse |
|
import json |
|
import time |
|
|
|
|
|
seed_everything(42)  # seed all RNGs for reproducibility (project helper; exact scope defined in Reranking.utils)
|
|
|
|
|
class TrainDataset(Dataset):
    """
    Custom Dataset for the training data.

    Each element of ``saved_data['data']`` describes one query and its
    candidate pool:
        - 'query': the query string
        - 'pred_dict': {node_id: retrieval score}
        - 'score_vector_dict': {node_id: score vector (list of floats)}
        - 'text_emb_dict': {node_id: key into saved_data['text2emb_dict']}
        - 'symb_enc_dict': {node_id: symbolic encoding (list of ints)}
        - 'ans_ids': ground-truth node ids

    The raw data is flattened to one training instance per positive candidate;
    negatives are sampled per query in ``collate_batch``.
    """

    def __init__(self, saved_data, max_neg_candidates=100):
        """
        Flatten the raw training data (roughly 10s per 1000 items).

        Args:
            saved_data: dict with 'text2emb_dict' (text key -> embedding
                tensor) and 'data' (list of per-query dicts, see class
                docstring).
            max_neg_candidates: number of negatives sampled per query at
                collate time.
        """
        print(f"start processing training dataset...")
        s_time = time.time()
        self.max_neg_candidates = max_neg_candidates
        # query -> list of (score_vector, text_emb, symb_enc) negatives,
        # sorted by retrieval score (descending).
        self.sorted_query2neg = defaultdict(list)

        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        new_data = []
        for i in range(len(self.data)):
            item = self.data[i]

            candidates_dict = item['pred_dict']
            ans_ids = item['ans_ids']

            # Positives are answer ids that actually appear in the candidate
            # pool; every other pooled candidate is a negative.
            pos_ids = [ans_id for ans_id in ans_ids if ans_id in candidates_dict]
            neg_ids = list(set(candidates_dict.keys()) - set(pos_ids))

            score_vector_dict = item['score_vector_dict']
            text_emb_dict = item['text_emb_dict']
            symb_enc_dict = item['symb_enc_dict']

            # Annotate the original item (also mutates the caller's dicts,
            # as the original implementation did).
            self.data[i]['pos_ids'] = pos_ids
            self.data[i]['neg_ids'] = neg_ids

            # One training instance per positive candidate.
            query = item['query']
            for pos_id in pos_ids:
                new_data.append((query,
                                 score_vector_dict[pos_id],
                                 self.text2emb_dict[text_emb_dict[pos_id]],
                                 symb_enc_dict[pos_id]))

            # Sort negatives by retrieval score, hardest (highest) first.
            sorted_neg_ids = sorted(neg_ids, key=lambda x: candidates_dict[x], reverse=True)

            self.sorted_query2neg[query] = [
                (score_vector_dict[neg_id],
                 self.text2emb_dict[text_emb_dict[neg_id]],
                 symb_enc_dict[neg_id])
                for neg_id in sorted_neg_ids
            ]

        self.data = new_data
        print(f"Complete data preparation")
        print(f"Time: {time.time() - s_time}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # (query, pos_score_vector, pos_text_emb, pos_symb_enc)
        return self.data[idx]

    def collate_batch(self, pairs):
        """
        Collate flattened (query, positive) tuples into a feed dict, sampling
        ``self.max_neg_candidates`` negatives per query.

        Relies on the module-level ``args.vector_dim`` to truncate score
        vectors.
        """
        batch_q = [pair[0] for pair in pairs]
        q_text = batch_q

        # Positive features.
        batch_p_score_vector = torch.tensor([pair[1] for pair in pairs])
        batch_p_score_vector = batch_p_score_vector[:, :args.vector_dim]

        batch_p_text_emb = torch.cat([pair[2] for pair in pairs], dim=0)

        batch_p_symb_enc = torch.tensor([pair[3] for pair in pairs])

        # Sample negatives per query. random.choices samples WITH replacement,
        # so queries with fewer than max_neg_candidates negatives still yield
        # a full (possibly duplicated) sample.
        batch_n = [random.choices(self.sorted_query2neg[query], k=self.max_neg_candidates)
                   for query in batch_q]

        # Negative features, reshaped to (B, max_neg_candidates, ...).
        batch_n_score_vector = torch.tensor([pair[0] for sublist in batch_n for pair in sublist])
        batch_n_score_vector = batch_n_score_vector.reshape(len(batch_q), self.max_neg_candidates, -1)
        batch_n_score_vector = batch_n_score_vector[:, :, :args.vector_dim]

        batch_n_text_emb = torch.cat([pair[1] for sublist in batch_n for pair in sublist], dim=0)
        batch_n_text_emb = batch_n_text_emb.reshape(len(batch_q), self.max_neg_candidates, -1)

        batch_n_symb_enc = torch.tensor([pair[2] for sublist in batch_n for pair in sublist])
        batch_n_symb_enc = batch_n_symb_enc.reshape(len(batch_q), self.max_neg_candidates, -1)

        feed_dict = {
            'query': q_text,
            'p_score_vector': batch_p_score_vector,
            'p_text_emb': batch_p_text_emb,
            'p_symb_enc': batch_p_symb_enc,
            'n_score_vector': batch_n_score_vector,
            'n_text_emb': batch_n_text_emb,
            'n_symb_enc': batch_n_symb_enc,
        }

        return feed_dict
|
|
|
|
|
class TestDataset(Dataset):
    """
    Dataset for evaluation.

    data format per item: {
        "query": query,
        "pred_dict": {node_id: score},
        'score_vector_dict': {node_id: [bm25, bm_25, bm25, ada]},
        "text_emb_dict": {node_id: text key},
        "symb_enc_dict": {node_id: symbolic encoding},
        "ans_ids": [],
    }

    ``__getitem__`` keeps only the top-50 candidates by retrieval score and
    maps text keys to row indices of ``self.text_emb_matrix``. It returns a
    fresh dict instead of mutating ``self.data`` in place: the original
    in-place version raised KeyError on a second access (already-remapped
    int indices were looked up in ``text2idx_dict`` again) and only survived
    because DataLoader workers operate on throwaway copies.
    """

    def __init__(self, saved_data):
        """
        Args:
            saved_data: dict with 'text2emb_dict' (text key -> embedding
                tensor) and 'data' (list of per-query dicts, see class
                docstring).
        """
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        # Stack all text embeddings into one matrix; row order follows the
        # insertion order of text2emb_dict.
        self.text_emb_matrix = torch.cat(list(self.text2emb_dict.values()), dim=0)

        # text key -> row index into text_emb_matrix
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}

        print(f"Complete data preparation: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        pred_dict = item['pred_dict']

        # Top-50 candidates by retrieval score, best first.
        sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)[:50]

        # Build a new dict; do NOT mutate self.data so repeated reads work.
        return {
            'query': item['query'],
            'pred_dict': pred_dict,
            'ans_ids': item['ans_ids'],
            'score_vector_dict': {key: item['score_vector_dict'][key] for key in sorted_ids},
            'symb_enc_dict': {key: item['symb_enc_dict'][key] for key in sorted_ids},
            'text_emb_dict': {key: self.text2idx_dict[item['text_emb_dict'][key]] for key in sorted_ids},
        }

    def collate_batch(self, batch):
        """
        Collate evaluation items into a feed dict.

        Assumes every item has the same number of candidates (torch.tensor on
        the nested lists requires a rectangular shape) — TODO confirm the
        upstream dumps guarantee this.

        Relies on the module-level ``args.vector_dim`` to truncate score
        vectors.
        """
        q_text = [item['query'] for item in batch]

        # Candidate ids per item, in score order (dicts preserve the order
        # produced by __getitem__).
        batch_c = torch.tensor([list(item['score_vector_dict'].keys()) for item in batch])

        c_score_vector = torch.tensor([list(item['score_vector_dict'].values()) for item in batch])[:, :, :args.vector_dim]

        # Gather candidate text embeddings: (B, num_candidates, emb_dim).
        c_text_emb = torch.cat(
            [self.text_emb_matrix[list(item['text_emb_dict'].values())].unsqueeze(0) for item in batch],
            dim=0)

        c_symb_enc = torch.tensor([list(item['symb_enc_dict'].values()) for item in batch])

        ans_ids = [item['ans_ids'] for item in batch]

        pred_ids = batch_c.tolist()

        feed_dict = {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids,
        }

        return feed_dict
|
|
|
|
|
|
|
def loss_fn(scores_pos, scores_neg):
    """
    InfoNCE-style ranking loss: the positive competes against the sampled
    negatives under a softmax cross-entropy with the positive at index 0.

    Args:
        scores_pos: (B, 1) scores of the positive candidates.
        scores_neg: (B, N, 1) scores of the sampled negatives.

    Returns:
        Scalar mean cross-entropy loss.
    """
    # (B, 1) ++ (B, N) -> (B, 1 + N); column 0 holds the positive.
    scores = torch.cat([scores_pos, scores_neg.squeeze(-1)], dim=1)

    # The correct "class" is the positive at index 0 for every row.
    target = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)

    # Functional form avoids constructing a CrossEntropyLoss module per call;
    # ignore_index=-1 is retained from the original (never triggered by the
    # all-zero targets).
    return F.cross_entropy(scores, target, ignore_index=-1)
|
|
|
|
|
|
|
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """
    Compute hit@1, hit@5, recall@20 and MRR for one evaluation batch.

    Args:
        skb: knowledge base object; only ``skb.candidate_ids`` (the full
            candidate universe) is used here.
        scores_cand: reranker scores for the shortlisted candidates; squeezed
            with ``.squeeze(-1)`` below, so assumed shape (B, K, 1) —
            TODO confirm against the reranker's eval_batch.
        ans_ids: list (length B) of lists of ground-truth candidate ids.
        batch: feed dict; ``batch['pred_ids']`` is the (B, K) nested list of
            candidate ids aligned with scores_cand.

    Returns:
        dict of per-example metric lists: 'hit@1', 'hit@5', 'recall@20', 'mrr'.
    """
    results = {}

    # Map every candidate id in the universe to a dense column index.
    candidates_ids = skb.candidate_ids
    id_to_idx = {candidate_id: idx for idx, candidate_id in enumerate(candidates_ids)}

    # Dense (B, |C|) score matrix; non-shortlisted candidates keep score 0.
    pred_matrix = torch.zeros((scores_cand.shape[0],len(candidates_ids)))

    flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()

    # Column index of each shortlisted candidate, reshaped back to (B, K).
    pred_idx = [id_to_idx[pred_id] for pred_id in flat_pred_ids]

    pred_idx = torch.tensor(pred_idx).reshape(scores_cand.shape[0], -1)

    pred_matrix = pred_matrix.to(scores_cand.device)

    # Scatter reranker scores into the dense matrix (row-wise fancy indexing).
    pred_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), pred_idx] = scores_cand.squeeze(-1)

    # Build the (B, |C|) 0/1 ground-truth matrix.
    flat_ans_idx = [id_to_idx[a_id] for sublist in ans_ids for a_id in sublist]

    # Row index for each flattened answer id (repeat row i once per answer).
    row_indices = torch.repeat_interleave(torch.arange(len(ans_ids)), torch.tensor([len(sublist) for sublist in ans_ids]))

    ans_matrix = torch.zeros((scores_cand.shape[0], len(candidates_ids)), device=scores_cand.device)
    ans_matrix[row_indices, torch.tensor(flat_ans_idx, device=scores_cand.device)] = 1

    # hit@1: is the single top-scored candidate an answer?
    max_score, max_idx = torch.max(pred_matrix, dim=1)

    batch_hit1 = ans_matrix[torch.arange(scores_cand.shape[0]), max_idx]
    hit1_list = batch_hit1.tolist()

    # hit@5: any answer among the top-5 scored candidates?
    _, top5_idx = torch.topk(pred_matrix, 5, dim=1)
    batch_hit5 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top5_idx]

    batch_hit5 = torch.max(batch_hit5, dim=1)[0]
    hit5_list = batch_hit5.tolist()

    # recall@20: fraction of this query's answers recovered in the top-20.
    _, top20_idx = torch.topk(pred_matrix, 20, dim=1)
    batch_recall20 = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), top20_idx]

    batch_recall20 = torch.sum(batch_recall20, dim=1)

    batch_recall20 = batch_recall20 / torch.sum(ans_matrix, dim=1)
    recall20_list = batch_recall20.tolist()

    # MRR: reciprocal rank of the best-ranked answer in the full sorted list.
    _, rank_idx = torch.sort(pred_matrix, dim=1, descending=True)

    batch_mrr = ans_matrix[torch.arange(scores_cand.shape[0]).unsqueeze(1), rank_idx]

    # argmax returns the FIRST maximal entry, i.e. the earliest answer hit.
    batch_mrr = torch.argmax(batch_mrr, dim=1)

    # Convert 0-based position to 1-based rank before taking the reciprocal.
    batch_mrr += 1

    batch_mrr = 1 / batch_mrr.float()
    mrr_list = batch_mrr.tolist()

    results['hit@1'] = hit1_list
    results['hit@5'] = hit5_list
    results['recall@20'] = recall20_list
    results['mrr'] = mrr_list

    return results
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate(reranker, test_loader):
    """
    Score every batch in test_loader with the reranker and return the mean
    retrieval metrics.

    Uses the module-level ``skb`` (candidate universe) via batch_evaluator.

    Returns:
        dict with averaged 'hit@1', 'hit@5', 'recall@20' and 'mrr'.
    """
    reranker.eval()

    metric_names = ["hit@1", "hit@5", "recall@20", "mrr"]
    all_results = {name: [] for name in metric_names}

    # Accumulated for parity with the original implementation (not returned).
    pred_list = []
    scores_cand_list = []
    ans_ids_list = []

    for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating', position=0)):
        batch = move_to_cuda(batch)

        # Unwrap DataParallel to reach the model's custom eval_batch method.
        model = reranker.module if isinstance(reranker, nn.DataParallel) else reranker
        scores_cand = model.eval_batch(batch)

        ans_ids = batch['ans_ids']
        batch_results = batch_evaluator(skb, scores_cand, ans_ids, batch)

        for name, values in batch_results.items():
            all_results[name].extend(values)

        pred_list.extend(batch['pred_ids'])
        scores_cand_list.extend(scores_cand.cpu().tolist())
        ans_ids_list.extend(ans_ids)

    avg_results = {name: np.mean(all_results[name]) for name in metric_names}

    print(f"Results: {avg_results}")

    return avg_results
|
|
|
|
|
|
|
def main(train_data, val_data, test_data, skb, dataset_name, args):
    """
    Train the PathReranker and evaluate on val/test after every epoch.

    Side effects:
        - saves checkpoints under ./data/checkpoints/{dataset_name}/path
          (best-by-val-hit@1 and the final epoch),
        - logs val/test metrics and train loss to wandb,
        - writes the best test results plus config to
          ./data/outputs/{dataset_name}/results_<timestamp>.json.

    Args:
        train_data / val_data / test_data: pickled retrieval dumps (schema
            documented on TrainDataset / TestDataset).
        skb: knowledge base, forwarded to batch_evaluator via evaluate().
        dataset_name: used only for checkpoint/output paths.
        args: parsed CLI namespace (epochs, device, lr, vector_dim, ...).
    """
    epochs = args.epochs
    device = args.device

    train_size = args.train_batch_size
    test_size = 64  # fixed evaluation batch size

    train_dataset = TrainDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=train_size, num_workers=32, collate_fn=train_dataset.collate_batch, drop_last=True)

    test_dataset = TestDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32, collate_fn=test_dataset.collate_batch)

    val_dataset = TestDataset(val_data)
    val_loader = DataLoader(val_dataset, batch_size=test_size, num_workers=32, collate_fn=val_dataset.collate_batch)

    # NOTE(review): 'socre_vector_input_dim' is (mis)spelled this way in
    # PathReranker's signature — keep in sync with the model definition.
    reranker = PathReranker(socre_vector_input_dim=args.vector_dim, text_emb_input_dim=768, symb_enc_dim=3, args=args)
    save_dir = f"./data/checkpoints/{dataset_name}/path"
    os.makedirs(save_dir, exist_ok=True)

    reranker.to(device)

    reranker = nn.DataParallel(reranker)

    optimizer = torch.optim.Adam(reranker.parameters(), lr=args.lr)
    best_val_hit1 = float('-inf')

    # Baseline evaluation before any training step.
    val_results = evaluate(reranker, val_loader)
    print(f"Val evaluation")
    print(val_results)

    test_results = evaluate(reranker, test_loader)
    print(f"Test evaluation")
    print(test_results)

    wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
               'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20']})

    best_test_results = {}
    for epoch in tqdm(range(epochs), desc='Training Epochs', position=0):
        total_loss = 0.0
        reranker.train()
        count = 0
        total_instances = 0

        for batch in tqdm(train_loader):

            batch = move_to_cuda(batch)

            # Forward pass yields positive and negative candidate scores.
            scores_pos, scores_neg = reranker(batch)

            batch_loss = loss_fn(scores_pos, scores_neg)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            count += 1

            total_instances += scores_pos.shape[0]
            total_loss += batch_loss.item()

        # NOTE(review): batch_loss is already a per-batch mean, so dividing the
        # sum by total_instances gives an unconventional average — fine as a
        # monitoring signal, but not a true per-instance loss.
        train_loss = total_loss / total_instances

        print(f"Epoch {epoch+1}/{epochs}, Average Train Loss: {train_loss}")

        val_results = evaluate(reranker, val_loader)
        print(f"Val evaluation")
        print(val_results)

        test_results = evaluate(reranker, test_loader)
        print(f"Test evaluation")
        print(test_results)

        wandb.log({'val_mrr': val_results['mrr'], 'val_hit1': val_results['hit@1'], 'val_hit5': val_results['hit@5'], 'val_recall@20': val_results['recall@20'],
                   'test_mrr': test_results['mrr'], 'test_hit1': test_results['hit@1'], 'test_hit5': test_results['hit@5'], 'test_recall@20': test_results['recall@20'],
                   'train_loss': train_loss})

        # Model selection criterion: validation hit@1.
        hit1 = val_results['hit@1']
        if best_val_hit1 < hit1:
            best_val_hit1 = hit1

            save_path = f"{save_dir}/best_{best_val_hit1}.pth"

            # Unwrap DataParallel so the checkpoint loads into a bare model.
            if isinstance(reranker, nn.DataParallel):
                torch.save(reranker.module.state_dict(), save_path)
            else:
                torch.save(reranker.state_dict(), save_path)
            # NOTE(review): message says "test hits@1" but hit1 is the VAL metric.
            print(f"Checkpoint saved at epoch {epoch+1} with test hits@1 {hit1}")

            args.checkpoint_path = save_path
            best_test_results = test_results

    # Always save the final epoch's weights as well.
    # NOTE(review): 'hit1' carries over from the last loop iteration; this
    # raises NameError if epochs == 0.
    save_path = f"{save_dir}/last_{hit1}.pth"
    if isinstance(reranker, nn.DataParallel):
        torch.save(reranker.module.state_dict(), save_path)
    else:
        torch.save(reranker.state_dict(), save_path)
    print(f"Final checkpoint saved at {save_path}")

    # Persist the run config together with the best test metrics.
    results = []
    results.append(
        {
            "config": vars(args),
            "test_results": best_test_results
        }
    )

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    output_dir = f"./data/outputs/{dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/results_{timestamp}.json", "w") as f:
        json.dump(results, f, indent=4)

    print(best_test_results)
|
|
|
def get_concat_num(combo):
    """
    Determine concat_num from the chosen embedding combination.

    Contributions:
        - score_vec adds +1
        - text_emb adds +1
        - symb_enc adds +3
    """
    # Table-driven sum of the enabled features' contributions.
    contributions = {"score_vec": 1, "text_emb": 1, "symb_enc": 3}
    return sum(weight for name, weight in contributions.items() if combo.get(name, False))
|
|
|
|
|
|
|
def parse_args():
    """Build and parse the command-line arguments for the path reranker."""
    parser = argparse.ArgumentParser(description="Run Pathreranker with dynamic combinations of embeddings.")

    # Optimization settings.
    parser.add_argument("--train_batch_size", type=int, default=256, help="Batch size for training or evaluation.")
    parser.add_argument("--lr", type=float, default=3e-5, help="Learning rate for optimizer.")
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train the model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g., 'cuda' or 'cpu').")

    # Dataset selection and data locations.
    parser.add_argument("--dataset_name", type=str, default="prime", help="Name of the dataset to use.")
    parser.add_argument("--train_path", type=str, default=f"../prime_train.pkl", help="Path to the training data.")
    parser.add_argument("--test_path", type=str, default=f"../prime_test.pkl", help="Path to the test data.")
    parser.add_argument("--val_path", type=str, default=f"../prime_val.pkl", help="Path to the validation data.")

    # Model/feature configuration.
    parser.add_argument("--concat_num", type=int, default=0, help="Number of concatenation of embeddings.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Path to save the checkpoints.")
    parser.add_argument("--vector_dim", type=int, default=4, help="Dimension of the similarity vector.")

    return parser.parse_args()
|
|
|
|
|
if __name__ == "__main__":

    # Parse CLI flags once; per-run feature flags are merged in below.
    base_args = parse_args()

    dataset_name = base_args.dataset_name
    train_path = base_args.train_path
    test_path = base_args.test_path
    val_path = base_args.val_path

    # NOTE(review): pickle.load assumes these dumps are trusted local
    # artifacts — never point these paths at untrusted files.
    with open(test_path, "rb") as f:
        test_data = pkl.load(f)

    with open(train_path, "rb") as f:
        train_data = pkl.load(f)

    with open(val_path, "rb") as f:
        val_data = pkl.load(f)

    # Load the knowledge base once; evaluate() reads it as a module global.
    skb = load_skb(dataset_name)

    # Which candidate features feed the reranker.
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)

    wandb.init(project=f'Reranking-{dataset_name}', name=f"path")
    # Merge CLI args and feature flags into the single namespace the rest of
    # the script (including the collate functions) reads as a global.
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num

    main(train_data, val_data, test_data, skb, dataset_name, args)
|
|