# CC/utils/generate_distractors.py
# Author: 谢璐璟
# Commit: e9ce3e8
# Size: 6.43 kB
import json
import re
from tqdm import tqdm
import os
import asyncio
from openai import AsyncOpenAI
from utils.api_utils import generate_from_openai_chat_completion, generate_from_claude_chat_completion
def construct_prompt_textonly(question: str, options: list, answer: str, answer_analysis: str) -> str:
    """Build the text-only prompt that asks the model for three new distractors.

    The prompt contains the instructions, a JSON template the model should
    fill in, and an "Input" section restating the question, the lettered
    options, the answer and the answer analysis.

    Args:
        question: The question stem.
        options: The original options. Every entry is listed (A, B, C, ...)
            in the "Input" section; only the first four are embedded in the
            JSON template as A-D.
        answer: The correct answer, embedded as given.
        answer_analysis: Explanation of the correct answer.

    Returns:
        The complete prompt string.

    Raises:
        IndexError: If ``options`` has fewer than four entries.
    """
    optionized_list = [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
    optionized_str = "\n".join(optionized_list)

    def _json_str(value) -> str:
        # Emit a quoted, JSON-escaped string so that quotes, backslashes or
        # newlines in the source data cannot break the template the model is
        # asked to reproduce. Plain text renders exactly as '"text"'.
        return json.dumps(str(value), ensure_ascii=False)

    # The template must itself be valid JSON (no trailing commas), otherwise
    # it contradicts the "no extra commas" instruction given just above it.
    prompt = f"""
Generate a multiple-choice question with additional distractors that increase the complexity of answer selection. Follow these instructions:
1. **Retain Original Structure**: Retain the original question and options.
2. **Add Three Distractors**: Add three new distractors that are **plausible and maintain professional validity**. These should increase the difficulty but still be incorrect, based on the original question and answer analysis.
3. **Use Answer Analysis**: Reference the **correct answer analysis** when creating distractors to ensure they challenge **subject-matter experts**.
4. **Expert-Level Difficulty**: Keep the distractors **challenging and hard to distinguish** from the correct answer, requiring **advanced knowledge** to avoid the correct answer being too obvious.
5. **Balanced Length**: Ensure all options have **similar lengths** to prevent any one option from standing out.
6. **Distractors Analysis**: Provide a **distractor analysis in Chinese**, explaining why the distractors are **incorrect** but **challenging and hard to distinguish**.
Please output the result in valid JSON format using the structure below. Make sure there are no extra commas, missing commas, extra quotation marks or missing quotation marks:
{{
"question": {_json_str(question)},
"options": {{
"A": {_json_str(options[0])},
"B": {_json_str(options[1])},
"C": {_json_str(options[2])},
"D": {_json_str(options[3])}
}},
"distractors": {{
"E": "New distractor 1",
"F": "New distractor 2",
"G": "New distractor 3",
"analysis_of_distractors": "Use Chinese to explain why the distractors are **incorrect** but **challenging and hard to distinguish**, based on the question, options, and answer analysis."
}},
"correct_answer": {_json_str(answer)}
}}
Input:
Question: {question}
Options:
{optionized_str}
Answer: {answer}
Answer Analysis: {answer_analysis}
"""
    return prompt
def prepare_q_text_input(query, prompt_func=construct_prompt_textonly):
    """Render the prompt text for a single query record.

    Reads ``question``, ``option_1`` .. ``option_4``, ``answer`` and
    ``answer_analysis`` from ``query`` and forwards them to ``prompt_func``
    as keyword arguments.

    Args:
        query: Mapping holding the keys listed above.
        prompt_func: Callable that builds the prompt; it receives
            ``question``, ``options``, ``answer`` and ``answer_analysis``.

    Returns:
        Whatever ``prompt_func`` returns.
    """
    stem = query['question']
    choices = [query[key] for key in ('option_1', 'option_2', 'option_3', 'option_4')]
    return prompt_func(
        question=stem,
        options=choices,
        answer=query['answer'],
        answer_analysis=query['answer_analysis'],
    )
def prepare_q_inputs(queries):
    """Turn each query into a one-message chat conversation.

    Args:
        queries: Iterable of query records accepted by
            ``prepare_q_text_input``.

    Returns:
        A list with one entry per query; each entry is a list containing a
        single ``{"role": "user", "content": <prompt>}`` message.
    """
    return [
        [
            {
                "role": "user",
                "content": prepare_q_text_input(item),
            },
        ]
        for item in queries
    ]
# def extract_json_from_text(text):
# text = json.dumps(text)
# # Strip escape characters and newlines
# text = text.replace('\\n', '').replace('\\"', '"')
# # Define the regular-expression pattern that matches the JSON object
# json_pattern = re.compile(
# r'\{\s*"question":\s*"([^"]*)",\s*"options":\s*\{\s*"A":\s*"([^"]*)",\s*"B":\s*"([^"]*)",\s*"C":\s*"([^"]*)",\s*"D":\s*"([^"]*)"\s*\},'
# r'\s*"distractors":\s*\{\s*"E":\s*"([^"]*)",\s*"F":\s*"([^"]*)",\s*"G":\s*"([^"]*)"\s*\},\s*"correct_answer":\s*"([^"]*)"\s*\}',
# re.DOTALL
# )
# # Match the JSON structure
# match = json_pattern.search(text)
# if match:
# # Captured match groups
# question = match.group(1)
# option_a = match.group(2)
# option_b = match.group(3)
# option_c = match.group(4)
# option_d = match.group(5)
# distractor_e = match.group(6)
# distractor_f = match.group(7)
# distractor_g = match.group(8)
# correct_answer = match.group(9)
# # Build the JSON object
# json_data = {
# "question": question,
# "options": {
# "A": option_a,
# "B": option_b,
# "C": option_c,
# "D": option_d
# },
# "distractors": {
# "E": distractor_e,
# "F": distractor_f,
# "G": distractor_g
# },
# "correct_answer": correct_answer
# }
# return json_data
# else:
# print("No JSON object found in the text.")
# return None
def _extract_distractor_fields(response) -> dict:
    """Map one model response to the four fields written back onto a query.

    Any field that cannot be read (empty response, response that is not a
    dict, missing or non-dict ``"distractors"`` entry, missing letter)
    defaults to the empty string.
    """
    distractors = response.get("distractors") if isinstance(response, dict) else None
    if not isinstance(distractors, dict):
        distractors = {}
    return {
        "option_5": distractors.get("E", ""),
        "option_6": distractors.get("F", ""),
        "option_7": distractors.get("G", ""),
        "distractor_analysis": distractors.get("analysis_of_distractors", ""),
    }


def generate_distractors(model_name: str,
                         queries: list,
                         n: int=1,
                         max_tokens: int=4096):
    """Ask an OpenAI chat model for three extra distractors per query.

    Builds one prompt per query, passes them all to
    ``generate_from_openai_chat_completion`` and stores the result on each
    query dict under ``option_5``, ``option_6``, ``option_7`` and
    ``distractor_analysis``. Fields that cannot be read from a response are
    set to ``""``.

    Args:
        model_name: One of the supported model names checked below.
        queries: List of query dicts; they are modified in place.
        n: Forwarded to the completion helper as ``n``.
            NOTE(review): the response shape for ``n > 1`` is not visible
            here - confirm against the helper before relying on it.
        max_tokens: Forwarded to the completion helper as ``max_tokens``.

    Returns:
        The same ``queries`` list, with the new fields filled in.

    Raises:
        AssertionError: If ``model_name`` is not a supported model.
    """
    supported_models = ("gpt-4o-mini", "gpt-4-turbo", "gpt-4o", "gpt-4o-2024-08-06")
    if model_name not in supported_models:
        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O`, while keeping the exception type callers may expect.
        raise AssertionError("Invalid model name")

    # The API key comes from the OPENAI_API_KEY environment variable; the
    # endpoint is fixed to the base URL below.
    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"),base_url="https://yanlp.zeabur.app/v1")
    messages = prepare_q_inputs(queries)

    # asyncio.run starts its own event loop, so this function must not be
    # called from code that is already running inside one.
    responses = asyncio.run(
        generate_from_openai_chat_completion(
            client,
            messages=messages,
            engine_name=model_name,
            n = n,
            max_tokens=max_tokens,
            requests_per_minute=30,
            json_format=True
        )
    )

    for query, response in zip(queries, responses):
        query.update(_extract_distractor_fields(response))
    return queries