Jangai commited on
Commit
528548d
·
verified ·
1 Parent(s): 591d82b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -34,9 +34,14 @@ def predict(image):
34
  inputs = feature_extractor(images=image, return_tensors="pt")
35
  outputs = model(**inputs)
36
  logits = outputs.logits
37
- predicted_class_idx = logits.argmax(-1).item()
 
 
 
 
 
38
  logging.info("Prediction successful.")
39
- return model.config.id2label[predicted_class_idx]
40
  except Exception as e:
41
  logging.error("Error during prediction: %s", e)
42
  return str(e)
@@ -45,7 +50,7 @@ def predict(image):
45
  iface = gr.Interface(
46
  fn=predict,
47
  inputs=gr.Sketchpad(),
48
- outputs=gr.Label(),
49
  title="Drawing Classifier",
50
  description="Draw something and the model will try to identify it!"
51
  )
 
34
  inputs = feature_extractor(images=image, return_tensors="pt")
35
  outputs = model(**inputs)
36
  logits = outputs.logits
37
+ probs = torch.nn.functional.softmax(logits, dim=-1)
38
+ top_probs, top_idxs = probs.topk(3, dim=-1)
39
+ top_probs = top_probs.detach().numpy()[0]
40
+ top_idxs = top_idxs.detach().numpy()[0]
41
+ top_classes = [model.config.id2label[idx] for idx in top_idxs]
42
+ result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
43
  logging.info("Prediction successful.")
44
+ return result
45
  except Exception as e:
46
  logging.error("Error during prediction: %s", e)
47
  return str(e)
 
50
  iface = gr.Interface(
51
  fn=predict,
52
  inputs=gr.Sketchpad(),
53
+ outputs=gr.JSON(),
54
  title="Drawing Classifier",
55
  description="Draw something and the model will try to identify it!"
56
  )