File size: 852 Bytes
8c5684e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import gradio as gr
from fastai.vision.all import *

# Training-image folder (relative, '/'-separated) -> number of cats in it.
# NOTE(review): this mapping and get_number_of_cats are defined before
# load_learner, presumably because the exported learner was pickled with
# get_number_of_cats as its labelling function and unpickling looks it up
# by name — TODO confirm.
mapping = {'cat_counter/cat': 1.0, 'cat_counter/two cats': 2.0, 'cat_counter/three cats': 3.0}


def get_number_of_cats(file_name):
    """Return the cat count for an image, derived from its parent folder.

    Args:
        file_name: A ``pathlib`` path to an image file.

    Returns:
        The float count that ``mapping`` assigns to the image's parent
        folder, or ``0.0`` when that folder is not listed.
    """
    # as_posix() normalises separators so Windows paths (whose str() uses
    # backslashes) still match the '/'-separated keys in ``mapping``.
    return mapping.get(file_name.parent.as_posix(), 0.0)

# Load the exported fastai learner exactly once; predict() below uses it.
# get_number_of_cats must already be defined at this point (see above),
# since nothing else in this file calls it. The path is resolved relative to
# the current working directory.
learn = load_learner('cat_counter.pkl')

def predict(img):
    """Count the cats in an image using the loaded learner.

    Args:
        img: The image value supplied by the Gradio input component; anything
            ``PILImage.create`` accepts.

    Returns:
        The model's output rounded to the nearest whole number, as a plain
        Python float, never negative.
    """
    image = PILImage.create(img)
    # Only the raw model output (third element of the tuple) is used.
    _, _, raw = learn.predict(image)
    # NOTE(review): assumes a single-value regression output — TODO confirm.
    # Clamp before rounding so a slightly negative prediction becomes 0
    # (a count cannot be negative); .item() gives Gradio a float, not a tensor.
    return torch.round(raw.clamp(min=0)).item()

title = "Cat Counter"
description = "A model that counts cats"

# Gradio UI: one image input (shape=(512, 512)) feeding predict(), which
# produces a single number.
demo = gr.Interface(
    fn=predict,
    # NOTE(review): gr.inputs.* is Gradio's legacy input namespace (deprecated
    # in 3.x, removed in 4.x); move to gr.Image once the pinned Gradio
    # version is confirmed.
    inputs=gr.inputs.Image(shape=(512, 512)),
    outputs="number",
    title=title,
    description=description,
)

if __name__ == "__main__":
    # share=True additionally asks Gradio for a public share link.
    demo.launch(share=True)