Hila commited on
Commit
5bba327
·
1 Parent(s): e0f37a0

Add original model results

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -74,6 +74,7 @@ def generate_visualization(model, original_image, class_index=None):
74
  return vis
75
 
76
  model_finetuned = None
 
77
 
78
  normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
79
  transform_224 = transforms.Compose([
@@ -93,10 +94,14 @@ def image_classifier(inp):
93
  prediction = torch.nn.functional.softmax(model_finetuned(image.unsqueeze(0))[0], dim=0)
94
  confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
95
  heatmap = generate_visualization(model_finetuned, image)
96
- return confidences, heatmap
 
 
 
 
97
 
98
  def _load_model(model_name: str):
99
- global model_finetuned
100
  path = hf_hub_download('Hila/RobustViT',
101
  f'{model_name}')
102
 
@@ -108,5 +113,5 @@ def _load_model(model_name: str):
108
  model_finetuned.eval()
109
 
110
  _load_model('ar_base.tar')
111
- demo = gr.Interface(image_classifier, gr.inputs.Image(shape=(224,224)), [gr.outputs.Label(label="Our Classification", num_top_classes=3), gr.Image(label="Our Relevance",shape=(224,224))],examples=["samples/augreg_base/tank.png", "samples/augreg_base/sundial.png", "samples/augreg_base/lizard.png", "samples/augreg_base/storck.png", "samples/augreg_base/hummingbird2.png", "samples/augreg_base/hummingbird.png"], capture_session=True)
112
  demo.launch(debug=True)
 
74
  return vis
75
 
76
  model_finetuned = None
77
+ model = None
78
 
79
  normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
80
  transform_224 = transforms.Compose([
 
94
  prediction = torch.nn.functional.softmax(model_finetuned(image.unsqueeze(0))[0], dim=0)
95
  confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
96
  heatmap = generate_visualization(model_finetuned, image)
97
+
98
+ prediction_orig = torch.nn.functional.softmax(model(image.unsqueeze(0))[0], dim=0)
99
+ confidences_orig = {labels[i]: float(prediction_orig[i]) for i in range(1000)}
100
+ heatmap_orig = generate_visualization(model, image)
101
+ return confidences, heatmap, confidences_orig, heatmap_orig
102
 
103
  def _load_model(model_name: str):
104
+ global model_finetuned, model
105
  path = hf_hub_download('Hila/RobustViT',
106
  f'{model_name}')
107
 
 
113
  model_finetuned.eval()
114
 
115
  _load_model('ar_base.tar')
116
+ demo = gr.Interface(image_classifier, gr.inputs.Image(shape=(224,224)), [gr.outputs.Label(label="Our Classification", num_top_classes=3), gr.Image(label="Our Relevance",shape=(224,224)), gr.outputs.Label(label="Original Classification", num_top_classes=3), gr.Image(label="Original Relevance",shape=(224,224))],examples=["samples/augreg_base/tank.png", "samples/augreg_base/sundial.png", "samples/augreg_base/lizard.png", "samples/augreg_base/storck.png", "samples/augreg_base/hummingbird2.png", "samples/augreg_base/hummingbird.png"], capture_session=True)
117
  demo.launch(debug=True)