cat_counter / app.py
AylinM's picture
Upload app.py
8c5684e
raw
history blame
852 Bytes
import gradio as gr
from fastai.vision.all import *
# Maps a training image's parent folder (POSIX-style, '/'-separated) to the
# number of cats pictured in it. Presumably these floats were the regression
# targets used when the model was trained — confirm against the training code.
mapping = {'cat_counter/cat': 1.0, 'cat_counter/two cats': 2.0, 'cat_counter/three cats': 3.0}


def get_number_of_cats(file_name):
    """Return the cat-count label for *file_name* based on its parent folder.

    NOTE(review): nothing in this app calls this function directly. Presumably
    the exported learner references it as its labelling function, so it must
    be defined at module level (under this exact name) before
    ``load_learner`` runs — confirm before renaming or removing it.

    Args:
        file_name: path object exposing ``.parent`` (e.g. ``pathlib.Path``).

    Returns:
        The count from ``mapping`` for a known folder, otherwise ``0.0``.
    """
    # as_posix() so the '/'-separated keys also match on Windows, where
    # str(path) would use backslashes and silently fall through to 0.
    return mapping.get(file_name.parent.as_posix(), 0.0)
# Load the exported fastai learner exactly once. `mapping` and
# `get_number_of_cats` are defined above this line; presumably the pickled
# learner refers to them, so they must exist before unpickling — verify.
# NOTE(security): load_learner unpickles the file, which can execute code.
# Only ever point this at the trusted model file shipped with the app.
learn = load_learner('cat_counter.pkl')
def predict(img):
    """Predict how many cats appear in *img*.

    Args:
        img: the image delivered by the Gradio image input; anything
            ``PILImage.create`` accepts.

    Returns:
        The model's raw output rounded to the nearest whole number, as a
        plain ``float`` for the Gradio "number" output, and never below 0
        because a count cannot be negative.
    """
    img = PILImage.create(img)
    # Only the third element of learn.predict's result (the raw model
    # output) is used; the first two are ignored.
    _, _, raw = learn.predict(img)
    # float() needs a single-element tensor; assumes a single-output
    # regression model — TODO confirm. torch.round rounds halves to even.
    count = float(torch.round(raw))
    # Clamp so a slightly negative prediction shows 0 instead of -1 or -0.0.
    return max(0.0, count)
title = "Cat Counter"
description = "A model that counts cats"

# Keep a module-level handle to the app (Gradio's reload mode looks for
# `demo`) and only start the server when this file is run as a script, so
# importing the module does not launch anything.
demo = gr.Interface(
    fn=predict,
    # NOTE(review): `gr.inputs.*` is deprecated in Gradio 3.x and removed in
    # 4.x; migrate to `gr.Image` once the pinned Gradio version is confirmed.
    inputs=gr.inputs.Image(shape=(512, 512)),
    outputs="number",
    title=title,
    description=description,
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    demo.launch(share=True)