jmeadows17 committed on
Commit
dab8570
1 Parent(s): a6d13ee

Upload MathT5.py

Browse files
Files changed (1) hide show
  1. MathT5.py +43 -0
MathT5.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
+
4
def pretty_print(text, prompt=True):
    """Print a prompt or a derivation with its parts on separate paragraphs.

    Args:
        text: The string to display.
        prompt: When true, ``text`` is split on ``", "`` into sections and
            each section on ``" and "`` into premises; the word ``and`` is
            printed on its own between the premises of a section. When
            false, ``text`` is split on ``" and "`` into equations.

    Returns:
        None. The formatted text is written to standard output.
    """
    gap = "\n\n\n"
    chunks = []
    if prompt:
        for section in text.split(", "):
            for position, premise in enumerate(section.split(" and ")):
                if position:
                    chunks.append("and")
                chunks.append(premise)
    else:
        # Split on the whole word only: a bare "and" substring would also cut
        # through tokens such as "\land" or "\operatorname{rand}".
        chunks.extend(text.split(" and "))
    # One trailing newline is kept after the last chunk; print() adds another.
    print(gap.join(chunks) + "\n")
19
+
20
+
21
def load_model(model_id):
    """Load the T5 tokenizer and model identified by ``model_id``.

    The model is moved to ``'cuda'`` when ``torch.cuda.is_available()``
    reports true, and to ``'cpu'`` otherwise.

    Args:
        model_id: Identifier passed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    if torch.cuda.is_available():
        target = 'cuda'
    else:
        target = 'cpu'
    tokenizer = T5Tokenizer.from_pretrained(model_id)
    network = T5ForConditionalGeneration.from_pretrained(model_id)
    return tokenizer, network.to(target)
26
+
27
+
28
def inference(prompt, tokenizer, model):
    """Generate text for ``prompt`` and repair the backslashes in the result.

    Args:
        prompt: Input text; it is truncated to 512 tokens when encoded.
        tokenizer: Object providing ``encode`` and ``decode`` (for example
            the tokenizer returned by ``load_model``).
        model: Object providing ``generate`` (for example the model returned
            by ``load_model``).

    Returns:
        The decoded first generated sequence after backslash post-processing.
    """
    # Send the inputs to the device the model itself reports, so a model kept
    # on CPU on a CUDA-capable host does not hit a device mismatch. Objects
    # without a ``device`` attribute fall back to the CUDA-availability check.
    device = getattr(model, "device", None)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    input_ids = tokenizer.encode(
        prompt, return_tensors='pt', max_length=512, truncation=True
    ).to(device)
    # NOTE(review): early_stopping presumably only matters for beam-based
    # generation; it is kept in case the model's generation config enables
    # beams -- confirm.
    output = model.generate(input_ids=input_ids, max_length=512, early_stopping=True)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return _restore_backslashes(generated_text)


def _restore_backslashes(generated_text):
    """Re-attach detached backslashes and re-prefix tokens that lack one.

    A backslash followed by a space is joined to the token after it. Then
    every space-separated token that contains no backslash, but whose
    spelling equals some backslashed token with its backslashes removed
    (``frac`` next to ``\\frac``), is replaced by that backslashed token.
    When several backslashed tokens share a stripped spelling, the first one
    in the text is used, so the result does not depend on set/hash ordering.
    """
    symbols = generated_text.replace("\\ ", "\\").split(" ")
    restored = {}
    for symbol in symbols:
        if "\\" in symbol:
            restored.setdefault(symbol.replace("\\", ""), symbol)
    # Keys never contain a backslash, so backslashed tokens map to themselves.
    return " ".join(restored.get(symbol, symbol) for symbol in symbols)