Update app.py
Browse files
app.py
CHANGED
@@ -34,9 +34,14 @@ def predict(image):
|
|
34 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
35 |
outputs = model(**inputs)
|
36 |
logits = outputs.logits
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
38 |
logging.info("Prediction successful.")
|
39 |
-
return
|
40 |
except Exception as e:
|
41 |
logging.error("Error during prediction: %s", e)
|
42 |
return str(e)
|
@@ -45,7 +50,7 @@ def predict(image):
|
|
45 |
iface = gr.Interface(
|
46 |
fn=predict,
|
47 |
inputs=gr.Sketchpad(),
|
48 |
-
outputs=gr.
|
49 |
title="Drawing Classifier",
|
50 |
description="Draw something and the model will try to identify it!"
|
51 |
)
|
|
|
34 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
35 |
outputs = model(**inputs)
|
36 |
logits = outputs.logits
|
37 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
38 |
+
top_probs, top_idxs = probs.topk(3, dim=-1)
|
39 |
+
top_probs = top_probs.detach().numpy()[0]
|
40 |
+
top_idxs = top_idxs.detach().numpy()[0]
|
41 |
+
top_classes = [model.config.id2label[idx] for idx in top_idxs]
|
42 |
+
result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
|
43 |
logging.info("Prediction successful.")
|
44 |
+
return result
|
45 |
except Exception as e:
|
46 |
logging.error("Error during prediction: %s", e)
|
47 |
return str(e)
|
|
|
50 |
iface = gr.Interface(
|
51 |
fn=predict,
|
52 |
inputs=gr.Sketchpad(),
|
53 |
+
outputs=gr.JSON(),
|
54 |
title="Drawing Classifier",
|
55 |
description="Draw something and the model will try to identify it!"
|
56 |
)
|