ThomasSimonini HF staff committed on
Commit
d4759cc
·
verified ·
1 Parent(s): 9dd4e7f

Delete instant-mesh

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. instant-mesh/configs/instant-mesh-base.yaml +0 -22
  2. instant-mesh/configs/instant-mesh-large.yaml +0 -22
  3. instant-mesh/configs/instant-nerf-base.yaml +0 -21
  4. instant-mesh/configs/instant-nerf-large.yaml +0 -21
  5. instant-mesh/examples/bird.jpg +0 -0
  6. instant-mesh/examples/bubble_mart_blue.png +0 -0
  7. instant-mesh/examples/cake.jpg +0 -0
  8. instant-mesh/examples/cartoon_dinosaur.png +0 -0
  9. instant-mesh/examples/cartoon_panda.png +0 -3
  10. instant-mesh/examples/chair_armed.png +0 -0
  11. instant-mesh/examples/chair_comfort.jpg +0 -0
  12. instant-mesh/examples/chair_wood.jpg +0 -0
  13. instant-mesh/examples/chest.jpg +0 -0
  14. instant-mesh/examples/cute_horse.jpg +0 -0
  15. instant-mesh/examples/cute_tiger.jpg +0 -0
  16. instant-mesh/examples/earphone.jpg +0 -0
  17. instant-mesh/examples/fox.jpg +0 -0
  18. instant-mesh/examples/fruit.jpg +0 -0
  19. instant-mesh/examples/fruit_elephant.jpg +0 -0
  20. instant-mesh/examples/genshin_building.png +0 -0
  21. instant-mesh/examples/genshin_teapot.png +0 -0
  22. instant-mesh/examples/hatsune_miku.png +0 -0
  23. instant-mesh/examples/house2.jpg +0 -0
  24. instant-mesh/examples/mushroom_teapot.jpg +0 -0
  25. instant-mesh/examples/pikachu.png +0 -0
  26. instant-mesh/examples/plant.jpg +0 -0
  27. instant-mesh/examples/robot.jpg +0 -0
  28. instant-mesh/examples/sea_turtle.png +0 -0
  29. instant-mesh/examples/skating_shoe.jpg +0 -0
  30. instant-mesh/examples/sorting_board.png +0 -0
  31. instant-mesh/examples/sword.png +0 -0
  32. instant-mesh/examples/toy_car.jpg +0 -0
  33. instant-mesh/examples/watermelon.png +0 -0
  34. instant-mesh/examples/whitedog.png +0 -0
  35. instant-mesh/examples/x_teapot.jpg +0 -0
  36. instant-mesh/examples/x_toyduck.jpg +0 -0
  37. instant-mesh/src/__init__.py +0 -0
  38. instant-mesh/src/data/__init__.py +0 -0
  39. instant-mesh/src/data/objaverse.py +0 -329
  40. instant-mesh/src/model.py +0 -310
  41. instant-mesh/src/model_mesh.py +0 -325
  42. instant-mesh/src/models/__init__.py +0 -0
  43. instant-mesh/src/models/decoder/__init__.py +0 -0
  44. instant-mesh/src/models/decoder/transformer.py +0 -123
  45. instant-mesh/src/models/encoder/__init__.py +0 -0
  46. instant-mesh/src/models/encoder/dino.py +0 -550
  47. instant-mesh/src/models/encoder/dino_wrapper.py +0 -80
  48. instant-mesh/src/models/geometry/__init__.py +0 -7
  49. instant-mesh/src/models/geometry/camera/__init__.py +0 -16
  50. instant-mesh/src/models/geometry/camera/perspective_camera.py +0 -35
instant-mesh/configs/instant-mesh-base.yaml DELETED
@@ -1,22 +0,0 @@
- model_config:
-   target: src.models.lrm_mesh.InstantMesh
-   params:
-     encoder_feat_dim: 768
-     encoder_freeze: false
-     encoder_model_name: facebook/dino-vitb16
-     transformer_dim: 1024
-     transformer_layers: 12
-     transformer_heads: 16
-     triplane_low_res: 32
-     triplane_high_res: 64
-     triplane_dim: 40
-     rendering_samples_per_ray: 96
-     grid_res: 128
-     grid_scale: 2.1
-
-
- infer_config:
-   unet_path: ckpts/diffusion_pytorch_model.bin
-   model_path: ckpts/instant_mesh_base.ckpt
-   texture_resolution: 1024
-   render_resolution: 512
 
instant-mesh/configs/instant-mesh-large.yaml DELETED
@@ -1,22 +0,0 @@
- model_config:
-   target: src.models.lrm_mesh.InstantMesh
-   params:
-     encoder_feat_dim: 768
-     encoder_freeze: false
-     encoder_model_name: facebook/dino-vitb16
-     transformer_dim: 1024
-     transformer_layers: 16
-     transformer_heads: 16
-     triplane_low_res: 32
-     triplane_high_res: 64
-     triplane_dim: 80
-     rendering_samples_per_ray: 128
-     grid_res: 128
-     grid_scale: 2.1
-
-
- infer_config:
-   unet_path: ckpts/diffusion_pytorch_model.bin
-   model_path: ckpts/instant_mesh_large.ckpt
-   texture_resolution: 1024
-   render_resolution: 512
 
instant-mesh/configs/instant-nerf-base.yaml DELETED
@@ -1,21 +0,0 @@
- model_config:
-   target: src.models.lrm.InstantNeRF
-   params:
-     encoder_feat_dim: 768
-     encoder_freeze: false
-     encoder_model_name: facebook/dino-vitb16
-     transformer_dim: 1024
-     transformer_layers: 12
-     transformer_heads: 16
-     triplane_low_res: 32
-     triplane_high_res: 64
-     triplane_dim: 40
-     rendering_samples_per_ray: 96
-
-
- infer_config:
-   unet_path: ckpts/diffusion_pytorch_model.bin
-   model_path: ckpts/instant_nerf_base.ckpt
-   mesh_threshold: 10.0
-   mesh_resolution: 256
-   render_resolution: 384
 
instant-mesh/configs/instant-nerf-large.yaml DELETED
@@ -1,21 +0,0 @@
- model_config:
-   target: src.models.lrm.InstantNeRF
-   params:
-     encoder_feat_dim: 768
-     encoder_freeze: false
-     encoder_model_name: facebook/dino-vitb16
-     transformer_dim: 1024
-     transformer_layers: 16
-     transformer_heads: 16
-     triplane_low_res: 32
-     triplane_high_res: 64
-     triplane_dim: 80
-     rendering_samples_per_ray: 128
-
-
- infer_config:
-   unet_path: ckpts/diffusion_pytorch_model.bin
-   model_path: ckpts/instant_nerf_large.ckpt
-   mesh_threshold: 10.0
-   mesh_resolution: 256
-   render_resolution: 384
 
instant-mesh/examples/bird.jpg DELETED
Binary file (38.1 kB)
 
instant-mesh/examples/bubble_mart_blue.png DELETED
Binary file (212 kB)
 
instant-mesh/examples/cake.jpg DELETED
Binary file (82 kB)
 
instant-mesh/examples/cartoon_dinosaur.png DELETED
Binary file (727 kB)
 
instant-mesh/examples/cartoon_panda.png DELETED

Git LFS Details

  • SHA256: c82fea6ac66b782b2aa1c6bd133447b5f54f688c7eb44998c4b00f190d47b2b7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
instant-mesh/examples/chair_armed.png DELETED
Binary file (184 kB)
 
instant-mesh/examples/chair_comfort.jpg DELETED
Binary file (10.8 kB)
 
instant-mesh/examples/chair_wood.jpg DELETED
Binary file (12.5 kB)
 
instant-mesh/examples/chest.jpg DELETED
Binary file (26.6 kB)
 
instant-mesh/examples/cute_horse.jpg DELETED
Binary file (150 kB)
 
instant-mesh/examples/cute_tiger.jpg DELETED
Binary file (41.1 kB)
 
instant-mesh/examples/earphone.jpg DELETED
Binary file (30.2 kB)
 
instant-mesh/examples/fox.jpg DELETED
Binary file (64 kB)
 
instant-mesh/examples/fruit.jpg DELETED
Binary file (32.1 kB)
 
instant-mesh/examples/fruit_elephant.jpg DELETED
Binary file (40.7 kB)
 
instant-mesh/examples/genshin_building.png DELETED
Binary file (52.8 kB)
 
instant-mesh/examples/genshin_teapot.png DELETED
Binary file (19.8 kB)
 
instant-mesh/examples/hatsune_miku.png DELETED
Binary file (96.2 kB)
 
instant-mesh/examples/house2.jpg DELETED
Binary file (79.2 kB)
 
instant-mesh/examples/mushroom_teapot.jpg DELETED
Binary file (36.3 kB)
 
instant-mesh/examples/pikachu.png DELETED
Binary file (350 kB)
 
instant-mesh/examples/plant.jpg DELETED
Binary file (43.9 kB)
 
instant-mesh/examples/robot.jpg DELETED
Binary file (31.9 kB)
 
instant-mesh/examples/sea_turtle.png DELETED
Binary file (185 kB)
 
instant-mesh/examples/skating_shoe.jpg DELETED
Binary file (37 kB)
 
instant-mesh/examples/sorting_board.png DELETED
Binary file (85.1 kB)
 
instant-mesh/examples/sword.png DELETED
Binary file (850 kB)
 
instant-mesh/examples/toy_car.jpg DELETED
Binary file (11.8 kB)
 
instant-mesh/examples/watermelon.png DELETED
Binary file (583 kB)
 
instant-mesh/examples/whitedog.png DELETED
Binary file (246 kB)
 
instant-mesh/examples/x_teapot.jpg DELETED
Binary file (77.4 kB)
 
instant-mesh/examples/x_toyduck.jpg DELETED
Binary file (102 kB)
 
instant-mesh/src/__init__.py DELETED
File without changes
instant-mesh/src/data/__init__.py DELETED
File without changes
instant-mesh/src/data/objaverse.py DELETED
@@ -1,329 +0,0 @@
- import os, sys
- import math
- import json
- import importlib
- from pathlib import Path
-
- import cv2
- import random
- import numpy as np
- from PIL import Image
- import webdataset as wds
- import pytorch_lightning as pl
-
- import torch
- import torch.nn.functional as F
- from torch.utils.data import Dataset
- from torch.utils.data import DataLoader
- from torch.utils.data.distributed import DistributedSampler
- from torchvision import transforms
-
- from src.utils.train_util import instantiate_from_config
- from src.utils.camera_util import (
-     FOV_to_intrinsics,
-     center_looking_at_camera_pose,
-     get_surrounding_views,
- )
-
-
- class DataModuleFromConfig(pl.LightningDataModule):
-     def __init__(
-         self,
-         batch_size=8,
-         num_workers=4,
-         train=None,
-         validation=None,
-         test=None,
-         **kwargs,
-     ):
-         super().__init__()
-
-         self.batch_size = batch_size
-         self.num_workers = num_workers
-
-         self.dataset_configs = dict()
-         if train is not None:
-             self.dataset_configs['train'] = train
-         if validation is not None:
-             self.dataset_configs['validation'] = validation
-         if test is not None:
-             self.dataset_configs['test'] = test
-
-     def setup(self, stage):
-
-         if stage in ['fit']:
-             self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
-         else:
-             raise NotImplementedError
-
-     def train_dataloader(self):
-
-         sampler = DistributedSampler(self.datasets['train'])
-         return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
-
-     def val_dataloader(self):
-
-         sampler = DistributedSampler(self.datasets['validation'])
-         return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
-
-     def test_dataloader(self):
-
-         return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
-
-
- class ObjaverseData(Dataset):
-     def __init__(self,
-         root_dir='objaverse/',
-         meta_fname='valid_paths.json',
-         input_image_dir='rendering_random_32views',
-         target_image_dir='rendering_random_32views',
-         input_view_num=6,
-         target_view_num=2,
-         total_view_n=32,
-         fov=50,
-         camera_rotation=True,
-         validation=False,
-     ):
-         self.root_dir = Path(root_dir)
-         self.input_image_dir = input_image_dir
-         self.target_image_dir = target_image_dir
-
-         self.input_view_num = input_view_num
-         self.target_view_num = target_view_num
-         self.total_view_n = total_view_n
-         self.fov = fov
-         self.camera_rotation = camera_rotation
-
-         with open(os.path.join(root_dir, meta_fname)) as f:
-             filtered_dict = json.load(f)
-         paths = filtered_dict['good_objs']
-         self.paths = paths
-
-         self.depth_scale = 4.0
-
-         total_objects = len(self.paths)
-         print('============= length of dataset %d =============' % len(self.paths))
-
-     def __len__(self):
-         return len(self.paths)
-
-     def load_im(self, path, color):
-         '''
-         replace background pixel with random color in rendering
-         '''
-         pil_img = Image.open(path)
-
-         image = np.asarray(pil_img, dtype=np.float32) / 255.
-         alpha = image[:, :, 3:]
-         image = image[:, :, :3] * alpha + color * (1 - alpha)
-
-         image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
-         alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
-         return image, alpha
-
-     def __getitem__(self, index):
-         # load data
-         while True:
-             input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
-             target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
-
-             indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
-             input_indices = indices[:self.input_view_num]
-             target_indices = indices[self.input_view_num:]
-
-             '''background color, default: white'''
-             bg_white = [1., 1., 1.]
-             bg_black = [0., 0., 0.]
-
-             image_list = []
-             alpha_list = []
-             depth_list = []
-             normal_list = []
-             pose_list = []
-
-             try:
-                 input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
-                 for idx in input_indices:
-                     image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
-                     normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
-                     depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
-                     depth = torch.from_numpy(depth).unsqueeze(0)
-                     pose = input_cameras[idx]
-                     pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
-
-                     image_list.append(image)
-                     alpha_list.append(alpha)
-                     depth_list.append(depth)
-                     normal_list.append(normal)
-                     pose_list.append(pose)
-
-                 target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
-                 for idx in target_indices:
-                     image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
-                     normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
-                     depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
-                     depth = torch.from_numpy(depth).unsqueeze(0)
-                     pose = target_cameras[idx]
-                     pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
-
-                     image_list.append(image)
-                     alpha_list.append(alpha)
-                     depth_list.append(depth)
-                     normal_list.append(normal)
-                     pose_list.append(pose)
-
-             except Exception as e:
-                 print(e)
-                 index = np.random.randint(0, len(self.paths))
-                 continue
-
-             break
-
-         images = torch.stack(image_list, dim=0).float()    # (6+V, 3, H, W)
-         alphas = torch.stack(alpha_list, dim=0).float()    # (6+V, 1, H, W)
-         depths = torch.stack(depth_list, dim=0).float()    # (6+V, 1, H, W)
-         normals = torch.stack(normal_list, dim=0).float()  # (6+V, 3, H, W)
-         w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float()    # (6+V, 4, 4)
-         c2ws = torch.linalg.inv(w2cs).float()
-
-         normals = normals * 2.0 - 1.0
-         normals = F.normalize(normals, dim=1)
-         normals = (normals + 1.0) / 2.0
-         normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
-
-         # random rotation along z axis
-         if self.camera_rotation:
-             degree = np.random.uniform(0, math.pi * 2)
-             rot = torch.tensor([
-                 [np.cos(degree), -np.sin(degree), 0, 0],
-                 [np.sin(degree), np.cos(degree), 0, 0],
-                 [0, 0, 1, 0],
-                 [0, 0, 0, 1],
-             ]).unsqueeze(0).float()
-             c2ws = torch.matmul(rot, c2ws)
-
-             # rotate normals
-             N, _, H, W = normals.shape
-             normals = normals * 2.0 - 1.0
-             normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
-             normals = F.normalize(normals, dim=1)
-             normals = (normals + 1.0) / 2.0
-             normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
-
-         # random scaling
-         if np.random.rand() < 0.5:
-             scale = np.random.uniform(0.8, 1.0)
-             c2ws[:, :3, 3] *= scale
-             depths *= scale
-
-         # instrinsics of perspective cameras
-         K = FOV_to_intrinsics(self.fov)
-         Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
-
-         data = {
-             'input_images': images[:self.input_view_num],       # (6, 3, H, W)
-             'input_alphas': alphas[:self.input_view_num],       # (6, 1, H, W)
-             'input_depths': depths[:self.input_view_num],       # (6, 1, H, W)
-             'input_normals': normals[:self.input_view_num],     # (6, 3, H, W)
-             'input_c2ws': c2ws_input[:self.input_view_num],     # (6, 4, 4)
-             'input_Ks': Ks[:self.input_view_num],               # (6, 3, 3)
-
-             # lrm generator input and supervision
-             'target_images': images[self.input_view_num:],      # (V, 3, H, W)
-             'target_alphas': alphas[self.input_view_num:],      # (V, 1, H, W)
-             'target_depths': depths[self.input_view_num:],      # (V, 1, H, W)
-             'target_normals': normals[self.input_view_num:],    # (V, 3, H, W)
-             'target_c2ws': c2ws[self.input_view_num:],          # (V, 4, 4)
-             'target_Ks': Ks[self.input_view_num:],              # (V, 3, 3)
-
-             'depth_available': 1,
-         }
-         return data
-
-
- class ValidationData(Dataset):
-     def __init__(self,
-         root_dir='objaverse/',
-         input_view_num=6,
-         input_image_size=256,
-         fov=50,
-     ):
-         self.root_dir = Path(root_dir)
-         self.input_view_num = input_view_num
-         self.input_image_size = input_image_size
-         self.fov = fov
-
-         self.paths = sorted(os.listdir(self.root_dir))
-         print('============= length of dataset %d =============' % len(self.paths))
-
-         cam_distance = 2.5
-         azimuths = np.array([30, 90, 150, 210, 270, 330])
-         elevations = np.array([30, -20, 30, -20, 30, -20])
-         azimuths = np.deg2rad(azimuths)
-         elevations = np.deg2rad(elevations)
-
-         x = cam_distance * np.cos(elevations) * np.cos(azimuths)
-         y = cam_distance * np.cos(elevations) * np.sin(azimuths)
-         z = cam_distance * np.sin(elevations)
-
-         cam_locations = np.stack([x, y, z], axis=-1)
-         cam_locations = torch.from_numpy(cam_locations).float()
-         c2ws = center_looking_at_camera_pose(cam_locations)
-         self.c2ws = c2ws.float()
-         self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
-
-         render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
-         render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
-         self.render_c2ws = render_c2ws.float()
-         self.render_Ks = render_Ks.float()
-
-     def __len__(self):
-         return len(self.paths)
-
-     def load_im(self, path, color):
-         '''
-         replace background pixel with random color in rendering
-         '''
-         pil_img = Image.open(path)
-         pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
-
-         image = np.asarray(pil_img, dtype=np.float32) / 255.
-         if image.shape[-1] == 4:
-             alpha = image[:, :, 3:]
-             image = image[:, :, :3] * alpha + color * (1 - alpha)
-         else:
-             alpha = np.ones_like(image[:, :, :1])
-
-         image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
-         alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
-         return image, alpha
-
-     def __getitem__(self, index):
-         # load data
-         input_image_path = os.path.join(self.root_dir, self.paths[index])
-
-         '''background color, default: white'''
-         # color = np.random.uniform(0.48, 0.52)
-         bkg_color = [1.0, 1.0, 1.0]
-
-         image_list = []
-         alpha_list = []
-
-         for idx in range(self.input_view_num):
-             image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
-             image_list.append(image)
-             alpha_list.append(alpha)
-
-         images = torch.stack(image_list, dim=0).float()    # (6+V, 3, H, W)
-         alphas = torch.stack(alpha_list, dim=0).float()    # (6+V, 1, H, W)
-
-         data = {
-             'input_images': images,         # (6, 3, H, W)
-             'input_alphas': alphas,         # (6, 1, H, W)
-             'input_c2ws': self.c2ws,        # (6, 4, 4)
-             'input_Ks': self.Ks,            # (6, 3, 3)
-
-             'render_c2ws': self.render_c2ws,
-             'render_Ks': self.render_Ks,
-         }
-         return data
 
instant-mesh/src/model.py DELETED
@@ -1,310 +0,0 @@
- import os
- import numpy as np
- import torch
- import torch.nn.functional as F
- from torchvision.transforms import v2
- from torchvision.utils import make_grid, save_image
- from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
- import pytorch_lightning as pl
- from einops import rearrange, repeat
-
- from src.utils.train_util import instantiate_from_config
-
-
- class MVRecon(pl.LightningModule):
-     def __init__(
-         self,
-         lrm_generator_config,
-         lrm_path=None,
-         input_size=256,
-         render_size=192,
-     ):
-         super(MVRecon, self).__init__()
-
-         self.input_size = input_size
-         self.render_size = render_size
-
-         # init modules
-         self.lrm_generator = instantiate_from_config(lrm_generator_config)
-         if lrm_path is not None:
-             lrm_ckpt = torch.load(lrm_path)
-             self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
-
-         self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
-
-         self.validation_step_outputs = []
-
-     def on_fit_start(self):
-         if self.global_rank == 0:
-             os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
-             os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
-
-     def prepare_batch_data(self, batch):
-         lrm_generator_input = {}
-         render_gt = {}   # for supervision
-
-         # input images
-         images = batch['input_images']
-         images = v2.functional.resize(
-             images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
-
-         lrm_generator_input['images'] = images.to(self.device)
-
-         # input cameras and render cameras
-         input_c2ws = batch['input_c2ws'].flatten(-2)
-         input_Ks = batch['input_Ks'].flatten(-2)
-         target_c2ws = batch['target_c2ws'].flatten(-2)
-         target_Ks = batch['target_Ks'].flatten(-2)
-         render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
-         render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
-         render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
-
-         input_extrinsics = input_c2ws[:, :, :12]
-         input_intrinsics = torch.stack([
-             input_Ks[:, :, 0], input_Ks[:, :, 4],
-             input_Ks[:, :, 2], input_Ks[:, :, 5],
-         ], dim=-1)
-         cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
-
-         # add noise to input cameras
-         cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
-
-         lrm_generator_input['cameras'] = cameras.to(self.device)
-         lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
-
-         # target images
-         target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
-         target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
-         target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
-
-         # random crop
-         render_size = np.random.randint(self.render_size, 513)
-         target_images = v2.functional.resize(
-             target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
-         target_depths = v2.functional.resize(
-             target_depths, render_size, interpolation=0, antialias=True)
-         target_alphas = v2.functional.resize(
-             target_alphas, render_size, interpolation=0, antialias=True)
-
-         crop_params = v2.RandomCrop.get_params(
-             target_images, output_size=(self.render_size, self.render_size))
-         target_images = v2.functional.crop(target_images, *crop_params)
-         target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
-         target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
-
-         lrm_generator_input['render_size'] = render_size
-         lrm_generator_input['crop_params'] = crop_params
-
-         render_gt['target_images'] = target_images.to(self.device)
-         render_gt['target_depths'] = target_depths.to(self.device)
-         render_gt['target_alphas'] = target_alphas.to(self.device)
-
-         return lrm_generator_input, render_gt
-
-     def prepare_validation_batch_data(self, batch):
-         lrm_generator_input = {}
-
-         # input images
-         images = batch['input_images']
-         images = v2.functional.resize(
-             images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
-
-         lrm_generator_input['images'] = images.to(self.device)
-
-         input_c2ws = batch['input_c2ws'].flatten(-2)
-         input_Ks = batch['input_Ks'].flatten(-2)
-
-         input_extrinsics = input_c2ws[:, :, :12]
-         input_intrinsics = torch.stack([
-             input_Ks[:, :, 0], input_Ks[:, :, 4],
-             input_Ks[:, :, 2], input_Ks[:, :, 5],
-         ], dim=-1)
-         cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
-
-         lrm_generator_input['cameras'] = cameras.to(self.device)
-
-         render_c2ws = batch['render_c2ws'].flatten(-2)
-         render_Ks = batch['render_Ks'].flatten(-2)
-         render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
-
-         lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
-         lrm_generator_input['render_size'] = 384
-         lrm_generator_input['crop_params'] = None
-
-         return lrm_generator_input
-
-     def forward_lrm_generator(
-         self,
-         images,
-         cameras,
-         render_cameras,
-         render_size=192,
-         crop_params=None,
-         chunk_size=1,
-     ):
-         planes = torch.utils.checkpoint.checkpoint(
-             self.lrm_generator.forward_planes,
-             images,
-             cameras,
-             use_reentrant=False,
-         )
-         frames = []
-         for i in range(0, render_cameras.shape[1], chunk_size):
-             frames.append(
-                 torch.utils.checkpoint.checkpoint(
-                     self.lrm_generator.synthesizer,
-                     planes,
-                     cameras=render_cameras[:, i:i+chunk_size],
-                     render_size=render_size,
-                     crop_params=crop_params,
-                     use_reentrant=False
-                 )
-             )
-         frames = {
-             k: torch.cat([r[k] for r in frames], dim=1)
-             for k in frames[0].keys()
-         }
-         return frames
-
-     def forward(self, lrm_generator_input):
-         images = lrm_generator_input['images']
-         cameras = lrm_generator_input['cameras']
-         render_cameras = lrm_generator_input['render_cameras']
-         render_size = lrm_generator_input['render_size']
-         crop_params = lrm_generator_input['crop_params']
-
-         out = self.forward_lrm_generator(
-             images,
-             cameras,
-             render_cameras,
-             render_size=render_size,
-             crop_params=crop_params,
-             chunk_size=1,
-         )
-         render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
-         render_depths = out['images_depth']
-         render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
-
-         out = {
-             'render_images': render_images,
-             'render_depths': render_depths,
-             'render_alphas': render_alphas,
-         }
-         return out
-
-     def training_step(self, batch, batch_idx):
-         lrm_generator_input, render_gt = self.prepare_batch_data(batch)
-
-         render_out = self.forward(lrm_generator_input)
-
-         loss, loss_dict = self.compute_loss(render_out, render_gt)
-
-         self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
-
-         if self.global_step % 1000 == 0 and self.global_rank == 0:
-             B, N, C, H, W = render_gt['target_images'].shape
-             N_in = lrm_generator_input['images'].shape[1]
-
-             input_images = v2.functional.resize(
-                 lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
-             input_images = torch.cat(
-                 [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
-
-             input_images = rearrange(
-                 input_images, 'b n c h w -> b c h (n w)')
-             target_images = rearrange(
-                 render_gt['target_images'], 'b n c h w -> b c h (n w)')
-             render_images = rearrange(
-                 render_out['render_images'], 'b n c h w -> b c h (n w)')
-             target_alphas = rearrange(
-                 repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             render_alphas = rearrange(
-                 repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             target_depths = rearrange(
-                 repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             render_depths = rearrange(
-                 repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             MAX_DEPTH = torch.max(target_depths)
-             target_depths = target_depths / MAX_DEPTH * target_alphas
-             render_depths = render_depths / MAX_DEPTH
-
-             grid = torch.cat([
-                 input_images,
-                 target_images, render_images,
-                 target_alphas, render_alphas,
-                 target_depths, render_depths,
-             ], dim=-2)
-             grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
-
-             save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
-
-         return loss
-
-     def compute_loss(self, render_out, render_gt):
-         # NOTE: the rgb value range of OpenLRM is [0, 1]
-         render_images = render_out['render_images']
-         target_images = render_gt['target_images'].to(render_images)
-         render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
-         target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
-
-         loss_mse = F.mse_loss(render_images, target_images)
-         loss_lpips = 2.0 * self.lpips(render_images, target_images)
-
-         render_alphas = render_out['render_alphas']
-         target_alphas = render_gt['target_alphas']
-         loss_mask = F.mse_loss(render_alphas, target_alphas)
-
-         loss = loss_mse + loss_lpips + loss_mask
-
-         prefix = 'train'
-         loss_dict = {}
-         loss_dict.update({f'{prefix}/loss_mse': loss_mse})
-         loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
-         loss_dict.update({f'{prefix}/loss_mask': loss_mask})
-         loss_dict.update({f'{prefix}/loss': loss})
-
-         return loss, loss_dict
-
-     @torch.no_grad()
-     def validation_step(self, batch, batch_idx):
-         lrm_generator_input = self.prepare_validation_batch_data(batch)
-
-         render_out = self.forward(lrm_generator_input)
-         render_images = render_out['render_images']
-         render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
-
-         self.validation_step_outputs.append(render_images)
-
-     def on_validation_epoch_end(self):
-         images = torch.cat(self.validation_step_outputs, dim=-1)
-
-         all_images = self.all_gather(images)
-         all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
-
-         if self.global_rank == 0:
-             image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
-
-             grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
-             save_image(grid, image_path)
-             print(f"Saved image to {image_path}")
-
-         self.validation_step_outputs.clear()
-
-     def configure_optimizers(self):
-         lr = self.learning_rate
-
-         params = []
-
-         lrm_params_fast, lrm_params_slow = [], []
-         for n, p in self.lrm_generator.named_parameters():
-             if 'adaLN_modulation' in n or 'camera_embedder' in n:
-                 lrm_params_fast.append(p)
-             else:
-                 lrm_params_slow.append(p)
-         params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
-         params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
-
-         optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
-         scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
-
-         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
 
instant-mesh/src/model_mesh.py DELETED
@@ -1,325 +0,0 @@
- import os
- import numpy as np
- import torch
- import torch.nn.functional as F
- from torchvision.transforms import v2
- from torchvision.utils import make_grid, save_image
- from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
- import pytorch_lightning as pl
- from einops import rearrange, repeat
-
- from src.utils.train_util import instantiate_from_config
-
-
- # Regulrarization loss for FlexiCubes
- def sdf_reg_loss_batch(sdf, all_edges):
-     sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
-     mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
-     sdf_f1x6x2 = sdf_f1x6x2[mask]
-     sdf_diff = F.binary_cross_entropy_with_logits(
-         sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
-             F.binary_cross_entropy_with_logits(
-                 sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
-     return sdf_diff
-
-
- class MVRecon(pl.LightningModule):
-     def __init__(
-         self,
-         lrm_generator_config,
-         input_size=256,
-         render_size=512,
-         init_ckpt=None,
-     ):
-         super(MVRecon, self).__init__()
-
-         self.input_size = input_size
-         self.render_size = render_size
-
-         # init modules
-         self.lrm_generator = instantiate_from_config(lrm_generator_config)
-
-         self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
-
-         # Load weights from pretrained MVRecon model, and use the mlp
-         # weights to initialize the weights of sdf and rgb mlps.
-         if init_ckpt is not None:
-             sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
-             sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
-             sd_fc = {}
-             for k, v in sd.items():
-                 if k.startswith('lrm_generator.synthesizer.decoder.net.'):
-                     if k.startswith('lrm_generator.synthesizer.decoder.net.6.'):    # last layer
-                         # Here we assume the density filed's isosurface threshold is t,
-                         # we reverse the sign of density filed to initialize SDF field.
-                         # -(w*x + b - t) = (-w)*x + (t - b)
-                         if 'weight' in k:
-                             sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
-                         else:
-                             sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
-                         sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
-                     else:
-                         sd_fc[k.replace('net.', 'net_sdf.')] = v
-                         sd_fc[k.replace('net.', 'net_rgb.')] = v
-                 else:
-                     sd_fc[k] = v
-             sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
-             # missing `net_deformation` and `net_weight` parameters
-             self.lrm_generator.load_state_dict(sd_fc, strict=False)
-             print(f'Loaded weights from {init_ckpt}')
-
-         self.validation_step_outputs = []
-
-     def on_fit_start(self):
-         device = torch.device(f'cuda:{self.global_rank}')
-         self.lrm_generator.init_flexicubes_geometry(device)
-         if self.global_rank == 0:
-             os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
-             os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
-
-     def prepare_batch_data(self, batch):
-         lrm_generator_input = {}
-         render_gt = {}
-
-         # input images
-         images = batch['input_images']
-         images = v2.functional.resize(
-             images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
-
-         lrm_generator_input['images'] = images.to(self.device)
-
-         # input cameras and render cameras
-         input_c2ws = batch['input_c2ws']
-         input_Ks = batch['input_Ks']
-         target_c2ws = batch['target_c2ws']
-
-         render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
-         render_w2cs = torch.linalg.inv(render_c2ws)
-
-         input_extrinsics = input_c2ws.flatten(-2)
-         input_extrinsics = input_extrinsics[:, :, :12]
-         input_intrinsics = input_Ks.flatten(-2)
-         input_intrinsics = torch.stack([
-             input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
-             input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
-         ], dim=-1)
-         cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
-
-         # add noise to input_cameras
-         cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
-
-         lrm_generator_input['cameras'] = cameras.to(self.device)
-         lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
-
-         # target images
-         target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
-         target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
-         target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
-         target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
-
-         render_size = self.render_size
-         target_images = v2.functional.resize(
-             target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
-         target_depths = v2.functional.resize(
-             target_depths, render_size, interpolation=0, antialias=True)
-         target_alphas = v2.functional.resize(
-             target_alphas, render_size, interpolation=0, antialias=True)
-         target_normals = v2.functional.resize(
-             target_normals, render_size, interpolation=3, antialias=True)
-
-         lrm_generator_input['render_size'] = render_size
-
-         render_gt['target_images'] = target_images.to(self.device)
-         render_gt['target_depths'] = target_depths.to(self.device)
-         render_gt['target_alphas'] = target_alphas.to(self.device)
-         render_gt['target_normals'] = target_normals.to(self.device)
-
-         return lrm_generator_input, render_gt
-
-     def prepare_validation_batch_data(self, batch):
-         lrm_generator_input = {}
-
-         # input images
-         images = batch['input_images']
-         images = v2.functional.resize(
-             images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
-
-         lrm_generator_input['images'] = images.to(self.device)
-
-         # input cameras
-         input_c2ws = batch['input_c2ws'].flatten(-2)
-         input_Ks = batch['input_Ks'].flatten(-2)
-
-         input_extrinsics = input_c2ws[:, :, :12]
-         input_intrinsics = torch.stack([
-             input_Ks[:, :, 0], input_Ks[:, :, 4],
-             input_Ks[:, :, 2], input_Ks[:, :, 5],
-         ], dim=-1)
-         cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
-
-         lrm_generator_input['cameras'] = cameras.to(self.device)
-
-         # render cameras
-         render_c2ws = batch['render_c2ws']
-         render_w2cs = torch.linalg.inv(render_c2ws)
-
-         lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
-         lrm_generator_input['render_size'] = 384
-
-         return lrm_generator_input
-
-     def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
-         planes = torch.utils.checkpoint.checkpoint(
-             self.lrm_generator.forward_planes,
-             images,
-             cameras,
-             use_reentrant=False,
-         )
-         out = self.lrm_generator.forward_geometry(
-             planes,
-             render_cameras,
-             render_size,
-         )
-         return out
-
-     def forward(self, lrm_generator_input):
-         images = lrm_generator_input['images']
-         cameras = lrm_generator_input['cameras']
-         render_cameras = lrm_generator_input['render_cameras']
-         render_size = lrm_generator_input['render_size']
-
-         out = self.forward_lrm_generator(
-             images, cameras, render_cameras, render_size=render_size)
-
-         return out
-
-     def training_step(self, batch, batch_idx):
-         lrm_generator_input, render_gt = self.prepare_batch_data(batch)
-
-         render_out = self.forward(lrm_generator_input)
-
-         loss, loss_dict = self.compute_loss(render_out, render_gt)
-
-         self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
-
-         if self.global_step % 1000 == 0 and self.global_rank == 0:
-             B, N, C, H, W = render_gt['target_images'].shape
-             N_in = lrm_generator_input['images'].shape[1]
-
-             target_images = rearrange(
-                 render_gt['target_images'], 'b n c h w -> b c h (n w)')
-             render_images = rearrange(
-                 render_out['img'], 'b n c h w -> b c h (n w)')
-             target_alphas = rearrange(
-                 repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             render_alphas = rearrange(
-                 repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             target_depths = rearrange(
-                 repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             render_depths = rearrange(
-                 repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
-             target_normals = rearrange(
-                 render_gt['target_normals'], 'b n c h w -> b c h (n w)')
-             render_normals = rearrange(
-                 render_out['normal'], 'b n c h w -> b c h (n w)')
-             MAX_DEPTH = torch.max(target_depths)
-             target_depths = target_depths / MAX_DEPTH * target_alphas
-             render_depths = render_depths / MAX_DEPTH
-
-             grid = torch.cat([
-                 target_images, render_images,
-                 target_alphas, render_alphas,
-                 target_depths, render_depths,
-                 target_normals, render_normals,
-             ], dim=-2)
-             grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
-
-             image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
-             save_image(grid, image_path)
-             print(f"Saved image to {image_path}")
-
-         return loss
-
-     def compute_loss(self, render_out, render_gt):
-         # NOTE: the rgb value range of OpenLRM is [0, 1]
-         render_images = render_out['img']
-         target_images = render_gt['target_images'].to(render_images)
-         render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
-         target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
-         loss_mse = F.mse_loss(render_images, target_images)
-         loss_lpips = 2.0 * self.lpips(render_images, target_images)
-
-         render_alphas = render_out['mask']
-         target_alphas = render_gt['target_alphas']
-         loss_mask = F.mse_loss(render_alphas, target_alphas)
-
-         render_depths = render_out['depth']
-         target_depths = render_gt['target_depths']
-         loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
-
-         render_normals = render_out['normal'] * 2.0 - 1.0
-         target_normals = render_gt['target_normals'] * 2.0 - 1.0
-         similarity = (render_normals * target_normals).sum(dim=-3).abs()
-         normal_mask = target_alphas.squeeze(-3)
-         loss_normal = 1 - similarity[normal_mask>0].mean()
-         loss_normal = 0.2 * loss_normal
-
-         # flexicubes regularization loss
-         sdf = render_out['sdf']
-         sdf_reg_loss = render_out['sdf_reg_loss']
-         sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
-         _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
-         flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
-         flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
-
-         loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
-
-         loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
-
-         prefix = 'train'
-         loss_dict = {}
-         loss_dict.update({f'{prefix}/loss_mse': loss_mse})
-         loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
-         loss_dict.update({f'{prefix}/loss_mask': loss_mask})
-         loss_dict.update({f'{prefix}/loss_normal': loss_normal})
-         loss_dict.update({f'{prefix}/loss_depth': loss_depth})
-         loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
-         loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
-         loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
-         loss_dict.update({f'{prefix}/loss': loss})
-
-         return loss, loss_dict
-
-     @torch.no_grad()
-     def validation_step(self, batch, batch_idx):
-         lrm_generator_input = self.prepare_validation_batch_data(batch)
-
-         render_out = self.forward(lrm_generator_input)
-         render_images = render_out['img']
-         render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
-
-         self.validation_step_outputs.append(render_images)
-
-     def on_validation_epoch_end(self):
-         images = torch.cat(self.validation_step_outputs, dim=-1)
-
-         all_images = self.all_gather(images)
-         all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
-
-         if self.global_rank == 0:
-             image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
-
-             grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
-             save_image(grid, image_path)
-             print(f"Saved image to {image_path}")
-
-         self.validation_step_outputs.clear()
-
-     def configure_optimizers(self):
-         lr = self.learning_rate
-
-         optimizer = torch.optim.AdamW(
-             self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
-         scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
-
-         return {'optimizer': optimizer, 'lr_scheduler': scheduler}
 
instant-mesh/src/models/__init__.py DELETED
File without changes
instant-mesh/src/models/decoder/__init__.py DELETED
File without changes
instant-mesh/src/models/decoder/transformer.py DELETED
@@ -1,123 +0,0 @@
- # Copyright (c) 2023, Zexin He
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
-
- import torch
- import torch.nn as nn
-
-
- class BasicTransformerBlock(nn.Module):
-     """
-     Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
-     """
-     # use attention from torch.nn.MultiHeadAttention
-     # Block contains a cross-attention layer, a self-attention layer, and a MLP
-     def __init__(
-         self,
-         inner_dim: int,
-         cond_dim: int,
-         num_heads: int,
-         eps: float,
-         attn_drop: float = 0.,
-         attn_bias: bool = False,
-         mlp_ratio: float = 4.,
-         mlp_drop: float = 0.,
-     ):
-         super().__init__()
-
-         self.norm1 = nn.LayerNorm(inner_dim)
-         self.cross_attn = nn.MultiheadAttention(
-             embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
-             dropout=attn_drop, bias=attn_bias, batch_first=True)
-         self.norm2 = nn.LayerNorm(inner_dim)
-         self.self_attn = nn.MultiheadAttention(
-             embed_dim=inner_dim, num_heads=num_heads,
-             dropout=attn_drop, bias=attn_bias, batch_first=True)
-         self.norm3 = nn.LayerNorm(inner_dim)
-         self.mlp = nn.Sequential(
-             nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
-             nn.GELU(),
-             nn.Dropout(mlp_drop),
-             nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
-             nn.Dropout(mlp_drop),
-         )
-
-     def forward(self, x, cond):
-         # x: [N, L, D]
-         # cond: [N, L_cond, D_cond]
-         x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
-         before_sa = self.norm2(x)
-         x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
-         x = x + self.mlp(self.norm3(x))
-         return x
-
-
- class TriplaneTransformer(nn.Module):
-     """
-     Transformer with condition that generates a triplane representation.
-
-     Reference:
-         Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
-     """
-     def __init__(
-         self,
-         inner_dim: int,
-         image_feat_dim: int,
-         triplane_low_res: int,
-         triplane_high_res: int,
-         triplane_dim: int,
-         num_layers: int,
-         num_heads: int,
-         eps: float = 1e-6,
-     ):
-         super().__init__()
-
-         # attributes
-         self.triplane_low_res = triplane_low_res
-         self.triplane_high_res = triplane_high_res
-         self.triplane_dim = triplane_dim
-
-         # modules
-         # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
-         self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
-         self.layers = nn.ModuleList([
-             BasicTransformerBlock(
-                 inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
-             for _ in range(num_layers)
-         ])
-         self.norm = nn.LayerNorm(inner_dim, eps=eps)
-         self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
-
-     def forward(self, image_feats):
-         # image_feats: [N, L_cond, D_cond]
-
-         N = image_feats.shape[0]
-         H = W = self.triplane_low_res
-         L = 3 * H * W
-
-         x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
-         for layer in self.layers:
-             x = layer(x, image_feats)
-         x = self.norm(x)
-
-         # separate each plane and apply deconv
-         x = x.view(N, 3, H, W, -1)
-         x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
-         x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
-         x = self.deconv(x)  # [3*N, D', H', W']
-         x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
-         x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
-         x = x.contiguous()
-
-         return x
 
instant-mesh/src/models/encoder/__init__.py DELETED
File without changes
instant-mesh/src/models/encoder/dino.py DELETED
@@ -1,550 +0,0 @@
- # coding=utf-8
- # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ PyTorch ViT model."""
-
-
- import collections.abc
- import math
- from typing import Dict, List, Optional, Set, Tuple, Union
-
- import torch
- from torch import nn
-
- from transformers.activations import ACT2FN
- from transformers.modeling_outputs import (
-     BaseModelOutput,
-     BaseModelOutputWithPooling,
- )
- from transformers import PreTrainedModel, ViTConfig
- from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
-
-
- class ViTEmbeddings(nn.Module):
-     """
-     Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
-     """
-
-     def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
-         super().__init__()
-
-         self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
-         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
-         self.patch_embeddings = ViTPatchEmbeddings(config)
-         num_patches = self.patch_embeddings.num_patches
-         self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
-         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-         self.config = config
-
-     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
-         """
-         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-         resolution images.
-
-         Source:
-         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
-         """
-
-         num_patches = embeddings.shape[1] - 1
-         num_positions = self.position_embeddings.shape[1] - 1
-         if num_patches == num_positions and height == width:
-             return self.position_embeddings
-         class_pos_embed = self.position_embeddings[:, 0]
-         patch_pos_embed = self.position_embeddings[:, 1:]
-         dim = embeddings.shape[-1]
-         h0 = height // self.config.patch_size
-         w0 = width // self.config.patch_size
-         # we add a small number to avoid floating point error in the interpolation
-         # see discussion at https://github.com/facebookresearch/dino/issues/8
-         h0, w0 = h0 + 0.1, w0 + 0.1
-         patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
-         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
-         patch_pos_embed = nn.functional.interpolate(
-             patch_pos_embed,
-             scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
-             mode="bicubic",
-             align_corners=False,
-         )
-         assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
-         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
-
-     def forward(
-         self,
-         pixel_values: torch.Tensor,
-         bool_masked_pos: Optional[torch.BoolTensor] = None,
-         interpolate_pos_encoding: bool = False,
-     ) -> torch.Tensor:
-         batch_size, num_channels, height, width = pixel_values.shape
-         embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
-
-         if bool_masked_pos is not None:
-             seq_length = embeddings.shape[1]
-             mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
-             # replace the masked visual tokens by mask_tokens
-             mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
-             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
-
-         # add the [CLS] token to the embedded patch tokens
-         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
-         embeddings = torch.cat((cls_tokens, embeddings), dim=1)
-
-         # add positional encoding to each token
-         if interpolate_pos_encoding:
-             embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
-         else:
-             embeddings = embeddings + self.position_embeddings
-
-         embeddings = self.dropout(embeddings)
-
-         return embeddings
-
-
- class ViTPatchEmbeddings(nn.Module):
-     """
-     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
-     `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
-     Transformer.
-     """
-
-     def __init__(self, config):
-         super().__init__()
-         image_size, patch_size = config.image_size, config.patch_size
-         num_channels, hidden_size = config.num_channels, config.hidden_size
-
-         image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
-         patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
-         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-         self.image_size = image_size
-         self.patch_size = patch_size
-         self.num_channels = num_channels
-         self.num_patches = num_patches
-
-         self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
-
-     def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
-         batch_size, num_channels, height, width = pixel_values.shape
-         if num_channels != self.num_channels:
-             raise ValueError(
-                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-                 f" Expected {self.num_channels} but got {num_channels}."
-             )
-         if not interpolate_pos_encoding:
-             if height != self.image_size[0] or width != self.image_size[1]:
-                 raise ValueError(
-                     f"Input image size ({height}*{width}) doesn't match model"
-                     f" ({self.image_size[0]}*{self.image_size[1]})."
-                 )
-         embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
-         return embeddings
-
-
- class ViTSelfAttention(nn.Module):
-     def __init__(self, config: ViTConfig) -> None:
-         super().__init__()
-         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
-             raise ValueError(
-                 f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
-                 f"heads {config.num_attention_heads}."
-             )
-
-         self.num_attention_heads = config.num_attention_heads
-         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-         self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-         self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
-         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
-         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
-
-         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-
-     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-         x = x.view(new_x_shape)
-         return x.permute(0, 2, 1, 3)
-
-     def forward(
-         self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
-     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
-         mixed_query_layer = self.query(hidden_states)
-
-         key_layer = self.transpose_for_scores(self.key(hidden_states))
-         value_layer = self.transpose_for_scores(self.value(hidden_states))
-         query_layer = self.transpose_for_scores(mixed_query_layer)
-
-         # Take the dot product between "query" and "key" to get the raw attention scores.
-         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-
-         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
-
-         # Normalize the attention scores to probabilities.
-         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-
-         # This is actually dropping out entire tokens to attend to, which might
-         # seem a bit unusual, but is taken from the original Transformer paper.
196
- attention_probs = self.dropout(attention_probs)
197
-
198
- # Mask heads if we want to
199
- if head_mask is not None:
200
- attention_probs = attention_probs * head_mask
201
-
202
- context_layer = torch.matmul(attention_probs, value_layer)
203
-
204
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
205
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
206
- context_layer = context_layer.view(new_context_layer_shape)
207
-
208
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
209
-
210
- return outputs
211
-
212
-
213
- class ViTSelfOutput(nn.Module):
214
- """
215
- The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
216
- layernorm applied before each block.
217
- """
218
-
219
- def __init__(self, config: ViTConfig) -> None:
220
- super().__init__()
221
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
222
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
-
224
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
225
- hidden_states = self.dense(hidden_states)
226
- hidden_states = self.dropout(hidden_states)
227
-
228
- return hidden_states
229
-
230
-
231
- class ViTAttention(nn.Module):
232
- def __init__(self, config: ViTConfig) -> None:
233
- super().__init__()
234
- self.attention = ViTSelfAttention(config)
235
- self.output = ViTSelfOutput(config)
236
- self.pruned_heads = set()
237
-
238
- def prune_heads(self, heads: Set[int]) -> None:
239
- if len(heads) == 0:
240
- return
241
- heads, index = find_pruneable_heads_and_indices(
242
- heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
243
- )
244
-
245
- # Prune linear layers
246
- self.attention.query = prune_linear_layer(self.attention.query, index)
247
- self.attention.key = prune_linear_layer(self.attention.key, index)
248
- self.attention.value = prune_linear_layer(self.attention.value, index)
249
- self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
250
-
251
- # Update hyper params and store pruned heads
252
- self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
253
- self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
254
- self.pruned_heads = self.pruned_heads.union(heads)
255
-
256
- def forward(
257
- self,
258
- hidden_states: torch.Tensor,
259
- head_mask: Optional[torch.Tensor] = None,
260
- output_attentions: bool = False,
261
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
262
- self_outputs = self.attention(hidden_states, head_mask, output_attentions)
263
-
264
- attention_output = self.output(self_outputs[0], hidden_states)
265
-
266
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
267
- return outputs
268
-
269
-
270
- class ViTIntermediate(nn.Module):
271
- def __init__(self, config: ViTConfig) -> None:
272
- super().__init__()
273
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
274
- if isinstance(config.hidden_act, str):
275
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
276
- else:
277
- self.intermediate_act_fn = config.hidden_act
278
-
279
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
280
- hidden_states = self.dense(hidden_states)
281
- hidden_states = self.intermediate_act_fn(hidden_states)
282
-
283
- return hidden_states
284
-
285
-
286
- class ViTOutput(nn.Module):
287
- def __init__(self, config: ViTConfig) -> None:
288
- super().__init__()
289
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
290
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
291
-
292
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
293
- hidden_states = self.dense(hidden_states)
294
- hidden_states = self.dropout(hidden_states)
295
-
296
- hidden_states = hidden_states + input_tensor
297
-
298
- return hidden_states
299
-
300
-
301
- def modulate(x, shift, scale):
302
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
303
-
304
-
305
- class ViTLayer(nn.Module):
306
- """This corresponds to the Block class in the timm implementation."""
307
-
308
- def __init__(self, config: ViTConfig) -> None:
309
- super().__init__()
310
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
311
- self.seq_len_dim = 1
312
- self.attention = ViTAttention(config)
313
- self.intermediate = ViTIntermediate(config)
314
- self.output = ViTOutput(config)
315
- self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
316
- self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
-
318
- self.adaLN_modulation = nn.Sequential(
319
- nn.SiLU(),
320
- nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
321
- )
322
- nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
323
- nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
324
-
325
- def forward(
326
- self,
327
- hidden_states: torch.Tensor,
328
- adaln_input: torch.Tensor = None,
329
- head_mask: Optional[torch.Tensor] = None,
330
- output_attentions: bool = False,
331
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
332
- shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
333
-
334
- self_attention_outputs = self.attention(
335
- modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
336
- head_mask,
337
- output_attentions=output_attentions,
338
- )
339
- attention_output = self_attention_outputs[0]
340
- outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
341
-
342
- # first residual connection
343
- hidden_states = attention_output + hidden_states
344
-
345
- # in ViT, layernorm is also applied after self-attention
346
- layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
347
- layer_output = self.intermediate(layer_output)
348
-
349
- # second residual connection is done here
350
- layer_output = self.output(layer_output, hidden_states)
351
-
352
- outputs = (layer_output,) + outputs
353
-
354
- return outputs
355
-
356
-
357
- class ViTEncoder(nn.Module):
358
- def __init__(self, config: ViTConfig) -> None:
359
- super().__init__()
360
- self.config = config
361
- self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
362
- self.gradient_checkpointing = False
363
-
364
- def forward(
365
- self,
366
- hidden_states: torch.Tensor,
367
- adaln_input: torch.Tensor = None,
368
- head_mask: Optional[torch.Tensor] = None,
369
- output_attentions: bool = False,
370
- output_hidden_states: bool = False,
371
- return_dict: bool = True,
372
- ) -> Union[tuple, BaseModelOutput]:
373
- all_hidden_states = () if output_hidden_states else None
374
- all_self_attentions = () if output_attentions else None
375
-
376
- for i, layer_module in enumerate(self.layer):
377
- if output_hidden_states:
378
- all_hidden_states = all_hidden_states + (hidden_states,)
379
-
380
- layer_head_mask = head_mask[i] if head_mask is not None else None
381
-
382
- if self.gradient_checkpointing and self.training:
383
- layer_outputs = self._gradient_checkpointing_func(
384
- layer_module.__call__,
385
- hidden_states,
386
- adaln_input,
387
- layer_head_mask,
388
- output_attentions,
389
- )
390
- else:
391
- layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
392
-
393
- hidden_states = layer_outputs[0]
394
-
395
- if output_attentions:
396
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
397
-
398
- if output_hidden_states:
399
- all_hidden_states = all_hidden_states + (hidden_states,)
400
-
401
- if not return_dict:
402
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
403
- return BaseModelOutput(
404
- last_hidden_state=hidden_states,
405
- hidden_states=all_hidden_states,
406
- attentions=all_self_attentions,
407
- )
408
-
409
-
410
- class ViTPreTrainedModel(PreTrainedModel):
411
- """
412
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
413
- models.
414
- """
415
-
416
- config_class = ViTConfig
417
- base_model_prefix = "vit"
418
- main_input_name = "pixel_values"
419
- supports_gradient_checkpointing = True
420
- _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
421
-
422
- def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
423
- """Initialize the weights"""
424
- if isinstance(module, (nn.Linear, nn.Conv2d)):
425
- # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
426
- # `trunc_normal_cpu` not implemented in `half` issues
427
- module.weight.data = nn.init.trunc_normal_(
428
- module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
429
- ).to(module.weight.dtype)
430
- if module.bias is not None:
431
- module.bias.data.zero_()
432
- elif isinstance(module, nn.LayerNorm):
433
- module.bias.data.zero_()
434
- module.weight.data.fill_(1.0)
435
- elif isinstance(module, ViTEmbeddings):
436
- module.position_embeddings.data = nn.init.trunc_normal_(
437
- module.position_embeddings.data.to(torch.float32),
438
- mean=0.0,
439
- std=self.config.initializer_range,
440
- ).to(module.position_embeddings.dtype)
441
-
442
- module.cls_token.data = nn.init.trunc_normal_(
443
- module.cls_token.data.to(torch.float32),
444
- mean=0.0,
445
- std=self.config.initializer_range,
446
- ).to(module.cls_token.dtype)
447
-
448
-
449
- class ViTModel(ViTPreTrainedModel):
450
- def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
451
- super().__init__(config)
452
- self.config = config
453
-
454
- self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
455
- self.encoder = ViTEncoder(config)
456
-
457
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
458
- self.pooler = ViTPooler(config) if add_pooling_layer else None
459
-
460
- # Initialize weights and apply final processing
461
- self.post_init()
462
-
463
- def get_input_embeddings(self) -> ViTPatchEmbeddings:
464
- return self.embeddings.patch_embeddings
465
-
466
- def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
467
- """
468
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
469
- class PreTrainedModel
470
- """
471
- for layer, heads in heads_to_prune.items():
472
- self.encoder.layer[layer].attention.prune_heads(heads)
473
-
474
- def forward(
475
- self,
476
- pixel_values: Optional[torch.Tensor] = None,
477
- adaln_input: Optional[torch.Tensor] = None,
478
- bool_masked_pos: Optional[torch.BoolTensor] = None,
479
- head_mask: Optional[torch.Tensor] = None,
480
- output_attentions: Optional[bool] = None,
481
- output_hidden_states: Optional[bool] = None,
482
- interpolate_pos_encoding: Optional[bool] = None,
483
- return_dict: Optional[bool] = None,
484
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
485
- r"""
486
- bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
487
- Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
488
- """
489
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
- output_hidden_states = (
491
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
- )
493
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494
-
495
- if pixel_values is None:
496
- raise ValueError("You have to specify pixel_values")
497
-
498
- # Prepare head mask if needed
499
- # 1.0 in head_mask indicate we keep the head
500
- # attention_probs has shape bsz x n_heads x N x N
501
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
502
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
503
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
504
-
505
- # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
506
- expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
507
- if pixel_values.dtype != expected_dtype:
508
- pixel_values = pixel_values.to(expected_dtype)
509
-
510
- embedding_output = self.embeddings(
511
- pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
512
- )
513
-
514
- encoder_outputs = self.encoder(
515
- embedding_output,
516
- adaln_input=adaln_input,
517
- head_mask=head_mask,
518
- output_attentions=output_attentions,
519
- output_hidden_states=output_hidden_states,
520
- return_dict=return_dict,
521
- )
522
- sequence_output = encoder_outputs[0]
523
- sequence_output = self.layernorm(sequence_output)
524
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
525
-
526
- if not return_dict:
527
- head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
528
- return head_outputs + encoder_outputs[1:]
529
-
530
- return BaseModelOutputWithPooling(
531
- last_hidden_state=sequence_output,
532
- pooler_output=pooled_output,
533
- hidden_states=encoder_outputs.hidden_states,
534
- attentions=encoder_outputs.attentions,
535
- )
536
-
537
-
538
- class ViTPooler(nn.Module):
539
- def __init__(self, config: ViTConfig):
540
- super().__init__()
541
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
542
- self.activation = nn.Tanh()
543
-
544
- def forward(self, hidden_states):
545
- # We "pool" the model by simply taking the hidden state corresponding
546
- # to the first token.
547
- first_token_tensor = hidden_states[:, 0]
548
- pooled_output = self.dense(first_token_tensor)
549
- pooled_output = self.activation(pooled_output)
550
- return pooled_output
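The deleted dino.py above is essentially the Hugging Face ViT implementation with one modification: each ViTLayer is conditioned through adaLN-style modulation, where `modulate` shifts and scales the pre-attention and pre-MLP layernorm outputs with a per-sample conditioning vector, and ViTModel.forward threads that vector through as `adaln_input`. A minimal usage sketch follows; it assumes the module is importable as `src.models.encoder.dino` and uses illustrative config values that are not taken from this commit.

```python
# Sketch only. Assumptions: the deleted module is importable as
# src.models.encoder.dino, and the ViTConfig values below are illustrative.
import torch
from transformers import ViTConfig
from src.models.encoder.dino import ViTModel  # the adaLN-conditioned ViT defined above

config = ViTConfig(
    hidden_size=768, num_hidden_layers=12, num_attention_heads=12,
    intermediate_size=3072, image_size=224, patch_size=16, num_channels=3,
)
model = ViTModel(config, add_pooling_layer=False)

pixel_values = torch.rand(2, 3, 320, 320)         # non-224 input is fine with interpolate_pos_encoding=True
adaln_input = torch.randn(2, config.hidden_size)  # one conditioning vector per image

outputs = model(
    pixel_values,
    adaln_input=adaln_input,
    interpolate_pos_encoding=True,  # bicubically resamples the 14x14 position grid to 20x20
)
print(outputs.last_hidden_state.shape)  # torch.Size([2, 401, 768]) = [batch, 20*20 patches + CLS, hidden]
```

Note that `adaln_input` is effectively required here: ViTLayer.forward applies `self.adaLN_modulation(adaln_input)` unconditionally, so passing `None` fails.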
instant-mesh/src/models/encoder/dino_wrapper.py DELETED
@@ -1,80 +0,0 @@
1
- # Copyright (c) 2023, Zexin He
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import torch.nn as nn
17
- from transformers import ViTImageProcessor
18
- from einops import rearrange, repeat
19
- from .dino import ViTModel
20
-
21
-
22
- class DinoWrapper(nn.Module):
23
- """
24
- Dino v1 wrapper using huggingface transformer implementation.
25
- """
26
- def __init__(self, model_name: str, freeze: bool = True):
27
- super().__init__()
28
- self.model, self.processor = self._build_dino(model_name)
29
- self.camera_embedder = nn.Sequential(
30
- nn.Linear(16, self.model.config.hidden_size, bias=True),
31
- nn.SiLU(),
32
- nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
33
- )
34
- if freeze:
35
- self._freeze()
36
-
37
- def forward(self, image, camera):
38
- # image: [B, N, C, H, W]
39
- # camera: [B, N, D]
40
- # RGB image with [0,1] scale and properly sized
41
- if image.ndim == 5:
42
- image = rearrange(image, 'b n c h w -> (b n) c h w')
43
- dtype = image.dtype
44
- inputs = self.processor(
45
- images=image.float(),
46
- return_tensors="pt",
47
- do_rescale=False,
48
- do_resize=False,
49
- ).to(self.model.device).to(dtype)
50
- # embed camera
51
- N = camera.shape[1]
52
- camera_embeddings = self.camera_embedder(camera)
53
- camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
54
- embeddings = camera_embeddings
55
- # This resampling of positional embedding uses bicubic interpolation
56
- outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
57
- last_hidden_states = outputs.last_hidden_state
58
- return last_hidden_states
59
-
60
- def _freeze(self):
61
- print(f"======== Freezing DinoWrapper ========")
62
- self.model.eval()
63
- for name, param in self.model.named_parameters():
64
- param.requires_grad = False
65
-
66
- @staticmethod
67
- def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
68
- import requests
69
- try:
70
- model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
71
- processor = ViTImageProcessor.from_pretrained(model_name)
72
- return model, processor
73
- except requests.exceptions.ProxyError as err:
74
- if proxy_error_retries > 0:
75
- print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
76
- import time
77
- time.sleep(proxy_error_cooldown)
78
- return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
79
- else:
80
- raise err
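The wrapper above pairs that conditioned ViT with a small camera embedder: each view carries a 16-dimensional camera vector (a flattened 4x4 camera matrix, judging by the Linear(16, hidden_size) input layer), which is mapped to the ViT hidden size and injected as adaln_input, while interpolate_pos_encoding=True lets the input resolution differ from the checkpoint's. A hedged usage sketch, assuming the file is importable as src.models.encoder.dino_wrapper and that a DINO ViT-B/16 checkpoint such as facebook/dino-vitb16 is intended:

```python
# Sketch only; shapes follow the comments in DinoWrapper.forward.
# The checkpoint id is an example, not taken from this commit.
import torch
from src.models.encoder.dino_wrapper import DinoWrapper

wrapper = DinoWrapper(model_name="facebook/dino-vitb16", freeze=True)

B, N = 2, 6                              # 2 objects, 6 views each
images = torch.rand(B, N, 3, 320, 320)   # RGB in [0, 1], already resized
cameras = torch.randn(B, N, 16)          # e.g. flattened 4x4 camera matrices per view (assumption)

tokens = wrapper(images, cameras)        # [(B*N), num_patches + 1, hidden_size]
print(tokens.shape)                      # torch.Size([12, 401, 768]) for 320x320 inputs and patch size 16
```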
instant-mesh/src/models/geometry/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
instant-mesh/src/models/geometry/camera/__init__.py DELETED
@@ -1,16 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- from torch import nn
11
-
12
-
13
- class Camera(nn.Module):
14
- def __init__(self):
15
- super(Camera, self).__init__()
16
- pass
instant-mesh/src/models/geometry/camera/perspective_camera.py DELETED
@@ -1,35 +0,0 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- from . import Camera
11
- import numpy as np
12
-
13
-
14
- def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
15
- if near_plane is None:
16
- near_plane = n
17
- return np.array(
18
- [[n / x, 0, 0, 0],
19
- [0, n / -x, 0, 0],
20
- [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
21
- [0, 0, -1, 0]]).astype(np.float32)
22
-
23
-
24
- class PerspectiveCamera(Camera):
25
- def __init__(self, fovy=49.0, device='cuda'):
26
- super(PerspectiveCamera, self).__init__()
27
- self.device = device
28
- focal = np.tan(fovy / 180.0 * np.pi * 0.5)
29
- self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
30
-
31
- def project(self, points_bxnx4):
32
- out = torch.matmul(
33
- points_bxnx4,
34
- torch.transpose(self.proj_mtx, 1, 2))
35
- return out
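projection() builds a standard OpenGL-style perspective matrix whose `x` argument is the tangent of half the vertical field of view, so PerspectiveCamera maps camera-space points (looking down -z) to clip space; dividing by the w component then yields normalized device coordinates. A short sketch, with device="cpu" and example points chosen purely for illustration:

```python
# Sketch: projecting homogeneous camera-space points with PerspectiveCamera.
# device="cpu" is used here for portability; the class defaults to "cuda".
import torch
from src.models.geometry.camera.perspective_camera import PerspectiveCamera

cam = PerspectiveCamera(fovy=49.0, device="cpu")

xyz = torch.tensor([[[0.0, 0.0, -2.0],     # 2 units in front of the camera
                     [0.5, 0.5, -3.0]]])   # shape [B=1, N=2, 3]
ones = torch.ones_like(xyz[..., :1])
points_bxnx4 = torch.cat([xyz, ones], dim=-1)  # homogeneous coordinates, [1, 2, 4]

clip = cam.project(points_bxnx4)           # clip-space coordinates, [1, 2, 4]
ndc = clip[..., :3] / clip[..., 3:4]       # perspective divide -> normalized device coordinates
print(ndc)
```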