|
import random

from fastai.vision.all import *
import gradio as gr
|
|
|
def is_cat(x):
    """Return True when the first element of ``x`` is an upper-case character.

    Not called anywhere in this file; presumably kept at module level under
    this exact name because the learner stored in model.pkl was trained with it
    as its label function and unpickling needs to find it -- confirm before
    removing or renaming.
    """
    first_char = x[0]
    return first_char.isupper()
|
|
|
# Load the exported fastai learner. The path is relative, so the app has to be
# started from the directory that contains model.pkl.
# NOTE(review): load_learner unpickles the file -- only load model files you trust.
learn = load_learner(fname='model.pkl')
|
|
|
# Display names for the two classes.
# NOTE(review): if these are ever zipped positionally with the learner's
# probabilities, their order must match learn.dls.vocab.
categories = ('Cat', 'Dog')

# Markdown headlines for a confident prediction; "{}" receives the class name.
prompts = ["# Definitely a {}!", "# Well, that must be a {}!",
           "# Oh, that's a {}!", "# That's a {}!",
           "# Looks like a {} to me!"]

# Markdown headlines for when no class clears the confidence threshold.
failure_prompts = ["# I'm not sure what that is.",
                   "# I don't know what that thing is.",
                   "# I've never seen that before.",
                   "# Looks familiar, but unsure.",
                   "# Something, something?",
                   "# Beats me."]
|
|
|
def classify_image(img):
    """Run the learner on ``img`` and return ``{label: probability}`` for every class.

    Labels are read from the learner's own vocabulary (``learn.dls.vocab``).
    That is the order in which ``learn.predict`` reports probabilities, so a
    label can never be paired with another class's probability.

    Boolean vocab entries, which a learner trained with the ``is_cat`` label
    function would have, are shown as 'Cat' (True) and 'Dog' (False). Any
    other entry is shown as its ``str()``.

    Args:
        img: the image to classify, as received from the ``gr.Image`` input.

    Returns:
        dict mapping each display label to its probability as a Python float.
    """
    _, _, probs = learn.predict(img)

    # Pairing probabilities with a hand-written tuple only works if the tuple
    # matches the vocab order. fastai sorts a boolean vocab as [False, True],
    # so ('Cat', 'Dog') would swap the two classes. Take the order from the
    # model instead.
    readable = {'True': 'Cat', 'False': 'Dog'}
    labels = [readable.get(str(entry), str(entry)) for entry in learn.dls.vocab]

    return dict(zip(labels, map(float, probs)))
|
|
|
def calculate(confidence_threshold, img):
    """Classify ``img`` and build the values for the heading and results widgets.

    Args:
        confidence_threshold: probability that the most likely class must
            strictly exceed before the app names it.
        img: the image from the ``gr.Image`` input. It is ``None`` when Classify
            is pressed before an image is provided, because Gradio passes
            ``None`` for an empty image input.

    Returns:
        A two-item list ``[heading_markdown, label_value]``: a randomly chosen
        headline for the ``gr.Markdown`` heading, and the value shown by the
        ``gr.Label`` results component.
    """
    if img is None:
        # Nothing to classify yet; don't hand None to the learner.
        return ["# Please upload an image first.", "Waiting to receive image."]

    classifications = classify_image(img)

    # Judge only the most probable class. Taking the first class in dict order
    # that clears the threshold reports the wrong animal when the threshold is
    # low enough for several classes to clear it. For example, with threshold
    # 0.3 and {'Cat': 0.35, 'Dog': 0.65} the old logic answered "Cat".
    best_label, best_score = max(
        classifications.items(), key=lambda item: item[1], default=(None, 0.0)
    )
    if best_label is not None and best_score > confidence_threshold:
        heading = random.choice(prompts).format(best_label)
    else:
        heading = random.choice(failure_prompts)

    return [heading, classifications]
|
|
|
|
|
# Gradio UI. A row holds two columns: inputs (image, threshold slider, Classify
# button) on the left and results on the right. An "Examples" section follows.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation was lost. Confirm the layout is as intended, in particular
# whether the Examples section belongs below the row or inside a column.
with gr.Blocks() as ui:

    # Created with render=False so they can be wired into btn.click below
    # before they are placed, and then placed in the right-hand column via
    # .render().
    heading = gr.Markdown(" # Dog or Cat?", render=False)
    results = gr.Label(value="Waiting to receive image.", label="Details", show_label=False, render=False)

    with gr.Row(equal_height=True):

        # Left column: inputs.
        with gr.Column():
            gr.Markdown("Upload an image of a cat or a dog.")

            with gr.Group():
                image = gr.Image(show_label=False, height=300)
                # Passed to calculate() as confidence_threshold.
                confidence = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Confidence Threshold")
                btn = gr.Button(value="Classify")
                # calculate(confidence, image) returns a two-item list. Its
                # items go to heading and results, in that order.
                btn.click(calculate, inputs=[confidence, image], outputs=[heading, results])

        # Right column: outputs (the components created above with render=False).
        with gr.Column():
            gr.Markdown("Then wait for the magic to happen")
            with gr.Group():
                results.render()
                heading.render()

    # No fn/outputs are passed to gr.Examples, so presumably choosing an
    # example only fills the image input and the user still presses Classify.
    # The image paths are relative.
    gr.Markdown(" # Examples")
    with gr.Group():
        gr.Examples(inputs=image, examples=['images/cat1.jpeg', 'images/cat2.jpeg', 'images/cat3.jpeg', 'images/dog1.jpeg', 'images/dog2.jpeg', 'images/dog3.jpeg'])
|
|
|
def main():
    """Start the Gradio server for the ``ui`` Blocks app."""
    ui.launch()


# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()