Create app.py
app.py
ADDED
@@ -0,0 +1,69 @@
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
from PIL import Image
import uuid

# Load the pre-trained MAE checkpoint and its preprocessor once at startup.
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

# ImageNet normalization statistics, needed later to de-normalize for display.
imagenet_mean = np.array(feature_extractor.image_mean)
imagenet_std = np.array(feature_extractor.image_std)
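
# Note: show_image below undoes the ImageNet normalization before plotting,
# applying per channel
#     pixel = clip((x * std + mean) * 255, 0, 255),
# and writes each panel to a uuid-named PNG so concurrent requests cannot
# overwrite each other's files (the files are not cleaned up afterwards).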
def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    unique_id = str(uuid.uuid4())
    plt.figure()  # start a fresh figure so successive calls do not stack state
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title)
    plt.axis('off')
    plt.savefig(f"{unique_id}.png", bbox_inches='tight', pad_inches=0)
    plt.close()

    return f"{unique_id}.png"
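
# Shape walkthrough for visualize (a sketch, assuming the checkpoint's default
# 224x224 input with 16x16 patches, i.e. 14 * 14 = 196 patches):
#     pixel_values:   (1, 3, 224, 224)
#     outputs.logits: (1, 196, 768)   one 16*16*3 pixel vector per patch
#     outputs.mask:   (1, 196)        1 = patch was masked, 0 = patch was kept
# unpatchify() folds the per-patch vectors back into a (1, 3, 224, 224) image.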
def visualize(image):
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

    # forward pass (no gradients needed for visualization)
    with torch.no_grad():
        outputs = model(pixel_values)
    y = model.unpatchify(outputs.logits)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()  # channel-last for plotting

    # expand the patch-level mask to pixel level
    mask = outputs.mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', pixel_values)

    # masked image: keep visible patches, zero out masked ones
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted over the visible patches
    im_paste = x * (1 - mask) + y * mask

    gallery_labels = ["Original Image", "Masked Image", "Reconstruction", "Reconstruction with Patches"]
    gallery_out = [show_image(out) for out in [x[0], im_masked[0], y[0], im_paste[0]]]

    return list(zip(gallery_out, gallery_labels))
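
# Mask expansion in miniature: with two patches and mask = [[1, 0]],
# unsqueeze/repeat yields [[[1, 1, ...], [0, 0, ...]]] (one value per pixel
# channel in each patch), so after unpatchify every pixel of patch 0 is 1
# (shown from the reconstruction) and every pixel of patch 1 is 0 (original).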
+
with gr.Blocks() as demo:
|
55 |
+
gr.Markdown("## ViTMAE Demo")
|
56 |
+
gr.Markdown("ViTMAE is an architecture that combine masked autoencoder and Vision Transformer (ViT) for self-supervised pre-training.")
|
57 |
+
gr.Markdown("By pre-training a ViT to reconstruct pixel values for masked patches, one can get results after fine-tuning that outperform supervised pre-training.")
|
58 |
+
|
59 |
+
with gr.Row():
|
60 |
+
|
61 |
+
input_img = gr.Image()
|
62 |
+
output = gr.Gallery()
|
63 |
+
|
64 |
+
input_img.change(visualize, inputs=input_img, outputs=output)
|
65 |
+
|
66 |
+
gr.Examples([["./cat.png"]], inputs=input_img, outputs=output, fn=visualize)
|
67 |
+
|
68 |
+
demo.launch(debug=True)
|
69 |
+
|