xdyu committed
Commit 55f9861 · verified · 1 Parent(s): 6fb3491

Upload run_program.py with huggingface_hub

Files changed (1): run_program.py (+144, -68)
run_program.py CHANGED
@@ -25,6 +25,10 @@ client = OpenAI(
     api_key='sk-proj-[REDACTED]'
 )
 
+# client = OpenAI(
+#     api_key='sk-svcacct-[REDACTED]'
+# )
+
 
 torch.random.manual_seed(0)
 
@@ -46,12 +50,12 @@ torch.random.manual_seed(0)
 # )
 # qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
 
-llama_pipeline = pipeline(
-    "text-generation",
-    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    device_map="auto",
-)
+# llama_pipeline = pipeline(
+#     "text-generation",
+#     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     device_map="auto",
+# )
 
 
 
@@ -238,7 +242,7 @@ def update_question_with_new_parameters():
     json.dump(program_data, outfile, indent=4)
 
 
-def call_answer_question(question, model_name='gpt', cot=False):
+def call_answer_question(question, model_name='gpt', cot=False, temp=0.7):
     if cot:
         prompt_template = PROMPT_DICT['prompt_answer_question_few_shot_cot']
     else:
@@ -250,12 +254,13 @@ def call_answer_question(question, model_name='gpt', cot=False):
     if model_name == 'gpt':
         response = client.chat.completions.create(
             model="gpt-4o",
+            # model="gpt-4-turbo",
             messages=[
                 {"role": "system", "content": "You are a helpful assistant."},
                 {"role": "user", "content": prompt}
             ],
-            temperature=0,
-            max_tokens=300,
+            temperature=temp,
+            max_tokens=1024,
             top_p=1
         )
         return response.choices[0].message.content
@@ -267,7 +272,7 @@ def call_answer_question(question, model_name='gpt', cot=False):
             messages=[
                 {"role": "user", "content": prompt}
             ],
-            temperature=0,
+            temperature=temp,
             top_p=1
         )
         return message.content[0].text
@@ -292,29 +297,31 @@ def call_answer_question(question, model_name='gpt', cot=False):
     #
     # output = pipe(messages, **generation_args)
     # print(output[0]['generated_text'])
-    # if model_name == 'qwen':
-    #     messages = [
-    #         {"role": "system", "content": "You are a helpful assistant."},
-    #         {"role": "user", "content": prompt}
-    #     ]
-    #     text = qwen_tokenizer.apply_chat_template(
-    #         messages,
-    #         tokenize=False,
-    #         add_generation_prompt=True
-    #     )
-    #     model_inputs = qwen_tokenizer([text], return_tensors="pt").to(qwen_model.device)
-    #
-    #     generated_ids = qwen_model.generate(
-    #         **model_inputs,
-    #         max_new_tokens=300,
-    #         temperature=0.7
-    #     )
-    #     generated_ids = [
-    #         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    #     ]
-    #
-    #     response = qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    #     return response
+
+    if model_name == 'qwen':
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = qwen_tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = qwen_tokenizer([text], return_tensors="pt").to(qwen_model.device)
+
+        generated_ids = qwen_model.generate(
+            **model_inputs,
+            max_new_tokens=300,
+            temperature=temp
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        response = qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return response
+
     if model_name == 'llama':
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -324,30 +331,39 @@ def call_answer_question(question, model_name='gpt', cot=False):
         messages,
         max_new_tokens=300,
         # temperature=0.00001
-        temperature = 0.7
+        temperature = temp
     )
     # print(outputs[0]["generated_text"][-1])
     return outputs[0]["generated_text"][-1]['content']
 
 
-def answer_question(model_name='gpt', cot=False):
-    infile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions.json', 'r')
+def answer_question(model_name='gpt', cot=False, temp=0.0):
+    # infile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions.json', 'r')
+    infile = open('data/math/test_dump_math_train_4o_perturbed_with_new_questions.json', 'r')
     program_data = json.load(infile)
     print(len(program_data))
    for case in tqdm(program_data):
-        response = call_answer_question(case['question'], model_name=model_name, cot=cot)
+        response = call_answer_question(case['question'], model_name=model_name, cot=cot, temp=temp)
         case['prediction'] = response
         # print(case['prediction'])
         case['new_prediction'] = []
         for question in case['new_questions']:
-            response = call_answer_question(question, model_name=model_name, cot=cot)
+            response = call_answer_question(question, model_name=model_name, cot=cot, temp=temp)
             case['new_prediction'].append(response)
             # print(case)
            # break
         # print(case)
         # break
-    # outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_llama8b.json', 'w')
-    outfile = open('data/math/gsm8k_cot_sc_llama3.1_8b/temp=0.7_iter=5.json', 'w')
+    # outfile = open('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_few_shot_cot_llama8b.json', 'w')
+    # outfile = open('data/math/gsm8k_cot_sc_qwen/temp=0.7_iter=5.json', 'w')
+    # outfile = open('data/math/gsm8k_cot_sc_llama3.1_8b/temp=0.7_iter=5.json', 'w')
+
+    # outfile = open('data/math/test_dump_math_train_4o_perturbed_with_new_questions_few_shot_cot_qwen.json', 'w')
+    # outfile = open('data/math/test_dump_math_train_4o_perturbed_with_new_questions_few_shot_cot_gpt4o.json', 'w')
+    outfile = open('data/math/math_cot_sc_gpt4o/temp=0.7_iter=2.json', 'w')
+    # outfile = open('data/math/math_cot_sc_qwen/temp=0.7_iter=5.json', 'w')
+    # outfile = open('data/math/math_cot_sc_llama3.1_8b/temp=0.7_iter=4.json', 'w')
+
     json.dump(program_data, outfile, indent=4)
 
 
@@ -356,10 +372,19 @@ def parse_answer(answer):
         if 'answer is' in answer:
             answer = answer.split('answer is')[-1].strip()
         else:
-            answer = answer.split(' ')[-1]
+            if '\\(' in answer and '\\)' in answer:
+                answer = answer.split('\\(')[-1].split('\\)')[0]
+            else:
+                # print("Before: ", answer)
+                answer = answer.split(' ')[-1]
+
         if len(answer) > 0 and answer[-1] == '.':
             answer = answer[0:-1]
+        print("##########Before: ", answer)
+        answer = answer.split('=')[-1]
         answer = re.sub("[^\d\.]", "", answer)
+        print("################After: ", answer)
+
         return answer
     else:
         answer_freq = {}
@@ -394,7 +419,7 @@ def collect_self_consistency_result(infile_path):
 
 
 
-def evaluator(infile_path):
+def evaluator(infile_path, normalize=False):
     infile = open(infile_path, 'r')
     data = json.load(infile)
     correct_case = 0
@@ -406,44 +431,95 @@ def evaluator(infile_path):
             continue
         total_case += 1
         prediction = parse_answer(case['prediction'])
-        if prediction == case['answer'] or case['answer'] in prediction:
+        parsed_gold = parse_answer(str(case['answer']))
+        case['answer'] = str(case['answer'])
+        if prediction == case['answer'] or case['answer'] in prediction or prediction == parsed_gold or parsed_gold in prediction:
             correct_case += 1
-        new_parameter_correct_case = 0
-        for idx, pred in enumerate(case['new_prediction']):
-            parsed_pred = parse_answer(pred)
-            if parsed_pred == case['new_answers'][idx] or case['new_answers'][idx] in parsed_pred:
-                new_parameter_correct_case += 1
-            else:
-                try:
-                    parsed_pred = round(float(parsed_pred))
-                    new_answer = round(float(case['new_answers'][idx]))
-                    if parsed_pred == new_answer:
-                        new_parameter_correct_case += 1
-                    else:
-                        print(parsed_pred, case['new_answers'][idx])
-                except:
-                    continue
-        total_parameter_correct_case = len(case['new_prediction'])
-        percentage = float(new_parameter_correct_case / total_parameter_correct_case)
-        total_percentage += percentage
-        if new_parameter_correct_case not in new_parameter_correct_counter:
-            new_parameter_correct_counter[new_parameter_correct_case] = 0
-        new_parameter_correct_counter[new_parameter_correct_case] += 1
+        else:
+            # print(prediction)
+            if normalize:
+                continue
+        new_parameter_correct_case = 0
+        for idx, pred in enumerate(case['new_prediction']):
+            parsed_pred = parse_answer(pred)
+            parsed_gold = parse_answer(case['new_answers'][idx])
+            if parsed_pred == case['new_answers'][idx] or case['new_answers'][idx] in parsed_pred or parsed_pred == parsed_gold or parsed_gold in parsed_pred:
+                new_parameter_correct_case += 1
+            else:
+                try:
+                    parsed_pred = round(float(parsed_pred))
+                    new_answer = round(float(case['new_answers'][idx]))
+                    if parsed_pred == new_answer:
+                        new_parameter_correct_case += 1
+                    # else:
+                    #     print(parsed_pred, case['new_answers'][idx])
+                except:
+                    continue
+        total_parameter_correct_case = len(case['new_prediction'])
+        percentage = float(new_parameter_correct_case / total_parameter_correct_case)
+        total_percentage += percentage
+        if new_parameter_correct_case not in new_parameter_correct_counter:
+            new_parameter_correct_counter[new_parameter_correct_case] = 0
+        new_parameter_correct_counter[new_parameter_correct_case] += 1
 
     # else:
     #     print(prediction, case['answer'])
     print(correct_case, total_case, correct_case/total_case)
-    print(total_percentage/correct_case)
+    if normalize:
+        print(total_percentage, total_percentage/correct_case)
+    else:
+        print(total_percentage, total_percentage/total_case)
     print(new_parameter_correct_counter)
+    print(new_parameter_correct_counter[5] / correct_case)
+
+
+def sample_questions(filepath):
+    infile = open(filepath, 'r')
+    data = json.load(infile)
+    filtered_data = []
+    for case in data:
+        if 'new_answers' not in case or len(case['new_answers']) != 5:
+            continue
+        filtered_data.append(case)
+    filtered_data = random.sample(filtered_data, 100)
+
+    # with open('data/sample_verification/gsk8k_sample.csv', 'w', newline='') as csvfile:
+    #     csvwriter = csv.writer(csvfile, delimiter=' ',
+    #                            quotechar='|', quoting=csv.QUOTE_MINIMAL)
+    #     for case in filtered_data:
+    #         csvwriter.writerow([case['question'], case['answer'], case['parameters'], case['selected_programs'][0].replace('\n', '\\n'),
+    #                             case['new_parameters'], case['new_questions'], case['new_answers']])
+    out_data = []
+    for case in filtered_data:
+        new_case = {
+            'question': case['question'],
+            'answer': case['answer'],
+            'parameters': case['parameters'],
+            'programs': case['candidate_programs'][0],
+            'new_parameters': case['new_parameters'],
+            'new_questions': case['new_questions'],
+            'new_answers': case['new_answers']
+        }
+        out_data.append(new_case)
+    outfile1 = open('data/sample_verification/math_xiaodong_split.json', 'w')
+    outfile2 = open('data/sample_verification/math_ben_split.json', 'w')
+    outfile3 = open('data/sample_verification/math_hao_split.json', 'w')
+    json.dump(out_data[0:34], outfile1, indent=4)
+    json.dump(out_data[34:67], outfile2, indent=4)
+    json.dump(out_data[67:100], outfile3, indent=4)
+
+
 
 def main():
     # generate_new_parameter_value()
     # update_question_with_new_parameters()
-    answer_question(model_name='llama', cot=True)
-    # collect_self_consistency_result('data/math/gsm8k_cot_sc')
-    # evaluator('data/math/gsm8k_cot_sc_merged_result.json')
+    # answer_question(model_name='gpt', cot=True, temp=0.7)
+    # collect_self_consistency_result('data/math/math_cot_sc_gpt4turbo')
+    evaluator('data/math/math_cot_sc_gpt4o/temp=0.7_iter=1.json', normalize=True)
     # evaluator('data/math/test_dump_gsm8k_train_perturbed_with_new_questions_answer_few_shot_cot_qwen.json')
 
+    # sample_questions('data/math/test_dump_math_train_4o_perturbed_with_new_questions.json')
+
 
 if __name__ == "__main__":
     main()
 
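Note on the client setup in the first hunk: run_program.py passes the OpenAI key inline to the OpenAI(...) constructor. A minimal sketch of an equivalent setup that keeps the secret out of the source, assuming the same openai v1 client the script already uses:

import os
from openai import OpenAI

# The v1 client also falls back to the OPENAI_API_KEY environment
# variable when the api_key argument is omitted entirely.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])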
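The temp=0.7_iter=N dumps written by answer_question are read back by collect_self_consistency_result, and the answer_freq branch of parse_answer points to a majority vote over the sampled generations. A minimal sketch of that aggregation, assuming the iteration files hold the same cases in the same order; the paths are hypothetical examples of the naming above, and the 'answer is' split stands in for the full parse_answer:

import json
from collections import Counter

def majority_vote(answers):
    # Most frequent non-empty answer wins; empty string if nothing parsed.
    counts = Counter(a for a in answers if a)
    return counts.most_common(1)[0][0] if counts else ''

# One sampled prediction per case and file, so zipping the runs
# groups the k samples belonging to each case.
paths = [f'data/math/math_cot_sc_gpt4o/temp=0.7_iter={i}.json' for i in (1, 2)]
runs = [json.load(open(p)) for p in paths]
for cases in zip(*runs):
    samples = [c['prediction'].split('answer is')[-1].strip() for c in cases]
    print(majority_vote(samples))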