""" Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from timm.models.layers import trunc_normal_ from unidepth.layers import (MLP, AttentionBlock, ConvUpsampleShuffle, NystromBlock, PositionEmbeddingSine) from unidepth.utils.geometric import flat_interpolate, generate_rays from unidepth.utils.positional_embedding import generate_fourier_features class ListAdapter(nn.Module): def __init__(self, input_dims: list[int], hidden_dim: int): super().__init__() self.input_adapters = nn.ModuleList([]) self.num_chunks = len(input_dims) self.checkpoint = True for input_dim in input_dims: self.input_adapters.append( nn.Sequential( nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU() ) ) def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor: xs = torch.split(x, splits.int().tolist(), dim=-1) xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)] return torch.cat(xs, dim=-1) class CameraHead(nn.Module): def __init__( self, hidden_dim: int, expansion: int = 4, dropout: float = 0.0, **kwargs, ): super().__init__() self.aggregate1 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.aggregate2 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.latents_pos = nn.Parameter( torch.randn(1, 4, hidden_dim), requires_grad=True ) self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout) self.project_cls = MLP(hidden_dim, dropout=dropout) self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) def fill_intrinsics(self, x): camera_intrinsics = torch.zeros( x.shape[0], 3, 3, device=x.device, requires_grad=False ) camera_intrinsics[:, 0, 0] = x[:, 0].exp() camera_intrinsics[:, 1, 1] = x[:, 1].exp() camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid() camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid() camera_intrinsics[:, 2, 2] = 1.0 return camera_intrinsics def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: features = features.unbind(dim=-1) cls_tokens = self.project_cls(cls_tokens) latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) features = self.in_features(torch.cat(features, dim=1) + pos_embed) features = torch.cat((features, cls_tokens), dim=1) cls_tokens = self.aggregate1( cls_tokens, context=features, pos_embed=latents_pos ) cls_tokens = self.aggregate2( cls_tokens, context=features, pos_embed=latents_pos ) # project to intrinsics x = self.out(cls_tokens).squeeze(-1) camera_intrinsics = self.fill_intrinsics(x) return camera_intrinsics def set_shapes(self, shapes: tuple[int, int]): self.shapes = shapes class GlobalHead(nn.Module): def __init__( self, hidden_dim: int, camera_dim: int, expansion: int = 4, dropout: float = 0.0, **kwargs, ): super().__init__() self.camera_dim = camera_dim self.in_features = nn.Linear(hidden_dim, hidden_dim) self.project_rays = nn.Linear(camera_dim + 3, hidden_dim) self.aggregate1 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.aggregate2 = AttentionBlock( hidden_dim, num_heads=1, expansion=expansion, dropout=dropout ) self.project_cls = MLP(hidden_dim, dropout=dropout) self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) def embed_rays(self, rays, shapes): rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes) rays_embedding = F.normalize(rays_embedding, dim=-1) rays_embedding = 
class GlobalHead(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        camera_dim: int,
        expansion: int = 4,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.camera_dim = camera_dim
        self.in_features = nn.Linear(hidden_dim, hidden_dim)
        self.project_rays = nn.Linear(camera_dim + 3, hidden_dim)
        self.aggregate1 = AttentionBlock(
            hidden_dim, num_heads=1, expansion=expansion, dropout=dropout
        )
        self.aggregate2 = AttentionBlock(
            hidden_dim, num_heads=1, expansion=expansion, dropout=dropout
        )
        self.project_cls = MLP(hidden_dim, dropout=dropout)
        self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1)

    def embed_rays(self, rays, shapes):
        rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes)
        rays_embedding = F.normalize(rays_embedding, dim=-1)
        rays_embedding = generate_fourier_features(
            rays_embedding,
            dim=self.camera_dim,
            max_freq=max(shapes) // 2,
            use_log=True,
            cat_orig=True,
        )
        return rays_embedding

    def set_original_shapes(self, shapes: tuple[int, int]):
        self.original_shapes = shapes

    def set_shapes(self, shapes: tuple[int, int]):
        self.shapes = shapes

    def get_scaleshift(self, x):
        scale, shift = torch.chunk(x, 2, dim=1)
        scale = scale.exp().reshape(-1, 1, 1, 1)
        shift = shift.reshape(-1, 1, 1, 1)
        return scale, shift

    def forward(self, features, cls_tokens, rays) -> torch.Tensor:
        features = features.unbind(dim=-1)
        cls_tokens = self.project_cls(cls_tokens)
        rays_embedding = self.project_rays(self.embed_rays(rays, self.shapes))
        rays_embedding = rays_embedding.repeat(1, len(features), 1)
        features = self.in_features(torch.cat(features, dim=1) + rays_embedding)
        features = torch.cat((features, cls_tokens), dim=1)
        cls_tokens = self.aggregate1(cls_tokens, context=features)
        cls_tokens = self.aggregate2(cls_tokens, context=features)

        x = self.out(cls_tokens).squeeze(-1)
        scale, shift = self.get_scaleshift(x)
        return scale, shift


class DepthHead(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        expansion: int = 4,
        depths: int | list[int] = 4,
        checkpoint: bool = True,
        camera_dim: int = 256,
        num_resolutions: int = 4,
        dropout: float = 0.0,
        **kwargs,
    ) -> None:
        super().__init__()
        # allow a single int for all resolutions, matching the annotation above
        if isinstance(depths, int):
            depths = [depths] * num_resolutions
        self.checkpoint = checkpoint
        self.camera_dim = camera_dim
        self.skip_depth = False

        self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout)
        self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim)
        self.aggregate_16 = AttentionBlock(
            hidden_dim,
            num_heads=1,
            expansion=expansion,
            dropout=dropout,
            context_dim=hidden_dim,
        )
        self.prompt_camera = AttentionBlock(
            hidden_dim,
            num_heads=1,
            expansion=expansion,
            dropout=dropout,
            context_dim=hidden_dim,
        )
        self.rays_layers = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.process_layers = nn.ModuleList([])
        self.norms, self.out_layers = nn.ModuleList([]), nn.ModuleList([])
        self.confidence_norms, self.confidence_out_layers = nn.ModuleList(
            []
        ), nn.ModuleList([])
        for i, depth in enumerate(depths):
            blk_lst = nn.ModuleList([])
            for _ in range(depth):
                blk_lst.append(
                    NystromBlock(
                        hidden_dim // int(2**i),
                        num_heads=num_heads // int(2**i),
                        expansion=expansion,
                        dropout=dropout,
                    )
                )
            self.process_layers.append(blk_lst)
            self.rays_layers.append(nn.Linear(camera_dim + 3, hidden_dim // int(2**i)))
            self.ups.append(
                ConvUpsampleShuffle(
                    hidden_dim // int(2**i),
                    expansion=expansion,
                    kernel_size=7,
                    num_layers=2,
                )
            )
            self.norms.append(nn.LayerNorm(hidden_dim // int(2 ** (i + 1))))
            self.out_layers.append(
                nn.Conv2d(hidden_dim // int(2 ** (i + 1)), 1, 3, padding=1)
            )
            self.confidence_norms.append(nn.LayerNorm(hidden_dim // int(2 ** (i + 1))))
            self.confidence_out_layers.append(
                nn.Conv2d(hidden_dim // int(2 ** (i + 1)), 1, 3, padding=1)
            )

    def set_original_shapes(self, shapes: tuple[int, int]):
        self.original_shapes = shapes

    def set_shapes(self, shapes: tuple[int, int]):
        self.shapes = shapes

    def embed_rays(self, rays, shapes):
        rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes)
        rays_embedding = F.normalize(rays_embedding, dim=-1)
        rays_embedding = generate_fourier_features(
            rays_embedding,
            dim=self.camera_dim,
            max_freq=max(shapes) // 2,
            use_log=True,
            cat_orig=True,
        )
        return rays_embedding

    def project_rays(self, rays, shapes):
        # one ray embedding per decoder resolution: 1x, 2x, 4x, ... the base shape
        embedded_rays = []
        for i, layer in enumerate(self.rays_layers):
            embedded_rays.append(
                layer(self.embed_rays(rays, [(2**i) * x for x in shapes]))
            )
        return embedded_rays
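    # Note: with cat_orig=True, embed_rays is expected to return camera_dim + 3
    # channels (Fourier features concatenated with the unit ray directions
    # themselves), which is why every rays_layer is a Linear(camera_dim + 3, ...).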
    def decode_depth(self, latents_16, rays, shapes):
        dtype = latents_16.dtype
        latents = latents_16
        out_features, confidences, outs = [], [], []
        for i, (up, layers, rays_embedding) in enumerate(
            zip(self.ups, self.process_layers, rays)
        ):
            for layer in layers:
                latents = layer(latents, pos_embed=rays_embedding)
            latents = up(
                rearrange(
                    latents + rays_embedding,
                    "b (h w) c -> b c h w",
                    h=shapes[0] * int(2**i),
                    w=shapes[1] * int(2**i),
                ).contiguous()
            )
            out = rearrange(
                latents,
                "b (h w) c -> b h w c",
                h=shapes[0] * int(2 ** (1 + i)),
                w=shapes[1] * int(2 ** (1 + i)),
            )
            out_features.append(out)

        for i, (norm, out_layer, features) in enumerate(
            zip(self.norms[::-1], self.out_layers[::-1], out_features[::-1])
        ):
            features = norm(features)
            out_d = out_layer(features.permute(0, 3, 1, 2))
            outs.append(out_d)
        out = sum(
            F.interpolate(
                x,
                size=outs[0].shape[-2:],
                mode="bilinear",
            )
            for x in outs
        )
        out = out / len(outs)
        # jit complains, fix as list (loose dyn input)
        out_shapes = [int(s) for s in out.shape[1:]]
        out = F.layer_norm(out.float(), out_shapes)
        out = out.clamp(-10.0, 10.0).exp().to(dtype, non_blocking=True)

        for i, (norm, out_layer, features) in enumerate(
            zip(
                self.confidence_norms[::-1],
                self.confidence_out_layers[::-1],
                out_features[::-1],
            )
        ):
            features = norm(features)
            out_c = out_layer(features.permute(0, 3, 1, 2))
            confidences.append(out_c)
        confidence = sum(
            F.interpolate(
                x,
                size=confidences[0].shape[-2:],
                mode="bilinear",
            )
            for x in confidences
        )
        confidence = confidence / len(confidences)
        confidence = torch.sigmoid(confidence)

        return out, confidence

    def init_latents(self, features, shapes):
        # Generate latents with init as pooled features
        features_channels = torch.cat(features, dim=-1)
        features_16 = self.features_channel_cat(features_channels)
        latents_16 = features_16 + self.to_latents(
            flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False)
        )
        return latents_16

    def forward(
        self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed
    ) -> torch.Tensor:
        B = features.shape[0]
        features = features.unbind(dim=-1)
        shapes = self.shapes

        # camera embedding
        rays_embeddings = self.project_rays(rays_hr, shapes)

        # Init latents
        init_latents_16 = self.init_latents(features, shapes)

        # Aggregate features: F -> D
        latents_16 = self.aggregate_16(
            init_latents_16,
            context=torch.cat(features, dim=1),
            pos_embed_context=pos_embed + level_embed,
        )

        # Aggregate camera: D -> D|E
        latents_16 = self.prompt_camera(latents_16, context=rays_embeddings[0])

        # Decode depth
        out, confidence = self.decode_depth(latents_16, rays_embeddings, shapes)

        return out, confidence, latents_16


class Decoder(nn.Module):
    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.build(config)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def get_adapted_features(self, features_flat, splits):
        features_flat_cat = torch.cat(features_flat, dim=-1)
        features_projected = self.input_adapter(
            features_flat_cat, splits
        )  # list [b hw c] shapes
        features = torch.chunk(features_projected, len(splits), dim=-1)
        return features
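    # Each encoder stage produces features with its own channel width; the
    # ListAdapter concatenates them along channels, projects every chunk to
    # hidden_dim, and the result is re-chunked per resolution before being
    # stacked along a trailing "resolution" axis in forward().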
    def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays_gt):
        # get cls tokens projections
        cls_tokens_splits = torch.tensor(
            [x.shape[-1] for x in cls_tokens],
            device=features.device,
            requires_grad=False,
            dtype=features.dtype,
        )
        cls_tokens = torch.cat(cls_tokens, dim=-1)
        cls_tokens = self.camera_token_adapter(cls_tokens, cls_tokens_splits)
        cls_tokens = torch.cat(
            torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1
        )

        # camera layer
        intrinsics = self.camera_layer(
            features=features, cls_tokens=cls_tokens, pos_embed=pos_embed
        )
        # denormalize: focal lengths to pixels, principal point to image coordinates
        intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0]
        intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1]
        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1]
        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0]

        # use ground-truth rays when provided, otherwise derive them from the
        # predicted camera
        rays = (
            rays_gt
            if rays_gt is not None
            else generate_rays(intrinsics, original_shapes)[0]
        )
        return intrinsics, rays

    def run_global(self, cls_tokens, features, rays):
        # get cls tokens projections
        cls_tokens_splits = torch.tensor(
            [x.shape[-1] for x in cls_tokens],
            device=features.device,
            requires_grad=False,
            dtype=torch.float32,
        )
        cls_tokens = torch.cat(cls_tokens, dim=-1)
        cls_tokens = self.global_token_adapter(cls_tokens, cls_tokens_splits)
        cls_tokens = torch.cat(
            torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1
        )
        scale, shift = self.global_layer(
            features=features, rays=rays, cls_tokens=cls_tokens
        )
        return scale, shift
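    # forward() runs the three heads in sequence: (1) CameraHead predicts the
    # intrinsics and hence the per-pixel rays, (2) GlobalHead predicts a global
    # (scale, shift) pair conditioned on those rays, and (3) DepthHead decodes
    # a normalized depth map that the (scale, shift) pair denormalizes.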
    def forward(self, inputs, image_metas) -> torch.Tensor:
        B, C, H, W = inputs["image"].shape
        device = inputs["image"].device

        # get features in b n d format
        # level shapes: the shape per level, e.g. [[128, 128], [64, 64], ...] for
        # Swin-like encoders, [[32, 32]] for ViT -> replicated num_resolutions times
        level_shapes = sorted(
            list(set([tuple([x.shape[1], x.shape[2]]) for x in inputs["features"]]))
        )[::-1]
        if len(level_shapes) == 1:
            level_shapes = level_shapes * self.num_resolutions
        # input shapes repeat the level shape once per layer in that level
        input_shapes = [
            level_shapes[i]
            for i, (start, end) in enumerate(self.slices_encoder)
            for _ in range(end - start)
        ]
        common_shape = level_shapes[-2]

        # resample every level to the common shape
        features_flat = [
            flat_interpolate(
                rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape
            )
            for x, input_shape in zip(inputs["features"], input_shapes)
        ]
        features_splits = torch.tensor(
            [x.shape[-1] for x in features_flat],
            device=device,
            requires_grad=False,
            dtype=torch.float32,
        )
        features = self.get_adapted_features(features_flat, features_splits)
        features = torch.stack(features, dim=-1)

        # positional embeddings, spatial and level
        level_embed = torch.cat(
            [
                self.level_embed_layer(self.level_embeds)[i : i + 1]
                .unsqueeze(0)
                .repeat(B, common_shape[0] * common_shape[1], 1)
                for i in range(self.num_resolutions)
            ],
            dim=1,
        )
        dummy_tensor = torch.zeros(
            B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False
        )
        pos_embed = self.pos_embed(dummy_tensor)
        pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
            1, self.num_resolutions, 1
        )

        self.camera_layer.set_shapes(common_shape)
        intrinsics, rays = self.run_camera(
            inputs["camera_tokens"],
            features=features,
            pos_embed=pos_embed + level_embed,
            original_shapes=(H, W),
            rays_gt=inputs.get("rays"),
        )

        self.global_layer.set_shapes(common_shape)
        self.global_layer.set_original_shapes((H, W))
        scale, shift = self.run_global(
            inputs["global_tokens"], features=features, rays=rays
        )

        # run bulk of the model
        self.depth_layer.set_shapes(common_shape)
        self.depth_layer.set_original_shapes((H, W))
        out_normalized, confidence, depth_features = self.depth_layer(
            features=features,
            rays_hr=rays,
            pos_embed=pos_embed,
            level_embed=level_embed,
        )

        # shift is scale-invariant if we denormalize as (x + mu) * sigma
        out = (out_normalized + shift) * scale

        outputs = {
            "depth": out.clamp(min=1e-3),
            "confidence": confidence,
            "K": intrinsics,
            "rays": rays,
            "depth_features": depth_features,
        }
        return outputs

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"latents_pos", "level_embeds"}

    def build(self, config):
        input_dims = config["model"]["pixel_encoder"]["embed_dims"]
        hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
        expansion = config["model"]["expansion"]
        num_heads = config["model"]["num_heads"]
        dropout = config["model"]["pixel_decoder"]["dropout"]
        depths_encoder = config["model"]["pixel_encoder"]["depths"]
        depth = config["model"]["pixel_decoder"]["depths"]

        self.downsample = 4
        self.num_resolutions = len(depths_encoder)
        self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder))
        cls_token_input_dims = [input_dims[i] for i in [-1, -2, -3, -4]]
        input_dims = [input_dims[d - 1] for d in depths_encoder]

        # camera layer
        self.camera_layer = CameraHead(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
        )

        # scale-shift layer
        self.global_layer = GlobalHead(
            hidden_dim=hidden_dim,
            camera_dim=96,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
        )

        # adapt encoder features: just a projection
        self.input_adapter = ListAdapter(input_dims, hidden_dim)
        self.camera_token_adapter = ListAdapter(cls_token_input_dims, hidden_dim)
        self.global_token_adapter = ListAdapter(cls_token_input_dims[:2], hidden_dim)
        self.depth_layer = DepthHead(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            depths=depth,
            dropout=dropout,
            camera_dim=96,
            num_resolutions=self.num_resolutions,
        )
        self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
        self.level_embeds = nn.Parameter(
            torch.randn(len(input_dims), hidden_dim), requires_grad=True
        )
        self.level_embed_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
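# Minimal usage sketch (illustrative only; the key names follow build() above,
# but all concrete values are hypothetical and must match the chosen encoder):
#
#   config = {
#       "model": {
#           "expansion": 4,
#           "num_heads": 8,
#           "pixel_encoder": {"embed_dims": [128, 256, 512, 1024], "depths": [2, 2, 18, 2]},
#           "pixel_decoder": {"hidden_dim": 512, "depths": [2, 2, 2], "dropout": 0.0},
#       }
#   }
#   decoder = Decoder(config)
#   # `inputs` expects: "image" (B, 3, H, W), "features" (list of B, h, w, c encoder
#   # maps), "camera_tokens" and "global_tokens" (lists of cls tokens), optional "rays".
#   outputs = decoder(inputs, image_metas)
#   # -> dict with "depth", "confidence", "K", "rays", "depth_features"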