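# Gradio demo: classify an image with a pretrained ViT and visualize which
# pixels drove the prediction, using a saliency map from transformers-interpret.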
import requests
import re
import gradio as gr
import numpy as np
from torch import topk
from torch.nn.functional import softmax
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers_interpret import ImageClassificationExplainer
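
# Human-readable class names: the gist below maps the 1000 ImageNet class
# indices to their labels, one "index: 'label'" entry per line.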
def load_label_data():
    file_url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
    response = requests.get(file_url)
    labels = []
    pattern = '["\'](.*?)["\']'  # capture the first quoted string on each line
    for line in response.text.split('\n'):
        try:
            tmp = re.findall(pattern, line)[0]
            labels.append(tmp)
        except IndexError:
            pass  # lines without a quoted label (e.g. surrounding braces) are skipped
    return labels
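

# Small wrapper class that loads the model once at startup and builds the
# Gradio interface around it.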
class WebUI:
    def __init__(self):
        self.nb_classes = 10  # number of top predictions to display
        self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
        self.labels = load_label_data()
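
    # Forward pass: preprocess the image, run ViT, and return the top-k
    # softmax probabilities with their class indices.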
    def run_model(self, image):
        inputs = self.processor(images=image, return_tensors="pt")
        outputs = self.model(**inputs)
        probs = softmax(outputs.logits, dim=1)
        return topk(probs, k=self.nb_classes)
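
    # Map the top-k predictions to a {label: probability} dict, the format
    # expected by the gr.Label component.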
    def classify_image(self, image):
        values, indices = self.run_model(image)
        return {self.labels[indices[0][i]]: float(values[0][i]) for i in range(self.nb_classes)}
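
    # Attribution map from transformers-interpret, clipped and normalized to
    # [-1, 1] so a few extreme values do not wash out the visualization.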
    def explain_pred(self, image):
        explainer = ImageClassificationExplainer(model=self.model, feature_extractor=self.processor)
        saliency = explainer(image)
        if hasattr(saliency, "detach"):
            # the explainer may return a torch tensor; convert to numpy first
            saliency = saliency.detach().cpu().numpy()
        saliency = np.squeeze(np.moveaxis(saliency, 1, 3))  # (1, C, H, W) -> (H, W, C)
        saliency = np.clip(saliency, -0.05, 0.05)
        saliency /= np.amax(np.abs(saliency))
        return saliency
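
    # UI layout: input image, top-10 label bar, saliency map, and a run button
    # that triggers both the classification and the explanation callbacks.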
    def run(self):
        # example images hosted in the workshop repository
        examples = [
            ['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/cat.jpg'],
            ['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/dog.jpeg'],
        ]

        with gr.Blocks() as demo:
            with gr.Row():
                image = gr.Image(height=512)
                label = gr.Label(num_top_classes=self.nb_classes)
                saliency = gr.Image(height=512, label="saliency map", show_label=True)
                with gr.Column(scale=1, min_width=150):  # scale must be an int in recent Gradio versions
                    run_btn = gr.Button("Run analysis", variant="primary", elem_id="run-button")
                    run_btn.click(
                        fn=self.explain_pred,
                        inputs=image,
                        outputs=saliency,
                    )
                    run_btn.click(
                        fn=self.classify_image,
                        inputs=image,
                        outputs=label,
                    )
            gr.Examples(
                examples=examples,
                inputs=image,
                outputs=image,
                fn=lambda x: x,
                cache_examples=True,
            )

        demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)


def main():
    ui = WebUI()
    ui.run()


if __name__ == "__main__":
    main()
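
# To try it locally (assuming this file is saved as app.py):
#   python app.py
# then open http://localhost:7860 in a browser.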