Upload run_program.py with huggingface_hub

run_program.py CHANGED (+144 -68)
@@ -25,6 +25,10 @@ client = OpenAI(
     api_key='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'
 )
 
+# client = OpenAI(
+#     api_key='sk-svcacct-JlNMlCPtZ_F0zJtJM9yaYSYzG8xnSdksl2uYUZLuabGoOCKqDtKGTWhHOlq-Idm4lT3BlbkFJ4zHo-hOjH6J8ne9IturX2sQA-tdKDOUw3Oj44pShZZ3iM-ptGsVcd8LFvB8pBIpAA'
+# )
+
 
 torch.random.manual_seed(0)
 
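Review note: both keys above ship in plaintext with the repository. A minimal alternative, assuming the same openai client shown in this file, reads the key from the environment instead:

    import os
    from openai import OpenAI

    # Keep secrets out of version control: load the key from the environment.
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])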
@@ -46,12 +50,12 @@ torch.random.manual_seed(0)
 # )
 # qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
 
-llama_pipeline = pipeline(
-    "text-generation",
-    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    device_map="auto",
-)
+# llama_pipeline = pipeline(
+#     "text-generation",
+#     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     device_map="auto",
+# )
 
 
 
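With both model setups commented out, the 'llama' and 'qwen' branches of call_answer_question below would hit a NameError if selected. A hypothetical lazy initializer (not part of this commit) would keep the branch usable without loading the model at import time:

    # Hypothetical helper: build the Llama pipeline on first use only.
    _llama_pipeline = None

    def get_llama_pipeline():
        global _llama_pipeline
        if _llama_pipeline is None:
            _llama_pipeline = pipeline(
                "text-generation",
                model="meta-llama/Meta-Llama-3.1-8B-Instruct",
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map="auto",
            )
        return _llama_pipeline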
@@ -238,7 +242,7 @@ def update_question_with_new_parameters():
     json.dump(program_data, outfile, indent=4)
 
 
-def call_answer_question(question, model_name='gpt', cot=False):
+def call_answer_question(question, model_name='gpt', cot=False, temp=0.7):
     if cot:
         prompt_template = PROMPT_DICT['prompt_answer_question_few_shot_cot']
     else:
@@ -250,12 +254,13 @@ def call_answer_question(question, model_name='gpt', cot=False):
     if model_name == 'gpt':
         response = client.chat.completions.create(
             model="gpt-4o",
+            # model="gpt-4-turbo",
             messages=[
                 {"role": "system", "content": "You are a helpful assistant."},
                 {"role": "user", "content": prompt}
             ],
-            temperature=
-            max_tokens=
+            temperature=temp,
+            max_tokens=1024,
             top_p=1
         )
         return response.choices[0].message.content
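The new temp argument threads the sampling temperature through each backend, which is what the temp=0.7_iter=N output files further down rely on for self-consistency. As a sketch, several chains could also be drawn in one request via the standard n parameter of chat.completions.create, instead of re-running the script once per iteration file:

    # Sketch: k sampled completions in a single call; response.choices
    # holds one element per sample.
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=1024,
        n=5,
    )
    samples = [choice.message.content for choice in response.choices]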
@@ -267,7 +272,7 @@ def call_answer_question(question, model_name='gpt', cot=False):
         messages=[
             {"role": "user", "content": prompt}
         ],
-        temperature=
+        temperature=temp,
         top_p=1
     )
     return message.content[0].text
@@ -292,29 +297,31 @@ def call_answer_question(question, model_name='gpt', cot=False):
     #
     # output = pipe(messages, **generation_args)
     # print(output[0]['generated_text'])
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    if model_name == 'qwen':
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = qwen_tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = qwen_tokenizer([text], return_tensors="pt").to(qwen_model.device)
+
+        generated_ids = qwen_model.generate(
+            **model_inputs,
+            max_new_tokens=300,
+            temperature=temp
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        response = qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return response
+
     if model_name == 'llama':
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
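One caveat on the new qwen branch: in transformers, generate() only honors temperature when sampling is enabled; under the default greedy decoding the value is ignored with a warning. A minimal sketch of the likely intent:

    # Sketch: enable sampling explicitly; fall back to greedy at temp == 0,
    # where temperature would be ignored anyway.
    if temp > 0:
        generated_ids = qwen_model.generate(
            **model_inputs, max_new_tokens=300, do_sample=True, temperature=temp
        )
    else:
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=300)

The same applies to the temperature passed to the llama pipeline in the next hunk.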
@@ -324,30 +331,39 @@ def call_answer_question(question, model_name='gpt', cot=False):
         messages,
         max_new_tokens=300,
         # temperature=0.00001
-        temperature =
+        temperature = temp
     )
     # print(outputs[0]["generated_text"][-1])
     return outputs[0]["generated_text"][-1]['content']
 
 
-def answer_question(model_name='gpt', cot=False):
-    infile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions.json', 'r')
+def answer_question(model_name='gpt', cot=False, temp=0.0):
+    # infile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions.json', 'r')
+    infile = open('data/math/test_dump_math_train_4o_perturbed_with_new_questions.json', 'r')
     program_data = json.load(infile)
     print(len(program_data))
     for case in tqdm(program_data):
-        response = call_answer_question(case['question'], model_name=model_name, cot=cot)
+        response = call_answer_question(case['question'], model_name=model_name, cot=cot, temp=temp)
         case['prediction'] = response
         # print(case['prediction'])
         case['new_prediction'] = []
         for question in case['new_questions']:
-            response = call_answer_question(question, model_name=model_name, cot=cot)
+            response = call_answer_question(question, model_name=model_name, cot=cot, temp=temp)
             case['new_prediction'].append(response)
         # print(case)
         # break
     # print(case)
     # break
-    # outfile = open('data/math/
-    outfile = open('data/math/
+    # outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_few_shot_cot_llama8b.json', 'w')
+    # outfile = open('data/math/gsm8k_cot_sc_qwen/temp=0.7_iter=5.json', 'w')
+    # outfile = open('data/math/gsm8k_cot_sc_llama3.1_8b/temp=0.7_iter=5.json', 'w')
+
+    # outfile = open('data/math/test_dump_math_train_4o_perturbed_with_new_questions_few_shot_cot_qwen.json', 'w')
+    # outfile = open('data/math/test_dump_math_train_4o_perturbed_with_new_questions_few_shot_cot_gpt4o.json', 'w')
+    outfile = open('data/math/math_cot_sc_gpt4o/temp=0.7_iter=2.json', 'w')
+    # outfile = open('data/math/math_cot_sc_qwen/temp=0.7_iter=5.json', 'w')
+    # outfile = open('data/math/math_cot_sc_llama3.1_8b/temp=0.7_iter=4.json', 'w')
+
     json.dump(program_data, outfile, indent=4)
 
 
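The temp=0.7_iter=N filenames suggest one output file per sampled chain, merged later by collect_self_consistency_result (not shown in this diff). A hypothetical sketch of the majority vote self-consistency takes over the per-iteration parsed answers:

    from collections import Counter

    # Hypothetical helper: the most frequent parsed answer across chains wins;
    # ties resolve to the first-seen answer.
    def majority_vote(parsed_answers):
        return Counter(parsed_answers).most_common(1)[0][0]

    # e.g. majority_vote(['42', '42', '41']) == '42'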
@@ -356,10 +372,19 @@ def parse_answer(answer):
         if 'answer is' in answer:
             answer = answer.split('answer is')[-1].strip()
         else:
-            answer = answer.split(' ')[-1]
+            if '\\(' in answer and '\\)' in answer:
+                answer = answer.split('\\(')[-1].split('\\)')[0]
+            else:
+                # print("Before: ", answer)
+                answer = answer.split(' ')[-1]
+
         if len(answer) > 0 and answer[-1] == '.':
             answer = answer[0:-1]
+        print("##########Before: ", answer)
+        answer = answer.split('=')[-1]
         answer = re.sub("[^\d\.]", "", answer)
+        print("################After: ", answer)
+
         return answer
     else:
         answer_freq = {}
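Minor: re.sub("[^\d\.]", "", answer) relies on escape sequences that newer Python versions flag with a SyntaxWarning in plain string literals; a raw string is the usual fix, and '.' needs no escape inside a character class:

    answer = re.sub(r"[^\d.]", "", answer)  # raw string avoids the invalid-escape warning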
@@ -394,7 +419,7 @@ def collect_self_consistency_result(infile_path):
 
 
 
-def evaluator(infile_path):
+def evaluator(infile_path, normalize=False):
     infile = open(infile_path, 'r')
     data = json.load(infile)
     correct_case = 0
@@ -406,44 +431,95 @@ def evaluator(infile_path):
             continue
         total_case += 1
         prediction = parse_answer(case['prediction'])
-
+        parsed_gold = parse_answer(str(case['answer']))
+        case['answer'] = str(case['answer'])
+        if prediction == case['answer'] or case['answer'] in prediction or prediction == parsed_gold or parsed_gold in prediction:
             correct_case += 1
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        else:
+            # print(prediction)
+            if normalize:
+                continue
+        new_parameter_correct_case = 0
+        for idx, pred in enumerate(case['new_prediction']):
+            parsed_pred = parse_answer(pred)
+            parsed_gold = parse_answer(case['new_answers'][idx])
+            if parsed_pred == case['new_answers'][idx] or case['new_answers'][idx] in parsed_pred or parsed_pred == parsed_gold or parsed_gold in parsed_pred:
+                new_parameter_correct_case += 1
+            else:
+                try:
+                    parsed_pred = round(float(parsed_pred))
+                    new_answer = round(float(case['new_answers'][idx]))
+                    if parsed_pred == new_answer:
+                        new_parameter_correct_case += 1
+                    # else:
+                    #     print(parsed_pred, case['new_answers'][idx])
+                except:
+                    continue
+        total_parameter_correct_case = len(case['new_prediction'])
+        percentage = float(new_parameter_correct_case / total_parameter_correct_case)
+        total_percentage += percentage
+        if new_parameter_correct_case not in new_parameter_correct_counter:
+            new_parameter_correct_counter[new_parameter_correct_case] = 0
+        new_parameter_correct_counter[new_parameter_correct_case] += 1
 
     # else:
     #     print(prediction, case['answer'])
     print(correct_case, total_case, correct_case/total_case)
-
+    if normalize:
+        print(total_percentage, total_percentage/correct_case)
+    else:
+        print(total_percentage, total_percentage/total_case)
     print(new_parameter_correct_counter)
+    print(new_parameter_correct_counter[5] / correct_case)
+
+
+def sample_questions(filepath):
+    infile = open(filepath, 'r')
+    data = json.load(infile)
+    filtered_data = []
+    for case in data:
+        if 'new_answers' not in case or len(case['new_answers']) != 5:
+            continue
+        filtered_data.append(case)
+    filtered_data = random.sample(filtered_data, 100)
+
+    # with open('data/sample_verification/gsk8k_sample.csv', 'w', newline='') as csvfile:
+    #     csvwriter = csv.writer(csvfile, delimiter=' ',
+    #         quotechar='|', quoting=csv.QUOTE_MINIMAL)
+    #     for case in filtered_data:
+    #         csvwriter.writerow([case['question'], case['answer'], case['parameters'], case['selected_programs'][0].replace('\n', '\\n'),
+    #             case['new_parameters'], case['new_questions'], case['new_answers']])
+    out_data = []
+    for case in filtered_data:
+        new_case = {
+            'question': case['question'],
+            'answer': case['answer'],
+            'parameters': case['parameters'],
+            'programs': case['candidate_programs'][0],
+            'new_parameters': case['new_parameters'],
+            'new_questions': case['new_questions'],
+            'new_answers': case['new_answers']
+        }
+        out_data.append(new_case)
+    outfile1 = open('data/sample_verification/math_xiaodong_split.json', 'w')
+    outfile2 = open('data/sample_verification/math_ben_split.json', 'w')
+    outfile3 = open('data/sample_verification/math_hao_split.json', 'w')
+    json.dump(out_data[0:34], outfile1, indent=4)
+    json.dump(out_data[34:67], outfile2, indent=4)
+    json.dump(out_data[67:100], outfile3, indent=4)
+
+
 
 def main():
     # generate_new_parameter_value()
     # update_question_with_new_parameters()
-    answer_question(model_name='
-    # collect_self_consistency_result('data/math/
-
+    # answer_question(model_name='gpt', cot=True, temp=0.7)
+    # collect_self_consistency_result('data/math/math_cot_sc_gpt4turbo')
+    evaluator('data/math/math_cot_sc_gpt4o/temp=0.7_iter=1.json', normalize=True)
     # evaluator('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_few_shot_cot_qwen.json')
 
+    # sample_questions('data/math/test_dump_math_train_4o_perturbed_with_new_questions.json')
+
 
 if __name__ == "__main__":
     main()
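Two fragile spots in the new evaluator code: new_parameter_correct_counter[5] raises a KeyError when no case gets all five perturbations right, and the bare except: also swallows KeyboardInterrupt. A sketch of both fixes, where numeric_match is a hypothetical helper:

    # Sketch: narrow exception handling for the numeric fallback.
    def numeric_match(pred, gold):
        try:
            return round(float(pred)) == round(float(gold))
        except (TypeError, ValueError):  # only what float()/round() can raise
            return False

    # A missing counter bucket should read as zero rather than raise:
    # print(new_parameter_correct_counter.get(5, 0) / correct_case)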