Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,410 Bytes
02df9f8 5c3cea5 02df9f8 9a65236 3f8feaa 9a65236 c03e8f2 02df9f8 9bfa66b 02df9f8 9428a07 02df9f8 9a65236 39a2dae bf65d9e f7dc2d2 9a65236 f7dc2d2 8fa0ae4 70487ef 9a65236 70487ef f7dc2d2 eaa0586 9a65236 d8750b1 9a65236 7c08a74 c8d010a 7c08a74 9a65236 dde88b9 9a65236 e01b57c 9a65236 ad4fc9e 129e149 9f650af ad4fc9e 9a65236 0ad2aca d8750b1 9a65236 cc2aeb6 d1a789b fe5b36f d1a789b fe5b36f 9a65236 7c08a74 9a65236 bf65d9e 3738ecc 9a65236 7c08a74 d8750b1 bf65d9e d8750b1 f7dc2d2 d8750b1 eaa0586 d8750b1 02df9f8 cf0cf7e 02df9f8 3f861c3 df028fe 3f861c3 9bfa66b 9a65236 fb97ac0 dc59d8f 9bfa66b 85fafa0 b8ed883 85fafa0 486c21f ebf0fc0 02df9f8 ebf0fc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the three GPT-2 checkpoints being compared. One tokenizer, loaded from
# the implicit-CoT repo, is used for every model in predict_product.
implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication'

implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)
no_cot_model = AutoModelForCausalLM.from_pretrained(no_cot_model_name)
explicit_cot_model = AutoModelForCausalLM.from_pretrained(explicit_cot_model_name)

# Short name -> model. Insertion order is the order in which the models are
# stepped during generation.
models = dict(implicit=implicit_cot_model, no=no_cot_model, explicit=explicit_cot_model)

# Constants
# Per-model cap on generation steps. The explicit-CoT model gets a much larger
# budget because its output contains intermediate steps before '=' and the answer.
MAX_PRODUCT_DIGITS_PER_MODEL = dict(implicit=100, no=100, explicit=960)
def preprocess(num):
    """Format a number for the models: digits reversed and space-separated.

    Outer whitespace and inner spaces are removed first, so ``' 12 3 '``
    becomes ``'3 2 1'``. Non-string inputs are converted with ``str``.
    """
    digits = str(num).strip().replace(' ', '')
    return ' '.join(reversed(digits))
def postprocess(raw_output):
    """Invert ``preprocess``: drop all spaces and restore normal digit order.

    ``'3 2 1'`` becomes ``'123'``.
    """
    compact = raw_output.replace(' ', '')
    return ''.join(reversed(compact))
def _ground_truth_digits_reversed(num1, num2):
    """Return the digits of ``num1 * num2`` least-significant first, or None.

    Both operands are normalized exactly as ``preprocess`` does (outer
    whitespace stripped, inner spaces removed), so the ground truth matches the
    digits the models are actually shown.

    Returns None (invalid input) in these cases:
    - an operand is empty or contains anything other than ASCII digits. Plain
      ``int()`` would also accept signs, underscores and non-ASCII digits,
      which ``preprocess`` would reverse character-by-character into the prompt.
    - a number is too long for the interpreter's int/str conversion limit.
    """
    normalized = [str(num).strip().replace(' ', '') for num in (num1, num2)]
    if not all(s.isascii() and s.isdigit() for s in normalized):
        return None
    try:
        return list(str(int(normalized[0]) * int(normalized[1])))[::-1]
    except ValueError:
        # Python 3.11+ raises ValueError converting very long digit strings.
        return None


def _annotate_digits(predicted_digits_reversed, ground_truth_digits_reversed):
    """Label predicted digits (least-significant first) against the ground truth.

    Returns ``(digit, label)`` pairs in the same order. The label is:
    - ``None`` when ``ground_truth_digits_reversed`` is None (invalid input);
    - otherwise ``"correct"`` or ``"wrong"``.

    A digit beyond the length of the true product counts as correct only if
    it is ``'0'`` and every earlier digit was correct.
    """
    annotations = []
    is_correct_sofar = True
    for i, predicted_digit in enumerate(predicted_digits_reversed):
        if ground_truth_digits_reversed is None:
            annotations.append((predicted_digit, None))
            continue
        if i >= len(ground_truth_digits_reversed):
            is_correct_digit = predicted_digit == '0' and is_correct_sofar
        else:
            is_correct_digit = predicted_digit == ground_truth_digits_reversed[i]
        if not is_correct_digit:
            is_correct_sofar = False
        annotations.append((predicted_digit, "correct" if is_correct_digit else "wrong"))
    return annotations


@spaces.GPU
def predict_product(num1, num2):
    """Stream greedy, one-token-per-step predictions of ``num1 * num2`` from every model.

    Operands are fed to the models reversed and space-separated (see
    ``preprocess``). Each step generates at most one token per still-active
    model (``do_sample=False``, KV cache reused between steps) and yields the
    current state of the four Gradio ``HighlightedText`` outputs.

    Args:
        num1, num2: operand text from the UI textboxes.

    Yields:
        ``(ground_truth, implicit, no_cot, explicit)`` lists of ``(text, label)``
        pairs, displayed most-significant first. The label is ``'correct'``,
        ``'wrong'`` or ``None`` (unscored).

        The ground truth reveals one more digit per step. For invalid input it
        is the single entry ``'Invalid Input!'`` and predictions are unscored.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    for model in models.values():
        model.to(device)
    input_len = inputs['input_ids'].shape[-1]

    ground_truth_digits_reversed = _ground_truth_digits_reversed(num1, num2)
    valid_input = ground_truth_digits_reversed is not None

    generated_ids_per_model = {name: inputs['input_ids'].clone() for name in models}
    finished_per_model = {name: False for name in models}
    past_key_values_per_model = {name: None for name in models}
    # Pre-seed with empty lists. A model whose very first token is EOS would
    # otherwise raise KeyError in the EOS branch and in the yield below.
    predicted_annotations_per_model = {name: [] for name in models}

    # Hard cap on steps to prevent runaway generation.
    for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())):
        # Ground truth
        if valid_input:
            revealed = ground_truth_digits_reversed[:step + 1]
            ground_truth_annotations = [(digit, None) for digit in revealed][::-1]
        else:
            ground_truth_annotations = [('Invalid Input!', None)]

        # Predicted
        for model_name, model in models.items():
            if finished_per_model[model_name] or step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
                continue
            generation_kwargs = {
                'input_ids': generated_ids_per_model[model_name],
                'max_new_tokens': 1,
                'do_sample': False,
                'return_dict_in_generate': True,
                'use_cache': True,
            }
            if step > 0:
                # Reuse the cache from the previous step; there is none at step 0.
                generation_kwargs['past_key_values'] = past_key_values_per_model[model_name]
            outputs = model.generate(**generation_kwargs)
            generated_ids = outputs.sequences

            if generated_ids[0, -1].item() == tokenizer.eos_token_id:
                finished_per_model[model_name] = True
                annotations = predicted_annotations_per_model[model_name]
                scored = sum(1 for _, label in annotations if label is not None)
                if valid_input and scored < len(ground_truth_digits_reversed):
                    # The answer ended early: flag the missing leading digit(s)
                    # with a single red blank.
                    annotations.insert(0, ('⠀', 'wrong'))
                continue

            generated_ids_per_model[model_name] = generated_ids
            past_key_values_per_model[model_name] = outputs.past_key_values
            output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
            tokens = output_text.strip().split(' ')

            predicted_annotations = []
            if model_name == 'explicit':
                # The explicit-CoT output is intermediate steps, '=', then the
                # answer. Tokens up to and including the first '=' are shown
                # unscored; if there is no '=' yet, nothing is scored.
                if '=' in tokens:
                    answer_start = tokens.index('=') + 1
                else:
                    answer_start = len(tokens)
                predicted_annotations = [(token, None) for token in tokens[:answer_start]]
                tokens = tokens[answer_start:]
            predicted_annotations += _annotate_digits(tokens, ground_truth_digits_reversed)
            # Digits are generated least-significant first; display most-significant first.
            predicted_annotations_per_model[model_name] = predicted_annotations[::-1]

        yield (
            ground_truth_annotations,
            predicted_annotations_per_model['implicit'],
            predicted_annotations_per_model['no'],
            predicted_annotations_per_model['explicit'],
        )

        # Stop once every model has finished or reached its step cap and the
        # ground truth is fully revealed. Every further iteration would only
        # re-yield this exact state.
        models_done = all(
            finished_per_model[name] or step + 1 >= MAX_PRODUCT_DIGITS_PER_MODEL[name]
            for name in models
        )
        ground_truth_complete = not valid_input or step + 1 >= len(ground_truth_digits_reversed)
        if models_done and ground_truth_complete:
            break
# Highlight colors for the labels produced by predict_product.
color_map = {"correct": "green", "wrong": "red"}

# Options shared by every output box: one span per digit, no legend.
_shared_highlight_options = dict(combine_adjacent=False, show_legend=False, color_map=color_map)

_DEMO_DESCRIPTION = "This demo showcases GPT-2's ability to directly predict the product of two large numbers without intermediate steps, using our stepwise internalization method. Compare the performance of implicit CoT (our method), no CoT, and explicit CoT. Implicit CoT offers accuracy and speed, while explicit CoT provides detailed reasoning but is slower."

_DEMO_ARTICLE = """
- [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
- [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
"""


def _prediction_box(label):
    """Build a model-prediction output box (labels hidden inline, colors only)."""
    return gr.HighlightedText(label=label, show_inline_category=False, **_shared_highlight_options)


demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 15 digits)', value='123456789'),
        gr.Textbox(label='Second Number (up to 15 digits)', value='987654321'),
    ],
    outputs=[
        gr.HighlightedText(label='Ground Truth Product', **_shared_highlight_options),
        _prediction_box('Implicit CoT Prediction (Ours)'),
        _prediction_box('No CoT Prediction'),
        _prediction_box('Explicit CoT Steps & Prediction'),
    ],
    title='Predicting Multiplication with GPT-2: Implicit vs. Explicit CoT',
    description=_DEMO_DESCRIPTION,
    article=_DEMO_ARTICLE,
    clear_btn=None,
    submit_btn="Multiply!",
    live=False,
    concurrency_limit=1,
)
# At most 20 queued requests; requests are processed one at a time (concurrency_limit=1).
demo.queue(max_size=20).launch()
|