# NOTE: header left over from the web file viewer this source was copied from
# (not part of the program): file size 2,113 bytes; commit-like ids 3f9313f and
# 94d26a2; followed by the viewer's line-number gutter (1-70).
import random

from fastai.vision.all import *
import gradio as gr
def is_cat(x: str) -> bool:
    """Return True when the first character of *x* is uppercase.

    An empty *x* yields False instead of raising ``IndexError``.

    NOTE(review): this is presumably the labelling function the pickled
    learner was exported with, so it has to stay defined under this exact
    name before the learner is loaded -- confirm against the training code.
    """
    return len(x) > 0 and x[0].isupper()
# Load the exported fastai learner from the file next to this script.
# NOTE(review): load_learner unpickles the file, so model.pkl must come from a
# trusted source.
learn = load_learner('model.pkl')
# Display names paired positionally with the predicted probabilities in
# classify_image.
# NOTE(review): the order has to match the learner's vocab. If the model was
# labelled with is_cat, the vocab is presumably (False, True), i.e. "Dog"
# first -- TODO confirm against learn.dls.vocab.
categories = ('Cat', 'Dog')
# Markdown headings for a confident prediction; calculate() picks one at
# random and fills "{}" with the category name via str.format.
prompts = [
    "# Definitely a {}!",
    "# Well, that must be a {}!",
    "# Oh, that's a {}!",
    "# That's a {}!",
    "# Looks like a {} to me!",
]
# Markdown headings shown when no category clears the confidence threshold;
# calculate() picks one at random.
failure_prompts = [
    "# I'm not sure what that is.",
    "# I don't know what that thing is.",
    "# I've never seen that before.",
    "# Looks familiar, but unsure.",
    "# Something, something?",
    "# Beats me.",
]
def classify_image(img):
    """Run the learner on *img* and map each category name to its probability.

    Returns:
        A dict keyed by the entries of ``categories``; the values are the
        probabilities returned by ``learn.predict``, converted to ``float``
        and paired with the categories by position.
    """
    # Only the third element of the prediction (the probabilities) is used.
    _, _, probabilities = learn.predict(img)
    return {label: float(p) for label, p in zip(categories, probabilities)}
def calculate(confidence_threshold, img):
    """Classify *img* and build the heading/details pair shown in the UI.

    Args:
        confidence_threshold: Probability the most likely category must
            exceed for it to be announced in the heading.
        img: Image handed over by the Gradio image input; expected to be
            ``None`` when nothing has been uploaded.

    Returns:
        A two-element list ``[heading, classifications]``: a Markdown
        heading string and the category -> probability dict produced by
        ``classify_image``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Fail with a readable message instead of an error from the learner.
        raise gr.Error("Please upload an image before classifying.")

    classifications = classify_image(img)

    # Announce the most probable category. Taking the first one above the
    # threshold would name the wrong category whenever the threshold is low
    # enough for several categories to exceed it.
    best_label, best_prob = max(
        classifications.items(),
        key=lambda item: item[1],
        default=(None, 0.0),
    )
    if best_label is not None and best_prob > confidence_threshold:
        message = random.choice(prompts).format(best_label)
    else:
        message = random.choice(failure_prompts)
    return [message, classifications]
# Page layout: upload and controls in one column, results in the other,
# followed by example images.
# NOTE(review): the source had lost its indentation, so the nesting below is a
# reconstruction -- confirm it against the running app.
with gr.Blocks() as ui:
    # Output components are created up front with render=False so the click
    # handler can reference them; they are placed later via .render().
    heading = gr.Markdown(" # Dog or Cat?", render=False)
    results = gr.Label(
        value="Waiting to receive image.",
        label="Details",
        show_label=False,
        render=False,
    )

    with gr.Row(equal_height=True):
        with gr.Column():
            gr.Markdown("Upload an image of a cat or a dog.")
            with gr.Group():
                image = gr.Image(show_label=False, height=300)
                confidence = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    label="Confidence Threshold",
                )
                btn = gr.Button(value="Classify")
                # Input order matches calculate(confidence_threshold, img);
                # output order matches its [heading, details] return value.
                btn.click(
                    calculate,
                    inputs=[confidence, image],
                    outputs=[heading, results],
                )

        with gr.Column():
            gr.Markdown("Then wait for the magic to happen")
            with gr.Group():
                results.render()
                heading.render()

    gr.Markdown(" # Examples")
    with gr.Group():
        gr.Examples(
            inputs=image,
            examples=[
                'images/cat1.jpeg',
                'images/cat2.jpeg',
                'images/cat3.jpeg',
                'images/dog1.jpeg',
                'images/dog2.jpeg',
                'images/dog3.jpeg',
            ],
        )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    ui.launch()