cc-bert / app.py
Shrey's picture
added changes for file name
dfbf9e3
raw
history blame
2.3 kB
#the inference function
from transformers import FillMaskPipeline ,DistilBertTokenizer,TFAutoModelForMaskedLM
from transformers import BertTokenizer
# Load the BERT tokenizer from the local vocabulary file.
tokenizer_path_1 = "./vocab.txt"
tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_path_1)

# Load the TensorFlow masked-language-model checkpoint from the local directory.
model_path = "./bert_lm_10"
model_1 = TFAutoModelForMaskedLM.from_pretrained(model_path)

# Build the Hugging Face fill-mask pipeline used for all inference below.
unmasker = FillMaskPipeline(model=model_1, tokenizer=tokenizer_1)

# Sanity-check the pipeline on a sample sentence (expected completion: "reduction").
txt = "a polynomial [MASK] from 3-SAT."
# This call must run: the loop below iterates over `results`. Without it the
# module raises NameError at import time and the UI is never launched.
results = unmasker(txt, top_k=5)

# Print each candidate completion together with its score.
for res in results:
    print(res["sequence"])
    print(res["score"])
# Wrap the fill-mask pipeline in a function whose output suits gr.outputs.Label.
def unmask_words(txt_with_mask, k_suggestions=5):
    """Return the top predictions for the masked token in *txt_with_mask*.

    Args:
        txt_with_mask: Input text containing a ``[MASK]`` token. Presumably a
            single mask is expected: with several masks the pipeline returns a
            nested list and the indexing below would fail (TODO confirm).
        k_suggestions: Number of candidate tokens to request from the
            pipeline. It is coerced to ``int`` because ``top_k`` must be an
            integer and the value arrives from a UI slider.

    Returns:
        dict mapping each predicted token (with spaces removed) to its score.
        If two predictions collapse to the same label, the higher score is kept.
    """
    results = unmasker(txt_with_mask, top_k=int(k_suggestions))
    labels = {}
    for res in results:
        # Remove spaces from the decoded token so it reads as a single label.
        label = res["token_str"].replace(" ", "")
        # Keep the best score on a label collision. A plain assignment would let
        # a later, lower-scored duplicate overwrite the better one.
        labels[label] = max(res["score"], labels.get(label, res["score"]))
    return labels
# Build and launch the Gradio web UI around unmask_words.
import os

import gradio as gr

# Markdown shown under the title in the UI.
# NOTE(review): the contact address below looks like a scrape-obfuscated
# placeholder rather than a real e-mail address. TODO: restore the real one.
description = """CC bert is an MLM model pretrained on data collected from ~200k papers mainly in Computational Complexity
or related domains. For more information visit [Theoremkb Project](https://github.com/PierreSenellart/theoremkb)
or contact [[email protected]]([email protected]).
"""

# Clickable sample inputs. Each inner list supplies the text box value.
examples = [
    ["as pspace is [MASK] under complement."],
    ["n!-(n-1)[MASK]"],
    ["[MASK] these two classes is a major problem."],
    ["This would show that the polynomial hierarchy at the second [MASK], which is considered only"],
    ["""we consider two ways of measuring complexity, data complexity, which is with respect to the size of the data,
and their combined [MASK]"""],
]

input_box = gr.inputs.Textbox(
    lines=20,
    placeholder="Unifying computational entropies via Kullback–Leibler [MASK]",
    label="Enter the masked text:",
)
interface = gr.Interface(
    fn=unmask_words,
    inputs=[input_box, gr.inputs.Slider(1, 10, 1, 5, label="No of Suggestions:")],
    outputs=gr.outputs.Label(label="top words:"),
    examples=examples,
    title="CC-Bert MLM",
    description=description,
)

# share=True exposes a public link, so the credentials matter. They can be
# overridden from the environment instead of relying on the hard-coded
# defaults, which are kept for backward compatibility.
auth_user = os.environ.get("CC_BERT_USER", "test")
auth_password = os.environ.get("CC_BERT_PASSWORD", "test")
interface.launch(debug=True, share=True, auth=(auth_user, auth_password))