|
import torch |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
|
def pretty_print(text, prompt=True):
    """Print *text* with two blank lines between each premise/equation.

    When ``prompt`` is True, *text* is treated as a comma-separated list of
    sections whose premises are joined by " and "; each premise and each
    literal "and" connector goes on its own line.  When ``prompt`` is False,
    *text* is split on every occurrence of "and" and the fragments are
    printed the same way.  Prints to stdout and returns None.
    """
    if prompt:
        pieces = []
        for section in text.split(', '):
            premises = section.split(" and ")
            # Every premise but the last is followed by a literal "and".
            for premise in premises[:-1]:
                pieces.append(premise)
                pieces.append("and")
            pieces.append(premises[-1])
    else:
        pieces = text.split("and")
    # Join-with-separator plus one trailing newline reproduces the original
    # "append separator everywhere, then chop two characters" formatting.
    return print("\n\n\n".join(pieces) + "\n")
|
|
|
|
|
def load_model(model_id):
    """Load the T5 tokenizer/model pair named by *model_id*.

    The model weights are moved to CUDA when a GPU is available, otherwise
    they stay on CPU.  Returns the ``(tokenizer, model)`` pair.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    tok = T5Tokenizer.from_pretrained(model_id)
    net = T5ForConditionalGeneration.from_pretrained(model_id)
    return tok, net.to(target)
|
|
|
|
|
def inference(prompt, tokenizer, model):
    """Generate a derivation for *prompt* and restore split LaTeX commands.

    Parameters
    ----------
    prompt : str
        Input text fed to the tokenizer (truncated to 512 tokens).
    tokenizer, model
        The pair produced by ``load_model`` (T5 tokenizer and model).

    Returns
    -------
    str
        Decoded generation with dropped backslashes re-attached.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_ids = tokenizer.encode(
        prompt, return_tensors='pt', max_length=512, truncation=True
    ).to(device)
    output = model.generate(input_ids=input_ids, max_length=512, early_stopping=True)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return _restore_backslashes(generated_text)


def _restore_backslashes(generated_text):
    """Re-attach backslashes to tokens the decoder emitted without them.

    First collapses "\\ " to "\\", then replaces any plain token that equals
    a backslash-bearing token with its backslashes stripped by the backslash
    form.  A dict keyed on the stripped form makes the pass O(n) and
    deterministic; the previous implementation scanned a ``set`` per token
    (O(n*m)) and its iteration order — hence the winner when two backslash
    symbols strip to the same plain token — depended on hash randomization.
    """
    derivation = generated_text.replace("\\ ", "\\")
    tokens = derivation.split(" ")
    # stripped form -> backslash form; later occurrences win, deterministically.
    canonical = {}
    for tok in tokens:
        if "\\" in tok:
            canonical[tok.replace("\\", "")] = tok
    # Tokens already containing a backslash never match a key (keys are
    # backslash-free), so they pass through unchanged, as before.
    return " ".join(canonical.get(tok, tok) for tok in tokens)
|
|