# Scraped page header (kept as a comment so the file stays valid Python):
# author "Jhp", "2424", commit 825f4db, 719 bytes.
import gradio as gr
from visualization import visualization
# pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
# pipeline = pipeline(task="image-classification", model="jhp/hoi")
def predict(image, threshold, topk, device=''):
    """Run HOI detection on ``image`` and return the rendered result.

    Args:
        image: Input image from the Gradio ``Image`` component (``type='pil'``),
            or ``None`` when the user submitted without uploading one.
        threshold: Detection score threshold, forwarded to ``visualization``.
        topk: Number of top predictions to show. Gradio's ``Number`` component
            delivers a float (e.g. ``5.0``) unless ``precision=0`` is set, so
            it is rounded to an ``int`` before being forwarded.
        device: Unused; kept so existing callers that pass it keep working.

    Returns:
        Whatever ``visualization`` returns (shown by the output ``Image``
        component), or ``None`` when no image was provided.
    """
    if image is None:
        # Nothing to annotate; return None (empty output panel) rather than
        # passing None on to ``visualization``.
        return None
    # A count must be an integer for slicing / top-k selection; leave None
    # untouched so that case is forwarded exactly as before.
    if topk is not None:
        topk = int(round(topk))
    return visualization(image, threshold, topk)
# Demo UI: an uploaded image, a score threshold and a top-k count are passed to
# ``predict``; the PIL image it returns is displayed as the result.
demo = gr.Interface(
    predict,
    inputs=[
        gr.Image(type="pil", label="input image"),
        gr.Slider(
            0,
            1,
            value=0.4,
            label="Threshold",
            info="Set detection score threshold between 0~1",
        ),
        # precision=0 makes Gradio round the value and deliver it as an int;
        # with the default (None) the count would arrive as a float like 5.0.
        gr.Number(value=5, label="Topk", info="Topk prediction", precision=0),
    ],
    outputs=gr.Image(type="pil", label="hoi detection results"),
    title="HOI detection",
)
# ``launch(enable_queue=...)`` is deprecated in Gradio 3.x and removed in 4.x;
# ``.queue()`` (which returns the same app) is the supported way to enable it.
demo.queue().launch(debug=True)