semantic-demo / app.py
jwalanthi's picture
checks input
5288696
raw
history blame
1.09 kB
import gradio as gr
import torch
import lightning
from minicons import cwe
import pandas as pd
import numpy as np
from model import FeatureNormPredictor
import sys
sys.path.insert(0, '/home/jjr4354/semantic-features')
def _as_layer_index(layer):
    """Return *layer* as a non-negative int, or None if it is not one.

    Numeric UI inputs may arrive as floats (e.g. ``5.0``) or as ``None``
    when left blank, so the value is normalised before it is used in a
    checkpoint file name or a range check.
    """
    try:
        as_float = float(layer)
    except (TypeError, ValueError):
        return None
    # is_integer() is False for nan/inf as well as for fractional values.
    if not as_float.is_integer() or as_float < 0:
        return None
    return int(as_float)


def predict(word, sentence, lm_name, layer, norm):
    """Validate the demo inputs, load the matching checkpoint, and return text.

    Args:
        word: Target word; must occur as a substring of ``sentence``.
        sentence: Sentence containing ``word``.
        lm_name: Language-model prefix used to build the checkpoint name.
        layer: Layer index; accepted as an int or an integral float.
        norm: Feature-norm name used to build the checkpoint name.

    Returns:
        An ``"invalid input: ..."`` message when validation fails; otherwise
        one line per input, each followed by a tab and a random integer in
        [0, 100).
    """
    if word not in sentence:
        return "invalid input: word not in sentence"
    # An unselected choice arrives as None; report it instead of raising
    # TypeError while building the checkpoint name.
    if lm_name is None or norm is None:
        return "invalid input: lm and norm must be selected"

    # Reject layers that can never be valid before paying for the LM load.
    layer_index = _as_layer_index(layer)
    if layer_index is None:
        return "invalid input: layer not in lm"

    # NOTE(review): the language model is always 'bert-base-uncased',
    # regardless of lm_name -- confirm whether that is intended.
    lm = cwe.CWE('bert-base-uncased')
    if layer_index not in range(lm.layers):
        return "invalid input: layer not in lm"

    # Use the integer form so a float such as 5.0 yields "bert5_to_Binder"
    # rather than "bert5.0_to_Binder".
    model_name = f"{lm_name}{layer_index}_to_{norm}"
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_name + '.ckpt',
        map_location=None
    )
    model.eval()

    # NOTE(review): the loaded model is not used yet; the scores below are
    # random placeholders.
    fields = [word, sentence, lm_name, str(layer_index), norm]
    outputs = [
        field + '\t' + str(np.random.randint(0, 100, size=1)[0])
        for field in fields
    ]
    return "\n".join(outputs)
# Choices offered by the two radio selectors.
_LM_CHOICES = ["bert", "roberta", "electra"]
_NORM_CHOICES = ["Binder", "McRae", "Buchanan"]

# Input components, listed in the same order as predict's parameters:
# word, sentence, lm_name, layer, norm.
_demo_inputs = [
    "text",
    "text",
    gr.Radio(_LM_CHOICES),
    "number",
    gr.Radio(_NORM_CHOICES),
]

# Wire predict to a single text output and start serving the interface.
demo = gr.Interface(fn=predict, inputs=_demo_inputs, outputs=["text"])
demo.launch()