# MoR/eval.py
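"""Evaluate the MoR retrieval pipeline on a STaRK benchmark.

The script runs in two stages: a planning stage, in which `Planner` turns
each query into a reasoning graph (cached under ./cache/), and a reasoning
stage, in which `MOR4Path` retrieves and scores candidates against that
graph and is evaluated with the standard STaRK retrieval metrics. The
retrieval trajectories are then dumped to disk for the downstream
reranking stage.

Example:
    python eval.py --dataset_name mag --text_retriever_name bm25 \
        --scorer_name ada --mod test --device cuda
"""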
import os
import pickle as pkl
from argparse import ArgumentParser

import numpy as np
import pandas as pd
import torch
from stark_qa import load_qa, load_skb
from tqdm import tqdm

from Planning.model import Planner
from Reasoning.mor4path import MOR4Path
from prepare_rerank import prepare_trajectories
# ----- command-line arguments -----
parser = ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="mag",
                    help="STaRK dataset to evaluate: amazon, mag, or prime.")
parser.add_argument("--text_retriever_name", type=str, default="bm25",
                    help="Text retriever used to fetch candidates.")
parser.add_argument("--scorer_name", type=str, default="ada",
                    help="Scoring model: 'contriever' (prime) or 'ada' (amazon, mag).")
parser.add_argument("--mod", type=str, default="test",
                    help="Dataset split to evaluate: train, valid, or test.")
parser.add_argument("--device", type=str, default="cuda",
                    help="Device to run the model (e.g., 'cuda' or 'cpu').")
if __name__ == "__main__":
    args = parser.parse_args()
    dataset_name = args.dataset_name
    scorer_name = args.scorer_name
    text_retriever_name = args.text_retriever_name

    # load the semi-structured knowledge base and the QA benchmark
    skb = load_skb(dataset_name)
    qa = load_qa(dataset_name, human_generated_eval=False)
    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]
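    # (recall@k: fraction of ground-truth answers ranked in the top k;
    #  hit@k: 1 if at least one ground-truth answer is ranked in the top k)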
    mor_path = MOR4Path(dataset_name, text_retriever_name, scorer_name, skb)  # retrieval + scoring
    reasoner = Planner(dataset_name)  # turns queries into reasoning graphs

    outputs = []
    topk = 100
    split_idx = qa.get_idx_split(test_ratio=1.0)
    mod = args.mod
    all_indices = split_idx[mod].tolist()
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)
    # ***** planning *****
    # reuse the cached plans if they exist; otherwise run the planner and cache them
    plan_cache_path = f"./cache/{dataset_name}/path/{mod}_20250222.pkl"
    if os.path.exists(plan_cache_path):
        with open(plan_cache_path, 'rb') as f:
            plan_output_list = pkl.load(f)
    else:
        plan_output_list = []
        for i in tqdm(all_indices):
            query, q_id, ans_ids, _ = qa[i]
            rg = reasoner(query)  # reasoning graph produced by the planner
            plan_output_list.append(
                {'query': query, 'q_id': q_id, 'ans_ids': ans_ids, 'rg': rg}
            )
        # cache the plans so subsequent runs can skip the planning stage
        os.makedirs(os.path.dirname(plan_cache_path), exist_ok=True)
        with open(plan_cache_path, 'wb') as f:
            pkl.dump(plan_output_list, f)
    # ***** reasoning *****
    for idx, i in enumerate(tqdm(all_indices)):
        query = plan_output_list[idx]['query']
        q_id = plan_output_list[idx]['q_id']
        ans_ids = plan_output_list[idx]['ans_ids']
        rg = plan_output_list[idx]['rg']

        # retrieve and score candidates for this query
        output = mor_path(query, q_id, ans_ids, rg, args)
        pred_dict = output['pred_dict']

        # score the predictions against the ground-truth answer ids
        ans_ids = torch.LongTensor(ans_ids)
        result = mor_path.evaluate(pred_dict, ans_ids, metrics=eval_metrics)
        result["idx"], result["query_id"] = i, q_id
        # candidate ids sorted by predicted score, truncated to the top k
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[:topk]
        ].tolist()
        # append this query's row to the results table
        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)

        output['q_id'] = q_id
        outputs.append(output)
    # report the mean of each metric over the evaluated split
    for metric in eval_metrics:
        print(f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(all_indices)][metric])}")
    print(f"MOR count: {mor_path.mor_count}")
    # ***** prepare trajectories for reranking and save them *****
    text_retriever = mor_path.text_retriever
    test_data = prepare_trajectories(dataset_name, text_retriever, skb, outputs)
    save_path = f"{dataset_name}_{mod}.pkl"
    with open(save_path, 'wb') as f:
        pkl.dump(test_data, f)
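    print(f"Saved trajectories to {save_path}")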