"""Gradio demo app serving a fastai learner that counts cats in an image."""

import gradio as gr
from fastai.vision.all import *

# Lookup from an image's parent directory (as a string) to a cat count.
mapping = {
    'cat_counter/cat': 1.0,
    'cat_counter/two cats': 2.0,
    'cat_counter/three cats': 3.0,
}


def get_number_of_cats(file_name):
    """Return the cat count registered for the parent directory of *file_name*.

    Args:
        file_name: A path-like object exposing ``.parent`` (e.g. a
            ``pathlib.Path``).

    Returns:
        The value stored in ``mapping`` for ``str(file_name.parent)``, or ``0``
        when that directory is not listed.
    """
    return mapping.get(str(file_name.parent), 0)


# NOTE(review): presumably the exported learner was trained with
# `get_number_of_cats` as its labelling function, in which case unpickling
# needs that name to be defined in this module before `load_learner` runs --
# confirm before moving or renaming it.
learn = load_learner('./cat_counter.pkl')


def predict(img):
    """Predict the number of cats in *img*, rounded to the nearest whole number.

    Args:
        img: The image handed over by the Gradio image input; it is converted
            with ``PILImage.create`` before inference.

    Returns:
        The rounded prediction as a plain Python ``float``.
    """
    img = PILImage.create(img)
    # Only the third element of the prediction tuple is used.
    _, _, probs = learn.predict(img)
    # The "number" output expects a plain number rather than a tensor, so
    # convert explicitly.
    # NOTE(review): assumes the learner emits a single value; only the first
    # element is reported -- confirm against the exported model.
    return float(torch.round(probs).flatten()[0])


title = "Cat Counter"
description = "A model that counts cats"

gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(512, 512)),
    outputs="number",
    title=title,
    description=description,
).launch(inline=False)