qq456cvb committed
Commit d8eb09b · 1 Parent(s): eb2eeeb

upload app.py
app.py CHANGED
@@ -1,7 +1,220 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import requests
+import spaces
+import timm
+import torch
+import torchvision.transforms as T
+import types
+import albumentations as A
+import torch.nn.functional as F
+
+from PIL import Image
+from tqdm import tqdm
+from sklearn.decomposition import PCA
+from torch_kmeans import KMeans, CosineSimilarity
+
+cmap = plt.get_cmap("tab20")
+imagenet_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+
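+# Returns a boolean foreground mask: if all four image borders share one constant
+# color, pixels differing from that color count as foreground; otherwise the whole
+# image is kept.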
+def get_bg_mask(image):
+    # detect background based on the four edges
+    image = np.array(image)
+    if np.all(image[:, 0] == image[0, 0]) and np.all(image[:, -1] == image[0, -1]) \
+            and np.all(image[0, :] == image[0, 0]) and np.all(image[-1, :] == image[-1, 0]) \
+            and np.all(image[0, 0] == image[0, -1]) and np.all(image[0, 0] == image[-1, 0]) \
+            and np.all(image[0, 0] == image[-1, -1]):
+        return np.any(image != image[0, 0], -1)
+    return np.ones_like(image[:, :, 0], dtype=bool)
+
+
+def download_image(url, save_path):
+    response = requests.get(url)
+    with open(save_path, 'wb') as file:
+        file.write(response.content)
+
+
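+# Resize an image so the longer side targets `res` pixels, snap both sides to whole
+# multiples of the ViT patch size (patch grid = target size // decimation), and apply
+# ImageNet normalization.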
+def process_image(image, res, patch_size, decimation=4):
+    image = torch.from_numpy(np.array(image) / 255.).float().permute(2, 0, 1).to(device)
+
+    tgt_size = (int(image.shape[-2] * res / image.shape[-1]), res)
+    if image.shape[-2] > image.shape[-1]:
+        tgt_size = (res, int(image.shape[-1] * res / image.shape[-2]))
+
+    patch_h, patch_w = tgt_size[0] // decimation, tgt_size[1] // decimation
+    image_resized = T.functional.resize(image, (patch_h * patch_size, patch_w * patch_size))
+
+    image_resized = imagenet_transform(image_resized)
+
+    return image_resized
+
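+# Build a regular grid of (x, y) keypoint coordinates with the given stride
+# (not called elsewhere in this script).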
+def generate_grid(x, y, stride):
+    x_coords = np.arange(0, x, stride)  # was `grid_stride`, an undefined name
+    y_coords = np.arange(0, y, stride)
+
+    x_mesh, y_mesh = np.meshgrid(x_coords, y_coords)
+    kp = np.column_stack((x_mesh.ravel(), y_mesh.ravel())).astype(float)
+    return kp
+
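+# Project flattened features onto their top `pca_dim` principal components; used
+# below to map patch features to RGB colors.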
+def pca(feat, pca_dim=3):
+    feat_flattened = feat
+    mean = torch.mean(feat_flattened, dim=0)
+    centered_features = feat_flattened - mean
+    U, S, V = torch.pca_lowrank(centered_features, q=pca_dim)
+    reduced_features = torch.matmul(centered_features, V[:, :pca_dim])
+    return reduced_features
+
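+# Run PCA jointly over two feature maps so both are projected into a shared basis
+# (not called elsewhere in this script).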
+def co_pca(feat1, feat2, pca_dim=3):
+    co_feats = torch.cat((feat1.reshape(-1, feat1.shape[-1]), feat2.reshape(-1, feat2.shape[-1])), dim=0)
+    feats = pca(co_feats)
+    feat1_pca = feats[:feat1.shape[0]*feat1.shape[1]].reshape(feat1.shape[0], feat1.shape[1], -1)
+    feat2_pca = feats[feat1.shape[0]*feat1.shape[1]:].reshape(feat2.shape[0], feat2.shape[1], -1)
+    return feat1_pca, feat2_pca
+
+
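+# For each foreground patch of image 2, find the nearest image-1 foreground patch in
+# feature space and copy its color, visualizing dense correspondences.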
+def draw_correspondence(feat1, feat2, color1, mask1, mask2):
+    original_mask2_shape = mask2.shape
+    mask1, mask2 = mask1.reshape(-1), mask2.reshape(-1)
+    distances = torch.cdist(feat1.reshape(-1, feat1.shape[-1])[mask1], feat2.reshape(-1, feat2.shape[-1])[mask2])
+    nearest = torch.argmin(distances, dim=0)
+    color2 = torch.zeros((mask2.shape[0], 3,)).to(device)
+    color2[mask2] = color1.reshape(-1, 3)[mask1][nearest]
+    color2 = color2.reshape(*original_mask2_shape, 3)
+    return color2
+
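+# Load each fine-tuned checkpoint from the local torch.hub entry point, plus the
+# matching original DINOv2 backbone for side-by-side comparison.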
+def load_model(options):
+    original_models = {}
+    fine_models = {}
+    for option in tqdm(options):
+        print('Please wait ...')
+        print('loading weights of ', option)
+        fine_models[option] = torch.hub.load(".", model_card[option], source='local').to(device)
+        original_models[option] = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=fine_models[option].backbone_name).eval().to(device)
+    print('Done! Now play the demo :)')
+    return original_models, fine_models
+
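+# Script entry point: detect the device, build the Gradio inputs, pre-load the
+# models, and launch the demo.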
+if __name__ == "__main__":
+
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+
+    print("device: ")
+    print(device)
+
+    example_dir = "examples"
+
+    os.makedirs(example_dir, exist_ok=True)
+
+    image_input1 = gr.Image(label="Choose an image:",
+                            height=500,
+                            type="pil",
+                            image_mode='RGB',
+                            sources=['upload', 'webcam', 'clipboard']
+                            )
+    image_input2 = gr.Image(label="Choose another image:",
+                            height=500,
+                            type="pil",
+                            image_mode='RGB',
+                            sources=['upload', 'webcam', 'clipboard']
+                            )
+
+    options = ['DINOv2-Base']
+    model_option = gr.Radio(options, value="DINOv2-Base", label='Choose a 2D foundation model')
+
+    model_card = {
+        "DINOv2-Base": "dinov2_base",
+    }
+
+
+    os.environ['TORCH_HOME'] = '/tmp/.cache'
+    # os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
+
+    # Pre-load all models
+    original_models, fine_models = load_model(options)
+
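+    # Compute DINOv2 patch features for both images, color the first image's
+    # foreground patches by a 3-component PCA of its features, transfer those colors
+    # to the second image via nearest-neighbor feature matching, and plot original
+    # vs. fine-tuned results in a 2x2 grid.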
+    @spaces.GPU
+    def main(image1, image2, model_option, kmeans_num=None):  # kmeans_num is unused; default added to match the three Interface inputs
+        if image1 is None or image2 is None:
+            return None
+        # Select model
+        original_model = original_models[model_option]
+        fine_model = fine_models[model_option]
+
+        images_resized = [process_image(image, 640, 14, decimation=8) for image in [image1, image2]]
+        masks = [torch.from_numpy(get_bg_mask(image)).to(device) for image in [image1, image2]]
+        feat_shapes = [(images_resized[0].shape[-2] // 14, images_resized[0].shape[-1] // 14),
+                       (images_resized[1].shape[-2] // 14, images_resized[1].shape[-1] // 14)]
+
+        masks_resized = [T.functional.resize(mask.float()[None], feat_shape,
+                                             interpolation=T.functional.InterpolationMode.NEAREST_EXACT)[0]
+                         for mask, feat_shape in zip(masks, feat_shapes)]
+
+        with torch.no_grad():
+            original_feats = [original_model.forward_features(image[None])['x_norm_patchtokens'].reshape(*feat_shape, -1)
+                              for image, feat_shape in zip(images_resized, feat_shapes)]
+            original_feats = [F.normalize(feat, p=2, dim=-1) for feat in original_feats]
+
+            original_color1 = torch.zeros((original_feats[0].shape[0] * original_feats[0].shape[1], 3,)).to(device)
+            color = pca((original_feats[0][masks_resized[0] > 0]), 3)
+            color = (color - color.min()) / (color.max() - color.min())
+            original_color1[masks_resized[0].reshape(-1) > 0] = color
+            original_color1 = original_color1.reshape(*original_feats[0].shape[:2], 3)
+
+            original_color2 = draw_correspondence(original_feats[0], original_feats[1], original_color1,
+                                                  masks_resized[0] > 0, masks_resized[1] > 0)
+
+            fine_feats = [fine_model.dinov2.forward_features(image[None])['x_norm_patchtokens'].reshape(*feat_shape, -1)
+                          for image, feat_shape in zip(images_resized, feat_shapes)]
+            fine_feats = [fine_model.refine_conv(feat[None].permute(0, 3, 1, 2)).permute(0, 2, 3, 1)[0] for feat in fine_feats]
+            fine_feats = [F.normalize(feat, p=2, dim=-1) for feat in fine_feats]
+            fine_color2 = draw_correspondence(fine_feats[0], fine_feats[1], original_color1,
+                                              masks_resized[0] > 0, masks_resized[1] > 0)
+
+        fig, ax = plt.subplots(2, 2, squeeze=False)
+        ax[0][0].imshow(original_color1.cpu().numpy())
+        ax[0][1].text(-0.1, 0.5, "Original " + model_option, fontsize=7, rotation=90, va='center', transform=ax[0][1].transAxes)
+        ax[0][1].imshow(original_color2.cpu().numpy())
+
+        # ax[1][0].imshow(fine_color1.cpu().numpy())
+        ax[1][1].text(-0.1, 0.5, "Finetuned " + model_option, fontsize=7, rotation=90, va='center', transform=ax[1][1].transAxes)
+        ax[1][1].imshow(fine_color2.cpu().numpy())
+        for xx in ax:
+            for x in xx:
+                x.xaxis.set_major_formatter(plt.NullFormatter())
+                x.yaxis.set_major_formatter(plt.NullFormatter())
+                x.set_xticks([])
+                x.set_yticks([])
+                x.axis('off')
+
+        plt.tight_layout()
+        plt.close(fig)
+        return fig
+
+
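+    # Two image inputs and a model selector feed `main`; the listed example pairs
+    # are pre-computed because cache_examples=True.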
+    demo = gr.Interface(
+        title="<div> \
+            <h1>3DCorrEnhance</h1> \
+            <h2>Multiview Equivariance Improves 3D Correspondence Understanding with Minimal Feature Finetuning</h2> \
+            <h2>ICLR 2025</h2> \
+        </div>",
+        description="<div style='display: flex; justify-content: center; align-items: center; text-align: center;'> \
+            <a href='https://arxiv.org/abs/2411.19458'><img src='https://img.shields.io/badge/arXiv-2411.19458-red'></a> \
+            &nbsp; \
+            <a href='#'><img src='https://img.shields.io/badge/Project_Page-3DCorrEnhance-green' alt='Project Page (Coming soon)'></a> \
+            &nbsp; \
+            <a href='https://github.com/qq456cvb/3DCorrEnhance'><img src='https://img.shields.io/badge/Github-Code-blue'></a> \
+        </div>",
+        fn=main,
+        inputs=[image_input1, image_input2, model_option],
+        outputs="plot",
+        examples=[
+            ["examples/objs/1-1.png", "examples/objs/1-2.png", "DINOv2-Base"],
+            ["examples/scenes/1-1.jpg", "examples/scenes/1-2.jpg", "DINOv2-Base"],
+            ["examples/scenes/2-1.jpg", "examples/scenes/2-2.jpg", "DINOv2-Base"],
+        ],
+        cache_examples=True)
+    demo.launch()
examples/objs/1-1.png ADDED
examples/objs/1-2.png ADDED
examples/scenes/1-1.jpg ADDED
examples/scenes/1-2.jpg ADDED
examples/scenes/2-1.jpg ADDED
examples/scenes/2-2.jpg ADDED