Spaces: Running on Zero
upload app.py
- app.py +217 -4
- examples/objs/1-1.png +0 -0
- examples/objs/1-2.png +0 -0
- examples/scenes/1-1.jpg +0 -0
- examples/scenes/1-2.jpg +0 -0
- examples/scenes/2-1.jpg +0 -0
- examples/scenes/2-2.jpg +0 -0
app.py CHANGED
@@ -1,7 +1,220 @@
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import spaces
import timm
import torch
import torchvision.transforms as T
import types
import albumentations as A
import torch.nn.functional as F

from PIL import Image
from tqdm import tqdm
from sklearn.decomposition import PCA
from torch_kmeans import KMeans, CosineSimilarity

cmap = plt.get_cmap("tab20")
imagenet_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

def get_bg_mask(image):
    # detect background based on the four edges
    image = np.array(image)
    if np.all(image[:, 0] == image[0, 0]) and np.all(image[:, -1] == image[0, -1]) \
            and np.all(image[0, :] == image[0, 0]) and np.all(image[-1, :] == image[-1, 0]) \
            and np.all(image[0, 0] == image[0, -1]) and np.all(image[0, 0] == image[-1, 0]) \
            and np.all(image[0, 0] == image[-1, -1]):
        return np.any(image != image[0, 0], -1)
    return np.ones_like(image[:, :, 0], dtype=bool)


def download_image(url, save_path):
    response = requests.get(url)
    with open(save_path, 'wb') as file:
        file.write(response.content)


def process_image(image, res, patch_size, decimation=4):
    image = torch.from_numpy(np.array(image) / 255.).float().permute(2, 0, 1).to(device)

    # Scale so the long side is `res`, then resize to a (tgt_size // decimation) grid of
    # `patch_size`-pixel patches so the ViT patch grid divides the image exactly.
    tgt_size = (int(image.shape[-2] * res / image.shape[-1]), res)
    if image.shape[-2] > image.shape[-1]:
        tgt_size = (res, int(image.shape[-1] * res / image.shape[-2]))

    patch_h, patch_w = tgt_size[0] // decimation, tgt_size[1] // decimation
    image_resized = T.functional.resize(image, (patch_h * patch_size, patch_w * patch_size))

    image_resized = imagenet_transform(image_resized)

    return image_resized

def generate_grid(x, y, stride):
    x_coords = np.arange(0, x, stride)
    y_coords = np.arange(0, y, stride)

    x_mesh, y_mesh = np.meshgrid(x_coords, y_coords)
    kp = np.column_stack((x_mesh.ravel(), y_mesh.ravel())).astype(float)
    return kp

def pca(feat, pca_dim=3):
    feat_flattened = feat
    mean = torch.mean(feat_flattened, dim=0)
    centered_features = feat_flattened - mean
    U, S, V = torch.pca_lowrank(centered_features, q=pca_dim)
    reduced_features = torch.matmul(centered_features, V[:, :pca_dim])
    return reduced_features

def co_pca(feat1, feat2, pca_dim=3):
    co_feats = torch.cat((feat1.reshape(-1, feat1.shape[-1]), feat2.reshape(-1, feat2.shape[-1])), dim=0)
    feats = pca(co_feats, pca_dim)
    feat1_pca = feats[:feat1.shape[0] * feat1.shape[1]].reshape(feat1.shape[0], feat1.shape[1], -1)
    feat2_pca = feats[feat1.shape[0] * feat1.shape[1]:].reshape(feat2.shape[0], feat2.shape[1], -1)
    return feat1_pca, feat2_pca


def draw_correspondence(feat1, feat2, color1, mask1, mask2):
    # For every foreground patch of image 2, copy the color of its nearest-neighbor
    # patch (in feature space) among the foreground patches of image 1.
    original_mask2_shape = mask2.shape
    mask1, mask2 = mask1.reshape(-1), mask2.reshape(-1)
    distances = torch.cdist(feat1.reshape(-1, feat1.shape[-1])[mask1], feat2.reshape(-1, feat2.shape[-1])[mask2])
    nearest = torch.argmin(distances, dim=0)
    color2 = torch.zeros((mask2.shape[0], 3,)).to(device)
    color2[mask2] = color1.reshape(-1, 3)[mask1][nearest]
    color2 = color2.reshape(*original_mask2_shape, 3)
    return color2

def load_model(options):
    original_models = {}
    fine_models = {}
    for option in tqdm(options):
        print('Please wait ...')
        print('loading weights of ', option)
        # Fine-tuned model from the local torch.hub entry point, original backbone from the DINOv2 hub.
        fine_models[option] = torch.hub.load(".", model_card[option], source='local').to(device)
        original_models[option] = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=fine_models[option].backbone_name).eval().to(device)
        print('Done! Now play the demo :)')
    return original_models, fine_models


if __name__ == "__main__":

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print("device:", device)

    example_dir = "examples"
    os.makedirs(example_dir, exist_ok=True)

    image_input1 = gr.Image(label="Choose an image:",
                            height=500,
                            type="pil",
                            image_mode='RGB',
                            sources=['upload', 'webcam', 'clipboard']
                            )
    image_input2 = gr.Image(label="Choose another image:",
                            height=500,
                            type="pil",
                            image_mode='RGB',
                            sources=['upload', 'webcam', 'clipboard']
                            )

    options = ['DINOv2-Base']
    model_option = gr.Radio(options, value="DINOv2-Base", label='Choose a 2D foundation model')

    model_card = {
        "DINOv2-Base": "dinov2_base",
    }

    os.environ['TORCH_HOME'] = '/tmp/.cache'
    # os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'

    # Pre-load all models
    original_models, fine_models = load_model(options)

    @spaces.GPU
    def main(image1, image2, model_option, kmeans_num=None):
        # kmeans_num is currently unused and is not wired to a Gradio input, so it defaults to None.
        if image1 is None or image2 is None:
            return None
        # Select model
        original_model = original_models[model_option]
        fine_model = fine_models[model_option]

        images_resized = [process_image(image, 640, 14, decimation=8) for image in [image1, image2]]
        masks = [torch.from_numpy(get_bg_mask(image)).to(device) for image in [image1, image2]]
        feat_shapes = [(images_resized[0].shape[-2] // 14, images_resized[0].shape[-1] // 14),
                       (images_resized[1].shape[-2] // 14, images_resized[1].shape[-1] // 14)]

        masks_resized = [T.functional.resize(mask.float()[None], feat_shape,
                                             interpolation=T.functional.InterpolationMode.NEAREST_EXACT)[0]
                         for mask, feat_shape in zip(masks, feat_shapes)]

        with torch.no_grad():
            # Patch features from the original (frozen) backbone.
            original_feats = [original_model.forward_features(image[None])['x_norm_patchtokens'].reshape(*feat_shape, -1)
                              for image, feat_shape in zip(images_resized, feat_shapes)]
            original_feats = [F.normalize(feat, p=2, dim=-1) for feat in original_feats]

            # PCA colors for the foreground patches of image 1.
            original_color1 = torch.zeros((original_feats[0].shape[0] * original_feats[0].shape[1], 3,)).to(device)
            color = pca(original_feats[0][masks_resized[0] > 0], 3)
            color = (color - color.min()) / (color.max() - color.min())
            original_color1[masks_resized[0].reshape(-1) > 0] = color
            original_color1 = original_color1.reshape(*original_feats[0].shape[:2], 3)

            original_color2 = draw_correspondence(original_feats[0], original_feats[1], original_color1,
                                                  masks_resized[0] > 0, masks_resized[1] > 0)

            # Patch features from the fine-tuned backbone plus its refinement layer.
            fine_feats = [fine_model.dinov2.forward_features(image[None])['x_norm_patchtokens'].reshape(*feat_shape, -1)
                          for image, feat_shape in zip(images_resized, feat_shapes)]
            fine_feats = [fine_model.refine_conv(feat[None].permute(0, 3, 1, 2)).permute(0, 2, 3, 1)[0] for feat in fine_feats]
            fine_feats = [F.normalize(feat, p=2, dim=-1) for feat in fine_feats]
            fine_color2 = draw_correspondence(fine_feats[0], fine_feats[1], original_color1,
                                              masks_resized[0] > 0, masks_resized[1] > 0)

        fig, ax = plt.subplots(2, 2, squeeze=False)
        ax[0][0].imshow(original_color1.cpu().numpy())
        ax[0][1].text(-0.1, 0.5, "Original " + model_option, fontsize=7, rotation=90, va='center', transform=ax[0][1].transAxes)
        ax[0][1].imshow(original_color2.cpu().numpy())

        # ax[1][0].imshow(fine_color1.cpu().numpy())
        ax[1][1].text(-0.1, 0.5, "Finetuned " + model_option, fontsize=7, rotation=90, va='center', transform=ax[1][1].transAxes)
        ax[1][1].imshow(fine_color2.cpu().numpy())
        for xx in ax:
            for x in xx:
                x.xaxis.set_major_formatter(plt.NullFormatter())
                x.yaxis.set_major_formatter(plt.NullFormatter())
                x.set_xticks([])
                x.set_yticks([])
                x.axis('off')

        plt.tight_layout()
        plt.close(fig)
        return fig


    demo = gr.Interface(
        title="<div> \
            <h1>3DCorrEnhance</h1> \
            <h2>Multiview Equivariance Improves 3D Correspondence Understanding with Minimal Feature Finetuning</h2> \
            <h2>ICLR 2025</h2> \
            </div>",
        description="<div style='display: flex; justify-content: center; align-items: center; text-align: center;'> \
            <a href='https://arxiv.org/abs/2411.19458'><img src='https://img.shields.io/badge/arXiv-2411.19458-red'></a> \
            \
            <a href='#'><img src='https://img.shields.io/badge/Project_Page-3DCorrEnhance-green' alt='Project Page (Coming soon)'></a> \
            \
            <a href='https://github.com/qq456cvb/3DCorrEnhance'><img src='https://img.shields.io/badge/Github-Code-blue'></a> \
            </div>",
        fn=main,
        inputs=[image_input1, image_input2, model_option],
        outputs="plot",
        examples=[
            ["examples/objs/1-1.png", "examples/objs/1-2.png", "DINOv2-Base"],
            ["examples/scenes/1-1.jpg", "examples/scenes/1-2.jpg", "DINOv2-Base"],
            ["examples/scenes/2-1.jpg", "examples/scenes/2-2.jpg", "DINOv2-Base"],
        ],
        cache_examples=True)
    demo.launch()
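
Note: app.py fetches the fine-tuned model through a local torch.hub entry point (torch.hub.load(".", "dinov2_base", source='local')), so the Space is expected to also contain a hubconf.py and checkpoint that are not part of this commit. The rest of app.py only assumes the returned module exposes backbone_name, .dinov2 (a DINOv2 backbone whose forward_features returns 'x_norm_patchtokens'), and .refine_conv (applied to a B x C x H x W patch-feature map). A minimal sketch of that assumed interface, for orientation only — the class and entry-point names below are hypothetical, not the actual 3DCorrEnhance implementation:

import torch
import torch.nn as nn


class FineDINOv2(nn.Module):
    # Hypothetical shape of the object returned by the local hub entry point.
    def __init__(self, backbone_name="dinov2_vitb14"):
        super().__init__()
        self.backbone_name = backbone_name  # read back by app.py to load the original backbone
        self.dinov2 = torch.hub.load("facebookresearch/dinov2", backbone_name)
        dim = self.dinov2.embed_dim  # 768 for ViT-B/14
        # Lightweight refinement applied to the (B, C, H, W) patch-feature map in app.py.
        self.refine_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)


def dinov2_base(pretrained=True, **kwargs):
    # Hypothetical hubconf.py entry point matching model_card["DINOv2-Base"];
    # the fine-tuned weights would be loaded here from a checkpoint shipped with the repo.
    return FineDINOv2()
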
examples/objs/1-1.png ADDED
examples/objs/1-2.png ADDED
examples/scenes/1-1.jpg ADDED
examples/scenes/1-2.jpg ADDED
examples/scenes/2-1.jpg ADDED
examples/scenes/2-2.jpg ADDED