import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration


def pretty_print(text, prompt=True):
    """Print *text* with one clause per paragraph (three newlines apart).

    Args:
        text: The string to display.
        prompt: When True, *text* is split on ``", "`` into sections and each
            section on ``" and "`` into premises; consecutive premises of the
            same section are separated by a standalone ``"and"`` paragraph.
            When False, *text* is split on ``" and "`` only.

    Returns:
        None. The formatted text is printed to stdout.
    """
    separator = "\n\n\n"
    if prompt:
        pieces = []
        for section in text.split(", "):
            premises = section.split(" and ")
            for premise in premises[:-1]:
                pieces.extend((premise, "and"))
            pieces.append(premises[-1])
    else:
        # Split on the whole word " and " (as the prompt branch does); splitting
        # on the bare substring "and" would also cut tokens such as "\land".
        pieces = text.split(" and ")
    print(separator.join(pieces))


def load_model(model_id):
    """Load a T5 tokenizer and seq2seq model from *model_id*.

    The model is moved to CUDA when ``torch.cuda.is_available()``, else CPU.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = T5Tokenizer.from_pretrained(model_id)
    model = T5ForConditionalGeneration.from_pretrained(model_id).to(device)
    return tokenizer, model


def _restore_backslashes(text):
    """Re-attach backslashes in decoded model output.

    First removes the space in every ``"\\ "`` sequence. Then every
    space-separated token without a backslash that equals some
    backslash-containing token with its backslashes removed (e.g. ``frac``
    when ``\\frac`` also occurs) is replaced by that backslash-containing
    token. When several backslash tokens strip to the same bare form, the
    first one in the text is used, so the result is deterministic.
    NOTE(review): presumably compensates for the tokenizer dropping or
    detaching backslashes — confirm against the model's vocabulary.
    """
    tokens = text.replace("\\ ", "\\").split(" ")
    restored = {}
    for token in tokens:
        if "\\" in token:
            bare = token.replace("\\", "")
            # Skip a lone backslash: its bare form "" would overwrite the empty
            # tokens produced by consecutive spaces.
            if bare:
                restored.setdefault(bare, token)
    # Keys never contain a backslash, so tokens that already have one are kept.
    return " ".join(restored.get(token, token) for token in tokens)


def inference(prompt, tokenizer, model, max_length=512):
    """Generate text for *prompt* with *model* and post-process it.

    Args:
        prompt: Input string; truncated to *max_length* tokens.
        tokenizer: Tokenizer matching *model* (e.g. from ``load_model``).
        model: A seq2seq model exposing ``generate`` and ``device``.
        max_length: Maximum token length for both the encoded input and the
            generated output (default 512, as before).

    Returns:
        The decoded output with special tokens skipped and backslashed
        symbols restored (see ``_restore_backslashes``).
    """
    # Place inputs wherever the model actually lives instead of re-guessing,
    # so a model loaded on CPU still works when CUDA is available.
    device = model.device
    input_ids = tokenizer.encode(
        prompt, return_tensors="pt", max_length=max_length, truncation=True
    ).to(device)
    output = model.generate(
        input_ids=input_ids, max_length=max_length, early_stopping=True
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return _restore_backslashes(generated_text)