"""
Desc: This file queries an LLM to generate training data (reasoning graphs)
for STaRK queries, to be used for fine-tuning.
"""

import argparse
import json
import os
import sys
from pathlib import Path

import openai
from openai import AzureOpenAI

# Make the project root importable so that `stark_qa` and `prompts` resolve.
current_file = Path(__file__).resolve()
project_root = current_file.parents[3]
sys.path.append(str(project_root))

from stark_qa import load_qa
from prompts import prompts


"""
Cost estimates (counts are per query; prices are for 1000 queries):

MAG:
    sys_content: 478/query
    output: 45/query
    input: 25/query

    Total price:
    1. o1: $13.29
    2. o3-mini: $0.97
    3. deepseek-chat: $0.24
    4. deepseek-reasoner: $0.49

Amazon:
    sys_content: 478/query
"""


def get_sys_content(dataset_name):
    """
    Input:
        dataset_name: the name of the dataset
    Output:
        sys_content: the system prompt for the dataset
    """
    return prompts(dataset_name)


def get_response(sys_content, user_content):
    """
    Input:
        sys_content: the system prompt
        user_content: the user query
    Output:
        response: the model's reply text

    Uses the module-level `client` and `parameters` set in __main__; the model
    name is read from parameters['azure']['model'] regardless of which client
    is active.
    """
    messages = [
        {"role": "system", "content": sys_content},
        {"role": "user", "content": user_content},
    ]

    chat_completion = client.chat.completions.create(
        messages=messages,
        model=parameters['azure']['model'],
        # Fixed seed for (best-effort) reproducible completions.
        seed=576879897,
    )
    return chat_completion.choices[0].message.content


def save_json(data, dataset_name):
    """
    Input:
        data: the data to be saved
        dataset_name: the name of the dataset
    """
    file_dir = f"/home/yongjia/dgl/Yongjia/MOE/Reasoner/data/finetune/{dataset_name}"
    os.makedirs(file_dir, exist_ok=True)
    file_path = f"{file_dir}/1000_{parameters['azure']['model']}.json"

    with open(file_path, 'w') as f:
        json.dump(data, f, indent=4)
    print(f"Saved to {file_path}")
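
# Shape of the saved file: a JSON list of query/answer pairs, e.g. (values
# are illustrative; the structure comes from get_rg below):
# [
#     {"query": "...", "answer": {"Metapath": "...", "Restriction": [...]}},
#     ...
# ]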


def get_rg(dataset_name):
    """
    Input:
        dataset_name: the name of the dataset
    Output:
        None; saves the query/reasoning-graph pairs for the dataset to disk.
    """
    sys_content = get_sys_content(dataset_name)

    qa = load_qa(dataset_name)
    train_qa = qa.get_subset('train')

    pair_list = []
    failure_count = 0
    # Scan up to 1500 training queries to collect 1000 successfully parsed pairs.
    for i in range(1500):
        query, q_id, ans_ids, _ = train_qa[i]

        response = get_response(sys_content, query)
        print(response)
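
        # Expected raw response format for 'prime' (three lines; the exact
        # payload syntax is an assumption inferred from the parsing below):
        #   Triplets: [["head", "relation", "tail"], ...]
        #   Restriction: <JSON>
        #   Target: <target entity type>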
        if dataset_name == 'prime':
            output = {
                "Triplets": [],
                "Restriction": [],
                "Target": ""
            }

            try:
                lines = response.split('\n')
                triplets_raw = lines[0].replace('Triplets:', '').strip()
                output['Triplets'] = json.loads(triplets_raw)

                restriction_raw = lines[1].replace('Restriction:', '').strip()
                output['Restriction'] = json.loads(restriction_raw)

                output['Target'] = lines[2].replace('Target:', '').strip()
            except Exception:
                # Malformed response; skip this query.
                failure_count += 1
                continue

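        # Expected raw response format for 'mag'/'amazon' (two lines; the exact
        # payload syntax is an assumption inferred from the parsing below):
        #   Metapath: <sequence of node types>
        #   Restriction: <JSON>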
        elif dataset_name in ('mag', 'amazon'):
            output = {
                "Metapath": "",
                "Restriction": [],
            }

            try:
                lines = response.split('\n')
                output['Metapath'] = lines[0].replace('Metapath:', '').strip()

                restriction_raw = lines[1].replace('Restriction:', '').strip()
                output['Restriction'] = json.loads(restriction_raw)
            except Exception:
                # Malformed response; skip this query.
                failure_count += 1
                continue

        else:
            raise ValueError('The dataset is not supported')

        pair_list.append({'query': query, 'answer': output})

        if len(pair_list) == 1000:
            break

    save_json(pair_list, dataset_name)
    print(f"Failure count: {failure_count}")


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Load LLM parameters and initialize API clients.")

    parser.add_argument("--dataset_name", type=str, required=True,
                        choices=["mag", "amazon", "prime"],
                        help="Specify the dataset to use.")

    parser.add_argument("--model", type=str, required=True,
                        choices=["gpt-4o-mini-20240718", "gpt-4o-2024-05-13",
                                 "deepseek-reasoner", "gpt-o1-2024-12-17",
                                 "o3-mini-2025-01-31"],
                        help="Specify the model to use.")

    parser.add_argument("--azure_api_key", type=str, default=None, help="Azure API Key")
    parser.add_argument("--azure_endpoint", type=str, default=None, help="Azure API Endpoint")
    parser.add_argument("--azure_api_version", type=str, default=None, help="Azure API Version")

    parser.add_argument("--openai_api_key", type=str, default=None, help="OpenAI API Key")
    parser.add_argument("--openai_endpoint", type=str, default=None, help="OpenAI API Endpoint")

    args = parser.parse_args()

    parameters = {
        "azure": {
            "api_key": args.azure_api_key,
            "azure_endpoint": args.azure_endpoint,
            "api_version": args.azure_api_version,
            # get_response() and save_json() read the model name from here.
            "model": args.model,
        },
        "openai": {
            "api_key": args.openai_api_key,
            "endpoint": args.openai_endpoint,
        }
    }

    # Prefer an OpenAI-compatible endpoint when an OpenAI key is supplied;
    # otherwise fall back to Azure OpenAI.
    if parameters["openai"]["api_key"]:
        client = openai.OpenAI(
            base_url=parameters["openai"]["endpoint"],
            api_key=parameters["openai"]["api_key"],
        )
    else:
        client = AzureOpenAI(
            azure_endpoint=parameters["azure"]["azure_endpoint"],
            api_key=parameters["azure"]["api_key"],
            api_version=parameters["azure"]["api_version"],
        )

    get_rg(args.dataset_name)
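
    # Example invocation (the script name and credential values below are
    # placeholders, not taken from this repo):
    #   python get_training_data.py --dataset_name mag --model o3-mini-2025-01-31 \
    #       --azure_api_key <KEY> --azure_endpoint <ENDPOINT> --azure_api_version <VERSION>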