# vit-explainer / app.py
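"""Gradio demo: classify an image with a pretrained ViT
(google/vit-base-patch16-224) and visualize a saliency map of the prediction
computed with transformers-interpret."""
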
import re

import gradio as gr
import numpy as np
import requests
from torch import topk
from torch.nn.functional import softmax
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers_interpret import ImageClassificationExplainer


def load_label_data():
    """Download and parse the ImageNet-1k class-index-to-label mapping."""
    file_url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
    response = requests.get(file_url)
    labels = []
    # Each line looks like `0: 'tench, Tinca tinca',`; capture the quoted label text.
    pattern = '["\'](.*?)["\']'
    for line in response.text.split("\n"):
        try:
            labels.append(re.findall(pattern, line)[0])
        except IndexError:
            # Skip lines without a quoted label (e.g., the surrounding braces).
            pass
    return labels
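
# Example of the parsed output (assuming the gist's format is unchanged):
#   labels[0] -> "tench, Tinca tinca"
#   labels[1] -> "goldfish, Carassius auratus"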


class WebUI:
    def __init__(self):
        self.nb_classes = 10
        # Pretrained ViT (ImageNet-1k, 224x224 inputs, 16x16 patches) and its matching preprocessor.
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
        self.model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
        self.labels = load_label_data()

    def run_model(self, image):
        # Preprocess, run a forward pass, and keep the top-k softmax probabilities.
        inputs = self.processor(images=image, return_tensors="pt")
        outputs = self.model(**inputs)
        probs = softmax(outputs.logits, dim=1)
        return topk(probs, k=self.nb_classes)
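
    # torch.topk returns a (values, indices) named tuple, each of shape
    # (1, nb_classes); classify_image below unpacks it.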

    def classify_image(self, image):
        # Map the top-k predictions to {label: probability} for the gr.Label component.
        probs, indices = self.run_model(image)
        return {self.labels[indices[0][i]]: float(probs[0][i]) for i in range(self.nb_classes)}

    def explain_pred(self, image):
        explainer = ImageClassificationExplainer(model=self.model, feature_extractor=self.processor)
        saliency = explainer(image)
        # Attributions arrive as (1, C, H, W); move channels last and drop the batch dim.
        saliency = np.squeeze(np.moveaxis(saliency, 1, 3))
        # Clip extreme attributions, then rescale so values span [-1, 1].
        saliency = np.clip(saliency, -0.05, 0.05)
        saliency /= np.amax(np.abs(saliency))
        return saliency
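
    # The returned saliency map has shape (H, W, 3) with values in [-1, 1],
    # and is fed straight to the gr.Image output below.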

    def run(self):
        # Example images reused by the gr.Examples gallery below.
        examples = [
            ["https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/cat.jpg"],
            ["https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/dog.jpeg"],
        ]

        with gr.Blocks() as demo:
            with gr.Row():
                image = gr.Image(height=512)
                label = gr.Label(num_top_classes=self.nb_classes)
                saliency = gr.Image(height=512, label="saliency map", show_label=True)

            with gr.Column(scale=0.2, min_width=150):
                run_btn = gr.Button("Run analysis", variant="primary", elem_id="run-button")
                # A single click updates both outputs: the saliency map and the label scores.
                run_btn.click(
                    fn=self.explain_pred,
                    inputs=image,
                    outputs=saliency,
                )
                run_btn.click(
                    fn=self.classify_image,
                    inputs=image,
                    outputs=label,
                )

            gr.Examples(
                examples=examples,
                inputs=image,
                outputs=image,
                fn=lambda x: x,
                cache_examples=True,
            )

        demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)


def main():
    ui = WebUI()
    ui.run()


if __name__ == "__main__":
    main()
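
# To run locally (assuming the dependencies are installed):
#   pip install gradio torch transformers transformers-interpret requests numpy
#   python app.py
# The demo is then served on port 7860.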