merve HF staff commited on
Commit
b479d0e
1 Parent(s): d79ba37

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
6
+ from PIL import Image
7
+ import uuid
8
+
9
+ feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
10
+ model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
11
+
12
+ imagenet_mean = np.array(feature_extractor.image_mean)
13
+ imagenet_std = np.array(feature_extractor.image_std)
14
+
15
+ def show_image(image, title=''):
16
+ # image is [H, W, 3]
17
+ assert image.shape[2] == 3
18
+ unique_id = str(uuid.uuid4())
19
+ plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
20
+ plt.axis('off')
21
+ plt.savefig(f"{unique_id}.png", bbox_inches='tight', pad_inches=0)
22
+
23
+ return f"{unique_id}.png"
24
+
25
+ def visualize(image):
26
+ pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
27
+ # forward pass
28
+ outputs = model(pixel_values)
29
+ y = model.unpatchify(outputs.logits)
30
+ y = torch.einsum('nchw->nhwc', y).detach().cpu()
31
+
32
+ # visualize the mask
33
+ mask = outputs.mask.detach()
34
+ mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 *3) # (N, H*W, p*p*3)
35
+ mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
36
+ mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
37
+
38
+ x = torch.einsum('nchw->nhwc', pixel_values)
39
+
40
+ # masked image
41
+ im_masked = x * (1 - mask)
42
+
43
+ # MAE reconstruction pasted with visible patches
44
+ im_paste = x * (1 - mask) + y * mask
45
+
46
+ gallery_labels = ["Original Image", "Masked Image", "Reconstruction", "Reconstruction with Patches"]
47
+ gallery_out = [show_image(out) for out in [x[0], im_masked[0], y[0], im_paste[0]]]
48
+
49
+ return [(k,v) for k,v in zip(gallery_out, gallery_labels)]
50
+
51
+
52
+
53
+
54
+ with gr.Blocks() as demo:
55
+ gr.Markdown("## ViTMAE Demo")
56
+ gr.Markdown("ViTMAE is an architecture that combine masked autoencoder and Vision Transformer (ViT) for self-supervised pre-training.")
57
+ gr.Markdown("By pre-training a ViT to reconstruct pixel values for masked patches, one can get results after fine-tuning that outperform supervised pre-training.")
58
+
59
+ with gr.Row():
60
+
61
+ input_img = gr.Image()
62
+ output = gr.Gallery()
63
+
64
+ input_img.change(visualize, inputs=input_img, outputs=output)
65
+
66
+ gr.Examples([["./cat.png"]], inputs=input_img, outputs=output, fn=visualize)
67
+
68
+ demo.launch(debug=True)
69
+