da03 commited on
Commit
8fa0ae4
·
1 Parent(s): 36cbe43
Files changed (1) hide show
  1. app.py +30 -3
app.py CHANGED
@@ -25,15 +25,36 @@ def predict_product(num1, num2):
25
  output = outputs[0][inputs['input_ids'].shape[-1]:]
26
  raw_output = tokenizer.decode(output, skip_special_tokens=True)
27
  prediction = postprocess(raw_output)
28
- return input_text, raw_output, prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  demo = gr.Interface(
31
  fn=predict_product,
32
- inputs=[gr.Number(label='First Number (up to 9 digits)', value=12345), gr.Number(label='Second Number (up to 9 digits)', value=67890)],
33
  outputs=[
34
  gr.Textbox(label='Raw Input to GPT-2'),
35
  gr.Textbox(label='Raw Output from GPT-2'),
36
- gr.Textbox(label='Predicted Product')
 
37
  ],
38
  title='GPT-2 Multiplication Predictor',
39
  description='Enter two numbers up to 9 digits each and get the predicted product.',
@@ -42,6 +63,12 @@ demo = gr.Interface(
42
  - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
43
  - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
44
  - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
 
 
 
 
 
 
45
  """
46
  )
47
 
 
25
  output = outputs[0][inputs['input_ids'].shape[-1]:]
26
  raw_output = tokenizer.decode(output, skip_special_tokens=True)
27
  prediction = postprocess(raw_output)
28
+
29
+ try:
30
+ num1_int = int(num1)
31
+ num2_int = int(num2)
32
+ valid_input = True
33
+ except ValueError:
34
+ valid_input = False
35
+
36
+ if valid_input:
37
+ correct_product = str(num1_int * num2_int)
38
+ is_correct = (prediction == correct_product)
39
+ result_color = "green" if is_correct else "red"
40
+ result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
41
+ else:
42
+ result_color = "black"
43
+ result_message = "Invalid input. Could not evaluate correctness."
44
+
45
+ return input_text, raw_output, prediction, result_message, result_color
46
+
47
+ def output_component(value, color):
48
+ return gr.HTML.update(value=f"<div style='color: {color};'>{value}</div>")
49
 
50
  demo = gr.Interface(
51
  fn=predict_product,
52
+ inputs=[gr.Textbox(label='First Number (up to 9 digits)', value='12345'), gr.Textbox(label='Second Number (up to 9 digits)', value='67890')],
53
  outputs=[
54
  gr.Textbox(label='Raw Input to GPT-2'),
55
  gr.Textbox(label='Raw Output from GPT-2'),
56
+ gr.Textbox(label='Predicted Product'),
57
+ gr.HTML(label='Result Message')
58
  ],
59
  title='GPT-2 Multiplication Predictor',
60
  description='Enter two numbers up to 9 digits each and get the predicted product.',
 
63
  - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
64
  - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
65
  - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
66
+ """,
67
+ live=True,
68
+ css="""
69
+ .output-html {
70
+ font-size: 1.5em;
71
+ }
72
  """
73
  )
74