Add original model results
app.py CHANGED
@@ -74,6 +74,7 @@ def generate_visualization(model, original_image, class_index=None):
     return vis
 
 model_finetuned = None
+model = None
 
 normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
 transform_224 = transforms.Compose([
@@ -93,10 +94,14 @@ def image_classifier(inp):
     prediction = torch.nn.functional.softmax(model_finetuned(image.unsqueeze(0))[0], dim=0)
     confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
     heatmap = generate_visualization(model_finetuned, image)
-    return confidences, heatmap
+
+    prediction_orig = torch.nn.functional.softmax(model(image.unsqueeze(0))[0], dim=0)
+    confidences_orig = {labels[i]: float(prediction_orig[i]) for i in range(1000)}
+    heatmap_orig = generate_visualization(model, image)
+    return confidences, heatmap, confidences_orig, heatmap_orig
 
 def _load_model(model_name: str):
-    global model_finetuned
+    global model_finetuned, model
     path = hf_hub_download('Hila/RobustViT',
                            f'{model_name}')
 
@@ -108,5 +113,5 @@ def _load_model(model_name: str):
     model_finetuned.eval()
 
 _load_model('ar_base.tar')
-demo = gr.Interface(image_classifier, gr.inputs.Image(shape=(224,224)), [gr.outputs.Label(label="Our Classification", num_top_classes=3), gr.Image(label="Our Relevance",shape=(224,224))],examples=["samples/augreg_base/tank.png", "samples/augreg_base/sundial.png", "samples/augreg_base/lizard.png", "samples/augreg_base/storck.png", "samples/augreg_base/hummingbird2.png", "samples/augreg_base/hummingbird.png"], capture_session=True)
+demo = gr.Interface(image_classifier, gr.inputs.Image(shape=(224,224)), [gr.outputs.Label(label="Our Classification", num_top_classes=3), gr.Image(label="Our Relevance",shape=(224,224)), gr.outputs.Label(label="Original Classification", num_top_classes=3), gr.Image(label="Original Relevance",shape=(224,224))],examples=["samples/augreg_base/tank.png", "samples/augreg_base/sundial.png", "samples/augreg_base/lizard.png", "samples/augreg_base/storck.png", "samples/augreg_base/hummingbird2.png", "samples/augreg_base/hummingbird.png"], capture_session=True)
 demo.launch(debug=True)
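The body of _load_model between the second and third hunks is hidden context, so this commit only shows that the function now declares `global model_finetuned, model` and that both models feed the two pairs of outputs. For orientation only, below is a minimal, hypothetical sketch of how such a loader could populate both globals: an original (not fine-tuned) backbone for the new "Original" outputs, and a fine-tuned copy restored from the checkpoint downloaded from the Hub. The use of timm.create_model, the architecture name, and the 'state_dict' checkpoint key are assumptions for illustration, not the repository's actual code; the real app most likely builds the repository's own ViT class so that generate_visualization can access attention gradients.

import torch
import timm  # assumption: any ViT factory works here; the real app likely uses the repo's own ViT implementation
from huggingface_hub import hf_hub_download

model_finetuned = None
model = None

def _load_model(model_name: str):
    # Hypothetical sketch of the loader implied by the diff above.
    global model_finetuned, model

    # Original (not fine-tuned) backbone, used only for the "Original Classification" / "Original Relevance" outputs.
    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    model.eval()

    # Fine-tuned weights restored from the checkpoint downloaded from the Hub.
    path = hf_hub_download('Hila/RobustViT', f'{model_name}')
    model_finetuned = timm.create_model('vit_base_patch16_224', pretrained=True)
    checkpoint = torch.load(path, map_location='cpu')
    # The 'state_dict' key is an assumption; .tar checkpoints often wrap the weights in a dict.
    model_finetuned.load_state_dict(checkpoint.get('state_dict', checkpoint), strict=False)
    model_finetuned.eval()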