import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
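
# GPT-2 checkpoint finetuned to predict the product of two numbers directly
# (implicit chain-of-thought, per the paper linked in the interface below).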
model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
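
# The model reads and writes numbers as space-separated digits in reversed order
# (least-significant digit first); the two helpers below convert to and from that format.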
def preprocess(num):
    num = str(num).strip().replace(' ', '')
    reversed_num = ' '.join(num[::-1])
    return reversed_num
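# Example: preprocess(12345) -> '5 4 3 2 1'
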
def postprocess(raw_output):
    prediction = raw_output.replace(' ', '')[::-1]
    return prediction
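# Example: postprocess('5 4 3 2 1') -> '12345' (inverse of preprocess)
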
@spaces.GPU  # On ZeroGPU Spaces, this decorator requests a GPU for the duration of the call.
def predict_product(num1, num2):
    # Reverse input digits and add spaces
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    inputs = tokenizer(input_text, return_tensors='pt').to('cuda' if torch.cuda.is_available() else 'cpu')
    model.to('cuda' if torch.cuda.is_available() else 'cpu')

    # Generate output
    outputs = model.generate(**inputs, max_new_tokens=40)
    output = outputs[0][inputs['input_ids'].shape[-1]:]  # drop the prompt tokens, keep only the continuation
    raw_output = tokenizer.decode(output, skip_special_tokens=True)
    prediction = postprocess(raw_output)

    # Evaluate the correctness of the result
    try:
        num1_int = int(num1)
        num2_int = int(num2)
        valid_input = True
    except ValueError:
        valid_input = False

    if valid_input:
        correct_product = str(num1_int * num2_int)
        is_correct = (prediction == correct_product)
        result_color = "green" if is_correct else "red"
        result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
    else:
        result_color = "black"
        result_message = "Invalid input. Could not evaluate correctness."

    result_html = f"<div style='color: {result_color};'>{result_message}</div>"
    return input_text, raw_output, prediction, result_html
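
# Gradio interface: two textboxes for the factors; outputs show the exact prompt fed to the
# model, its raw generation, the parsed product, and a correctness message.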
demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 9 digits)', value='12345'),
        gr.Textbox(label='Second Number (up to 9 digits)', value='67890'),
    ],
    outputs=[
        gr.Textbox(label='Raw Input to GPT-2'),
        gr.Textbox(label='Raw Output from GPT-2'),
        gr.Textbox(label='Predicted Product'),
        gr.HTML(label='Result Message')
    ],
    title='GPT-2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
    description='This demo uses GPT-2 to directly predict the product of two numbers without producing any intermediate steps. The GPT-2 model is finetuned to internalize chain-of-thought reasoning in its hidden states, using our stepwise internalization approach detailed in the paper below.',
    article="""
### Additional Resources
- [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
""",
    live=False
)
demo.launch()
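
# Quick sanity check (illustrative only, never executed by the app; assumes the checkpoint
# has been downloaded and the function is called without launching the UI):
#   input_text, raw_output, prediction, result_html = predict_product('12345', '67890')
#   print(input_text)   # '5 4 3 2 1 * 0 9 8 7 6 ='
#   print(prediction)   # model's predicted product, compared against 12345 * 67890 = 838102050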