import os
import pickle as pkl
from argparse import ArgumentParser

import pandas as pd
import torch
from stark_qa import load_qa, load_skb
from tqdm import tqdm

from Planning.model import Planner
from Reasoning.mor4path import MOR4Path
from prepare_rerank import prepare_trajectories
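
# End-to-end evaluation script: plan each query with the Planner, retrieve
# candidates with MOR4Path, score them against the ground-truth answer ids,
# and dump the retrieval trajectories for the downstream reranking step.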

parser = ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="mag",
                    help="Name of the STaRK dataset to load (e.g., 'mag').")
parser.add_argument("--text_retriever_name", type=str, default="bm25",
                    help="Text retriever used inside MOR4Path (e.g., 'bm25').")
parser.add_argument("--scorer_name", type=str, default="ada",
                    help="Embedding scorer to use: 'contriever' or 'ada'.")
parser.add_argument("--mod", type=str, default="test",
                    help="Dataset split to evaluate: 'train', 'valid', or 'test'.")
parser.add_argument("--device", type=str, default="cuda",
                    help="Device to run the model on (e.g., 'cuda' or 'cpu').")

if __name__ == "__main__":
    args = parser.parse_args()
    dataset_name = args.dataset_name
    scorer_name = args.scorer_name
    text_retriever_name = args.text_retriever_name

    # Load the semi-structured knowledge base and the QA benchmark
    # (synthesized queries, not the human-generated evaluation set).
    skb = load_skb(dataset_name)
    qa = load_qa(dataset_name, human_generated_eval=False)

    # Ranking metrics computed for every query.
    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]

    # Path-based retriever over the knowledge base, plus the planner that
    # produces a plan ("rg") for each query.
    mor_path = MOR4Path(dataset_name, text_retriever_name, scorer_name, skb)
    reasoner = Planner(dataset_name)

    outputs = []
    topk = 100
    split_idx = qa.get_idx_split(test_ratio=1.0)
    mod = args.mod
    all_indices = split_idx[mod].tolist()
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)

    count = 0
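
    # Stage 1: query planning. Planner outputs are cached on disk so repeated
    # runs of the same split can skip the planning pass.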
    plan_cache_path = f"./cache/{dataset_name}/path/{mod}_20250222.pkl"
    if os.path.exists(plan_cache_path):
        with open(plan_cache_path, 'rb') as f:
            plan_output_list = pkl.load(f)
    else:
        plan_output_list = []
        for i in tqdm(all_indices):
            query, q_id, ans_ids, _ = qa[i]
            rg = reasoner(query)

            plan_output_list.append({
                'query': query,
                'q_id': q_id,
                'ans_ids': ans_ids,
                'rg': rg,
            })

        os.makedirs(os.path.dirname(plan_cache_path), exist_ok=True)
        with open(plan_cache_path, 'wb') as f:
            pkl.dump(plan_output_list, f)
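
    # Stage 2: retrieval and evaluation. For every cached plan, run MOR4Path to
    # score candidates, compute the ranking metrics, and log the top-k ids.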
    for idx, i in enumerate(tqdm(all_indices)):
        query = plan_output_list[idx]['query']
        q_id = plan_output_list[idx]['q_id']
        ans_ids = plan_output_list[idx]['ans_ids']
        rg = plan_output_list[idx]['rg']

        # Retrieve candidates for this query, guided by the planner output.
        output = mor_path(query, q_id, ans_ids, rg, args)

        ans_ids = torch.LongTensor(ans_ids)

        # pred_dict maps candidate id -> retrieval score.
        pred_dict = output['pred_dict']
        result = mor_path.evaluate(pred_dict, ans_ids, metrics=eval_metrics)

        result["idx"], result["query_id"] = i, q_id
        # Candidate ids sorted by descending score, truncated to the top-k.
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[:topk]
        ].tolist()

        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)

        output['q_id'] = q_id
        outputs.append(output)

        count += 1

    print(f"MOR count: {mor_path.mor_count}")
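
    # Stage 3: build reranking trajectories from the raw retrieval outputs and
    # save them to the working directory as f"{dataset_name}_{mod}.pkl".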
    bm25 = mor_path.text_retriever
    test_data = prepare_trajectories(dataset_name, bm25, skb, outputs)

    save_path = f"{dataset_name}_{mod}.pkl"
    with open(save_path, 'wb') as f:
        pkl.dump(test_data, f)