diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1676ebb7f6799ca0286b07f5119f7663988c6f7c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,104 @@
+import gradio as gr
+import torch
+import cv2
+import numpy as np
+import json
+from unidepth.models import UniDepthV2
+import os
+import matplotlib.pyplot as plt
+import matplotlib
+from PIL import Image
+
+
+# Load model configurations and initialize model
+def load_model(config_path, model_path, encoder, device):
+    with open(config_path) as f:
+        config = json.load(f)
+
+    model = UniDepthV2(config)
+    model.load_state_dict(torch.load(model_path, map_location=device)['model'], strict=True)
+    model = model.to(device).eval()
+
+    return model
+
+# Inference function
+def depth_estimation(image, model_path, encoder='vits'):
+    try:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        config_path = 'configs/config_v2_vits14.json'
+
+        # Ensure model path exists or download if needed
+        if not os.path.exists(model_path):
+            return "Model checkpoint not found. Please upload a valid model path."
+
+        model = load_model(config_path, model_path, encoder, device)
+
+        # Preprocess image
+        rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device)  # C, H, W
+        predictions = model.infer(rgb)
+        depth = predictions["depth"].squeeze().to('cpu').numpy()
+
+        min_depth = depth.min()
+        max_depth = depth.max()
+
+        depth_normalized = (depth - min_depth) / (max_depth - min_depth)
+
+        # Apply colormap
+        cmap = matplotlib.colormaps.get_cmap('Spectral')
+        depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
+
+        # Create a figure and axis for the colorbar
+        fig, ax = plt.subplots(figsize=(6, 0.4))
+        fig.subplots_adjust(bottom=0.5)
+
+        # Create a colorbar
+        norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+        cbar = fig.colorbar(sm, cax=ax, orientation='horizontal', label='Depth (meters)')
+
+        # Save the colorbar to a BytesIO object
+        from io import BytesIO
+        buf = BytesIO()
+        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
+        plt.close(fig)
+        buf.seek(0)
+
+        # Open the colorbar image
+        colorbar_img = Image.open(buf)
+
+        # Create a new image with space for the colorbar
+        new_height = depth_color.shape[0] + colorbar_img.size[1]
+        new_img = Image.new('RGB', (depth_color.shape[1], new_height), (255, 255, 255))
+
+        # Paste the depth image and colorbar
+        new_img.paste(Image.fromarray(depth_color), (0, 0))
+        new_img.paste(colorbar_img, (0, depth_color.shape[0]))
+
+        return new_img
+
+
+    except Exception as e:
+        return f"Error occurred: {str(e)}"
+
+# Gradio Interface
+def main():
+    iface = gr.Interface(
+        fn=depth_estimation,
+        inputs=[
+            gr.Image(type="numpy", label="Input Image"),
+            gr.Textbox(value='checkpoint/latest.pth', label='Model Path'),
+            gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
+        ],
+        outputs=[
+            gr.Image(type="pil", label="Predicted Depth")
+        ],
+        title="UniDepth V2 Metric Depth Estimation",
+        description="Upload an image to get its estimated depth map using UniDepth V2.",
+    )
+
+    iface.launch()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/configs/config_v1_cnvnxtl.json b/configs/config_v1_cnvnxtl.json new file mode 100644 index 0000000000000000000000000000000000000000..f2002be27672042d8845aa1af2e96733c7305427 --- /dev/null +++ b/configs/config_v1_cnvnxtl.json @@ -0,0 +1,24 @@ +{ + "generic": { + "seed": 13 + }, + "training": { + }, + 
"data": { + "image_shape": [462, 616] + }, + "model": { + "name": "UniDepthV1", + "num_heads": 8, + "expansion": 4, + "pixel_decoder": { + "hidden_dim": 512, + "depths": [3, 2, 1], + "dropout": 0.0 + }, + "pixel_encoder": { + "name": "convnext_large", + "pretrained": null + } + } +} \ No newline at end of file diff --git a/configs/config_v1_vitl14.json b/configs/config_v1_vitl14.json new file mode 100644 index 0000000000000000000000000000000000000000..b57f6db2549297bc407a77b89177e14077d00d8c --- /dev/null +++ b/configs/config_v1_vitl14.json @@ -0,0 +1,23 @@ +{ + "generic": { + "seed": 13 + }, + "training": {}, + "data": { + "image_shape": [462, 616] + }, + "model": { + "name": "UniDepthV1", + "num_heads": 8, + "expansion": 4, + "pixel_decoder": { + "hidden_dim": 512, + "depths": [3, 2, 1], + "dropout": 0.0 + }, + "pixel_encoder": { + "name": "dinov2_vitl14", + "pretrained": null + } + } +} \ No newline at end of file diff --git a/configs/config_v2_vitl14.json b/configs/config_v2_vitl14.json new file mode 100644 index 0000000000000000000000000000000000000000..f5e6c3e3c5c8c654b2aee076ee76675ab0811241 --- /dev/null +++ b/configs/config_v2_vitl14.json @@ -0,0 +1,32 @@ +{ + "generic": { + "seed": 13, + "deterministic": true + }, + "training": {}, + "data": { + "image_shape": [420, 560], + "shape_constraints": { + "ratio_bounds": [0.66, 2.0], + "pixels_bounds": [1400, 2400], + "patch_size": 14 + } + }, + "model": { + "name": "UniDepthV2", + "num_heads": 8, + "expansion": 4, + "pixel_decoder": { + "hidden_dim": 512, + "depths": [6, 0, 0], + "dropout": 0.0 + }, + "pixel_encoder": { + "name": "dinov2_vitl14", + "pretrained": null, + "use_norm": true, + "stacking_fn": "last", + "output_idx": [21,22,23,24] + } + } +} \ No newline at end of file diff --git a/configs/config_v2_vits14.json b/configs/config_v2_vits14.json new file mode 100644 index 0000000000000000000000000000000000000000..ae7a39432755695871f2e29ddfca77a60c2a3ac8 --- /dev/null +++ b/configs/config_v2_vits14.json @@ -0,0 +1,32 @@ +{ + "generic": { + "seed": 13, + "deterministic": true + }, + "training": {}, + "data": { + "image_shape": [420, 560], + "shape_constraints": { + "ratio_bounds": [0.66, 2.0], + "pixels_bounds": [1400, 2400], + "patch_size": 14 + } + }, + "model": { + "name": "UniDepthV2", + "num_heads": 8, + "expansion": 4, + "pixel_decoder": { + "hidden_dim": 512, + "depths": [6, 0, 0], + "dropout": 0.0 + }, + "pixel_encoder": { + "name": "dinov2_vits14", + "pretrained": null, + "use_norm": true, + "stacking_fn": "last", + "output_idx": [9,10,11,12] + } + } +} \ No newline at end of file diff --git a/unidepth/layers/__init__.py b/unidepth/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5eb48381b64a8f3c79a201934b239d0a8357bc --- /dev/null +++ b/unidepth/layers/__init__.py @@ -0,0 +1,22 @@ +from .activation import GEGLU, SwiGLU +from .attention import AttentionBlock, AttentionDecoderBlock +from .convnext import CvnxtBlock +from .mlp import MLP +from .nystrom_attention import NystromBlock +from .positional_encoding import PositionEmbeddingSine +from .upsample import (ConvUpsample, ConvUpsampleShuffle, + ConvUpsampleShuffleResidual) + +__all__ = [ + "SwiGLU", + "GEGLU", + "CvnxtBlock", + "AttentionBlock", + "NystromBlock", + "PositionEmbeddingSine", + "ConvUpsample", + "MLP", + "ConvUpsampleShuffle", + "AttentionDecoderBlock", + "ConvUpsampleShuffleResidual", +] diff --git a/unidepth/layers/__pycache__/__init__.cpython-311.pyc b/unidepth/layers/__pycache__/__init__.cpython-311.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..8b499f69f120f9a1f2639c974a60fdb9c8893a9b Binary files /dev/null and b/unidepth/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/activation.cpython-311.pyc b/unidepth/layers/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ada14dcc03542b79a98b4cab4e6da793819bff8 Binary files /dev/null and b/unidepth/layers/__pycache__/activation.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/attention.cpython-311.pyc b/unidepth/layers/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c544ffeb17fcaf72c089e165f4ceace046ea36e9 Binary files /dev/null and b/unidepth/layers/__pycache__/attention.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/convnext.cpython-311.pyc b/unidepth/layers/__pycache__/convnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75923582443920de99f776edde4d57e4c9db3ca1 Binary files /dev/null and b/unidepth/layers/__pycache__/convnext.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/layer_scale.cpython-311.pyc b/unidepth/layers/__pycache__/layer_scale.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c85be89fe9512d815f129f7280917a4f3e3c7534 Binary files /dev/null and b/unidepth/layers/__pycache__/layer_scale.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/mlp.cpython-311.pyc b/unidepth/layers/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..466fafe6be14405bb7351c6faf0d2ec5150c0e7e Binary files /dev/null and b/unidepth/layers/__pycache__/mlp.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/nystrom_attention.cpython-311.pyc b/unidepth/layers/__pycache__/nystrom_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1635cc2d376a7bfb6fca1f70646936368c4d593a Binary files /dev/null and b/unidepth/layers/__pycache__/nystrom_attention.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/positional_encoding.cpython-311.pyc b/unidepth/layers/__pycache__/positional_encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..576ef50caa79ac583cb99d4ec0a315308a3106a6 Binary files /dev/null and b/unidepth/layers/__pycache__/positional_encoding.cpython-311.pyc differ diff --git a/unidepth/layers/__pycache__/upsample.cpython-311.pyc b/unidepth/layers/__pycache__/upsample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fab09739152af0adc4e12a3536497a4027a7f3d Binary files /dev/null and b/unidepth/layers/__pycache__/upsample.cpython-311.pyc differ diff --git a/unidepth/layers/activation.py b/unidepth/layers/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..f5787a340013ba59e2956b6b829f724d9cfb7fcc --- /dev/null +++ b/unidepth/layers/activation.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SwiGLU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gates = x.chunk(2, dim=-1) + return x * F.silu(gates) + + +class GEGLU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gates = x.chunk(2, dim=-1) + return x * F.gelu(gates) diff --git a/unidepth/layers/attention.py b/unidepth/layers/attention.py new file mode 100644 index 
0000000000000000000000000000000000000000..c9fc5f79003e28815e65f9f8fe71474b7ed021a1 --- /dev/null +++ b/unidepth/layers/attention.py @@ -0,0 +1,308 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .layer_scale import LayerScale +from .mlp import MLP + + +class SimpleAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + dropout: float = 0.0, + cosine: bool = False, + context_dim: int | None = None, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + context_dim = context_dim or dim + + self.kv = nn.Linear(context_dim, dim * 2, bias=False) + self.q = nn.Linear(dim, dim, bias=False) + self.norm_attnx = nn.LayerNorm(dim) + self.norm_attnctx = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out = nn.Linear(dim, dim) + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b h n d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + context_dim = context_dim or dim + self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) + self.kv = nn.Linear(context_dim, dim * 2) + self.q = nn.Linear(dim, dim) + self.norm_attnx = nn.LayerNorm(dim) + self.norm_attnctx = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out = nn.Linear(dim, dim) + self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + + def attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b h n d", 
h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b h n d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = ( + self.ls1( + self.attn( + x, + rope=rope, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + + x + ) + x = self.ls2(self.mlp(x)) + x + return x + + +class AttentionDecoderBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + single_head_ca: bool = True, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + self.single_head_ca = single_head_ca + context_dim = context_dim or dim + self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) + self.kv_ca = nn.Linear(context_dim, dim * 2) + self.q_ca = nn.Linear(dim, dim) + self.kv_sa = nn.Linear(dim, dim * 2) + self.q_sa = nn.Linear(dim, dim) + self.norm_x_sa = nn.LayerNorm(dim) + self.norm_x_ca = nn.LayerNorm(dim) + self.norm_ctx_ca = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out_ca = nn.Linear(dim, dim) + self.out_sa = nn.Linear(dim, dim) + self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + + def cross_attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + num_heads = 1 if self.single_head_ca else self.num_heads + x = self.norm_x_ca(x) + context = self.norm_ctx_ca(context) + k, v = rearrange( + self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out_ca(x) + return x + + def self_attn( + self, + x: torch.Tensor, + attn_bias: 
torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + x = self.norm_x_sa(x) + k, v = rearrange( + self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + elif pos_embed is not None: + pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads) + q = q + pos_embed + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out_sa(x) + return x + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = ( + self.ls1( + self.cross_attn( + x, + rope=rope, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + + x + ) + x = ( + self.ls2( + self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed) + ) + + x + ) + x = self.ls3(self.mlp(x)) + x + return x diff --git a/unidepth/layers/convnext.py b/unidepth/layers/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..23a104b1b444313b84aff0a1a92a5600ec42d288 --- /dev/null +++ b/unidepth/layers/convnext.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn + + +class CvnxtBlock(nn.Module): + def __init__( + self, + dim, + kernel_size=7, + layer_scale=1.0, + expansion=4, + dilation=1, + padding_mode: str = "zeros", + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + groups=dim, + dilation=dilation, + padding_mode=padding_mode, + ) # depthwise conv + self.norm = nn.LayerNorm(dim) + self.pwconv1 = nn.Linear(dim, expansion * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(expansion * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0 + ) + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + x = self.gamma * x + x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + return x diff --git a/unidepth/layers/drop_path.py b/unidepth/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..781ff566500c923b1f199542b0c7dfb862a077ca --- /dev/null +++ b/unidepth/layers/drop_path.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) 
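For readers skimming this patch, here is a minimal, self-contained sketch of how the DropPath module added in unidepth/layers/drop_path.py behaves (stochastic depth: whole samples are dropped and the survivors rescaled during training, identity at eval time). The drop probability and tensor shape below are illustrative only, not values used elsewhere in this diff.

import torch

from unidepth.layers.drop_path import DropPath

x = torch.ones(4, 16, 32)            # (batch, tokens, dim); shape chosen for illustration
layer = DropPath(drop_prob=0.2)

layer.eval()
assert torch.equal(layer(x), x)      # inference: identity, no rescaling

layer.train()
out = layer(x)                       # training: each sample is either zeroed entirely
                                     # or scaled by 1 / (1 - drop_prob) = 1.25
print(out[:, 0, 0])                  # per-sample values drawn from {0.0, 1.25}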
diff --git a/unidepth/layers/layer_scale.py b/unidepth/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..01b6662490d7296725f103d1abf8790cac84d0f8 --- /dev/null +++ b/unidepth/layers/layer_scale.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float | torch.Tensor = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/unidepth/layers/mlp.py b/unidepth/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..c9340834b0ad047bf87915d12a124276248d713f --- /dev/null +++ b/unidepth/layers/mlp.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from unidepth.utils.misc import default + +from .activation import SwiGLU + + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + expansion: int = 4, + dropout: float = 0.0, + gated: bool = False, + output_dim: int | None = None, + ): + super().__init__() + if gated: + expansion = int(expansion * 2 / 3) + hidden_dim = int(input_dim * expansion) + output_dim = default(output_dim, input_dim) + self.norm = nn.LayerNorm(input_dim) + self.proj1 = nn.Linear(input_dim, hidden_dim) + self.proj2 = nn.Linear(hidden_dim, output_dim) + self.act = nn.GELU() if not gated else SwiGLU() + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.proj1(x) + x = self.act(x) + x = self.proj2(x) + x = self.dropout(x) + return x diff --git a/unidepth/layers/nystrom_attention.py b/unidepth/layers/nystrom_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7476f114a68617bf64bc4cb51eec6c98445df5 --- /dev/null +++ b/unidepth/layers/nystrom_attention.py @@ -0,0 +1,74 @@ +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from xformers.components.attention import NystromAttention + +from .attention import AttentionBlock + + +class NystromBlock(AttentionBlock): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + cosine=cosine, + gated=gated, + layer_scale=layer_scale, + context_dim=context_dim, + ) + self.attention_fn = NystromAttention( + num_landmarks=128, num_heads=num_heads, dropout=dropout + ) + + def attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b n h d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is 
not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = self.attention_fn(q, k, v, key_padding_mask=attn_bias) + x = rearrange(x, "b n h d -> b n (h d)") + x = self.out(x) + return x diff --git a/unidepth/layers/positional_encoding.py b/unidepth/layers/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..431918d6b281bd7e3588d6c04425e2eacd6c23e4 --- /dev/null +++ b/unidepth/layers/positional_encoding.py @@ -0,0 +1,227 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from math import pi +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +class PositionEmbeddingSine(nn.Module): + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * pi + self.scale = scale + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class LearnedSinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def generate_fourier_features(x, max_freq=64, num_bands=16): + x = x.unsqueeze(-1) + device, dtype, orig_x = x.device, x.dtype, x + + scales = torch.linspace( + -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype + ) + scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] + + x = x * scales * pi + x = 
torch.cat([x.sin(), x.cos()], dim=-1) + x = torch.cat((x, orig_x), dim=-1) + return x.flatten(-2) + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... 
(n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + print("======== shape of rope freq", self.freqs_cos.shape, "========") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward(self, t): + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin diff --git a/unidepth/layers/upsample.py b/unidepth/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..11d91e84cf5c8cc06ce718b2b06cd1eba7779ba6 --- /dev/null +++ b/unidepth/layers/upsample.py @@ -0,0 +1,134 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import torch +import torch.nn as nn +from einops import rearrange + +from .convnext import CvnxtBlock + + +class ConvUpsample(nn.Module): + def __init__( + self, + hidden_dim, + num_layers: int = 2, + expansion: int = 4, + layer_scale: float = 1.0, + kernel_size: int = 7, + **kwargs, + ): + super().__init__() + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + CvnxtBlock( + hidden_dim, + kernel_size=kernel_size, + expansion=expansion, + layer_scale=layer_scale, + ) + ) + self.up = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0), + nn.UpsamplingBilinear2d(scale_factor=2), + nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1), + ) + + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x + + +class ConvUpsampleShuffle(nn.Module): + def __init__( + self, + hidden_dim, + num_layers: int = 2, + expansion: int = 4, + layer_scale: float = 1.0, + kernel_size: int = 7, + **kwargs, + ): + super().__init__() + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + CvnxtBlock( + hidden_dim, + kernel_size=kernel_size, + expansion=expansion, + 
layer_scale=layer_scale, + ) + ) + self.up = nn.Sequential( + nn.PixelShuffle(2), + nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1), + ) + + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x + + +class ConvUpsampleShuffleResidual(nn.Module): + def __init__( + self, + hidden_dim, + num_layers: int = 2, + expansion: int = 4, + layer_scale: float = 1.0, + kernel_size: int = 7, + padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + CvnxtBlock( + hidden_dim, + kernel_size=kernel_size, + expansion=expansion, + layer_scale=layer_scale, + padding_mode=padding_mode, + ) + ) + self.up = nn.Sequential( + nn.PixelShuffle(2), + nn.Conv2d( + hidden_dim // 4, + hidden_dim // 4, + kernel_size=7, + padding=3, + padding_mode=padding_mode, + groups=hidden_dim // 4, + ), + nn.ReLU(), + nn.Conv2d( + hidden_dim // 4, + hidden_dim // 2, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + ) + self.residual = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0), + nn.UpsamplingBilinear2d(scale_factor=2), + ) + + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + self.residual(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x diff --git a/unidepth/models/__init__.py b/unidepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..285f53b893a7c538ab4e9796081eca6f6e1593cc --- /dev/null +++ b/unidepth/models/__init__.py @@ -0,0 +1,7 @@ +from .unidepthv1 import UniDepthV1 +from .unidepthv2 import UniDepthV2 + +__all__ = [ + "UniDepthV1", + "UniDepthV2", +] diff --git a/unidepth/models/__pycache__/__init__.cpython-311.pyc b/unidepth/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd532a7a6f51ad1f42cb4c2eff633b2a4199c241 Binary files /dev/null and b/unidepth/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/unidepth/models/__pycache__/encoder.cpython-311.pyc b/unidepth/models/__pycache__/encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5da25c35f9a8741751676a4af80934a19a0ea8cf Binary files /dev/null and b/unidepth/models/__pycache__/encoder.cpython-311.pyc differ diff --git a/unidepth/models/backbones/__init__.py b/unidepth/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d443b3faa2d658af6ae95d6600650c3c876ec605 --- /dev/null +++ b/unidepth/models/backbones/__init__.py @@ -0,0 +1,9 @@ +from .convnext import ConvNeXt +from .convnext2 import ConvNeXtV2 +from .dinov2 import _make_dinov2_model + +__all__ = [ + "ConvNeXt", + "ConvNeXtV2", + "_make_dinov2_model", +] diff --git a/unidepth/models/backbones/__pycache__/__init__.cpython-311.pyc b/unidepth/models/backbones/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24f655aaf802e8a27da2912b26854fed028b5b57 Binary files /dev/null and b/unidepth/models/backbones/__pycache__/__init__.cpython-311.pyc differ diff --git a/unidepth/models/backbones/__pycache__/convnext.cpython-311.pyc b/unidepth/models/backbones/__pycache__/convnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32edac850da8fbed23393b48b1105b337b31fb47 Binary files /dev/null and 
b/unidepth/models/backbones/__pycache__/convnext.cpython-311.pyc differ diff --git a/unidepth/models/backbones/__pycache__/convnext2.cpython-311.pyc b/unidepth/models/backbones/__pycache__/convnext2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff065181a6a9b6b1209bd2105343718e5cb1c8b6 Binary files /dev/null and b/unidepth/models/backbones/__pycache__/convnext2.cpython-311.pyc differ diff --git a/unidepth/models/backbones/__pycache__/dinov2.cpython-311.pyc b/unidepth/models/backbones/__pycache__/dinov2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8bc1ca4326f67500f0fbd6e4803c0accf3ad21a Binary files /dev/null and b/unidepth/models/backbones/__pycache__/dinov2.cpython-311.pyc differ diff --git a/unidepth/models/backbones/convnext.py b/unidepth/models/backbones/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..11b89a94ac66de221671f078ee785eb92f727040 --- /dev/null +++ b/unidepth/models/backbones/convnext.py @@ -0,0 +1,580 @@ +from collections import OrderedDict +from functools import partial +from typing import Callable, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from timm.layers import (AvgPool2dSame, DropPath, GlobalResponseNormMlp, + LayerNorm, LayerNorm2d, Mlp, create_conv2d, + get_act_layer, make_divisible, to_ntuple, + trunc_normal_) +from torch.utils.checkpoint import checkpoint + + +def get_num_layer_for_convnext(var_name): + """ + Divide [3, 3, 27, 3] layers into 12 groups; each group is three + consecutive blocks, including possible neighboring downsample layers; + adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py + """ + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + elif stage_id == 3: + layer_id = 12 + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[3]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = 12 + + elif var_name.startswith("stem"): + return 0 + else: + layer_id = 12 + return layer_id + 1 + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None): + parameter_group_names = {} + parameter_group_vars = {} + skip = set() + if skip_list is not None: + skip = skip_list + if hasattr(model, "no_weight_decay"): + skip.update(model.no_weight_decay()) + num_layers = 12 + layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip: + group_name = "no_decay" + this_wd = 0.0 + else: + group_name = "decay" + this_wd = wd + + layer_id = get_num_layer_for_convnext(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_wd, + "weight_decay_init": this_wd, + "weight_decay_base": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_wd, + "weight_decay_init": this_wd, + "weight_decay_base": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + 
"lr": cur_lr, + } + if this_wd == 0.0: + parameter_group_names[group_name]["weight_decay_final"] = 0.0 + parameter_group_vars[group_name]["weight_decay_final"] = 0.0 + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + # from unidepth.utils import is_main_process + # import json + # if is_main_process(): + # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class Downsample(nn.Module): + def __init__(self, in_chs, out_chs, stride=1, dilation=1): + super().__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = ( + AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + ) + self.pool = avg_pool_fn( + 2, avg_stride, ceil_mode=True, count_include_pad=False + ) + else: + self.pool = nn.Identity() + + if in_chs != out_chs: + self.conv = create_conv2d(in_chs, out_chs, 1, stride=1) + else: + self.conv = nn.Identity() + + def forward(self, x): + x = self.pool(x) + x = self.conv(x) + return x + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block + There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + + Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate + choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear + is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = (1, 1), + mlp_ratio: float = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + ls_init_value: Optional[float] = 1e-6, + act_layer: Union[str, Callable] = "gelu", + norm_layer: Optional[Callable] = None, + drop_path: float = 0.0, + ): + """ + + Args: + in_chs: Block input channels. + out_chs: Block output channels (same as in_chs if None). + kernel_size: Depthwise convolution kernel size. + stride: Stride of depthwise convolution. + dilation: Tuple specifying input and output dilation of block. + mlp_ratio: MLP expansion ratio. + conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True. + conv_bias: Apply bias for all convolution (linear) layers. + use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2) + ls_init_value: Layer-scale init values, layer-scale applied if not None. + act_layer: Activation layer. + norm_layer: Normalization layer (defaults to LN if not specified). + drop_path: Stochastic depth probability. 
+ """ + super().__init__() + out_chs = out_chs or in_chs + dilation = to_ntuple(2)(dilation) + act_layer = get_act_layer(act_layer) + if not norm_layer: + norm_layer = LayerNorm2d if conv_mlp else LayerNorm + mlp_layer = partial( + GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp + ) + self.use_conv_mlp = conv_mlp + self.conv_dw = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation[0], + depthwise=True, + bias=conv_bias, + ) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) + self.gamma = ( + nn.Parameter(ls_init_value * torch.ones(out_chs)) + if ls_init_value is not None + else None + ) + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = Downsample( + in_chs, out_chs, stride=stride, dilation=dilation[0] + ) + else: + self.shortcut = nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x.contiguous()) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + else: + x = x.permute(0, 2, 3, 1).contiguous() + x = self.norm(x) + x = self.mlp(x) + x = x.permute(0, 3, 1, 2).contiguous() + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + + x = self.drop_path(x) + self.shortcut(shortcut) + return x.contiguous() + + +class ConvNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=7, + stride=2, + depth=2, + dilation=(1, 1), + drop_path_rates=None, + ls_init_value=1.0, + conv_mlp=False, + conv_bias=True, + use_grn=False, + act_layer="gelu", + norm_layer=None, + norm_layer_cl=None, + ): + super().__init__() + self.grad_checkpointing = False + + if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]: + ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1 + pad = ( + "same" if dilation[1] > 1 else 0 + ) # same padding needed if dilation used + self.downsample = nn.Sequential( + norm_layer(in_chs), + create_conv2d( + in_chs, + out_chs, + kernel_size=ds_ks, + stride=stride, + dilation=dilation[0], + padding=pad, + bias=conv_bias, + ), + ) + in_chs = out_chs + else: + self.downsample = nn.Identity() + + drop_path_rates = drop_path_rates or [0.0] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append( + ConvNeXtBlock( + in_chs=in_chs, + out_chs=out_chs, + kernel_size=kernel_size, + dilation=dilation[1], + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer if conv_mlp else norm_layer_cl, + ) + ) + in_chs = out_chs + self.blocks = nn.ModuleList(stage_blocks) + + def forward(self, x): + xs = [] + x = self.downsample(x) + for block in self.blocks: + if self.grad_checkpointing: + x = checkpoint(block, x) + else: + x = block(x) + xs.append(x) + return xs + + +class ConvNeXt(nn.Module): + def __init__( + self, + in_chans: int = 3, + output_stride: int = 32, + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] 
= (96, 192, 384, 768), + kernel_sizes: Union[int, Tuple[int, ...]] = 7, + ls_init_value: Optional[float] = 1e-6, + stem_type: str = "patch", + patch_size: int = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + act_layer: Union[str, Callable] = "gelu", + norm_layer: Optional[Union[str, Callable]] = None, + norm_eps: Optional[float] = None, + drop_path_rate: float = 0.0, + output_idx=[], + use_checkpoint=False, + ): + """ + Args: + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + global_pool: Global pooling type. + output_stride: Output stride of network, one of (8, 16, 32). + depths: Number of blocks at each stage. + dims: Feature dimension at each stage. + kernel_sizes: Depthwise convolution kernel-sizes for each stage. + ls_init_value: Init value for Layer Scale, disabled if None. + stem_type: Type of stem. + patch_size: Stem patch size for patch stem. + head_init_scale: Init scaling value for classifier weights and biases. + head_norm_first: Apply normalization before global pool + head. + head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False. + conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last. + conv_bias: Use bias layers w/ all convolutions. + use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP. + act_layer: Activation layer type. + norm_layer: Normalization layer type. + drop_rate: Head pre-classifier dropout rate. + drop_path_rate: Stochastic depth drop rate. + """ + super().__init__() + self.num_layers = len(depths) + self.depths = output_idx + self.embed_dims = [ + int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) + ] + self.embed_dim = dims[0] + + assert output_stride in (8, 16, 32) + kernel_sizes = to_ntuple(4)(kernel_sizes) + if norm_layer is None: + norm_layer = LayerNorm2d + norm_layer_cl = norm_layer if conv_mlp else LayerNorm + if norm_eps is not None: + norm_layer = partial(norm_layer, eps=norm_eps) + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + else: + assert ( + conv_mlp + ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input" + norm_layer_cl = norm_layer + if norm_eps is not None: + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + + self.feature_info = [] + + assert stem_type in ("patch", "overlap", "overlap_tiered") + if stem_type == "patch": + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + self.stem = nn.Sequential( + nn.Conv2d( + in_chans, + dims[0], + kernel_size=patch_size, + stride=patch_size, + bias=conv_bias, + ), + norm_layer(dims[0]), + ) + stem_stride = patch_size + else: + mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0] + self.stem = nn.Sequential( + nn.Conv2d( + in_chans, + mid_chs, + kernel_size=3, + stride=2, + padding=1, + bias=conv_bias, + ), + nn.Conv2d( + mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias + ), + norm_layer(dims[0]), + ) + stem_stride = 4 + + self.stages = nn.Sequential() + dp_rates = [ + x.tolist() + for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths) + ] + stages = [] + prev_chs = dims[0] + curr_stride = stem_stride + dilation = 1 + # 4 feature resolution stages, each consisting of multiple residual blocks + for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + if curr_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + curr_stride *= stride + 
first_dilation = 1 if dilation in (1, 2) else 2 + out_chs = dims[i] + stages.append( + ConvNeXtStage( + prev_chs, + out_chs, + kernel_size=kernel_sizes[i], + stride=stride, + dilation=(first_dilation, dilation), + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + ) + ) + prev_chs = out_chs + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + self.feature_info += [ + dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}") + ] + self.stages = nn.ModuleList(stages) + self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1)) + self.num_features = prev_chs + self.apply(self._init_weights) + self.set_grad_checkpointing(use_checkpoint) + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + nn.init.zeros_(module.bias) + + def forward(self, x, masks=None): + outs = [] + x = self.stem(x) + if masks is not None: + masks = torch.nn.functional.interpolate( + masks.float(), size=x.shape[-2:], mode="nearest" + ) + x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous() + for stage in self.stages: + xs = stage(x) + outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs]) + x = xs[-1] + return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r"^stem", + blocks=( + r"^stages\.(\d+)" + if coarse + else [ + (r"^stages\.(\d+)\.downsample", (0,)), # blocks + (r"^stages\.(\d+)\.blocks\.(\d+)", None), + (r"^norm_pre", (99999,)), + ] + ), + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def no_weight_decay(self): + return {"mask_token"} + + @classmethod + def build(cls, config): + obj = globals()[config["model"]["encoder"]["name"]](config) + return obj + + +def checkpoint_filter_fn(state_dict, model): + """Remap FB checkpoints -> timm""" + if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict: + return state_dict # non-FB checkpoint + if "model" in state_dict: + state_dict = state_dict["model"] + + out_dict = {} + if "visual.trunk.stem.0.weight" in state_dict: + out_dict = { + k.replace("visual.trunk.", ""): v + for k, v in state_dict.items() + if k.startswith("visual.trunk.") + } + if "visual.head.proj.weight" in state_dict: + out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"] + out_dict["head.fc.bias"] = torch.zeros( + state_dict["visual.head.proj.weight"].shape[0] + ) + elif "visual.head.mlp.fc1.weight" in state_dict: + out_dict["head.pre_logits.fc.weight"] = state_dict[ + "visual.head.mlp.fc1.weight" + ] + out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"] + out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"] + out_dict["head.fc.bias"] = torch.zeros( + state_dict["visual.head.mlp.fc2.weight"].shape[0] + ) + 
return out_dict + + import re + + for k, v in state_dict.items(): + k = k.replace("downsample_layers.0.", "stem.") + k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k) + k = re.sub( + r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k + ) + k = k.replace("dwconv", "conv_dw") + k = k.replace("pwconv", "mlp.fc") + if "grn" in k: + k = k.replace("grn.beta", "mlp.grn.bias") + k = k.replace("grn.gamma", "mlp.grn.weight") + v = v.reshape(v.shape[-1]) + k = k.replace("head.", "head.fc.") + if k.startswith("norm."): + k = k.replace("norm", "head.norm") + if v.ndim == 2 and "head" not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + + return out_dict + + +HF_URL = { + "convnext_xxlarge_pt": ( + "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup", + "open_clip_pytorch_model.bin", + ), + "convnext_large_pt": ( + "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup", + "open_clip_pytorch_model.bin", + ), + "convnext_large": ( + "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384", + "pytorch_model.bin", + ), +} diff --git a/unidepth/models/backbones/convnext2.py b/unidepth/models/backbones/convnext2.py new file mode 100644 index 0000000000000000000000000000000000000000..74d5bbf12bf586552ef35e2e82d59e47b9c9cc42 --- /dev/null +++ b/unidepth/models/backbones/convnext2.py @@ -0,0 +1,288 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, trunc_normal_ + + +def get_num_layer_for_convnext_single(var_name, depths): + """ + Each layer is assigned distinctive layer ids + """ + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + layer_id = sum(depths[:stage_id]) + 1 + return layer_id + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[2]) + layer_id = sum(depths[:stage_id]) + block_id + 1 + return layer_id + + else: + return sum(depths) + 1 + + +def get_num_layer_for_convnext(var_name): + """ + Divide [3, 3, 27, 3] layers into 12 groups; each group is three + consecutive blocks, including possible neighboring downsample layers; + adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py + """ + num_max_layer = 12 + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + elif stage_id == 3: + layer_id = 12 + return layer_id + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[2]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = 12 + return layer_id + else: + return num_max_layer + 1 + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): + parameter_group_names = {} + parameter_group_vars = {} + skip = {} + if skip_list is not None: + skip = skip_list + elif hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + num_layers = 12 # sum(model.depths) + layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if ( + len(param.shape) == 1 + or name.endswith(".bias") + or name in skip + or name.endswith(".gamma") + or name.endswith(".beta") + ): + group_name = "no_decay" + 
this_weight_decay = 0.0 + else: + group_name = "decay" + this_weight_decay = wd + + # layer_id = get_num_layer_for_convnext_single(name, model.depths) + layer_id = get_num_layer_for_convnext(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + "lr": cur_lr, + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + # if is_main_process(): + # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class LayerNorm(nn.Module): + """LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class Block(nn.Module): + """ConvNeXtV2 Block. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. 
Default: 0.0 + """ + + def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False): + super().__init__() + self.dwconv = nn.Conv2d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, mult * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(mult * dim) + self.pwconv2 = nn.Linear(mult * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.use_checkpoint = use_checkpoint + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXtV2(nn.Module): + """ConvNeXt V2 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, + in_chans=3, + depths=[3, 3, 9, 3], + dims=96, + drop_path_rate=0.0, + output_idx=[], + use_checkpoint=False, + ): + super().__init__() + self.num_layers = len(depths) + self.depths = output_idx + self.embed_dims = [ + int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) + ] + self.embed_dim = dims[0] + + self.downsample_layers = ( + nn.ModuleList() + ) # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = ( + nn.ModuleList() + ) # 4 feature resolution stages, each consisting of multiple residual blocks + self.out_norms = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.ModuleList( + [ + Block( + dim=dims[i], + drop_path=dp_rates[cur + j], + use_checkpoint=use_checkpoint, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + outs = [] + for i in range(4): + x = self.downsample_layers[i](x) + for stage in self.stages[i]: + x = stage(x) + outs.append(x.permute(0, 2, 3, 1)) + cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] + return outs, cls_tokens + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + @classmethod + def build(cls, config): + obj = 
globals()[config["model"]["encoder"]["name"]](config) + return obj diff --git a/unidepth/models/backbones/dinov2.py b/unidepth/models/backbones/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..726846d9325562d748397d02ad88262eb4d3bfe7 --- /dev/null +++ b/unidepth/models/backbones/dinov2.py @@ -0,0 +1,455 @@ +import logging +import math +from functools import partial +from typing import Callable, Sequence + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ + +from .metadinov2 import Attention, MemEffAttention, Mlp +from .metadinov2 import NestedTensorBlock as Block +from .metadinov2 import PatchEmbed, SwiGLUFFNFused + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" +logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): + parameter_group_names = {} + parameter_group_vars = {} + skip = {} + if skip_list is not None: + skip = skip_list + elif hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + + num_layers = model.n_blocks + layer_scale = list(ld ** (num_layers - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if len(param.shape) == 1: # norm + group_name = "no_decay" + this_wd = 0.0 + # layer scale, bias beta? 
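# --- Editor's note (sketch of the DINOv2 parameter grouping in this function;
# the ld and depth values below are made-up examples) ---
# Each parameter gets a layer id from its name ("blocks.7. ..." -> 7,
# "patch_embed. ..." and everything else -> 0) and its lr is scaled by
# ld ** (n_blocks - layer_id), so the earliest blocks decay the most.
ld, n_blocks = 0.9, 24
for name in ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight", "blocks.23.mlp.fc2.weight"]:
    layer_id = int(name.split(".")[1]) if name.startswith("blocks") else 0
    print(name, round(ld ** (n_blocks - layer_id), 4))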
+ elif ( + name in skip + or name.endswith(".gamma") + or name.endswith(".beta") + or name.endswith(".bias") + ): + group_name = "no_decay" + this_wd = 0.0 + elif "cls_token" in name or "pos_embed" in name or "mask_token" in name: + group_name = "no_decay" + this_wd = 0.0 + else: + group_name = "decay" + this_wd = wd + + if name.startswith("blocks"): + layer_id = int(name.split(".")[1]) + elif name.startswith("patch_embed"): + layer_id = 0 + else: + layer_id = 0 + + group_name = f"layer_{layer_id}_{group_name}" + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + output_idx=[5, 12, 18, 24], + checkpoint: bool = False, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.0, + use_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.embed_dims = [embed_dim] * output_idx[-1] + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.depths = output_idx + self.checkpoint = checkpoint + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = 
nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = nn.Parameter( + torch.zeros(1, max(1, num_register_tokens), embed_dim) + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.use_norm = use_norm + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.num_register_tokens: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = 
self.patch_embed(x) + if masks is not None: + masks = masks.bool().view(B, -1, 1) + x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.num_register_tokens: + x = torch.cat( + (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), + dim=1, + ) + return x + + def forward(self, x, masks=None): + shapes = [val // self.patch_size for val in x.shape[-2:]] + batch_size = x.shape[0] + x = self.prepare_tokens_with_masks(x, masks) + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + outputs.append(x) + + if self.use_norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, :1] for out in outputs] + outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs] + outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs] + + return (outputs, class_tokens) + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self.mask_token.requires_grad = False + self.register_tokens.requires_grad = False + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, export=False, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + block_fn=partial(Block, attn_class=Attention if export else MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, export=False, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + block_fn=partial(Block, attn_class=Attention if export else MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, export=False, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + block_fn=partial(Block, attn_class=Attention if export else MemEffAttention), + **kwargs, + ) + return model + + +def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + return f"dinov2_{compact_arch_name}{patch_size}" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + pretrained: str = "", + output_idx: Sequence[int] = [], + num_register_tokens: int = 0, + drop_path_rate: float = 0.0, + use_norm: bool = False, + export: bool = False, + interpolate_offset: float = 0.0, + **kwargs, +): + model_name = _make_dinov2_model_name(arch_name, patch_size) + + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + 
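# --- Editor's note (standalone sketch of interpolate_pos_encoding above;
# sizes are illustrative: M = 37 is the 518 / 14 pretraining grid) ---
# The patch position embeddings live on a square M x M grid and are resized
# bicubically to the current w0 x h0 patch grid, with the class-token
# embedding concatenated back unchanged.
import torch
import torch.nn as nn

D, M, w0, h0 = 384, 37, 32, 24
pos_embed = torch.randn(1, M * M + 1, D)                  # [CLS] + patch embeddings
patch_pos = pos_embed[:, 1:].reshape(1, M, M, D).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(patch_pos, size=(w0, h0), mode="bicubic")
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, D)
new_pos = torch.cat((pos_embed[:, :1], patch_pos), dim=1)
print(new_pos.shape)  # torch.Size([1, 769, 384])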
block_chunks=block_chunks, + output_idx=output_idx, + drop_path_rate=drop_path_rate, + num_register_tokens=num_register_tokens, + use_norm=use_norm, + export=export, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = eval(arch_name)(**vit_kwargs) + if pretrained == "": + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}" + if num_register_tokens > 0: + url += "_reg4" + url += "_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + ) + info = model.load_state_dict(state_dict, strict=False) + print(info) + elif pretrained is not None: + state_dict = torch.load(pretrained, map_location="cpu") + info = model.load_state_dict(state_dict, strict=False) + print(f"loading from {pretrained} with:", info) + return model diff --git a/unidepth/models/backbones/metadinov2/__init__.py b/unidepth/models/backbones/metadinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9efd71a64023875330cc1c42a87452597dd9ae3 --- /dev/null +++ b/unidepth/models/backbones/metadinov2/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .attention import Attention, MemEffAttention +from .block import NestedTensorBlock +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused diff --git a/unidepth/models/backbones/metadinov2/__pycache__/__init__.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0daa24d078f6a6dd447c657e63d46f6593a1966 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/__init__.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/__pycache__/attention.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d112603d72c0a6f904c8d5c3be0b95ad0325a40 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/attention.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/__pycache__/block.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb66772a2bb7f0da991b9f5a7b2d269ad747f264 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/block.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/__pycache__/dino_head.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/dino_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45660885d00820d042db5fa504ce7bdd2ed6a348 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/dino_head.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/__pycache__/drop_path.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/drop_path.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbe8d85f6c009d37a3fc7ab8cd27f2eace206720 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/drop_path.cpython-311.pyc differ diff --git 
a/unidepth/models/backbones/metadinov2/__pycache__/layer_scale.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/layer_scale.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dd2db4fa3d7708c477deaea059c95c5abad8ea7 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/layer_scale.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/__pycache__/mlp.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64fecee26ff2a693c37b58949b33d74b28f59ab3 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/mlp.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/__pycache__/patch_embed.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/patch_embed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4669290d5adc71e18a245317dd31aeab241fc45d Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/patch_embed.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/__pycache__/swiglu_ffn.cpython-311.pyc b/unidepth/models/backbones/metadinov2/__pycache__/swiglu_ffn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af25ef6efe2d2fa99f3d73d9854743f36856e741 Binary files /dev/null and b/unidepth/models/backbones/metadinov2/__pycache__/swiglu_ffn.cpython-311.pyc differ diff --git a/unidepth/models/backbones/metadinov2/attention.py b/unidepth/models/backbones/metadinov2/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3c8867852e0ae138e2dc90ded3e7a0e6ee5a17 --- /dev/null +++ b/unidepth/models/backbones/metadinov2/attention.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
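# --- Editor's note (sketch of the checkpoint URL scheme used by
# _make_dinov2_model above, following the dl.fbaipublicfiles.com layout in
# the code) ---
# "vit_large" with patch size 14 becomes "dinov2_vitl14", and register-token
# variants get a "_reg4" suffix before "_pretrain.pth".
arch_name, patch_size, num_register_tokens = "vit_large", 14, 0
model_name = f"dinov2_{arch_name.replace('_', '')[:4]}{patch_size}"
url = f"https://dl.fbaipublicfiles.com/dinov2/{model_name}/{model_name}"
if num_register_tokens > 0:
    url += "_reg4"
url += "_pretrain.pth"
print(url)  # https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth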
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +import torch.nn as nn +from torch import Tensor + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha, memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/unidepth/models/backbones/metadinov2/block.py b/unidepth/models/backbones/metadinov2/block.py new file mode 100644 index 0000000000000000000000000000000000000000..801f8f68785e36fb16dcb1a743302e90b70c0c94 --- /dev/null +++ b/unidepth/models/backbones/metadinov2/block.py @@ -0,0 +1,282 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
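# --- Editor's note (shape sketch of Attention.forward above; sizes are
# arbitrary) ---
# The fused qkv projection is reshaped to (3, B, heads, N, head_dim); q is
# pre-scaled by head_dim ** -0.5 before the dot product with k.
import torch
import torch.nn as nn

B, N, C, heads = 2, 5, 16, 4
qkv = nn.Linear(C, C * 3)(torch.randn(B, N, C))
qkv = qkv.reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * (C // heads) ** -0.5, qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
print(out.shape)  # torch.Size([2, 5, 16])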
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.nn as nn + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha, index_select_cat, scaled_index_add + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def attn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: torch.Tensor, + residual_func: Callable[[torch.Tensor], torch.Tensor], + sample_drop_ratio: float = 0.0, +) -> torch.Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the 
residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[torch.Tensor], + residual_func: Callable[[torch.Tensor, Any], torch.Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> torch.Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( 
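# --- Editor's note (standalone sketch of the batch-subset stochastic depth
# defined above; torch.tanh stands in for the real attention / FFN branch) ---
# Only a random subset of the batch runs the residual branch; the result is
# added back with torch.index_add scaled by b / subset_size so the residual
# keeps its expected magnitude.
import torch

b, n, d, drop = 8, 4, 16, 0.25
x = torch.randn(b, n, d)
subset = max(int(b * (1 - drop)), 1)
brange = torch.randperm(b)[:subset]
residual = torch.tanh(x[brange])
out = torch.index_add(x.flatten(1), 0, brange, residual.flatten(1), alpha=b / subset)
print(out.view_as(x).shape)  # torch.Size([8, 4, 16])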
+ x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls1.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls2.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + return x_list + else: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, torch.Tensor): + return super(NestedTensorBlock, self).forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert ( + XFORMERS_AVAILABLE + ), "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/unidepth/models/backbones/metadinov2/dino_head.py b/unidepth/models/backbones/metadinov2/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1147dd3a3c046aee8d427b42b1055f38a218275b --- /dev/null +++ b/unidepth/models/backbones/metadinov2/dino_head.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp( + nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True +): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/unidepth/models/backbones/metadinov2/drop_path.py b/unidepth/models/backbones/metadinov2/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..35b1a620d06ba862ea05297d271d8c2c625b5f93 --- /dev/null +++ 
b/unidepth/models/backbones/metadinov2/drop_path.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +import torch.nn as nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/unidepth/models/backbones/metadinov2/layer_scale.py b/unidepth/models/backbones/metadinov2/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..40d18b5427183534d5516652b076f9883a609fc6 --- /dev/null +++ b/unidepth/models/backbones/metadinov2/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +import torch.nn as nn +from torch import Tensor + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/unidepth/models/backbones/metadinov2/mlp.py b/unidepth/models/backbones/metadinov2/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/unidepth/models/backbones/metadinov2/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
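# --- Editor's note (sketch of drop_path above) ---
# Whole samples are zeroed with probability 1 - keep_prob and the survivors
# are rescaled by 1 / keep_prob, so the output matches the input in
# expectation.
import torch

x = torch.ones(10000, 8)
keep_prob = 0.9
random_tensor = x.new_empty((x.shape[0], 1)).bernoulli_(keep_prob).div_(keep_prob)
print((x * random_tensor).mean().item())  # close to 1.0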
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/unidepth/models/backbones/metadinov2/patch_embed.py b/unidepth/models/backbones/metadinov2/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a56c02609e67922eb8f859588ef274e5298b55 --- /dev/null +++ b/unidepth/models/backbones/metadinov2/patch_embed.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/unidepth/models/backbones/metadinov2/swiglu_ffn.py b/unidepth/models/backbones/metadinov2/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..e82999e9b09b41cd6aba9edbc4c05d51ab663a1e --- /dev/null +++ b/unidepth/models/backbones/metadinov2/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional + +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/unidepth/models/encoder.py b/unidepth/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..319e21c5600251f1b3c2da0c3df628b90f357b11 --- /dev/null +++ b/unidepth/models/encoder.py @@ -0,0 +1,193 @@ +import torch +import torch.nn as nn + +from unidepth.models.backbones import ConvNeXt, ConvNeXtV2, _make_dinov2_model + + +class ModelWrap(nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.backbone = model + + def forward(self, x, *args, **kwargs): + features = [] + for layer in self.backbone.features: + x = layer(x) + features.append(x) + return features + + +def convnextv2_base(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[128, 256, 512, 1024], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_large(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_large_mae(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt" + state_dict = 
torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_huge(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[352, 704, 1408, 2816], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_huge_mae(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[352, 704, 1408, 2816], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnext_large_pt(config, **kwargs): + model = ConvNeXt( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + **kwargs, + ) + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import disable_progress_bars + + from unidepth.models.backbones.convnext import HF_URL, checkpoint_filter_fn + + disable_progress_bars() + repo_id, filename = HF_URL["convnext_large_pt"] + state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename)) + state_dict = checkpoint_filter_fn(state_dict, model) + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnext_large(config, **kwargs): + model = ConvNeXt( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + **kwargs, + ) + return model + + +def dinov2_vits14(config, pretrained: bool = True, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + vit = _make_dinov2_model( + arch_name="vit_small", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [3, 6, 9, 12]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + use_norm=config.get("use_norm", False), + export=config.get("export", False), + interpolate_offset=config.get("interpolate_offset", 0.0), + **kwargs, + ) + return vit + + +def dinov2_vitb14(config, pretrained: bool = True, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 
+ """ + vit = _make_dinov2_model( + arch_name="vit_base", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [3, 6, 9, 12]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + use_norm=config.get("use_norm", False), + export=config.get("export", False), + interpolate_offset=config.get("interpolate_offset", 0.0), + **kwargs, + ) + return vit + + +def dinov2_vitl14(config, pretrained: str = "", **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + vit = _make_dinov2_model( + arch_name="vit_large", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [5, 12, 18, 24]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + use_norm=config.get("use_norm", False), + export=config.get("export", False), + interpolate_offset=config.get("interpolate_offset", 0.0), + **kwargs, + ) + return vit diff --git a/unidepth/models/unidepthv1/__init__.py b/unidepth/models/unidepthv1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1781bda94cdc13b0c0c805e7cde0872defc20cd3 --- /dev/null +++ b/unidepth/models/unidepthv1/__init__.py @@ -0,0 +1,5 @@ +from .unidepthv1 import UniDepthV1 + +__all__ = [ + "UniDepthV1", +] diff --git a/unidepth/models/unidepthv1/__pycache__/__init__.cpython-311.pyc b/unidepth/models/unidepthv1/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab614fbbd802eaaea976656c345aa13154b66c61 Binary files /dev/null and b/unidepth/models/unidepthv1/__pycache__/__init__.cpython-311.pyc differ diff --git a/unidepth/models/unidepthv1/__pycache__/decoder.cpython-311.pyc b/unidepth/models/unidepthv1/__pycache__/decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e99035a33d379d432c508faa7ab966fa130749a Binary files /dev/null and b/unidepth/models/unidepthv1/__pycache__/decoder.cpython-311.pyc differ diff --git a/unidepth/models/unidepthv1/__pycache__/unidepthv1.cpython-311.pyc b/unidepth/models/unidepthv1/__pycache__/unidepthv1.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cb67fdcd6070f39563b9abd92aa252cf72a52a6 Binary files /dev/null and b/unidepth/models/unidepthv1/__pycache__/unidepthv1.cpython-311.pyc differ diff --git a/unidepth/models/unidepthv1/decoder.py b/unidepth/models/unidepthv1/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..87b0a0a483884ee74c0e86fe164c57216713213f --- /dev/null +++ b/unidepth/models/unidepthv1/decoder.py @@ -0,0 +1,535 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from timm.models.layers import trunc_normal_ + +from unidepth.layers import (MLP, AttentionBlock, ConvUpsample, NystromBlock, + PositionEmbeddingSine) +from unidepth.utils.geometric import flat_interpolate, generate_rays +from unidepth.utils.misc import max_stack +from unidepth.utils.sht import rsh_cart_8 + + +class ListAdapter(nn.Module): + def __init__(self, input_dims: List[int], hidden_dim: int): + super().__init__() + self.input_adapters = nn.ModuleList([]) + self.num_chunks = len(input_dims) + for 
input_dim in input_dims: + self.input_adapters.append( + nn.Sequential( + nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU() + ) + ) + + def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor: + xs = torch.split(x, splits.int().tolist(), dim=-1) + xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)] + return torch.cat(xs, dim=-1) + + +class CameraHead(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depth: int = 4, + dropout: float = 0.0, + layer_scale: float = 1.0, + **kwargs, + ): + super().__init__() + + self.aggregate = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + self.latents_pos = nn.Parameter( + torch.randn(1, 4, hidden_dim), requires_grad=True + ) + self.layers = nn.ModuleList([]) + self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout) + for _ in range(depth): + blk = AttentionBlock( + hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + self.layers.append(blk) + self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) + self.cls_project = nn.Sequential( + nn.LayerNorm(input_dim), + nn.Linear(input_dim, hidden_dim // 2), + nn.GELU(), + nn.Linear(hidden_dim // 2, hidden_dim), + ) + + def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: + features = features.unbind(dim=-1) + cls_tokens = self.cls_project(cls_tokens) + features_stack = torch.cat(features, dim=1) + features_stack = features_stack + pos_embed + latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) + features_stack = self.in_features(features_stack) + features = torch.cat((features_stack, cls_tokens), dim=1) + cls_tokens = self.aggregate(cls_tokens, context=features, pos_embed=latents_pos) + for i, layer in enumerate(self.layers): + cls_tokens = layer(cls_tokens, pos_embed=latents_pos) + + # project + x = self.out(cls_tokens).squeeze(-1) + camera_intrinsics = torch.zeros( + x.shape[0], 3, 3, device=x.device, requires_grad=False + ) + camera_intrinsics[:, 0, 0] = x[:, 0].exp() + camera_intrinsics[:, 1, 1] = x[:, 1].exp() + camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid() + camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid() + camera_intrinsics[:, 2, 2] = 1.0 + return camera_intrinsics + + def set_shapes(self, shapes: Tuple[int, int]): + self.shapes = shapes + + +class DepthHead(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depths: int | list[int] = 4, + camera_dim: int = 256, + num_resolutions: int = 4, + dropout: float = 0.0, + layer_scale: float = 1.0, + **kwargs, + ) -> None: + super().__init__() + if isinstance(depths, int): + depths = [depths] * 3 + assert len(depths) == 3 + + self.project_rays16 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim + ) + self.project_rays8 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2 + ) + self.project_rays4 = MLP( + camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4 + ) + self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout) + + self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim) + + self.up8 = ConvUpsample( + hidden_dim, expansion=expansion, layer_scale=layer_scale + ) + self.up4 = ConvUpsample( + hidden_dim // 2, expansion=expansion, layer_scale=layer_scale + ) + self.up2 = ConvUpsample( + hidden_dim // 4, 
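# --- Editor's note (sketch of the pinhole intrinsics built at the end of
# CameraHead.forward above; the random input stands in for the 4 predicted
# camera scalars) ---
# Focal lengths go through exp and the (normalized) principal point through
# sigmoid, then K is assembled as a 3x3 matrix per batch element.
import torch

x = torch.randn(2, 4)
K = torch.zeros(x.shape[0], 3, 3)
K[:, 0, 0] = x[:, 0].exp()       # fx
K[:, 1, 1] = x[:, 1].exp()       # fy
K[:, 0, 2] = x[:, 2].sigmoid()   # cx (normalized image coordinates)
K[:, 1, 2] = x[:, 3].sigmoid()   # cy (normalized image coordinates)
K[:, 2, 2] = 1.0
print(K[0])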
expansion=expansion, layer_scale=layer_scale + ) + + self.layers_16 = nn.ModuleList([]) + self.layers_8 = nn.ModuleList([]) + self.layers_4 = nn.ModuleList([]) + self.aggregate_16 = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + context_dim=hidden_dim, + ) + self.prompt_camera = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + context_dim=hidden_dim, + ) + for i, (blk_lst, depth) in enumerate( + zip([self.layers_16, self.layers_8, self.layers_4], depths) + ): + attn_cls = AttentionBlock if i == 0 else NystromBlock + for _ in range(depth): + blk_lst.append( + attn_cls( + hidden_dim // (2**i), + num_heads=num_heads // (2**i), + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + ) + + self.out2 = nn.Conv2d(hidden_dim // 8, 1, 3, padding=1) + self.out4 = nn.Conv2d(hidden_dim // 4, 1, 3, padding=1) + self.out8 = nn.Conv2d(hidden_dim // 2, 1, 3, padding=1) + + def set_original_shapes(self, shapes: Tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: Tuple[int, int]): + self.shapes = shapes + + def forward( + self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed + ) -> torch.Tensor: + features = features.unbind(dim=-1) + shapes = self.shapes + + # camera_embedding + # torch.cuda.synchronize() + # start = time() + # print(f'shapes\n:{self.original_shapes, shapes})') + rays_embedding_16 = F.normalize( + flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1 + ) + rays_embedding_8 = F.normalize( + flat_interpolate( + rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes] + ), + dim=-1, + ) + rays_embedding_4 = F.normalize( + flat_interpolate( + rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes] + ), + dim=-1, + ) + rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16)) + rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8)) + rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4)) + # torch.cuda.synchronize() + # print(f"camera_embedding took {time() - start} seconds") + features_tokens = torch.cat(features, dim=1) + features_tokens_pos = pos_embed + level_embed + + # Generate latents with init as pooled features + features_channels = torch.cat(features, dim=-1) + features_16 = self.features_channel_cat(features_channels) + latents_16 = self.to_latents( + flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False) + ) + + # Aggregate features: F -> D + latents_16 = self.aggregate_16( + latents_16, context=features_tokens, pos_embed_context=features_tokens_pos + ) + + # Aggregate camera: D- > D|E + latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16) + + # Block 16 - Out 8 + for layer in self.layers_16: + latents_16 = layer(latents_16, pos_embed=rays_embedding_16) + latents_8 = self.up8( + rearrange( + latents_16 + rays_embedding_16, + "b (h w) c -> b c h w", + h=shapes[0], + w=shapes[1], + ).contiguous() + ) + out8 = self.out8( + rearrange( + latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2 + ) + ) + + # Block 8 - Out 4 + for layer in self.layers_8: + latents_8 = layer(latents_8, pos_embed=rays_embedding_8) + latents_4 = self.up4( + rearrange( + latents_8 + rays_embedding_8, + "b (h w) c -> b c h w", + h=shapes[0] * 2, + w=shapes[1] * 2, + ).contiguous() + ) + out4 = self.out4( + rearrange( + latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4 + ) 
+ ) + + # Block 4 - Out 2 + for layer in self.layers_4: + latents_4 = layer(latents_4, pos_embed=rays_embedding_4) + latents_2 = self.up2( + rearrange( + latents_4 + rays_embedding_4, + "b (h w) c -> b c h w", + h=shapes[0] * 4, + w=shapes[1] * 4, + ).contiguous() + ) + out2 = self.out2( + rearrange( + latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8 + ) + ) + + # Depth features + proj_latents_16 = rearrange( + latents_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1] + ).contiguous() + + # MS Outputs + out2 = out2.clamp(-10.0, 10.0).exp() + out4 = out4.clamp(-10.0, 10.0).exp() + out8 = out8.clamp(-10.0, 10.0).exp() + + return out8, out4, out2, proj_latents_16 + + +class Decoder(nn.Module): + def __init__( + self, + config, + *args, + **kwargs, + ): + super().__init__() + self.build(config) + self.apply(self._init_weights) + self.test_fixed_camera = False + self.skip_camera = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_adapted_features(self, features_flat, splits): + features_flat_cat = torch.cat(features_flat, dim=-1) + features_projected = self.input_adapter( + features_flat_cat, splits + ) # list [b hw c] shapes + features = torch.chunk(features_projected, len(splits), dim=-1) + return features + + def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays): + # get cls tokens projections + cls_tokens_splits = torch.tensor( + [x.shape[-1] for x in cls_tokens], + device=features.device, + requires_grad=False, + dtype=features.dtype, + ) + cls_tokens = torch.cat(cls_tokens, dim=-1) + cls_tokens = self.token_adapter(cls_tokens, cls_tokens_splits) + cls_tokens = torch.cat( + torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1 + ) + + # camera layer + intrinsics = self.camera_layer( + features=features, cls_tokens=cls_tokens, pos_embed=pos_embed + ) + intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0] + intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1] + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1] + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0] + if not self.test_fixed_camera: + rays, _ = generate_rays(intrinsics, original_shapes, noisy=False) + + return intrinsics, rays + + def forward(self, inputs, image_metas) -> torch.Tensor: + B, _, H, W = inputs["image"].shape + device = inputs["image"].device + + # make stride happy? 
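+        # the .contiguous() calls below give the encoder outputs and class tokens a
+        # dense memory layout before they are sliced and max-stacked per encoder range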
+ original_encoder_outputs = [x.contiguous() for x in inputs["encoder_outputs"]] + cls_tokens = [x.contiguous() for x in inputs["cls_tokens"]] + + # collect features and tokens + original_encoder_outputs = [ + max_stack(original_encoder_outputs[i:j]) + for i, j in self.slices_encoder_range + ] + cls_tokens = [cls_tokens[-i - 1] for i in range(len(self.slices_encoder_range))] + + # get features in b n d format + # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions + resolutions = [ + tuple(sorted([x.shape[1], x.shape[2]])) for x in original_encoder_outputs + ] + level_shapes = sorted(list(set(resolutions)))[::-1] + + if len(level_shapes) == 1: + level_shapes = level_shapes * self.num_resolutions + input_shapes = [ + level_shapes[i] + for i, (start, end) in enumerate(self.slices_encoder) + for _ in range(end - start) + ] + common_shape = level_shapes[-2] + + # input shapes repeat shapes for each level, times the amount of the layers: + features_flat = [ + flat_interpolate( + rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape + ) + for x, input_shape in zip(original_encoder_outputs, input_shapes) + ] + features_splits = torch.tensor( + [x.shape[-1] for x in features_flat], + device=device, + requires_grad=False, + dtype=torch.float32, + ) + + # input adapter, then do mean of features in same blocks + features = self.get_adapted_features(features_flat, features_splits) + features = torch.stack(features, dim=-1) + + # positional embeddings, spatial and level + level_embed = torch.cat( + [ + self.level_embed_layer(self.level_embeds)[i : i + 1] + .unsqueeze(0) + .repeat(B, common_shape[0] * common_shape[1], 1) + for i in range(self.num_resolutions) + ], + dim=1, + ) + pos_embed = self.pos_embed( + torch.zeros( + B, + 1, + common_shape[0], + common_shape[1], + device=device, + requires_grad=False, + ) + ) + pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( + 1, self.num_resolutions, 1 + ) + + self.camera_layer.set_shapes(common_shape) + intrinsics, rays = ( + self.run_camera( + cls_tokens, + features=features, + pos_embed=pos_embed + level_embed, + original_shapes=(H, W), + rays=inputs.get("rays", None), + ) + if not self.skip_camera + else (inputs["K"], inputs["rays"]) + ) + + # run bulk of the model + self.depth_layer.set_shapes(common_shape) + self.depth_layer.set_original_shapes((H, W)) + out8, out4, out2, depth_features = self.depth_layer( + features=features, + rays_hr=rays, + pos_embed=pos_embed, + level_embed=level_embed, + ) + + return intrinsics, [out8, out4, out2], depth_features + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"latents_pos", "level_embeds"} + + def build(self, config): + depth = config["model"]["pixel_decoder"]["depths"] + input_dims = config["model"]["pixel_encoder"]["embed_dims"] + hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] + num_heads = config["model"]["num_heads"] + expansion = config["model"]["expansion"] + dropout = config["model"]["pixel_decoder"]["dropout"] + depths_encoder = config["model"]["pixel_encoder"]["depths"] + num_steps = config["model"].get("num_steps", 100000) + layer_scale = 1.0 + + self.depth = depth + self.dim = hidden_dim + self.downsample = 4 + self.num_heads = num_heads + self.num_resolutions = len(depths_encoder) + self.depths_encoder = depths_encoder + + self.slices_encoder_single = list( + zip([d - 1 for d in self.depths_encoder], self.depths_encoder) + ) + self.slices_encoder_range = list( + 
zip([0, *self.depths_encoder[:-1]], self.depths_encoder) + ) + cls_token_input_dims = [input_dims[-i - 1] for i in range(len(depths_encoder))] + + input_dims = [input_dims[d - 1] for d in depths_encoder] + self.slices_encoder = self.slices_encoder_single + + # adapt from encoder features, just project + self.input_adapter = ListAdapter(input_dims, hidden_dim) + self.token_adapter = ListAdapter(cls_token_input_dims, hidden_dim) + + # camera layer + self.camera_layer = CameraHead( + input_dim=hidden_dim, + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depth=2, + dropout=dropout, + layer_scale=layer_scale, + ) + + self.depth_layer = DepthHead( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depths=depth, + dropout=dropout, + camera_dim=81, + num_resolutions=self.num_resolutions, + layer_scale=layer_scale, + ) + + # transformer part + self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + self.level_embeds = nn.Parameter( + torch.randn(len(input_dims), hidden_dim), requires_grad=True + ) + self.level_embed_layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + ) diff --git a/unidepth/models/unidepthv1/unidepthv1.py b/unidepth/models/unidepthv1/unidepthv1.py new file mode 100644 index 0000000000000000000000000000000000000000..49d6450cc1c1753cbea2426edd9a59115192da57 --- /dev/null +++ b/unidepth/models/unidepthv1/unidepthv1.py @@ -0,0 +1,336 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import importlib +from copy import deepcopy +from math import ceil + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from einops import rearrange +from huggingface_hub import PyTorchModelHubMixin + +from unidepth.models.unidepthv1.decoder import Decoder +from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, + IMAGENET_DATASET_STD) +from unidepth.utils.distributed import is_main_process +from unidepth.utils.geometric import (generate_rays, + spherical_zbuffer_to_euclidean) +from unidepth.utils.misc import get_params + +MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"} + + +# inference helpers +def _paddings(image_shape, network_shape): + cur_h, cur_w = image_shape + h, w = network_shape + pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 + pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 + return pad_left, pad_right, pad_top, pad_bottom + + +def _shapes(image_shape, network_shape): + h, w = image_shape + input_ratio = w / h + output_ratio = network_shape[1] / network_shape[0] + if output_ratio > input_ratio: + ratio = network_shape[0] / h + elif output_ratio <= input_ratio: + ratio = network_shape[1] / w + return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio + + +def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): + (pad_left, pad_right, pad_top, pad_bottom) = pads + rgbs = F.interpolate( + rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True + ) + rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") + if intrinsics is not None: + intrinsics = intrinsics.clone() + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + 
pad_top + return rgbs, intrinsics + return rgbs, None + + +def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes): + (pad_left, pad_right, pad_top, pad_bottom) = pads + # pred mean, trim paddings, and upsample to input dim + predictions = sum( + [ + F.interpolate( + x.clone(), + size=shapes, + mode="bilinear", + align_corners=False, + antialias=True, + ) + for x in predictions + ] + ) / len(predictions) + predictions = predictions[ + ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right + ] + predictions = F.interpolate( + predictions, + size=original_shapes, + mode="bilinear", + align_corners=False, + antialias=True, + ) + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio + intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio + intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio + return predictions, intrinsics + + +class UniDepthV1( + nn.Module, + PyTorchModelHubMixin, + library_name="UniDepth", + repo_url="https://github.com/lpiccinelli-eth/UniDepth", + tags=["monocular-metric-depth-estimation"], +): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super().__init__() + self.build(config) + self.eps = eps + + def forward(self, inputs, image_metas=None): + rgbs = inputs['image'] + gt_intrinsics = inputs.get('K') + H, W = rgbs.shape[-2:] + + # Encode + encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) + if "dino" in self.pixel_encoder.__class__.__name__.lower(): + encoder_outputs = [ + (x + y.unsqueeze(1)).contiguous() + for x, y in zip(encoder_outputs, cls_tokens) + ] + inputs["encoder_outputs"] = encoder_outputs + inputs["cls_tokens"] = cls_tokens + + # Get camera infos, if any + if gt_intrinsics is not None: + rays, angles = generate_rays( + gt_intrinsics, self.image_shape, noisy=self.training + ) + inputs["rays"] = rays + inputs["angles"] = angles + inputs["K"] = gt_intrinsics + self.pixel_decoder.test_fixed_camera = True # use GT camera in fwd + + # Decode + pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {}) + predictions = sum( + [ + F.interpolate( + x.clone(), + size=self.image_shape, + mode="bilinear", + align_corners=False, + antialias=True, + ) + for x in predictions + ] + ) / len(predictions) + + # Final 3D points backprojection + pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1] + # You may want to use inputs["angles"] if available? + pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W) + pred_angles = F.interpolate( + pred_angles.clone(), + size=self.image_shape, + mode="bilinear", + align_corners=False, + antialias=True, + ) + + points_3d = torch.cat((pred_angles, predictions), dim=1) + points_3d = spherical_zbuffer_to_euclidean( + points_3d.permute(0, 2, 3, 1) + ).permute(0, 3, 1, 2) + + # Output data, use for loss computation + outputs = { + "angles": pred_angles, + "intrinsics": pred_intrinsics, + "points": points_3d, + "depth": predictions.squeeze(1), + } + self.pixel_decoder.test_fixed_camera = False + return outputs + + @torch.no_grad() + def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False): + if rgbs.ndim == 3: + rgbs = rgbs.unsqueeze(0) + if intrinsics is not None and intrinsics.ndim == 2: + intrinsics = intrinsics.unsqueeze(0) + B, _, H, W = rgbs.shape + + rgbs = rgbs.to(self.device) + if intrinsics is not None: + intrinsics = intrinsics.to(self.device) + + # process image and intrinsiscs (if any) to match network input (slow?) 
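+        # inputs passed as uint8 or with values in [0, 255] are first rescaled to
+        # [0, 1], then standardized with the ImageNet mean/std the encoder expects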
+ if rgbs.max() > 5 or rgbs.dtype == torch.uint8: + rgbs = rgbs.to(torch.float32).div(255) + if rgbs.min() >= 0.0 and rgbs.max() <= 1.0: + rgbs = TF.normalize( + rgbs, + mean=IMAGENET_DATASET_MEAN, + std=IMAGENET_DATASET_STD, + ) + + (h, w), ratio = _shapes((H, W), self.image_shape) + pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape) + rgbs, gt_intrinsics = _preprocess( + rgbs, + intrinsics, + (h, w), + (pad_left, pad_right, pad_top, pad_bottom), + ratio, + self.image_shape, + ) + + # run encoder + encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) + if "dino" in self.pixel_encoder.__class__.__name__.lower(): + encoder_outputs = [ + (x + y.unsqueeze(1)).contiguous() + for x, y in zip(encoder_outputs, cls_tokens) + ] + + # get data for decoder and adapt to given camera + inputs = {} + inputs["encoder_outputs"] = encoder_outputs + inputs["cls_tokens"] = cls_tokens + inputs["image"] = rgbs + if gt_intrinsics is not None: + rays, angles = generate_rays( + gt_intrinsics, self.image_shape, noisy=self.training + ) + inputs["rays"] = rays + inputs["angles"] = angles + inputs["K"] = gt_intrinsics + self.pixel_decoder.test_fixed_camera = True + self.pixel_decoder.skip_camera = skip_camera + + # decode all + pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {}) + + # undo the reshaping and get original image size (slow) + predictions, pred_intrinsics = _postprocess( + predictions, + pred_intrinsics, + self.image_shape, + (pad_left, pad_right, pad_top, pad_bottom), + ratio, + (H, W), + ) + + # final 3D points backprojection + intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics + angles = generate_rays(intrinsics, (H, W), noisy=False)[-1] + angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W) + points_3d = torch.cat((angles, predictions), dim=1) + points_3d = spherical_zbuffer_to_euclidean( + points_3d.permute(0, 2, 3, 1) + ).permute(0, 3, 1, 2) + + # output data + outputs = { + "intrinsics": pred_intrinsics, + "points": points_3d, + "depth": predictions[:, -1:], + } + self.pixel_decoder.test_fixed_camera = False + self.pixel_decoder.skip_camera = False + return outputs + + def load_pretrained(self, model_file): + device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + dict_model = torch.load(model_file, map_location=device) + + if "model" in dict_model: + dict_model = dict_model["model"] + new_state_dict = deepcopy( + {k.replace("module.", ""): v for k, v in dict_model.items()} + ) + + info = self.load_state_dict(new_state_dict, strict=False) + if is_main_process(): + print( + f"Loaded from {model_file} for {self.__class__.__name__} results in:", + info, + ) + + def get_params(self, config): + if hasattr(self.pixel_encoder, "get_params"): + encoder_p, encoder_lr = self.pixel_encoder.get_params( + config["model"]["pixel_encoder"]["lr"], + config["training"]["wd"], + config["training"]["ld"], + ) + else: + encoder_p, encoder_lr = get_params( + self.pixel_encoder, + config["model"]["pixel_encoder"]["lr"], + config["training"]["wd"], + ) + decoder_p, decoder_lr = get_params( + self.pixel_decoder, config["training"]["lr"], config["training"]["wd"] + ) + return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr] + + @property + def device(self): + return next(self.parameters()).device + + def build(self, config): + mod = importlib.import_module("unidepth.models.encoder") + pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"]) + pixel_encoder_config = { + 
**config["training"], + **config["data"], + **config["model"]["pixel_encoder"], + "interpolate_offset": 0.1, + } + pixel_encoder = pixel_encoder_factory(pixel_encoder_config) + + config["model"]["pixel_encoder"]["patch_size"] = ( + 14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16 + ) + pixel_encoder_embed_dims = ( + pixel_encoder.embed_dims + if hasattr(pixel_encoder, "embed_dims") + else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)] + ) + config["model"]["pixel_encoder"]["embed_dim"] = getattr( + pixel_encoder, "embed_dim" + ) + config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims + config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths + + self.pixel_encoder = pixel_encoder + self.pixel_decoder = Decoder(config) + self.image_shape = config["data"]["image_shape"] diff --git a/unidepth/models/unidepthv2/__init__.py b/unidepth/models/unidepthv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..678562babe6f4661592390ceb990201e583ce6f3 --- /dev/null +++ b/unidepth/models/unidepthv2/__init__.py @@ -0,0 +1,5 @@ +from .unidepthv2 import UniDepthV2 + +__all__ = [ + "UniDepthV2", +] diff --git a/unidepth/models/unidepthv2/__pycache__/__init__.cpython-311.pyc b/unidepth/models/unidepthv2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06d4b0817662aef6a5402671f4867841a1c3b34a Binary files /dev/null and b/unidepth/models/unidepthv2/__pycache__/__init__.cpython-311.pyc differ diff --git a/unidepth/models/unidepthv2/__pycache__/decoder.cpython-311.pyc b/unidepth/models/unidepthv2/__pycache__/decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f219e7066ffe99f56725c28b4990b49b1dad6fbd Binary files /dev/null and b/unidepth/models/unidepthv2/__pycache__/decoder.cpython-311.pyc differ diff --git a/unidepth/models/unidepthv2/__pycache__/unidepthv2.cpython-311.pyc b/unidepth/models/unidepthv2/__pycache__/unidepthv2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22d6aabe15196222073267e6fb89986592d17f45 Binary files /dev/null and b/unidepth/models/unidepthv2/__pycache__/unidepthv2.cpython-311.pyc differ diff --git a/unidepth/models/unidepthv2/decoder.py b/unidepth/models/unidepthv2/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..45de9ce3f4674a1f9188d0179441fa58b4beab4a --- /dev/null +++ b/unidepth/models/unidepthv2/decoder.py @@ -0,0 +1,585 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from timm.models.layers import trunc_normal_ + +from unidepth.layers import (MLP, AttentionBlock, ConvUpsampleShuffleResidual, + NystromBlock, PositionEmbeddingSine) +from unidepth.utils.geometric import flat_interpolate, generate_rays +from unidepth.utils.positional_embedding import generate_fourier_features + + +class ListAdapter(nn.Module): + def __init__(self, input_dims: list[int], hidden_dim: int): + super().__init__() + self.input_adapters = nn.ModuleList([]) + self.num_chunks = len(input_dims) + self.checkpoint = True + for input_dim in input_dims: + self.input_adapters.append( + nn.Sequential( + nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU() + ) + ) + + def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor: + xs = torch.split(x, splits.int().tolist(), dim=-1) + xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)] + return torch.cat(xs, dim=-1) + + 
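+# Illustrative ListAdapter usage (hypothetical dimensions, not taken from any config
+# in this repo): four flattened encoder levels are concatenated on the channel axis,
+# and each chunk is projected to the shared hidden width before re-concatenation.
+#
+#   adapter = ListAdapter([256, 512, 1024, 1024], hidden_dim=512)
+#   x = torch.randn(2, 32 * 32, 256 + 512 + 1024 + 1024)   # (B, HW, sum(dims))
+#   splits = torch.tensor([256.0, 512.0, 1024.0, 1024.0])  # per-chunk channel sizes
+#   out = adapter(x, splits)                                # (B, HW, 4 * 512)
+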
+class CameraHead(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + dropout: float = 0.0, + **kwargs, + ): + super().__init__() + self.aggregate1 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.aggregate2 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.latents_pos = nn.Parameter( + torch.randn(1, 4, hidden_dim), requires_grad=True + ) + self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout) + self.project_cls = MLP(hidden_dim, dropout=dropout) + self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) + + def fill_intrinsics(self, x): + camera_intrinsics = torch.zeros( + x.shape[0], 3, 3, device=x.device, requires_grad=False + ) + camera_intrinsics[:, 0, 0] = x[:, 0].exp() + camera_intrinsics[:, 1, 1] = x[:, 1].exp() + camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid() + camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid() + camera_intrinsics[:, 2, 2] = 1.0 + return camera_intrinsics + + def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: + features = features.unbind(dim=-1) + cls_tokens = self.project_cls(cls_tokens) + latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) + features = self.in_features(torch.cat(features, dim=1) + pos_embed) + features = torch.cat((features, cls_tokens), dim=1) + cls_tokens = self.aggregate1( + cls_tokens, context=features, pos_embed=latents_pos + ) + cls_tokens = self.aggregate2( + cls_tokens, context=features, pos_embed=latents_pos + ) + + # project to intrinsics + x = self.out(cls_tokens).squeeze(-1) + camera_intrinsics = self.fill_intrinsics(x) + + return camera_intrinsics + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + +class GlobalHead(nn.Module): + def __init__( + self, + hidden_dim: int, + camera_dim: int, + expansion: int = 4, + dropout: float = 0.0, + **kwargs, + ): + super().__init__() + self.camera_dim = camera_dim + self.in_features = nn.Linear(hidden_dim, hidden_dim) + self.project_rays = nn.Linear(camera_dim + 3, hidden_dim) + self.aggregate1 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.aggregate2 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.project_cls = MLP(hidden_dim, dropout=dropout) + self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) + + def embed_rays(self, rays, shapes): + rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes) + rays_embedding = F.normalize(rays_embedding, dim=-1) + rays_embedding = generate_fourier_features( + rays_embedding, + dim=self.camera_dim, + max_freq=max(shapes) // 2, + use_log=True, + cat_orig=True, + ) + return rays_embedding + + def set_original_shapes(self, shapes: tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + def get_scaleshift(self, x): + scale, shift = torch.chunk(x, 2, dim=1) + scale = scale.exp().reshape(-1, 1, 1, 1) + shift = shift.reshape(-1, 1, 1, 1) + return scale, shift + + def forward(self, features, cls_tokens, rays) -> torch.Tensor: + features = features.unbind(dim=-1) + cls_tokens = self.project_cls(cls_tokens) + rays_embedding = self.project_rays(self.embed_rays(rays, self.shapes)) + rays_embedding = rays_embedding.repeat(1, len(features), 1) + features = self.in_features(torch.cat(features, dim=1) + rays_embedding) + features = torch.cat((features, cls_tokens), dim=1) 
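+        # two cross-attention rounds pool the ray-conditioned features into the
+        # class tokens, which the output MLP then maps to a per-image (scale, shift)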
+ cls_tokens = self.aggregate1(cls_tokens, context=features) + cls_tokens = self.aggregate2(cls_tokens, context=features) + x = self.out(cls_tokens).squeeze(-1) + scale, shift = self.get_scaleshift(x) + return scale, shift + + +class DepthHead(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depths: int | list[int] = 4, + checkpoint: bool = True, + camera_dim: int = 256, + num_resolutions: int = 4, + dropout: float = 0.0, + **kwargs, + ) -> None: + super().__init__() + self.checkpoint = checkpoint + self.camera_dim = camera_dim + self.skip_depth = False + + self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout) + self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim) + self.aggregate_16 = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + context_dim=hidden_dim, + ) + self.prompt_camera = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + context_dim=hidden_dim, + ) + + self.rays_layers = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + self.process_layers = nn.ModuleList([]) + self.norms, self.out_layers = nn.ModuleList([]), nn.ModuleList([]) + self.depth_mlp, self.confidence_mlp = nn.ModuleList([]), nn.ModuleList([]) + for i, depth in enumerate(depths): + blk_lst = nn.ModuleList([]) + for _ in range(depth): + blk_lst.append( + NystromBlock( + hidden_dim // int(2**i), + num_heads=num_heads // int(2**i), + expansion=expansion, + dropout=dropout, + ) + ) + self.process_layers.append(blk_lst) + self.rays_layers.append(nn.Linear(camera_dim + 3, hidden_dim // int(2**i))) + self.ups.append( + ConvUpsampleShuffleResidual( + hidden_dim // int(2**i), + expansion=expansion, + kernel_size=7, + num_layers=2, + ) + ) + self.depth_mlp.append( + MLP( + input_dim=hidden_dim // int(2 ** (i + 1)), + output_dim=16, + expansion=1, + ) + ) + self.confidence_mlp.append( + MLP( + input_dim=hidden_dim // int(2 ** (i + 1)), + output_dim=16, + expansion=1, + ) + ) + self.to_depth = nn.Conv2d( + 16 * len(depths), 1, 7, padding=3, padding_mode="reflect" + ) + self.to_confidence = nn.Conv2d( + 16 * len(depths), 1, 7, padding=3, padding_mode="reflect" + ) + + def set_original_shapes(self, shapes: tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + def embed_rays(self, rays, shapes): + rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes) + rays_embedding = F.normalize(rays_embedding, dim=-1) + rays_embedding = generate_fourier_features( + rays_embedding, + dim=self.camera_dim, + max_freq=max(shapes) // 2, + use_log=True, + cat_orig=True, + ) + return rays_embedding + + def project_rays(self, rays, shapes): + embedded_rays = [] + for i, layer in enumerate(self.rays_layers): + embedded_rays.append( + layer(self.embed_rays(rays, [(2**i) * x for x in shapes])) + ) + return embedded_rays + + def decode_depth(self, latents_16, rays, shapes): + latents = latents_16 + out_features, depths, confidences = [], [], [] + for i, (up, layers, rays_embedding) in enumerate( + zip(self.ups, self.process_layers, rays) + ): + for layer in layers: + latents = layer(latents, pos_embed=rays_embedding) + latents = up( + rearrange( + latents + rays_embedding, + "b (h w) c -> b c h w", + h=shapes[0] * int(2**i), + w=shapes[1] * int(2**i), + ).contiguous() + ) + out = rearrange( + latents, + "b (h w) c -> b h w c", + h=shapes[0] * int(2 ** (1 + i)), + w=shapes[1] * int(2 ** (1 + i)), + 
) + out_features.append(out) + + # aggregate output and project to depth + for i, (layer, features) in enumerate( + zip(self.depth_mlp[::-1], out_features[::-1]) + ): + out_depth_features = layer(features).permute(0, 3, 1, 2) + out_depth_features = F.interpolate( + out_depth_features, size=self.original_shapes, mode="bilinear" + ) + depths.append(out_depth_features) + logdepth = self.to_depth(torch.cat(depths, dim=1)) + + # aggregate output and project to confidences + for i, (layer, features) in enumerate( + zip(self.confidence_mlp[::-1], out_features[::-1]) + ): + out_conf_features = layer(features).permute(0, 3, 1, 2) + out_conf_features = F.interpolate( + out_conf_features, size=self.original_shapes, mode="bilinear" + ) + confidences.append(out_conf_features) + confidence = self.to_confidence(torch.cat(confidences, dim=1)) + + # apply sigmoid ot get conf in [0, 1] + confidence = torch.sigmoid(confidence) + + return logdepth, confidence + + def init_latents(self, features, shapes): + # Generate latents with init as pooled features + features_channels = torch.cat(features, dim=-1) + features_16 = self.features_channel_cat(features_channels) + latents_16 = features_16 + self.to_latents( + flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False) + ) + return latents_16 + + def forward( + self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed + ) -> torch.Tensor: + B = features.shape[0] + features = features.unbind(dim=-1) + shapes = self.shapes + + # camera_embedding + rays_embeddings = self.project_rays(rays_hr, shapes) + + # Init latents + init_latents_16 = self.init_latents(features, shapes) + + # Aggregate features: F -> D + latents_16 = self.aggregate_16( + init_latents_16, + context=torch.cat(features, dim=1), + pos_embed_context=pos_embed + level_embed, + ) + + # Aggregate camera: D -> D|E + latents_16 = self.prompt_camera(latents_16, context=rays_embeddings[0]) + + # Decode depth + logdepth, confidence = self.decode_depth(latents_16, rays_embeddings, shapes) + + return logdepth, confidence, latents_16 + + +class Decoder(nn.Module): + def __init__( + self, + config, + ): + super().__init__() + self.build(config) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + def get_adapted_features(self, features_flat, splits): + features_flat_cat = torch.cat(features_flat, dim=-1) + features_projected = self.input_adapter( + features_flat_cat, splits + ) # list [b hw c] shapes + features = torch.chunk(features_projected, splits.shape[0], dim=-1) + return features + + def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays_gt): + # get cls tokens projections + cls_tokens_splits = torch.tensor( + [x.shape[-1] for x in cls_tokens], + device=features.device, + requires_grad=False, + dtype=features.dtype, + ) + cls_tokens = torch.cat(cls_tokens, dim=-1) + cls_tokens = self.camera_token_adapter(cls_tokens, cls_tokens_splits) + cls_tokens = torch.cat( + torch.chunk(cls_tokens, cls_tokens_splits.shape[0], dim=-1), dim=1 + ) + + # camera layer + intrinsics = self.camera_layer( + features=features, cls_tokens=cls_tokens, 
pos_embed=pos_embed + ) + intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0] + intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1] + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1] + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0] + + rays = ( + rays_gt + if rays_gt is not None + else generate_rays(intrinsics, original_shapes)[0] + ) + return intrinsics, rays + + def run_global(self, cls_tokens, features, rays): + # get cls tokens projections + cls_tokens_splits = torch.tensor( + [x.shape[-1] for x in cls_tokens], + device=features.device, + requires_grad=False, + dtype=torch.float32, + ) + cls_tokens = torch.cat(cls_tokens, dim=-1) + cls_tokens = self.global_token_adapter(cls_tokens, cls_tokens_splits) + cls_tokens = torch.cat( + torch.chunk(cls_tokens, cls_tokens_splits.shape[0], dim=-1), dim=1 + ) + + scale, shift = self.global_layer( + features=features, rays=rays, cls_tokens=cls_tokens + ) + + return scale, shift + + def forward(self, inputs, image_metas) -> torch.Tensor: + B, C, H, W = inputs["image"].shape + device = inputs["image"].device + dtype = inputs["image"].dtype + + # get features in b n d format + # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions + level_shapes = sorted( + list(set([tuple([x.shape[1], x.shape[2]]) for x in inputs["features"]])) + )[::-1] + if len(level_shapes) == 1: + level_shapes = level_shapes * self.num_resolutions + input_shapes = [ + level_shapes[i] + for i, (start, end) in enumerate(self.slices_encoder) + for _ in range(end - start) + ] + common_shape = level_shapes[-2] + + # input shapes repeat shapes for each level, times the amount of the layers: + features_flat = [ + flat_interpolate( + rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape + ) + for x, input_shape in zip(inputs["features"], input_shapes) + ] + features_splits = torch.tensor( + [x.shape[-1] for x in features_flat], + device=device, + requires_grad=False, + dtype=torch.float32, + ) + features = self.get_adapted_features(features_flat, features_splits) + features = torch.stack(features, dim=-1) + + # positional embeddings, spatial and level + level_embed = torch.cat( + [ + self.level_embed_layer(self.level_embeds)[i : i + 1] + .unsqueeze(0) + .repeat(B, common_shape[0] * common_shape[1], 1) + for i in range(self.num_resolutions) + ], + dim=1, + ) + dummy_tensor = torch.zeros( + B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False + ) + pos_embed = self.pos_embed(dummy_tensor) + pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( + 1, self.num_resolutions, 1 + ) + + self.camera_layer.set_shapes(common_shape) + intrinsics, rays = self.run_camera( + inputs["camera_tokens"], + features=features, + pos_embed=pos_embed + level_embed, + original_shapes=(H, W), + rays_gt=inputs.get("rays"), + ) + + self.global_layer.set_shapes(common_shape) + self.global_layer.set_original_shapes((H, W)) + scale, shift = self.run_global( + inputs["global_tokens"], features=features, rays=rays + ) + + # run bulk of the model + self.depth_layer.set_shapes(common_shape) + self.depth_layer.set_original_shapes((H, W)) + logdepth, confidence, depth_features = self.depth_layer( + features=features, + rays_hr=rays, + pos_embed=pos_embed, + level_embed=level_embed, + ) + logdepth = logdepth.to(torch.float32, non_blocking=True) + + # norm in log space, why performs better? 
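+        # layer-norm over the spatial extent standardizes log-depth per image, so the
+        # depth head predicts a scale/shift-invariant map; the global (scale, shift)
+        # restores the metric range and the softplus keeps the result strictly positive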
+ shapes = [int(x) for x in logdepth.shape[-2:]] + depth_normalized = F.layer_norm(logdepth, shapes).exp() + + depth = ( + depth_normalized + shift + ) * scale # shift is scale invariant if we do (x + mu) * sigma + depth = F.softplus(depth, beta=10.0).to(dtype, non_blocking=True) + + outputs = { + "depth": depth, + "confidence": confidence, + "depth_features": depth_features, + "K": intrinsics, + } + return outputs + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"latents_pos", "level_embeds"} + + def build(self, config): + input_dims = config["model"]["pixel_encoder"]["embed_dims"] + hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] + expansion = config["model"]["expansion"] + num_heads = config["model"]["num_heads"] + dropout = config["model"]["pixel_decoder"]["dropout"] + depths_encoder = config["model"]["pixel_encoder"]["depths"] + depth = config["model"]["pixel_decoder"]["depths"] + depths_encoder = config["model"]["pixel_encoder"]["depths"] + self.downsample = 4 + self.num_resolutions = len(depths_encoder) + + self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder)) + cls_token_input_dims = [input_dims[i] for i in [-1, -2, -3, -4]] + input_dims = [input_dims[d - 1] for d in depths_encoder] + + # # camera layer + self.camera_layer = CameraHead( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + ) + + # # scale shift layer + self.global_layer = GlobalHead( + hidden_dim=hidden_dim, + camera_dim=96, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + ) + + # # adapt from encoder features, just project + self.input_adapter = ListAdapter(input_dims, hidden_dim) + self.camera_token_adapter = ListAdapter(cls_token_input_dims, hidden_dim) + self.global_token_adapter = ListAdapter(cls_token_input_dims[:2], hidden_dim) + + self.depth_layer = DepthHead( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depths=depth, + dropout=dropout, + camera_dim=96, + num_resolutions=self.num_resolutions, + ) + + self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + self.level_embeds = nn.Parameter( + torch.randn(len(input_dims), hidden_dim), requires_grad=True + ) + self.level_embed_layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + ) diff --git a/unidepth/models/unidepthv2/decoder_old.py b/unidepth/models/unidepthv2/decoder_old.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0bf4213df099b65271f9f089c4288ffda4ee92 --- /dev/null +++ b/unidepth/models/unidepthv2/decoder_old.py @@ -0,0 +1,588 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from timm.models.layers import trunc_normal_ + +from unidepth.layers import (MLP, AttentionBlock, ConvUpsampleShuffle, + NystromBlock, PositionEmbeddingSine) +from unidepth.utils.geometric import flat_interpolate, generate_rays +from unidepth.utils.positional_embedding import generate_fourier_features + + +class ListAdapter(nn.Module): + def __init__(self, input_dims: list[int], hidden_dim: int): + super().__init__() + self.input_adapters = nn.ModuleList([]) + self.num_chunks = len(input_dims) + self.checkpoint = True + for input_dim in input_dims: + self.input_adapters.append( + nn.Sequential( + nn.LayerNorm(input_dim), 
nn.Linear(input_dim, hidden_dim), nn.GELU() + ) + ) + + def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor: + xs = torch.split(x, splits.int().tolist(), dim=-1) + xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)] + return torch.cat(xs, dim=-1) + + +class CameraHead(nn.Module): + def __init__( + self, + hidden_dim: int, + expansion: int = 4, + dropout: float = 0.0, + **kwargs, + ): + super().__init__() + self.aggregate1 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.aggregate2 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.latents_pos = nn.Parameter( + torch.randn(1, 4, hidden_dim), requires_grad=True + ) + self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout) + self.project_cls = MLP(hidden_dim, dropout=dropout) + self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) + + def fill_intrinsics(self, x): + camera_intrinsics = torch.zeros( + x.shape[0], 3, 3, device=x.device, requires_grad=False + ) + camera_intrinsics[:, 0, 0] = x[:, 0].exp() + camera_intrinsics[:, 1, 1] = x[:, 1].exp() + camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid() + camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid() + camera_intrinsics[:, 2, 2] = 1.0 + return camera_intrinsics + + def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor: + features = features.unbind(dim=-1) + cls_tokens = self.project_cls(cls_tokens) + latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) + features = self.in_features(torch.cat(features, dim=1) + pos_embed) + features = torch.cat((features, cls_tokens), dim=1) + cls_tokens = self.aggregate1( + cls_tokens, context=features, pos_embed=latents_pos + ) + cls_tokens = self.aggregate2( + cls_tokens, context=features, pos_embed=latents_pos + ) + + # project to intrinsics + x = self.out(cls_tokens).squeeze(-1) + camera_intrinsics = self.fill_intrinsics(x) + + return camera_intrinsics + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + +class GlobalHead(nn.Module): + def __init__( + self, + hidden_dim: int, + camera_dim: int, + expansion: int = 4, + dropout: float = 0.0, + **kwargs, + ): + super().__init__() + self.camera_dim = camera_dim + self.in_features = nn.Linear(hidden_dim, hidden_dim) + self.project_rays = nn.Linear(camera_dim + 3, hidden_dim) + self.aggregate1 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.aggregate2 = AttentionBlock( + hidden_dim, num_heads=1, expansion=expansion, dropout=dropout + ) + self.project_cls = MLP(hidden_dim, dropout=dropout) + self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1) + + def embed_rays(self, rays, shapes): + rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes) + rays_embedding = F.normalize(rays_embedding, dim=-1) + rays_embedding = generate_fourier_features( + rays_embedding, + dim=self.camera_dim, + max_freq=max(shapes) // 2, + use_log=True, + cat_orig=True, + ) + return rays_embedding + + def set_original_shapes(self, shapes: tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + def get_scaleshift(self, x): + scale, shift = torch.chunk(x, 2, dim=1) + scale = scale.exp().reshape(-1, 1, 1, 1) + shift = shift.reshape(-1, 1, 1, 1) + return scale, shift + + def forward(self, features, cls_tokens, rays) -> torch.Tensor: + features = features.unbind(dim=-1) + cls_tokens = self.project_cls(cls_tokens) + 
rays_embedding = self.project_rays(self.embed_rays(rays, self.shapes)) + rays_embedding = rays_embedding.repeat(1, len(features), 1) + features = self.in_features(torch.cat(features, dim=1) + rays_embedding) + features = torch.cat((features, cls_tokens), dim=1) + cls_tokens = self.aggregate1(cls_tokens, context=features) + cls_tokens = self.aggregate2(cls_tokens, context=features) + x = self.out(cls_tokens).squeeze(-1) + scale, shift = self.get_scaleshift(x) + return scale, shift + + +class DepthHead(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depths: int | list[int] = 4, + checkpoint: bool = True, + camera_dim: int = 256, + num_resolutions: int = 4, + dropout: float = 0.0, + **kwargs, + ) -> None: + super().__init__() + self.checkpoint = checkpoint + self.camera_dim = camera_dim + self.skip_depth = False + + self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout) + self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim) + self.aggregate_16 = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + context_dim=hidden_dim, + ) + self.prompt_camera = AttentionBlock( + hidden_dim, + num_heads=1, + expansion=expansion, + dropout=dropout, + context_dim=hidden_dim, + ) + + self.rays_layers = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + self.process_layers = nn.ModuleList([]) + self.norms, self.out_layers = nn.ModuleList([]), nn.ModuleList([]) + self.confidence_norms, self.confidence_out_layers = nn.ModuleList( + [] + ), nn.ModuleList([]) + for i, depth in enumerate(depths): + blk_lst = nn.ModuleList([]) + for _ in range(depth): + blk_lst.append( + NystromBlock( + hidden_dim // int(2**i), + num_heads=num_heads // int(2**i), + expansion=expansion, + dropout=dropout, + ) + ) + self.process_layers.append(blk_lst) + self.rays_layers.append(nn.Linear(camera_dim + 3, hidden_dim // int(2**i))) + self.ups.append( + ConvUpsampleShuffle( + hidden_dim // int(2**i), + expansion=expansion, + kernel_size=7, + num_layers=2, + ) + ) + self.norms.append(nn.LayerNorm(hidden_dim // int(2 ** (i + 1)))) + self.out_layers.append( + nn.Conv2d(hidden_dim // int(2 ** (i + 1)), 1, 3, padding=1) + ) + self.confidence_norms.append(nn.LayerNorm(hidden_dim // int(2 ** (i + 1)))) + self.confidence_out_layers.append( + nn.Conv2d(hidden_dim // int(2 ** (i + 1)), 1, 3, padding=1) + ) + + def set_original_shapes(self, shapes: tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + def embed_rays(self, rays, shapes): + rays_embedding = flat_interpolate(rays, old=self.original_shapes, new=shapes) + rays_embedding = F.normalize(rays_embedding, dim=-1) + rays_embedding = generate_fourier_features( + rays_embedding, + dim=self.camera_dim, + max_freq=max(shapes) // 2, + use_log=True, + cat_orig=True, + ) + return rays_embedding + + def project_rays(self, rays, shapes): + embedded_rays = [] + for i, layer in enumerate(self.rays_layers): + embedded_rays.append( + layer(self.embed_rays(rays, [(2**i) * x for x in shapes])) + ) + return embedded_rays + + def decode_depth(self, latents_16, rays, shapes): + dtype = latents_16.dtype + latents = latents_16 + out_features, confidences, outs = [], [], [] + for i, (up, layers, rays_embedding) in enumerate( + zip(self.ups, self.process_layers, rays) + ): + for layer in layers: + latents = layer(latents, pos_embed=rays_embedding) + latents = up( + rearrange( + latents + rays_embedding, + "b (h w) c -> b c 
h w", + h=shapes[0] * int(2**i), + w=shapes[1] * int(2**i), + ).contiguous() + ) + out = rearrange( + latents, + "b (h w) c -> b h w c", + h=shapes[0] * int(2 ** (1 + i)), + w=shapes[1] * int(2 ** (1 + i)), + ) + out_features.append(out) + + for i, (norm, out_layer, features) in enumerate( + zip(self.norms[::-1], self.out_layers[::-1], out_features[::-1]) + ): + features = norm(features) + out_d = out_layer(features.permute(0, 3, 1, 2)) + outs.append(out_d) + out = sum( + F.interpolate( + x, + size=outs[0].shape[-2:], + mode="bilinear", + ) + for x in outs + ) + out = out / len(outs) + # jit complains, fix as list (loose dyn input) + out_shapes = [int(s) for s in out.shape[1:]] + out = F.layer_norm(out.float(), out_shapes) + out = out.clamp(-10.0, 10.0).exp().to(dtype, non_blocking=True) + + for i, (norm, out_layer, features) in enumerate( + zip( + self.confidence_norms[::-1], + self.confidence_out_layers[::-1], + out_features[::-1], + ) + ): + features = norm(features) + out_c = out_layer(features.permute(0, 3, 1, 2)) + confidences.append(out_c) + confidence = sum( + F.interpolate( + x, + size=confidences[0].shape[-2:], + mode="bilinear", + ) + for x in confidences + ) + confidence = confidence / len(confidences) + confidence = torch.sigmoid(confidence) + + return out, confidence + + def init_latents(self, features, shapes): + # Generate latents with init as pooled features + features_channels = torch.cat(features, dim=-1) + features_16 = self.features_channel_cat(features_channels) + latents_16 = features_16 + self.to_latents( + flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False) + ) + return latents_16 + + def forward( + self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed + ) -> torch.Tensor: + B = features.shape[0] + features = features.unbind(dim=-1) + shapes = self.shapes + + # camera_embedding + rays_embeddings = self.project_rays(rays_hr, shapes) + + # Init latents + init_latents_16 = self.init_latents(features, shapes) + + # Aggregate features: F -> D + latents_16 = self.aggregate_16( + init_latents_16, + context=torch.cat(features, dim=1), + pos_embed_context=pos_embed + level_embed, + ) + + # Aggregate camera: D -> D|E + latents_16 = self.prompt_camera(latents_16, context=rays_embeddings[0]) + + # Decode depth + out, confidence = self.decode_depth(latents_16, rays_embeddings, shapes) + + return out, confidence, latents_16 + + +class Decoder(nn.Module): + def __init__( + self, + config, + ): + super().__init__() + self.build(config) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + def get_adapted_features(self, features_flat, splits): + features_flat_cat = torch.cat(features_flat, dim=-1) + features_projected = self.input_adapter( + features_flat_cat, splits + ) # list [b hw c] shapes + features = torch.chunk(features_projected, len(splits), dim=-1) + return features + + def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays_gt): + # get cls tokens projections + cls_tokens_splits = torch.tensor( + [x.shape[-1] for x in cls_tokens], + device=features.device, + requires_grad=False, + 
dtype=features.dtype, + ) + cls_tokens = torch.cat(cls_tokens, dim=-1) + cls_tokens = self.camera_token_adapter(cls_tokens, cls_tokens_splits) + cls_tokens = torch.cat( + torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1 + ) + + # camera layer + intrinsics = self.camera_layer( + features=features, cls_tokens=cls_tokens, pos_embed=pos_embed + ) + intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0] + intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1] + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1] + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0] + + rays = ( + rays_gt + if rays_gt is not None + else generate_rays(intrinsics, original_shapes)[0] + ) + return intrinsics, rays + + def run_global(self, cls_tokens, features, rays): + # get cls tokens projections + cls_tokens_splits = torch.tensor( + [x.shape[-1] for x in cls_tokens], + device=features.device, + requires_grad=False, + dtype=torch.float32, + ) + cls_tokens = torch.cat(cls_tokens, dim=-1) + cls_tokens = self.global_token_adapter(cls_tokens, cls_tokens_splits) + cls_tokens = torch.cat( + torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1 + ) + + scale, shift = self.global_layer( + features=features, rays=rays, cls_tokens=cls_tokens + ) + + return scale, shift + + def forward(self, inputs, image_metas) -> torch.Tensor: + B, C, H, W = inputs["image"].shape + device = inputs["image"].device + + # get features in b n d format + # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions + level_shapes = sorted( + list(set([tuple([x.shape[1], x.shape[2]]) for x in inputs["features"]])) + )[::-1] + if len(level_shapes) == 1: + level_shapes = level_shapes * self.num_resolutions + input_shapes = [ + level_shapes[i] + for i, (start, end) in enumerate(self.slices_encoder) + for _ in range(end - start) + ] + common_shape = level_shapes[-2] + + # input shapes repeat shapes for each level, times the amount of the layers: + features_flat = [ + flat_interpolate( + rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape + ) + for x, input_shape in zip(inputs["features"], input_shapes) + ] + features_splits = torch.tensor( + [x.shape[-1] for x in features_flat], + device=device, + requires_grad=False, + dtype=torch.float32, + ) + features = self.get_adapted_features(features_flat, features_splits) + features = torch.stack(features, dim=-1) + + # positional embeddings, spatial and level + level_embed = torch.cat( + [ + self.level_embed_layer(self.level_embeds)[i : i + 1] + .unsqueeze(0) + .repeat(B, common_shape[0] * common_shape[1], 1) + for i in range(self.num_resolutions) + ], + dim=1, + ) + dummy_tensor = torch.zeros( + B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False + ) + pos_embed = self.pos_embed(dummy_tensor) + pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( + 1, self.num_resolutions, 1 + ) + + self.camera_layer.set_shapes(common_shape) + intrinsics, rays = self.run_camera( + inputs["camera_tokens"], + features=features, + pos_embed=pos_embed + level_embed, + original_shapes=(H, W), + rays_gt=inputs.get("rays"), + ) + + self.global_layer.set_shapes(common_shape) + self.global_layer.set_original_shapes((H, W)) + scale, shift = self.run_global( + inputs["global_tokens"], features=features, rays=rays + ) + + # run bulk of the model + self.depth_layer.set_shapes(common_shape) + self.depth_layer.set_original_shapes((H, W)) + 
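+        # the depth head decodes at common_shape but is given the full (H, W) so the
+        # camera rays can be re-interpolated to each upsampling level's resolution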
out_normalized, confidence, depth_features = self.depth_layer( + features=features, + rays_hr=rays, + pos_embed=pos_embed, + level_embed=level_embed, + ) + # shift is scale invariant if we do (x + mu) * sigma + out = (out_normalized + shift) * scale + + outputs = { + "depth": out.clamp(min=1e-3), + "confidence": confidence, + "K": intrinsics, + "rays": rays, + "depth_features": depth_features, + } + return outputs + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"latents_pos", "level_embeds"} + + def build(self, config): + input_dims = config["model"]["pixel_encoder"]["embed_dims"] + hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] + expansion = config["model"]["expansion"] + num_heads = config["model"]["num_heads"] + dropout = config["model"]["pixel_decoder"]["dropout"] + depths_encoder = config["model"]["pixel_encoder"]["depths"] + depth = config["model"]["pixel_decoder"]["depths"] + depths_encoder = config["model"]["pixel_encoder"]["depths"] + self.downsample = 4 + self.num_resolutions = len(depths_encoder) + + self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder)) + cls_token_input_dims = [input_dims[i] for i in [-1, -2, -3, -4]] + input_dims = [input_dims[d - 1] for d in depths_encoder] + + # # camera layer + self.camera_layer = CameraHead( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + ) + + # # scale shift layer + self.global_layer = GlobalHead( + hidden_dim=hidden_dim, + camera_dim=96, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + ) + + # # adapt from encoder features, just project + self.input_adapter = ListAdapter(input_dims, hidden_dim) + self.camera_token_adapter = ListAdapter(cls_token_input_dims, hidden_dim) + self.global_token_adapter = ListAdapter(cls_token_input_dims[:2], hidden_dim) + + self.depth_layer = DepthHead( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depths=depth, + dropout=dropout, + camera_dim=96, + num_resolutions=self.num_resolutions, + ) + + self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + self.level_embeds = nn.Parameter( + torch.randn(len(input_dims), hidden_dim), requires_grad=True + ) + self.level_embed_layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + ) diff --git a/unidepth/models/unidepthv2/export.py b/unidepth/models/unidepthv2/export.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc3d1969c91e7ed317f3f33d0ef60b2741258fe --- /dev/null +++ b/unidepth/models/unidepthv2/export.py @@ -0,0 +1,218 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import argparse +import json +import os +from math import ceil + +import huggingface_hub +import torch.nn.functional as F +import torch.onnx + +from unidepth.models.unidepthv2 import UniDepthV2 +from unidepth.utils.geometric import generate_rays + + +class UniDepthV2ONNX(UniDepthV2): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super(UniDepthV2ONNX, self).__init__(config, eps) + + def forward(self, rgbs): + H, W = rgbs.shape[-2:] + + features, tokens = self.pixel_encoder(rgbs) + + cls_tokens = [x.contiguous() for x in tokens] + features = [ + self.stacking_fn(features[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + tokens = [ + self.stacking_fn(tokens[i:j]).contiguous() + for i, j in 
self.slices_encoder_range + ] + global_tokens = [cls_tokens[i] for i in [-2, -1]] + camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]] + + inputs = {} + inputs["image"] = rgbs + inputs["features"] = features + inputs["tokens"] = tokens + inputs["global_tokens"] = global_tokens + inputs["camera_tokens"] = camera_tokens + + outs = self.pixel_decoder(inputs, {}) + + predictions = F.interpolate( + outs["depth"], + size=(H, W), + mode="bilinear", + ) + confidence = F.interpolate( + outs["confidence"], + size=(H, W), + mode="bilinear", + ) + + return outs["K"], predictions, confidence + + +class UniDepthV2wCamONNX(UniDepthV2): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super(UniDepthV2wCamONNX, self).__init__(config, eps) + + def forward(self, rgbs, K): + H, W = rgbs.shape[-2:] + + features, tokens = self.pixel_encoder(rgbs) + + cls_tokens = [x.contiguous() for x in tokens] + features = [ + self.stacking_fn(features[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + tokens = [ + self.stacking_fn(tokens[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + global_tokens = [cls_tokens[i] for i in [-2, -1]] + camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]] + + inputs = {} + inputs["image"] = rgbs + inputs["features"] = features + inputs["tokens"] = tokens + inputs["global_tokens"] = global_tokens + inputs["camera_tokens"] = camera_tokens + rays, angles = generate_rays(K, (H, W)) + inputs["rays"] = rays + inputs["angles"] = angles + inputs["K"] = K + + outs = self.pixel_decoder(inputs, {}) + + predictions = F.interpolate( + outs["depth"], + size=(H, W), + mode="bilinear", + ) + predictions_normalized = F.interpolate( + outs["depth_ssi"], + size=(H, W), + mode="bilinear", + ) + confidence = F.interpolate( + outs["confidence"], + size=(H, W), + mode="bilinear", + ) + + return outs["K"], predictions, predictions_normalized, confidence + + +def export(model, path, shape=(462, 616), with_camera=False): + model.eval() + image = torch.rand(1, 3, *shape) + dynamic_axes_in = {"image": {0: "batch"}} + inputs = [image] + if with_camera: + K = torch.rand(1, 3, 3) + inputs.append(K) + dynamic_axes_in["K"] = {0: "batch"} + + dynamic_axes_out = { + "out_K": {0: "batch"}, + "depth": {0: "batch"}, + "confidence": {0: "batch"}, + } + torch.onnx.export( + model, + tuple(inputs), + path, + input_names=list(dynamic_axes_in.keys()), + output_names=list(dynamic_axes_out.keys()), + opset_version=14, + dynamic_axes={**dynamic_axes_in, **dynamic_axes_out}, + ) + print(f"Model exported to {path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Export UniDepthV2 model to ONNX") + parser.add_argument( + "--version", type=str, default="v2", choices=["v2"], help="UniDepth version" + ) + parser.add_argument( + "--backbone", + type=str, + default="vitl14", + choices=["vits14", "vitl14"], + help="Backbone model", + ) + parser.add_argument( + "--shape", + type=int, + nargs=2, + default=(462, 616), + help="Input shape. 
No dyamic shape supported!", + ) + parser.add_argument( + "--output-path", type=str, default="unidepthv2.onnx", help="Output ONNX file" + ) + parser.add_argument( + "--with-camera", + action="store_true", + help="Export model that expects GT camera matrix at inference", + ) + args = parser.parse_args() + + version = args.version + backbone = args.backbone + shape = args.shape + output_path = args.output_path + with_camera = args.with_camera + + # force shape to be multiple of 14 + shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape] + if list(shape) != list(shape_rounded): + print(f"Shape {shape} is not multiple of 14. Rounding to {shape_rounded}") + shape = shape_rounded + + # assumes command is from root of repo + with open(os.path.join("configs", f"config_{version}_{backbone}.json")) as f: + config = json.load(f) + + # tell DINO not to use efficient attention: not exportable + config["training"]["export"] = True + + model_factory = UniDepthV2ONNX if not with_camera else UniDepthV2wCamONNX + model = model_factory(config) + path = huggingface_hub.hf_hub_download( + repo_id=f"lpiccinelli/unidepth-{version}-{backbone}", + filename=f"pytorch_model.bin", + repo_type="model", + ) + info = model.load_state_dict(torch.load(path), strict=False) + print(f"UniDepth_{version}_{backbone} is loaded with:") + print(f"\t missing keys: {info.missing_keys}") + print(f"\t additional keys: {info.unexpected_keys}") + + export( + model=model, + path=os.path.join(os.environ["TMPDIR"], output_path), + shape=shape, + with_camera=with_camera, + ) diff --git a/unidepth/models/unidepthv2/unidepthv2.py b/unidepth/models/unidepthv2/unidepthv2.py new file mode 100644 index 0000000000000000000000000000000000000000..551ea340fb70aacf7f8f561763e7cb2cfca32f2e --- /dev/null +++ b/unidepth/models/unidepthv2/unidepthv2.py @@ -0,0 +1,364 @@ +import importlib +import warnings +from copy import deepcopy +from math import ceil + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from einops import rearrange +from huggingface_hub import PyTorchModelHubMixin +import torchvision.transforms as transforms + + +from unidepth.models.unidepthv2.decoder import Decoder +from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, + IMAGENET_DATASET_STD) +from unidepth.utils.distributed import is_main_process +from unidepth.utils.geometric import (generate_rays, + spherical_zbuffer_to_euclidean) +from unidepth.utils.misc import (first_stack, last_stack, max_stack, + mean_stack, softmax_stack) + +STACKING_FNS = { + "max": max_stack, + "mean": mean_stack, + "first": first_stack, + "last": last_stack, + "softmax": softmax_stack, +} +RESOLUTION_LEVELS = 10 + + +# inference helpers +def _check_ratio(image_ratio, ratio_bounds): + ratio_bounds = sorted(ratio_bounds) + if ratio_bounds is not None and ( + image_ratio < ratio_bounds[0] or image_ratio > ratio_bounds[1] + ): + warnings.warn( + f"Input image ratio ({image_ratio:.3f}) is out of training " + f"distribution: {ratio_bounds}. This may lead to unexpected results. " + f"Consider resizing/padding the image to match the training distribution." + ) + + +def _check_resolution(shape_constraints, resolution_level): + if resolution_level is None: + warnings.warn( + "Resolution level is not set. Using max resolution. " + "You can tradeoff resolution for speed by setting a number in [0,10]. " + "This can be achieved by setting model's `resolution_level` attribute." 
+ ) + resolution_level = RESOLUTION_LEVELS + pixel_bounds = sorted(shape_constraints["pixels_bounds_ori"]) + pixel_range = pixel_bounds[-1] - pixel_bounds[0] + clipped_resolution_level = min(max(resolution_level, 0), RESOLUTION_LEVELS) + if clipped_resolution_level != resolution_level: + warnings.warn( + f"Resolution level {resolution_level} is out of bounds ([0,{RESOLUTION_LEVELS}]). " + f"Clipping to {clipped_resolution_level}." + ) + shape_constraints["pixels_bounds"] = [ + pixel_bounds[0] + + ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS), + pixel_bounds[0] + + ceil(pixel_range * clipped_resolution_level / RESOLUTION_LEVELS), + ] + return shape_constraints + + +def _get_closes_num_pixels(image_shape, pixels_bounds): + h, w = image_shape + num_pixels = h * w + pixels_bounds = sorted(pixels_bounds) + num_pixels = max(min(num_pixels, pixels_bounds[1]), pixels_bounds[0]) + return num_pixels + + +def _shapes(image_shape, shape_constraints): + h, w = image_shape + image_ratio = w / h + # _check_ratio(image_ratio, shape_constraints["ratio_bounds"]) + num_pixels = _get_closes_num_pixels( + (h / shape_constraints["patch_size"], w / shape_constraints["patch_size"]), + shape_constraints["pixels_bounds"], + ) + h = ceil((num_pixels / image_ratio) ** 0.5 - 0.5) + w = ceil(h * image_ratio - 0.5) + ratio = h / image_shape[0] * shape_constraints["patch_size"] + return ( + h * shape_constraints["patch_size"], + w * shape_constraints["patch_size"], + ), ratio + + +def _preprocess(rgbs, intrinsics, shapes, ratio): + rgbs = F.interpolate(rgbs, size=shapes, mode="bilinear", antialias=True) + if intrinsics is not None: + intrinsics = intrinsics.clone() + intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio + intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio + intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + return rgbs, intrinsics + return rgbs, None + + +def _postprocess(outs, ratio, original_shapes, mode="nearest-exact"): + outs["depth"] = F.interpolate(outs["depth"], size=original_shapes, mode=mode) + outs["confidence"] = F.interpolate( + outs["confidence"], size=original_shapes, mode="bilinear", antialias=True + ) + outs["K"][:, 0, 0] = outs["K"][:, 0, 0] / ratio + outs["K"][:, 1, 1] = outs["K"][:, 1, 1] / ratio + outs["K"][:, 0, 2] = outs["K"][:, 0, 2] / ratio + outs["K"][:, 1, 2] = outs["K"][:, 1, 2] / ratio + return outs + + +class UniDepthV2( + nn.Module, + PyTorchModelHubMixin, + library_name="UniDepth", + repo_url="https://github.com/lpiccinelli-eth/UniDepth", + tags=["monocular-metric-depth-estimation"], +): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super().__init__() + self.build(config) + self.interpolation_mode = "bilinear" + self.eps = eps + self.resolution_level = 10 + + def forward(self, inputs, image_metas=None): + H, W = inputs["depth"].shape[-2:] + + if "K" in inputs: + rays, angles = generate_rays(inputs["K"], (H, W)) + inputs["rays"] = rays + inputs["angles"] = angles + + features, tokens = self.pixel_encoder(inputs[f"image"]) + + cls_tokens = [x.contiguous() for x in tokens] + features = [ + self.stacking_fn(features[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + tokens = [ + self.stacking_fn(tokens[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + global_tokens = [cls_tokens[i] for i in [-2, -1]] + camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]] + + inputs["features"] = features + inputs["tokens"] = tokens + 
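As a concrete reading of `_check_resolution` above: `resolution_level` in [0, 10] linearly interpolates the patch-token budget between the lower and upper entries of the original bounds, and both ends of `pixels_bounds` collapse to that single value. A small sketch, assuming the default `pixels_bounds_ori` of [1400, 2400] set later in this patch:

from math import ceil

RESOLUTION_LEVELS = 10
pixels_bounds_ori = [1400, 2400]  # bounds on the number of 14x14 patches, not raw pixels

def budget_for_level(level: int) -> int:
    lo, hi = sorted(pixels_bounds_ori)
    level = min(max(level, 0), RESOLUTION_LEVELS)  # clip, as _check_resolution does
    return lo + ceil((hi - lo) * level / RESOLUTION_LEVELS)

for level in (0, 5, 10):
    print(level, budget_for_level(level))  # 0 -> 1400, 5 -> 1900, 10 -> 2400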
inputs["global_tokens"] = global_tokens + inputs["camera_tokens"] = camera_tokens + + outs = self.pixel_decoder(inputs, image_metas) + + angles = rearrange( + generate_rays(outs["K"], (H, W), noisy=False)[-1], + "b (h w) c -> b c h w", + h=H, + w=W, + ) + predictions = F.interpolate( + outs["depth"], + size=(H, W), + mode="bilinear", + align_corners=False, + antialias=True, + ) + confidence = F.interpolate( + outs["confidence"], + size=(H, W), + mode="bilinear", + align_corners=False, + antialias=True, + ) + predictions_3d = torch.cat((angles, predictions), dim=1) + predictions_3d = spherical_zbuffer_to_euclidean( + predictions_3d.permute(0, 2, 3, 1) + ).permute(0, 3, 1, 2) + + outputs = { + "K": outs["K"], + "depth": predictions.squeeze(1), + "confidence": confidence, + "points": predictions_3d, + "depth_features": outs["depth_features"], + } + return outputs + + @torch.no_grad() + def infer(self, rgbs: torch.Tensor, intrinsics=None): + shape_constraints = self.shape_constraints + if rgbs.ndim == 3: + rgbs = rgbs.unsqueeze(0) + if intrinsics is not None and intrinsics.ndim == 2: + intrinsics = intrinsics.unsqueeze(0) + + B, _, H, W = rgbs.shape + target_aspect_ratio = 1.33 # for example + # Calculate the new width from the target aspect ratio + new_width = int(H * target_aspect_ratio) + # Resize the image + resize_transform = transforms.Resize((H, new_width)) # You can also pad if needed + rgbs = resize_transform(rgbs) + rgbs = rgbs.to(self.device) + + if intrinsics is not None: + scale_width = new_width / W + # Adjust the intrinsic matrix (intrinsics is batched: B x 3 x 3) + K_new = intrinsics.clone() + K_new[:, 0, 0] = K_new[:, 0, 0] * scale_width # f_x + K_new[:, 0, 2] = K_new[:, 0, 2] * scale_width # c_x + intrinsics = K_new.to(self.device) + + # process image and intrinsics (if any) to match network input (slow?)
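The width-only resize in `infer` above rescales only fx and cx. For reference, the general rule when an image is resized by independent factors sx = new_W / W and sy = new_H / H is that fx and cx scale with sx while fy and cy scale with sy; a small standalone sketch (the helper name `rescale_intrinsics` is ours):

import torch

def rescale_intrinsics(K: torch.Tensor, sx: float, sy: float) -> torch.Tensor:
    # K: (B, 3, 3) pinhole intrinsics.
    K = K.clone()
    K[:, 0, 0] *= sx  # fx
    K[:, 0, 2] *= sx  # cx
    K[:, 1, 1] *= sy  # fy
    K[:, 1, 2] *= sy  # cy
    return K

K = torch.tensor([[[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]])
print(rescale_intrinsics(K, sx=1.33, sy=1.0))  # width-only resize, as above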
+ if rgbs.max() > 5 or rgbs.dtype == torch.uint8: + rgbs = rgbs.to(torch.float32).div(255) + if rgbs.min() >= 0.0 and rgbs.max() <= 1.0: + rgbs = TF.normalize( + rgbs, + mean=IMAGENET_DATASET_MEAN, + std=IMAGENET_DATASET_STD, + ) + + # check resolution constraints: tradeoff resolution and speed + shape_constraints = _check_resolution(shape_constraints, self.resolution_level) + + # get image shape + (h, w), ratio = _shapes((H, W), shape_constraints) + rgbs, gt_intrinsics = _preprocess( + rgbs, + intrinsics, + (h, w), + ratio, + ) + + # run encoder + features, tokens = self.pixel_encoder(rgbs) + + cls_tokens = [x.contiguous() for x in tokens] + features = [ + self.stacking_fn(features[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + tokens = [ + self.stacking_fn(tokens[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + global_tokens = [cls_tokens[i] for i in [-2, -1]] + camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]] + + # get data fro decoder and adapt to given camera + inputs = {} + inputs["features"] = features + inputs["tokens"] = tokens + inputs["global_tokens"] = global_tokens + inputs["camera_tokens"] = camera_tokens + inputs["image"] = rgbs + if gt_intrinsics is not None: + rays, angles = generate_rays(gt_intrinsics, (h, w)) + inputs["rays"] = rays + inputs["angles"] = angles + inputs["K"] = gt_intrinsics + + outs = self.pixel_decoder(inputs, {}) + # undo the reshaping and get original image size (slow) + outs = _postprocess(outs, ratio, (H, W), mode=self.interpolation_mode) + pred_intrinsics = outs["K"] + depth = outs["depth"] + confidence = outs["confidence"] + + # final 3D points backprojection + intrinsics = intrinsics if intrinsics is not None else pred_intrinsics + angles = generate_rays(intrinsics, (H, W))[-1] + angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W) + points_3d = torch.cat((angles, depth), dim=1) + points_3d = spherical_zbuffer_to_euclidean( + points_3d.permute(0, 2, 3, 1) + ).permute(0, 3, 1, 2) + + outputs = { + "intrinsics": pred_intrinsics, + "points": points_3d, + "depth": depth.squeeze(1), + "confidence": confidence, + } + return outputs + + def load_pretrained(self, model_file): + device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + dict_model = torch.load(model_file, map_location=device) + if "model" in dict_model: + dict_model = dict_model["model"] + new_state_dict = deepcopy( + {k.replace("module.", ""): v for k, v in dict_model.items()} + ) + + info = self.load_state_dict(new_state_dict, strict=False) + if is_main_process(): + print( + f"Loaded from {model_file} for {self.__class__.__name__} results in:", + info, + ) + + @property + def device(self): + return next(self.parameters()).device + + def build(self, config): + mod = importlib.import_module("unidepth.models.encoder") + pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"]) + pixel_encoder_config = { + **config["training"], + **config["data"], + **config["model"]["pixel_encoder"], + } + pixel_encoder = pixel_encoder_factory(pixel_encoder_config) + + config["model"]["pixel_encoder"]["patch_size"] = ( + 14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16 + ) + pixel_encoder_embed_dims = ( + pixel_encoder.embed_dims + if hasattr(pixel_encoder, "embed_dims") + else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)] + ) + config["model"]["pixel_encoder"]["embed_dim"] = getattr( + pixel_encoder, "embed_dim" + ) + 
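When the pixel encoder exposes only a single `embed_dim`, `build` above falls back to doubling it per stage, which matches hierarchical (Swin-style) backbones with four feature levels; encoders that already define per-stage `embed_dims` are used as-is. A quick illustration of that fallback (the value 96 is a hypothetical example, not taken from this patch):

# Fallback used when the encoder defines `embed_dim` but not `embed_dims`:
embed_dim = 96                                    # e.g. a Swin-T-like backbone
embed_dims = [embed_dim * 2**i for i in range(4)]
print(embed_dims)                                 # [96, 192, 384, 768]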
config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims + config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths + + pixel_decoder = Decoder(config) + + self.pixel_encoder = pixel_encoder + self.pixel_decoder = pixel_decoder + stacking_fn = config["model"]["pixel_encoder"]["stacking_fn"] + assert ( + stacking_fn in STACKING_FNS + ), f"Stacking function {stacking_fn} not found in {STACKING_FNS.keys()}" + self.stacking_fn = STACKING_FNS[stacking_fn] + + self.slices_encoder_range = list( + zip([0, *pixel_encoder.depths[:-1]], pixel_encoder.depths) + ) + self.shape_constraints = config["data"]["shape_constraints"] + self.shape_constraints["pixels_bounds_ori"] = self.shape_constraints.get( + "pixels_bounds", [1400, 2400] + ) diff --git a/unidepth/ops/__init__.py b/unidepth/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb25b2f28686628ae481af311f8890d3267cc751 --- /dev/null +++ b/unidepth/ops/__init__.py @@ -0,0 +1,9 @@ +from .losses import MSE, SelfCons, SILog +from .scheduler import CosineScheduler + +__all__ = [ + "SILog", + "MSE", + "SelfCons", + "CosineScheduler", +] diff --git a/unidepth/ops/losses.py b/unidepth/ops/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..c0827694dbeb8d2140bf1b249a118453d6588302 --- /dev/null +++ b/unidepth/ops/losses.py @@ -0,0 +1,428 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +FNS = { + "sqrt": torch.sqrt, + "log": torch.log, + "log1": lambda x: torch.log(x + 1), + "linear": lambda x: x, + "square": torch.square, + "disp": lambda x: 1 / x, +} + + +FNS_INV = { + "sqrt": torch.square, + "log": torch.exp, + "log1": lambda x: torch.exp(x) - 1, + "linear": lambda x: x, + "square": torch.sqrt, + "disp": lambda x: 1 / x, +} + + +def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + mask_var = torch.sum( + mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + return mask_mean.squeeze(dim), mask_var.squeeze(dim) + + +def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean + + +def masked_mae(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): + if mask is None: + return data.abs().mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data.abs() * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean + + +def masked_mse(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): + if mask is None: + return (data**2).mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum((data**2) * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean + + +def 
masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + ndim = data.ndim + data = data.flatten(ndim - len(dim)) + mask = mask.flatten(ndim - len(dim)) + mask_median = torch.median(data[mask], dim=-1).values + return mask_median + + +def masked_median_mad(data: torch.Tensor, mask: torch.Tensor): + data = data.flatten() + mask = mask.flatten() + mask_median = torch.median(data[mask]) + n_samples = torch.clamp(torch.sum(mask.float()), min=1.0) + mask_mad = torch.sum((data[mask] - mask_median).abs()) / n_samples + return mask_median, mask_mad + + +def masked_weighted_mean_var( + data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] +): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum( + mask * weights, dim=dim, keepdim=True + ).clamp(min=1.0) + # V1**2 - V2, V1: sum w_i, V2: sum w_i**2 + denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum( + (mask * weights).square(), dim=dim, keepdim=True + ) + # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd) + correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp( + min=1.0 + ) + mask_var = correction_factor * torch.sum( + weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) + return mask_mean, mask_var + + +def masked_mean_var_q(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + mask_var = torch.sum( + mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + return mask_mean, mask_var + + +class SILog(nn.Module): + def __init__( + self, + weight: float, + scale_pred_weight: float = 0.15, + output_fn: str = "sqrt", + input_fn: str = "log", + legacy: bool = False, + abs_rel: bool = False, + norm: bool = False, + eps: float = 1e-5, + ): + super().__init__() + assert output_fn in FNS + self.name: str = self.__class__.__name__ + self.weight: float = weight + + self.scale_pred_weight: float = scale_pred_weight + self.dims = (-4, -3, -2, -1) if legacy else (-2, -1) + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.abs_rel = abs_rel + self.norm = norm + self.eps: float = eps + + @torch.cuda.amp.autocast(enabled=False) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None, + interpolate: bool = True, + scale_inv: torch.Tensor | None = None, + ss_inv: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + if interpolate: + input = F.interpolate( + input, target.shape[-2:], mode="bilinear", align_corners=False + ) + if mask is not None: + mask = mask.to(torch.bool) + if ss_inv is not None: + ss_inv = ~ss_inv + + if input.shape[1] > 1: + input_ = torch.cat( + [input[:, :-1], self.input_fn(input[:, -1:].clamp(min=self.eps))], dim=1 + ) + target_ = torch.cat( + [target[:, :-1], self.input_fn(target[:, -1:].clamp(min=self.eps))], + dim=1, + ) + error = torch.norm(input_ - target_, dim=1, keepdim=True) + else: + input_ = self.input_fn(input.clamp(min=self.eps)) + target_ = self.input_fn(target.clamp(min=self.eps)) + error = input_ - target_ + + mean_error, 
var_error = masked_mean_var(data=error, mask=mask, dim=self.dims) + + # prevoiusly was inverted!! + if self.abs_rel: + scale_error = (input - target).abs()[:, -1:] / target[:, -1:].clip( + min=self.eps + ) + scale_error = masked_mean(data=scale_error, mask=mask, dim=self.dims) + else: + scale_error = mean_error**2 + + if var_error.ndim > 1: + var_error = var_error.sum(dim=1) + scale_error = scale_error.sum(dim=1) + + # if scale inv -> mask scale error, if scale/shift, mask the full loss + if scale_inv is not None: + scale_error = (1 - scale_inv.int()) * scale_error + scale_error = self.scale_pred_weight * scale_error + loss = var_error + scale_error + out_loss = self.output_fn(loss.clamp(min=self.eps)) + out_loss = masked_mean(data=out_loss, mask=ss_inv, dim=(0,)) + return out_loss.mean() + + @classmethod + def build(cls, config: Dict[str, Any]): + obj = cls( + weight=config["weight"], + legacy=config["legacy"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + norm=config.get("norm", False), + scale_pred_weight=config.get("gamma", 0.15), + abs_rel=config.get("abs_rel", False), + ) + return obj + + +class MSE(nn.Module): + def __init__( + self, + weight: float = 1.0, + input_fn: str = "linear", + output_fn: str = "linear", + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.weight: float = weight + self.eps = 1e-6 + + @torch.cuda.amp.autocast(enabled=False) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + batch_mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + input = input[..., : target.shape[-1]] # B N C or B H W C + error = self.input_fn(input + self.eps) - self.input_fn(target + self.eps) + abs_error = torch.square(error).sum(dim=-1) + mean_error = masked_mean(data=abs_error, mask=mask, dim=(-1,)).mean(dim=-1) + batched_error = masked_mean( + self.output_fn(mean_error.clamp(self.eps)), batch_mask, dim=(0,) + ) + return batched_error.mean(), mean_error.detach() + + @classmethod + def build(cls, config: Dict[str, Any]): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + ) + return obj + + +class SelfCons(nn.Module): + def __init__( + self, + weight: float, + scale_pred_weight: float = 0.15, + output_fn: str = "sqrt", + input_fn: str = "log", + abs_rel: bool = False, + norm: bool = False, + eps: float = 1e-5, + ): + super().__init__() + assert output_fn in FNS + self.name: str = self.__class__.__name__ + self.weight: float = weight + + self.scale_pred_weight: float = scale_pred_weight + self.dims = (-2, -1) + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.abs_rel = abs_rel + self.norm = norm + self.eps: float = eps + + @torch.cuda.amp.autocast(enabled=False) + def forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + metas: List[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + chunks = input.shape[0] // 2 + device = input.device + mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest") + + rescales = input.shape[-2] / torch.tensor( + [x["resized_shape"][0] for x in metas], device=device + ) + cams = torch.cat([x["K_target"] for x in metas], dim=0).to(device) + flips = torch.tensor([x["flip"] for x in metas], device=device) + + iters = zip( + input.chunk(chunks), + mask.chunk(chunks), + cams.chunk(chunks), + rescales.chunk(chunks), + flips.chunk(chunks), + ) + inputs0, inputs1, masks = [], [], [] + for 
i, (pair_input, pair_mask, pair_cam, pair_rescale, pair_flip) in enumerate( + iters + ): + mask0, mask1 = pair_mask + input0, input1 = pair_input + cam0, cam1 = pair_cam + rescale0, rescale1 = pair_rescale + flip0, flip1 = pair_flip + + fx_0 = cam0[0, 0] * rescale0 + fx_1 = cam1[0, 0] * rescale1 + cx_0 = (cam0[0, 2] - 0.5) * rescale0 + 0.5 + cx_1 = (cam1[0, 2] - 0.5) * rescale1 + 0.5 + cy_0 = (cam0[1, 2] - 0.5) * rescale0 + 0.5 + cy_1 = (cam1[1, 2] - 0.5) * rescale1 + 0.5 + + # flip image + if flip0 ^ flip1: + input0 = torch.flip(input0, dims=(2,)) + mask0 = torch.flip(mask0, dims=(2,)) + cx_0 = input0.shape[-1] - cx_0 + + # calc zoom + zoom_x = float(fx_1 / fx_0) + + # apply zoom + input0 = F.interpolate( + input0.unsqueeze(0), + scale_factor=zoom_x, + mode="bilinear", + align_corners=True, + ).squeeze(0) + mask0 = F.interpolate( + mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest" + ).squeeze(0) + + # calc translation + change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5) + change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5) + change_right = input1.shape[-1] - change_left - input0.shape[-1] + change_bottom = input1.shape[-2] - change_top - input0.shape[-2] + + # apply translation + pad_left = max(0, change_left) + pad_right = max(0, change_right) + pad_top = max(0, change_top) + pad_bottom = max(0, change_bottom) + + crop_left = max(0, -change_left) + crop_right = max(0, -change_right) + crop_top = max(0, -change_top) + crop_bottom = max(0, -change_bottom) + + input0 = F.pad( + input0, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=0, + ) + mask0 = F.pad( + mask0, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=0, + ) + input0 = input0[ + :, + crop_top : input0.shape[-2] - crop_bottom, + crop_left : input0.shape[-1] - crop_right, + ] + mask0 = mask0[ + :, + crop_top : mask0.shape[-2] - crop_bottom, + crop_left : mask0.shape[-1] - crop_right, + ] + + mask = torch.logical_and(mask0, mask1) + + inputs0.append(input0) + inputs1.append(input1) + masks.append(mask) + + inputs0 = torch.stack(inputs0, dim=0) + inputs1 = torch.stack(inputs1, dim=0) + masks = torch.stack(masks, dim=0) + loss1 = self.loss(inputs0, inputs1.detach(), masks) + loss2 = self.loss(inputs1, inputs0.detach(), masks) + return torch.cat([loss1, loss2], dim=0).mean() + + def loss( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + loss = masked_mean( + (input - target).square().mean(dim=1), mask=mask, dim=(-2, -1) + ) + return self.output_fn(loss + self.eps) + + @classmethod + def build(cls, config: Dict[str, Any]): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + ) + return obj diff --git a/unidepth/ops/scheduler.py b/unidepth/ops/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a182ff6e204ab445a67846314a8bea087119685e --- /dev/null +++ b/unidepth/ops/scheduler.py @@ -0,0 +1,70 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import numpy as np + + +class CosineScheduler(object): + def __init__( + self, + optimizer, + warmup_iters, + total_iters, + key, + overwrite=False, + init_value=None, + base_value=None, + final_value=None, + step_init=-1, + ): + super().__init__() + self.iter = step_init + self.overwrite = overwrite + self.optimizer = optimizer + self.base_value = base_value + self.init_value = init_value + self.final_value = final_value + 
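The arguments stored here drive the per-parameter-group schedule built just below: a quadratic warmup from `init_value` to `base_value` over `warmup_iters`, followed by a cosine decay from `base_value` to `final_value`. A compact standalone restatement of that curve (the numeric values are illustrative only):

import numpy as np

def cosine_with_warmup(init_value, base_value, final_value, warmup_iters, total_iters):
    # quadratic warmup: init_value -> base_value
    t = np.linspace(0, 1, warmup_iters, endpoint=True) ** 2
    warmup = (base_value - init_value) * t + init_value
    # cosine decay: base_value -> final_value
    iters = np.arange(total_iters - warmup_iters)
    decay = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
    return np.concatenate((warmup, decay))

lr = cosine_with_warmup(1e-8, 1e-4, 1e-6, warmup_iters=1000, total_iters=10000)
print(lr[0], lr[999], lr[-1])  # 1e-08 at the start, 1e-04 after warmup, ~1e-06 at the end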
self.total_iters = total_iters + self.warmup_iters = warmup_iters + self.key = key + self.schedulers = [ + self.get_schedulers(group) for group in optimizer.param_groups + ] + + def get_schedulers(self, group): + init_value = group.get(self.key + "_init", self.init_value) + base_value = group.get(self.key + "_base", self.base_value) + final_value = group.get(self.key + "_final", self.final_value) + warmup_iters = self.warmup_iters + total_iters = self.total_iters + if self.overwrite: + final_value = self.final_value + + # normalize in 0,1, then apply function (power) and denormalize + normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) + normalized_schedule = np.power(normalized_schedule, 2) + warmup_schedule = (base_value - init_value) * normalized_schedule + init_value + + # main scheduling + iters = np.arange(total_iters - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * iters / len(iters)) + ) + return np.concatenate((warmup_schedule, schedule)) + + def step(self): + self.iter = self.iter + 1 + vals = self[self.iter] + for group, val in zip(self.optimizer.param_groups, vals): + if isinstance(group[self.key], (tuple, list)): + val = (val, *group[self.key][1:]) + group[self.key] = val + + def __getitem__(self, it): + it = min(it, self.total_iters - 1) + return [scheduler[it] for scheduler in self.schedulers] + + def get(self): + return [group[self.key] for group in self.optimizer.param_groups] diff --git a/unidepth/utils/__init__.py b/unidepth/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08d90351357dc4034b1508ea57bea750d1f49b3d --- /dev/null +++ b/unidepth/utils/__init__.py @@ -0,0 +1,29 @@ +from .distributed import (barrier, get_dist_info, get_rank, is_main_process, + setup_multi_processes, setup_slurm, + sync_tensor_across_gpus) +from .evaluation_depth import DICT_METRICS, eval_depth +from .geometric import spherical_zbuffer_to_euclidean, unproject_points +from .misc import format_seconds, get_params, identity, remove_padding +from .visualization import colorize, image_grid, log_train_artifacts + +__all__ = [ + "eval_depth", + "DICT_METRICS", + "colorize", + "image_grid", + "log_train_artifacts", + "format_seconds", + "remove_padding", + "get_params", + "identity", + "is_main_process", + "setup_multi_processes", + "setup_slurm", + "sync_tensor_across_gpus", + "barrier", + "get_rank", + "unproject_points", + "spherical_zbuffer_to_euclidean", + "validate", + "get_dist_info", +] diff --git a/unidepth/utils/__pycache__/__init__.cpython-311.pyc b/unidepth/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5432f5b911d0545adfa57392971811cdf50420d1 Binary files /dev/null and b/unidepth/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/unidepth/utils/__pycache__/constants.cpython-311.pyc b/unidepth/utils/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dbd4ed1766861c973606678274228633fac9921 Binary files /dev/null and b/unidepth/utils/__pycache__/constants.cpython-311.pyc differ diff --git a/unidepth/utils/__pycache__/distributed.cpython-311.pyc b/unidepth/utils/__pycache__/distributed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee14cb1cff45dcebb89ac780e3a80d65d60c51bd Binary files /dev/null and b/unidepth/utils/__pycache__/distributed.cpython-311.pyc differ diff --git 
a/unidepth/utils/__pycache__/evaluation_depth.cpython-311.pyc b/unidepth/utils/__pycache__/evaluation_depth.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..007a2c2de3e032d58dc41754e63d97a6f23435dd Binary files /dev/null and b/unidepth/utils/__pycache__/evaluation_depth.cpython-311.pyc differ diff --git a/unidepth/utils/__pycache__/geometric.cpython-311.pyc b/unidepth/utils/__pycache__/geometric.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3afb74a00861135cf110f5fd9bc2a7707dc49dd Binary files /dev/null and b/unidepth/utils/__pycache__/geometric.cpython-311.pyc differ diff --git a/unidepth/utils/__pycache__/misc.cpython-311.pyc b/unidepth/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a63135c6f078d17b62b469104d11ccc7abc8a2e9 Binary files /dev/null and b/unidepth/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/unidepth/utils/__pycache__/positional_embedding.cpython-311.pyc b/unidepth/utils/__pycache__/positional_embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4fbe48e5c8562a46ba56bac965b39446f606fe4 Binary files /dev/null and b/unidepth/utils/__pycache__/positional_embedding.cpython-311.pyc differ diff --git a/unidepth/utils/__pycache__/sht.cpython-311.pyc b/unidepth/utils/__pycache__/sht.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d03844eac9c900d77f3733a3d41dac9beb12bfe Binary files /dev/null and b/unidepth/utils/__pycache__/sht.cpython-311.pyc differ diff --git a/unidepth/utils/__pycache__/visualization.cpython-311.pyc b/unidepth/utils/__pycache__/visualization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9815fe4b0874dc9fe2547b00ddb1ae079e09f61 Binary files /dev/null and b/unidepth/utils/__pycache__/visualization.cpython-311.pyc differ diff --git a/unidepth/utils/constants.py b/unidepth/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..fb18fd1bf8677e0e1189b19c9064b05b71db1a7a --- /dev/null +++ b/unidepth/utils/constants.py @@ -0,0 +1,22 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import math + +import torch + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DATASET_STD = (0.229, 0.224, 0.225) +DEPTH_BINS = torch.cat( + ( + torch.logspace(math.log10(0.1), math.log10(180.0), steps=512), + torch.tensor([260.0]), + ), + dim=0, +) +LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1) +LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1) diff --git a/unidepth/utils/distributed.py b/unidepth/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2f16f9d0e781af76a5ce3c3937039d7860ff74 --- /dev/null +++ b/unidepth/utils/distributed.py @@ -0,0 +1,178 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import os +import platform +import subprocess +import warnings + +import cv2 +import torch +import torch.utils.data.distributed +from torch import distributed as dist +from torch import multiprocessing as mp + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def 
get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def barrier(): + if not is_dist_avail_and_initialized(): + return + dist.barrier() + + +def is_main_process(): + return get_rank() == 0 + + +def is_rank_zero(args): + return args.rank == 0 + + +def get_dist_info(): + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + # set multi-process start method as `fork` to speed up the training + if platform.system() != "Windows": + mp_start_method = cfg.get("mp_start_method", "fork") + current_method = mp.get_start_method(allow_none=True) + if current_method is not None and current_method != mp_start_method: + warnings.warn( + f"Multi-processing start method `{mp_start_method}` is " + f"different from the previous setting `{current_method}`." + f"It will be force set to `{mp_start_method}`. You can change " + f"this behavior by changing `mp_start_method` in your config." + ) + mp.set_start_method(mp_start_method, force=True) + + # disable opencv multithreading to avoid system being overloaded + opencv_num_threads = cfg.get("opencv_num_threads", 0) + cv2.setNumThreads(opencv_num_threads) + + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + workers_per_gpu = cfg.get("workers_per_gpu", 4) + + if "OMP_NUM_THREADS" not in os.environ and workers_per_gpu > 1: + omp_num_threads = 1 + warnings.warn( + f"Setting OMP_NUM_THREADS environment variable for each process " + f"to be {omp_num_threads} in default, to avoid your system being " + f"overloaded, please further tune the variable for optimal " + f"performance in your application as needed." + ) + os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + + # setup MKL threads + if "MKL_NUM_THREADS" not in os.environ and workers_per_gpu > 1: + mkl_num_threads = os.environ.get("OMP_NUM_THREADS", 1) + warnings.warn( + f"Setting MKL_NUM_THREADS environment variable for each process " + f"to be {mkl_num_threads} in default, to avoid your system being " + f"overloaded, please further tune the variable for optimal " + f"performance in your application as needed." + ) + os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads) + + +def setup_slurm(backend: str, port: str) -> None: + """Initialize slurm distributed training environment. + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
+ """ + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + + num_gpus = torch.cuda.device_count() + + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + os.environ["MASTER_PORT"] = str(port) + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(ntasks) + os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) + os.environ["RANK"] = str(proc_id) + print( + proc_id, + ntasks, + num_gpus, + proc_id % num_gpus, + node_list, + addr, + os.environ["MASTER_PORT"], + os.system("nvidia-smi -L"), + ) + dist.init_process_group(backend, rank=proc_id, world_size=ntasks) + + +def sync_tensor_across_gpus(t, dim=0, cat=True): + if t is None or not (dist.is_available() and dist.is_initialized()): + return t + t = torch.atleast_1d(t) + group = dist.group.WORLD + group_size = torch.distributed.get_world_size(group) + + local_size = torch.tensor(t.size(dim), device=t.device) + all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)] + dist.all_gather(all_sizes, local_size) + max_size = max(all_sizes) + size_diff = max_size.item() - local_size.item() + if size_diff: + padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype) + t = torch.cat((t, padding)) + + gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)] + dist.all_gather(gather_t_tensor, t) + all_ts = [] + for t, size in zip(gather_t_tensor, all_sizes): + all_ts.append(t[:size]) + if cat: + return torch.cat(all_ts, dim=0) + return all_ts + + +import pickle + + +def sync_string_across_gpus(keys: list[str], device, dim=0): + keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL) + keys_serialized_tensor = torch.frombuffer(keys_serialized, dtype=torch.uint8).to( + device + ) + keys_serialized_tensor = sync_tensor_across_gpus( + keys_serialized_tensor, dim=0, cat=False + ) + keys = [ + key + for keys in keys_serialized_tensor + for key in pickle.loads(bytes(keys.cpu().tolist())) + ] + return keys diff --git a/unidepth/utils/ema_torch.py b/unidepth/utils/ema_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..bea1d29a8c224b81b19ac68fe02453a4330ed58c --- /dev/null +++ b/unidepth/utils/ema_torch.py @@ -0,0 +1,341 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from __future__ import division, unicode_literals + +import contextlib +import copy +import weakref +from math import tanh +from typing import Iterable, Optional + +import torch + + +class DummyExponentialMovingAverage: + def __init__(self, *args, **kwargs): + pass + + def _get_parameters(self, *args, **kwargs): + pass + + def get_current_decay(self, *args, **kwargs): + pass + + def update(self, *args, **kwargs): + pass + + def copy_to(self, *args, **kwargs): + pass + + def store(self, *args, **kwargs): + return + + def restore(self, *args, **kwargs): + return + + @contextlib.contextmanager + def average_parameters(self, *args, **kwargs): + try: + yield + finally: + pass + + def to(self, *args, **kwargs): + pass + + def state_dict(self, *args, **kwargs): + pass + + def load_state_dict(self, *args, **kwargs): + pass + + +class ExponentialMovingAverage: + """ + Maintains (exponential) moving average of a set of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter` (typically from + `model.parameters()`). 
+ Note that EMA is computed on *all* provided parameters, + regardless of whether or not they have `requires_grad = True`; + this allows a single EMA object to be consistently used even + if which parameters are trainable changes step to step. + + If you want to exclude some parameters from the EMA, do not pass them + to the object in the first place. For example: + + ExponentialMovingAverage( + parameters=[p for p in model.parameters() if p.requires_grad], + decay=0.9 + ) + + will ignore parameters that do not require grad. + + decay: The exponential decay. + + use_num_updates: Whether to use number of updates when computing + averages. + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float, + use_num_updates: bool = True, + update_after_step: int = 10000, + tau: int = 20000, + switch: bool = False, + ): + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + self.decay = decay + self.switch = switch # if True, keep EMA params in the model after `average_parameters` exits + self.num_updates = 0 if use_num_updates else None + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + self.collected_params = None + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. + self._params_refs = [weakref.ref(p) for p in parameters] + self.update_after_step = update_after_step + self.tau = tau + + def _get_parameters( + self, parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected." + ) + return parameters + else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) + return parameters + + def get_current_decay(self): + epoch = max(self.num_updates - self.update_after_step - 1, 0.0) + if epoch <= 0: + return 0.0 + value = tanh(epoch / self.tau) * self.decay + return value + + def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: + """ + Update currently maintained parameters. + + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used.
+ """ + parameters = self._get_parameters(parameters) + decay = self.get_current_decay() + if self.num_updates is not None: + self.num_updates += 1 + + one_minus_decay = 1.0 - decay + with torch.no_grad(): + for s_param, param in zip(self.shadow_params, parameters): + tmp = s_param - param + # tmp will be a new tensor so we can do in-place + tmp.mul_(one_minus_decay) + s_param.sub_(tmp) + + def copy_to( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: + """ + Save the current parameters for restoring later. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.collected_params = [param.detach().clone() for param in parameters] + + def restore( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) + parameters = self._get_parameters(parameters) + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + + @contextlib.contextmanager + def average_parameters( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + if not self.switch: + self.restore(parameters) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
+ + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.collected_params + ] + return + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance( + self.num_updates, int + ), "Invalid num_updates" + + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + assert isinstance( + self.collected_params, list + ), "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len( + self.shadow_params + ), "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." 
+ ) diff --git a/unidepth/utils/evaluation_depth.py b/unidepth/utils/evaluation_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..ab35a9335f0947332ab8419b5cbe4940368b2a0c --- /dev/null +++ b/unidepth/utils/evaluation_depth.py @@ -0,0 +1,174 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +# We prefer not to install PyTorch3D in the package +# Code commented is how 3D metrics are computed + +from collections import defaultdict +from functools import partial + +import torch +import torch.nn.functional as F + +from unidepth.utils.constants import DEPTH_BINS + +# from chamfer_distance import ChamferDistance + + +# chamfer_cls = ChamferDistance() + + +# def chamfer_dist(tensor1, tensor2): +# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) +# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) +# dist1, dist2, idx1, idx2 = chamfer_cls( +# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths +# ) +# return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2 + + +# def auc(tensor1, tensor2, thresholds): +# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) +# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) +# dist1, dist2, idx1, idx2 = chamfer_cls( +# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths +# ) +# # compute precision recall +# precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] +# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] +# auc_value = torch.trapz( +# torch.tensor(precisions, device=tensor1.device), +# torch.tensor(recalls, device=tensor1.device), +# ) +# return auc_value + + +def delta(tensor1, tensor2, exponent): + inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) + return (inlier < 1.25**exponent).to(torch.float32).mean() + + +def ssi(tensor1, tensor2, qtl=0.05): + stability_mat = 1e-9 * torch.eye(2, device=tensor1.device) + error = (tensor1 - tensor2).abs() + mask = error < torch.quantile(error, 1 - qtl) + tensor1_mask = tensor1[mask] + tensor2_mask = tensor2[mask] + tensor2_one = torch.stack( + [tensor2_mask.detach(), torch.ones_like(tensor2_mask).detach()], dim=1 + ) + scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( + tensor2_one.T @ tensor1_mask.unsqueeze(1) + ) + scale, shift = scale_shift.squeeze().chunk(2, dim=0) + return tensor2 * scale + shift + # tensor2_one = torch.stack([tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1) + # scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (tensor2_one.T @ tensor1.unsqueeze(1)) + # scale, shift = scale_shift.squeeze().chunk(2, dim=0) + # return tensor2 * scale + shift + + +def d1_ssi(tensor1, tensor2): + delta_ = delta(tensor1, ssi(tensor1, tensor2), 1.0) + return delta_ + + +def d_auc(tensor1, tensor2): + exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device) + deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents] + return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0 + + +# def f1_score(tensor1, tensor2, thresholds): +# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) +# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) +# dist1, dist2, idx1, idx2 = chamfer_cls( +# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths +# ) +# # compute precision recall +# precisions = [(dist1 < 
threshold).sum() / dist1.numel() for threshold in thresholds] +# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] +# precisions = torch.tensor(precisions, device=tensor1.device) +# recalls = torch.tensor(recalls, device=tensor1.device) +# f1_thresholds = 2 * precisions * recalls / (precisions + recalls) +# f1_thresholds = torch.where( +# torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds +# ) +# f1_value = torch.trapz(f1_thresholds) / len(thresholds) +# return f1_value + + +DICT_METRICS = { + "d1": partial(delta, exponent=1.0), + "d2": partial(delta, exponent=2.0), + "d3": partial(delta, exponent=3.0), + "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()), + "rmselog": lambda gt, pred: torch.sqrt( + ((torch.log(gt) - torch.log(pred)) ** 2).mean() + ), + "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(), + "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(), + "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(), + "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(), + "medianlog": lambda gt, pred: 100 + * (torch.log(pred) - torch.log(gt)).median().abs(), + "d_auc": d_auc, + "d1_ssi": d1_ssi, +} + + +# DICT_METRICS_3D = { +# "chamfer": lambda gt, pred, thresholds: chamfer_dist( +# gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) +# ), +# "F1": lambda gt, pred, thresholds: f1_score( +# gt.unsqueeze(0).permute(0, 2, 1), +# pred.unsqueeze(0).permute(0, 2, 1), +# thresholds=thresholds, +# ), +# } + + +DICT_METRICS_D = { + "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to( + torch.float32 + ), + "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt), +} + + +def eval_depth( + gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None +): + summary_metrics = defaultdict(list) + preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") + for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): + if max_depth is not None: + mask = torch.logical_and(mask, gt <= max_depth) + for name, fn in DICT_METRICS.items(): + summary_metrics[name].append(fn(gt[mask], pred[mask]).mean()) + return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} + + +# def eval_3d( +# gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None +# ): +# summary_metrics = defaultdict(list) +# w_max = min(gts.shape[-1] // 4, 400) +# gts = F.interpolate( +# gts, (int(w_max * gts.shape[-2] / gts.shape[-1]), w_max), mode="nearest" +# ) +# preds = F.interpolate(preds, gts.shape[-2:], mode="nearest") +# masks = F.interpolate( +# masks.to(torch.float32), gts.shape[-2:], mode="nearest" +# ).bool() +# for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): +# if not torch.any(mask): +# continue +# for name, fn in DICT_METRICS_3D.items(): +# summary_metrics[name].append( +# fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean() +# ) +# return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} diff --git a/unidepth/utils/geometric.py b/unidepth/utils/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f288de30b9220ea7390cd635a4b421cc937696 --- /dev/null +++ b/unidepth/utils/geometric.py @@ -0,0 +1,252 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from typing import Tuple + +import torch +from torch.nn import functional as F + 
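For readers unfamiliar with the threshold accuracy used by `DICT_METRICS` in `evaluation_depth.py` above: `delta` with exponent n is the fraction of pixels whose ratio max(gt/pred, pred/gt) falls below 1.25**n. A tiny numeric check (the values are made up):

import torch

def delta(gt: torch.Tensor, pred: torch.Tensor, exponent: float) -> torch.Tensor:
    inlier = torch.maximum(gt / pred, pred / gt)
    return (inlier < 1.25**exponent).float().mean()

gt = torch.tensor([1.0, 2.0, 4.0, 8.0])
pred = torch.tensor([1.1, 2.0, 6.0, 8.2])  # the third value is off by 1.5x
print(delta(gt, pred, 1.0))  # 0.75: three of the four pixels are within the 1.25 threshold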
+ +@torch.jit.script +def generate_rays( + camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False +): + batch_size, device, dtype = ( + camera_intrinsics.shape[0], + camera_intrinsics.device, + camera_intrinsics.dtype, + ) + height, width = image_shape + # Generate grid of pixel coordinates + pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype) + pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype) + if noisy: + pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 + pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 + pixel_coords = torch.stack( + [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2 + ) # (H, W, 2) + pixel_coords = pixel_coords + 0.5 + + # Calculate ray directions + intrinsics_inv = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1) + intrinsics_inv[:, 0, 0] = 1.0 / camera_intrinsics[:, 0, 0] + intrinsics_inv[:, 1, 1] = 1.0 / camera_intrinsics[:, 1, 1] + intrinsics_inv[:, 0, 2] = -camera_intrinsics[:, 0, 2] / camera_intrinsics[:, 0, 0] + intrinsics_inv[:, 1, 2] = -camera_intrinsics[:, 1, 2] / camera_intrinsics[:, 1, 1] + homogeneous_coords = torch.cat( + [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2 + ) # (H, W, 3) + ray_directions = torch.matmul( + intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1) + ) # (3, H*W) + ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W) + ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3) + + theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1]) + phi = torch.acos(ray_directions[..., 1]) + # pitch = torch.asin(ray_directions[..., 1]) + # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1]) + angles = torch.stack([theta, phi], dim=-1) + return ray_directions, angles + + +@torch.jit.script +def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: + theta = spherical_tensor[..., 0] # Extract polar angle + phi = spherical_tensor[..., 1] # Extract azimuthal angle + z = spherical_tensor[..., 2] # Extract zbuffer depth + + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + # => + # r = z / sin(phi) / cos(theta) + # y = z / (sin(phi) / cos(phi)) / cos(theta) + # x = z * sin(theta) / cos(theta) + x = z * torch.tan(theta) + y = z / torch.tan(phi) / torch.cos(theta) + + euclidean_tensor = torch.stack((x, y, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: + theta = spherical_tensor[..., 0] # Extract polar angle + phi = spherical_tensor[..., 1] # Extract azimuthal angle + r = spherical_tensor[..., 2] # Extract radius + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + x = r * torch.sin(phi) * torch.sin(theta) + y = r * torch.cos(phi) + z = r * torch.cos(theta) * torch.sin(phi) + + euclidean_tensor = torch.stack((x, y, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor: + x = spherical_tensor[..., 0] # Extract polar angle + y = spherical_tensor[..., 1] # Extract azimuthal angle + z = spherical_tensor[..., 2] # Extract radius + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + r = torch.sqrt(x**2 + y**2 + z**2) + theta = torch.atan2(x / r, z / r) + phi = torch.acos(y / r) + + euclidean_tensor = torch.stack((theta, phi, r), 
dim=-1) + return euclidean_tensor + + +@torch.jit.script +def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor: + pitch = torch.asin(euclidean_tensor[..., 1]) + yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1]) + z = euclidean_tensor[..., 2] # Extract zbuffer depth + euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def unproject_points( + depth: torch.Tensor, camera_intrinsics: torch.Tensor +) -> torch.Tensor: + """ + Unprojects a batch of depth maps to 3D point clouds using camera intrinsics. + + Args: + depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W). + camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3). + + Returns: + torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W). + """ + batch_size, _, height, width = depth.shape + device = depth.device + + # Create pixel grid + y_coords, x_coords = torch.meshgrid( + torch.arange(height, device=device), + torch.arange(width, device=device), + indexing="ij", + ) + pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2) + + # Get homogeneous coords (u v 1) + pixel_coords_homogeneous = torch.cat( + (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1 + ) + pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten( + 1 + ) # (3, H*W) + # Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W] + unprojected_points = torch.matmul( + torch.inverse(camera_intrinsics), pixel_coords_homogeneous + ) # (B, 3, H*W) + unprojected_points = unprojected_points.view( + batch_size, 3, height, width + ) # (B, 3, H, W) + unprojected_points = unprojected_points * depth # (B, 3, H, W) + return unprojected_points + + +@torch.jit.script +def project_points( + points_3d: torch.Tensor, + intrinsic_matrix: torch.Tensor, + image_shape: Tuple[int, int], +) -> torch.Tensor: + # Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T + points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2)) + + # Normalize projected points: (u v w) -> (u / w, v / w, 1) + points_2d = points_2d[..., :2] / points_2d[..., 2:] + + points_2d = points_2d.int() + + # points need to be inside the image (can it diverge onto all points out???) 
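+    # Only points that project inside the image bounds contribute below; depths of
+    # points that land on the same pixel are summed with scatter_add_ and then
+    # averaged using the per-pixel counts.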
+ valid_mask = ( + (points_2d[..., 0] >= 0) + & (points_2d[..., 0] < image_shape[1]) + & (points_2d[..., 1] >= 0) + & (points_2d[..., 1] < image_shape[0]) + ) + + # Calculate the flat indices of the valid pixels + flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1] + flat_indices = flat_points_2d.long() + + # Create depth maps and counts using scatter_add, (B, H, W) + depth_maps = torch.zeros( + [points_3d.shape[0], *image_shape], device=points_3d.device + ) + counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device) + + # Loop over batches to apply masks and accumulate depth/count values + for i in range(points_3d.shape[0]): + valid_indices = flat_indices[i, valid_mask[i]] + depth_maps[i].view(-1).scatter_add_( + 0, valid_indices, points_3d[i, valid_mask[i], 2] + ) + counts[i].view(-1).scatter_add_( + 0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2]) + ) + + # Calculate mean depth for each pixel in each batch + mean_depth_maps = depth_maps / counts.clamp(min=1.0) + return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W) + + +@torch.jit.script +def downsample(data: torch.Tensor, downsample_factor: int = 2): + N, _, H, W = data.shape + data = data.view( + N, + H // downsample_factor, + downsample_factor, + W // downsample_factor, + downsample_factor, + 1, + ) + data = data.permute(0, 1, 3, 5, 2, 4).contiguous() + data = data.view(-1, downsample_factor * downsample_factor) + data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data) + data = torch.min(data_tmp, dim=-1).values + data = data.view(N, 1, H // downsample_factor, W // downsample_factor) + data = torch.where(data > 1000, torch.zeros_like(data), data) + return data + + +@torch.jit.script +def flat_interpolate( + flat_tensor: torch.Tensor, + old: Tuple[int, int], + new: Tuple[int, int], + antialias: bool = True, + mode: str = "bilinear", +) -> torch.Tensor: + if old[0] == new[0] and old[1] == new[1]: + return flat_tensor + tensor = flat_tensor.reshape(flat_tensor.shape[0], old[0], old[1], -1).permute( + 0, 3, 1, 2 + ) # b c h w + tensor_interp = F.interpolate( + tensor, + size=(new[0], new[1]), + mode=mode, + align_corners=False, + antialias=antialias, + ) + flat_tensor_interp = tensor_interp.reshape( + flat_tensor.shape[0], -1, new[0] * new[1] + ).permute( + 0, 2, 1 + ) # b (h w) c + return flat_tensor_interp.contiguous() diff --git a/unidepth/utils/misc.py b/unidepth/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7084f074c8aaa241a5ca76f2ec6e7a5f1871f152 --- /dev/null +++ b/unidepth/utils/misc.py @@ -0,0 +1,418 @@ +from collections import defaultdict +from functools import partial, wraps + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from scipy import interpolate + + +def max_stack(tensors): + if len(tensors) == 1: + return tensors[0] + return torch.stack(tensors, dim=-1).max(dim=-1).values + + +def last_stack(tensors): + return tensors[-1] + + +def first_stack(tensors): + return tensors[0] + + +def softmax_stack(tensors, temperature=1.0): + if len(tensors) == 1: + return tensors[0] + return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1) + + +def mean_stack(tensors): + if len(tensors) == 1: + return tensors[0] + return torch.stack(tensors, dim=-1).mean(dim=-1) + + +def sum_stack(tensors): + if len(tensors) == 1: + return tensors[0] + return torch.stack(tensors, dim=-1).sum(dim=-1) + + +def 
convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def format_seconds(seconds): + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + return f"{hours:d}:{minutes:02d}:{seconds:02d}" + + +def get_params(module, lr, wd): + skip_list = {} + skip_keywords = {} + if hasattr(module, "no_weight_decay"): + skip_list = module.no_weight_decay() + if hasattr(module, "no_weight_decay_keywords"): + skip_keywords = module.no_weight_decay_keywords() + has_decay = [] + no_decay = [] + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if ( + (name in skip_list) + or any((kw in name for kw in skip_keywords)) + or len(param.shape) == 1 + ): + # if (name in skip_list) or any((kw in name for kw in skip_keywords)): + # print(name, skip_keywords) + no_decay.append(param) + else: + has_decay.append(param) + + group1 = { + "params": has_decay, + "weight_decay": wd, + "lr": lr, + "weight_decay_init": wd, + "weight_decay_base": wd, + "lr_init": lr, + "lr_base": lr, + } + group2 = { + "params": no_decay, + "weight_decay": 0.0, + "lr": lr, + "weight_decay_init": 0.0, + "weight_decay_base": 0.0, + "weight_decay_final": 0.0, + "lr_init": lr, + "lr_base": lr, + } + return [group1, group2], [lr, lr] + + +def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage): + if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("layers"): + if var_name.split(".")[2] == "blocks": + stage_id = int(var_name.split(".")[1]) + layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id]) + return layer_id + 1 + elif var_name.split(".")[2] == "downsample": + stage_id = int(var_name.split(".")[1]) + layer_id = sum(layers_per_stage[: stage_id + 1]) + return layer_id + else: + return num_max_layer - 1 + + +def get_params_layerdecayswin(module, lr, wd, ld): + skip_list = {} + skip_keywords = {} + if hasattr(module, "no_weight_decay"): + skip_list = module.no_weight_decay() + if hasattr(module, "no_weight_decay_keywords"): + skip_keywords = module.no_weight_decay_keywords() + layers_per_stage = module.depths + num_layers = sum(layers_per_stage) + 1 + lrs = [] + params = [] + for name, param in module.named_parameters(): + if not param.requires_grad: + print(f"{name} frozen") + continue # frozen weights + layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage) + lr_cur = lr * ld ** (num_layers - layer_id - 1) + # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"): + if (name in skip_list) or any((kw in name for kw in skip_keywords)): + wd_cur = 0.0 + else: + wd_cur = wd + params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur}) + lrs.append(lr_cur) + return params, lrs + + +def log(t, eps: float = 1e-5): + return torch.log(t.clamp(min=eps)) + + +def l2norm(t): + return F.normalize(t, dim=-1) + + +def exists(val): + return val is not None + + +def identity(t, *args, **kwargs): + 
return t + + +def divisible_by(numer, denom): + return (numer % denom) == 0 + + +def first(arr, d=None): + if len(arr) == 0: + return d + return arr[0] + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def maybe(fn): + @wraps(fn) + def inner(x): + if not exists(x): + return x + return fn(x) + + return inner + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +def _many(fn): + @wraps(fn) + def inner(tensors, pattern, **kwargs): + return (fn(tensor, pattern, **kwargs) for tensor in tensors) + + return inner + + +rearrange_many = _many(rearrange) +repeat_many = _many(repeat) +reduce_many = _many(reduce) + + +def load_pretrained(state_dict, checkpoint): + checkpoint_model = checkpoint["model"] + if any([True if "encoder." in k else False for k in checkpoint_model.keys()]): + checkpoint_model = { + k.replace("encoder.", ""): v + for k, v in checkpoint_model.items() + if k.startswith("encoder.") + } + print("Detect pre-trained model, remove [encoder.] prefix.") + else: + print("Detect non-pre-trained model, pass without doing anything.") + print(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") + checkpoint = load_checkpoint_swin(state_dict, checkpoint_model) + + +def load_checkpoint_swin(model, checkpoint_model): + state_dict = model.state_dict() + # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size + all_keys = list(checkpoint_model.keys()) + for key in all_keys: + if "relative_position_bias_table" in key: + relative_position_bias_table_pretrained = checkpoint_model[key] + relative_position_bias_table_current = state_dict[key] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if nH1 != nH2: + print(f"Error in loading {key}, passing......") + else: + if L1 != L2: + print(f"{key}: Interpolate relative_position_bias_table using geo.") + src_size = int(L1**0.5) + dst_size = int(L2**0.5) + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + print("Original positions = %s" % str(x)) + print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(nH1): + z = ( + relative_position_bias_table_pretrained[:, i] + .view(src_size, src_size) + .float() + .numpy() + ) + f_cubic = interpolate.interp2d(x, y, z, kind="cubic") + all_rel_pos_bias.append( + torch.Tensor(f_cubic(dx, dy)) + .contiguous() + .view(-1, 1) + .to(relative_position_bias_table_pretrained.device) + ) + + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + checkpoint_model[key] = new_rel_pos_bias + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in checkpoint_model.keys() if "relative_position_index" in k + ] + for k in relative_position_index_keys: + del checkpoint_model[k] + + # delete relative_coords_table since we always re-init it + 
relative_coords_table_keys = [ + k for k in checkpoint_model.keys() if "relative_coords_table" in k + ] + for k in relative_coords_table_keys: + del checkpoint_model[k] + + # # re-map keys due to name change + rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k] + for k in rpe_mlp_keys: + checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k) + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] + for k in attn_mask_keys: + del checkpoint_model[k] + + encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")] + for k in encoder_keys: + checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k) + + return checkpoint_model + + +def add_padding_metas(out, image_metas): + device = out.device + # left, right, top, bottom + paddings = [img_meta.get("padding_size", [0] * 4) for img_meta in image_metas] + paddings = torch.stack(paddings).to(device) + outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)] + return torch.stack(outs) + + +def remove_padding(out, paddings): + B, C, H, W = out.shape + device = out.device + # left, right, top, bottom + paddings = torch.stack(paddings).to(device) + outs = [ + o[:, padding[1] : H - padding[3], padding[0] : W - padding[2]] + for padding, o in zip(paddings, out) + ] + return torch.stack(outs) + + +def remove_padding_metas(out, image_metas): + # left, right, top, bottom + paddings = [ + torch.tensor(img_meta.get("padding_size", [0] * 4)) for img_meta in image_metas + ] + return remove_padding(out, paddings) + + +def ssi_helper(tensor1, tensor2): + stability_mat = 1e-4 * torch.eye(2, device=tensor1.device) + tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1) + scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( + tensor2_one.T @ tensor1.unsqueeze(1) + ) + scale, shift = scale_shift.squeeze().chunk(2, dim=0) + return scale, shift + + +def calculate_mean_values(names, values): + # Create a defaultdict to store sum and count for each name + name_values = {name: {} for name in names} + + # Iterate through the lists and accumulate values for each name + for name, value in zip(names, values): + name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value + name_values[name]["count"] = name_values[name].get("count", 0.0) + 1 + + # Calculate mean values and create the output dictionary + output_dict = { + name: name_values[name]["sum"] / name_values[name]["count"] + for name in name_values + } + + return output_dict + + +def remove_leading_dim(infos): + if isinstance(infos, dict): + return {k: remove_leading_dim(v) for k, v in infos.items()} + elif isinstance(infos, torch.Tensor): + return infos.squeeze(0) + else: + return infos + + +def to_cpu(infos): + if isinstance(infos, dict): + return {k: to_cpu(v) for k, v in infos.items()} + elif isinstance(infos, torch.Tensor): + return infos.detach() + else: + return infos diff --git a/unidepth/utils/positional_embedding.py b/unidepth/utils/positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..63872de03f672d988f5e7bdec04102d380ce2c08 --- /dev/null +++ b/unidepth/utils/positional_embedding.py @@ -0,0 +1,273 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from math import pi +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +class 
PositionEmbeddingSine(nn.Module): + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * pi + self.scale = scale + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class LearnedSinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + print("======== shape of rope freq", self.freqs_cos.shape, "========") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward(self, t): + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin + + +from math import log2 + + +def generate_fourier_features( + x: torch.Tensor, + dim: int = 512, + max_freq: int = 64, + use_cos: bool = False, + use_log: bool = False, + cat_orig: bool = False, +): + x_orig = x + device, dtype, input_dim = x.device, x.dtype, x.shape[-1] + num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim + + if use_log: + scales = 2.0 ** torch.linspace( + 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype + ) + else: + scales = torch.linspace( + 1.0, max_freq / 2, num_bands, device=device, dtype=dtype + ) + + x = x.unsqueeze(-1) + scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] + + x = x * scales * pi + x = torch.cat( + ( + [x.sin(), x.cos()] + if use_cos + else [ + x.sin(), + ] + ), + dim=-1, + ) + x = x.flatten(-2) + if cat_orig: + return torch.cat((x, x_orig), dim=-1) + return x + + +# from PIL import Image +# from unidepth.utils import image_grid, colorize +# if __name__ == "__main__": +# H, W = 512, 512 +# resolution = 128 +# mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)) +# mesh = torch.stack(mesh, dim=0).unsqueeze(0) +# mesh = mesh.view(1, 2, -1).permute(0, 2, 1) + +# features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True) +# channels = features.shape[-1] +# print(features.shape) + +# features = features[0].view(H, W, channels).permute(2, 0, 1).numpy() +# Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png") diff --git a/unidepth/utils/sht.py b/unidepth/utils/sht.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6492ad047e6f7656e12f5e68fceda87e248d25 --- /dev/null +++ b/unidepth/utils/sht.py @@ -0,0 +1,1638 @@ +"""Real spherical harmonics in Cartesian form for PyTorch. + +This is an autogenerated file. See +https://github.com/cheind/torch-spherical-harmonics +for more information. +""" + +import torch + + +def rsh_cart_0(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 0. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,1) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + ], + -1, + ) + + +def rsh_cart_1(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 1. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,4) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. 
+ """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + ], + -1, + ) + + +def rsh_cart_2(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 2. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,9) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + ], + -1, + ) + + +def rsh_cart_3(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 3. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,16) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + ], + -1, + ) + + +def rsh_cart_4(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 4. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,25) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. 
+ """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + ], + -1, + ) + + +def rsh_cart_5(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 5. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,36) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. 
+ """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + ], + -1, + ) + + +def rsh_cart_6(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 6. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,49) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. 
+ """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) 
+ 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + ], + -1, + ) + + +def rsh_cart_7(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 7. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,64) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * 
x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + -0.707162732524596 + * y + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 9.98394571852353e-5 + * y + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00239614697244565 + * xy + * (x2 - y2) + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), + 0.00397356022507413 + * y + * (3.0 * x2 - y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.0561946276120613 + * xy + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.206472245902897 + * y + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 1.24862677781952 * z * (1.5 * z2 - 0.5) + - 1.68564615005635 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 2.02901851395672 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.499450711127808 * z, + 0.206472245902897 + * x + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 
0.0280973138060306 + * (x2 - y2) + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.00397356022507413 + * x + * (x2 - 3.0 * y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.000599036743111412 + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + * (-6.0 * x2 * y2 + x4 + y4), + 9.98394571852353e-5 + * x + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -0.707162732524596 + * x + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + ], + -1, + ) + + +# @torch.jit.script +def rsh_cart_8(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 8. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,81) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + # z4 = z2**2 + return torch.stack( + [ + 0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + 
x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + -0.707162732524596 + * y + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 9.98394571852353e-5 + * y + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00239614697244565 + * xy + * (x2 - y2) + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), + 0.00397356022507413 + * y + * (3.0 * x2 - y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.0561946276120613 + * xy + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.206472245902897 + * y + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 1.24862677781952 * z * (1.5 * z2 - 0.5) + - 1.68564615005635 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 2.02901851395672 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.499450711127808 * z, + 0.206472245902897 + * x + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 
2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 0.0280973138060306 + * (x2 - y2) + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.00397356022507413 + * x + * (x2 - 3.0 * y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.000599036743111412 + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + * (-6.0 * x2 * y2 + x4 + y4), + 9.98394571852353e-5 + * x + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -0.707162732524596 + * x + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3), + -2.91570664069932 + * yz + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 5.10587282657803e-5 + * y + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00147275890257803 + * xy + * (x2 - y2) + * ( + 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 0.0028519853513317 + * y + * (3.0 * x2 - y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.0463392770473559 + * xy + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.193851103820053 + * y + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 1.48417251362228 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.86581687426801 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 2.1808249179756 + * z + * ( + 1.14285714285714 * z * (1.5 * z2 - 0.5) + - 1.54285714285714 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 1.85714285714286 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 
0.457142857142857 * z + ) + - 0.954110901614325 * z2 + + 0.318036967204775, + 0.193851103820053 + * x + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 0.0231696385236779 + * (x2 - y2) + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.0028519853513317 + * x + * (x2 - 3.0 * y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.000368189725644507 + * (-6.0 * x2 * y2 + x4 + y4) + * ( + 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 5.10587282657803e-5 + * x + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -2.91570664069932 + * xz + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + -20.4099464848952 * x2**3 * y2 + - 20.4099464848952 * x2 * y2**3 + + 0.72892666017483 * x4**2 + + 51.0248662122381 * x4 * y4 + + 0.72892666017483 * y4**2, + ], + -1, + ) + + +__all__ = [ + "rsh_cart_0", + "rsh_cart_1", + "rsh_cart_2", + "rsh_cart_3", + "rsh_cart_4", + "rsh_cart_5", + "rsh_cart_6", + "rsh_cart_7", + "rsh_cart_8", +] + + +from typing import Optional + +import torch + + +class SphHarm(torch.nn.Module): + def __init__(self, m, n, dtype=torch.float32) -> None: + super().__init__() + self.dtype = dtype + m = torch.tensor(list(range(-m + 1, m))) + n = torch.tensor(list(range(n))) + self.is_normalized = False + vals = torch.cartesian_prod(m, n).T + vals = vals[:, vals[0] <= vals[1]] + m, n = vals.unbind(0) + + self.register_buffer("m", tensor=m) + self.register_buffer("n", tensor=n) + self.register_buffer("l_max", tensor=torch.max(self.n)) + + f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre() + self.register_buffer("f_a", tensor=f_a) + self.register_buffer("f_b", tensor=f_b) + self.register_buffer("d0_mask_3d", tensor=d0_mask_3d) + self.register_buffer("d1_mask_3d", tensor=d1_mask_3d) + self.register_buffer("initial_value", tensor=initial_value) + + @property + def device(self): + return next(self.buffers()).device + + def forward(self, points: torch.Tensor) -> torch.Tensor: + """Computes the spherical harmonics.""" + # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi) + B, N, D = points.shape + dtype = points.dtype + theta, phi = points.view(-1, D).to(self.dtype).unbind(-1) + cos_colatitude = torch.cos(phi) + legendre = self._gen_associated_legendre(cos_colatitude) + vals = torch.stack([self.m.abs(), self.n], dim=0) + vals = torch.cat( + [ + vals.repeat(1, theta.shape[0]), + torch.arange(theta.shape[0], device=theta.device) + .unsqueeze(0) + .repeat_interleave(vals.shape[1], dim=1), + ], + dim=0, + ) + legendre_vals = 
legendre[vals[0], vals[1], vals[2]] + legendre_vals = legendre_vals.reshape(-1, theta.shape[0]) + angle = torch.outer(self.m.abs(), theta) + vandermonde = torch.complex(torch.cos(angle), torch.sin(angle)) + harmonics = torch.complex( + legendre_vals * torch.real(vandermonde), + legendre_vals * torch.imag(vandermonde), + ) + + # Negative order. + m = self.m.unsqueeze(-1) + harmonics = torch.where( + m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics + ) + harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype) + return harmonics + + def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]: + """Generates mask for recurrence relation on the remaining entries. + + The remaining entries are with respect to the diagonal and offdiagonal + entries. + + Args: + l_max: see `gen_normalized_legendre`. + Returns: + torch.Tensors representing the mask used by the recurrence relations. + """ + + # Computes all coefficients. + m_mat, l_mat = torch.meshgrid( + torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), + torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), + indexing="ij", + ) + if self.is_normalized: + c0 = l_mat * l_mat + c1 = m_mat * m_mat + c2 = 2.0 * l_mat + c3 = (l_mat - 1.0) * (l_mat - 1.0) + d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) + d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) + else: + d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) + d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) + + d0_mask_indices = torch.triu_indices(self.l_max + 1, 1) + d1_mask_indices = torch.triu_indices(self.l_max + 1, 2) + + d_zeros = torch.zeros( + (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device + ) + d_zeros[d0_mask_indices] = d0[d0_mask_indices] + d0_mask = d_zeros + + d_zeros = torch.zeros( + (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device + ) + d_zeros[d1_mask_indices] = d1[d1_mask_indices] + d1_mask = d_zeros + + # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. + i = torch.arange(self.l_max + 1, device=self.device)[:, None, None] + j = torch.arange(self.l_max + 1, device=self.device)[None, :, None] + k = torch.arange(self.l_max + 1, device=self.device)[None, None, :] + mask = (i + j - k == 0).to(self.dtype) + d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask) + d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask) + return (d0_mask_3d, d1_mask_3d) + + def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + coeff_0 = self.d0_mask_3d[i] + coeff_1 = self.d1_mask_3d[i] + h = torch.einsum( + "ij,ijk->ijk", + coeff_0, + torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x), + ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1)) + p_val = p_val + h + return p_val + + def _init_legendre(self): + a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device) + b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device) + if self.is_normalized: + # The initial value p(0,0). + initial_value: torch.Tensor = torch.tensor( + 0.5 / (torch.pi**0.5), device=self.device + ) + f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0) + f_b = torch.sqrt(2.0 * b_idx + 3.0) + else: + # The initial value p(0,0). 
+ initial_value = torch.tensor(1.0, device=self.device) + f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0) + f_b = 2.0 * b_idx + 1.0 + + d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask() + return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d + + def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor: + r"""Computes associated Legendre functions (ALFs) of the first kind. + + The ALFs of the first kind are used in spherical harmonics. The spherical + harmonic of degree `l` and order `m` can be written as + `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the + normalization factor and θ and φ are the colatitude and longitude, + respectively. `N_l^m` is chosen in such a way that the spherical harmonics form + a set of orthonormal basis functions of L^2(S^2). For the computational + efficiency of spherical harmonics transform, the normalization factor is + used in the computation of the ALFs. In addition, normalizing `P_l^m` + avoids overflow/underflow and achieves better numerical stability. Three + recurrence relations are used in the computation. + + Args: + l_max: The maximum degree of the associated Legendre function. Both the + degrees and orders are `[0, 1, 2, ..., l_max]`. + x: A vector of type `float32`, `float64` containing the sampled points in + spherical coordinates, at which the ALFs are computed; `x` is essentially + `cos(θ)`. For the numerical integration used by the spherical harmonics + transforms, `x` contains the quadrature points in the interval of + `[-1, 1]`. There are several approaches to provide the quadrature points: + Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev + method (`scipy.special.roots_chebyu`), and Driscoll & Healy + method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier + transforms and convolutions on the 2-sphere." Advances in applied + mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature + points are nearly equally spaced along θ and provide exact discrete + orthogonality, (P^m)^T W P^m = I, where `T` represents the transpose + operation, `W` is a diagonal matrix containing the quadrature weights, + and `I` is the identity matrix. The Gauss-Chebyshev points are equally + spaced, which only provide approximate discrete orthogonality. The + Driscoll & Healy quadrature points are equally spaced and provide the + exact discrete orthogonality. The number of sampling points is required to + be twice the number of frequency points (modes) in the Driscoll & Healy + approach, which enables FFT and achieves a fast spherical harmonics + transform. + is_normalized: True if the associated Legendre functions are normalized. + With normalization, `N_l^m` is applied such that the spherical harmonics + form a set of orthonormal basis functions of L^2(S^2). + + Returns: + The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values + of the ALFs at `x`; the dimensions in the sequence of order, degree, and + evaluation points. + """ + p = torch.zeros( + (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device + ) + p[0, 0] = self.initial_value + + # Compute the diagonal entries p(l,l) with recurrence.
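+ # Diagonal recurrence: P_l^l(x) ∝ (1 - x*x)^(l/2); f_a carries the per-degree factors + # ((-1)^l (2l - 1)!! in the unnormalized case) and y the cumulative powers of sqrt(1 - x*x).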
+ y = torch.cumprod( + torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0 + ) + p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y) + # torch.diag_indices(l_max + 1) + diag_indices = torch.stack( + [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0 + ) + p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag + + diag_indices = torch.stack( + [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0 + ) + + # Compute the off-diagonal entries with recurrence. + p_offdiag = torch.einsum( + "ij,ij->ij", + torch.einsum("i,j->ij", self.f_b, x), + p[(diag_indices[0], diag_indices[1])], + ) # p[torch.diag_indices(l_max)]) + p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = ( + p_offdiag + ) + + # Compute the remaining entries with recurrence. + if self.l_max > 1: + for i in range(2, self.l_max + 1): + p = self._recursive(i, p, x) + return p diff --git a/unidepth/utils/visualization.py b/unidepth/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..85beba496af2bd1f178d63b801e86a2abe011185 --- /dev/null +++ b/unidepth/utils/visualization.py @@ -0,0 +1,201 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch +import wandb +from PIL import Image + +from unidepth.utils.misc import ssi_helper + + +def colorize( + value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r" +): + # if already RGB, do nothing + if value.ndim > 2: + if value.shape[-1] > 1: + return value + value = value[..., 0] + invalid_mask = value < 0.0001 + # normalize + vmin = value.min() if vmin is None else vmin + vmax = value.max() if vmax is None else vmax + value = (value - vmin) / (vmax - vmin) # vmin..vmax + + # set color + cmapper = plt.get_cmap(cmap) + value = cmapper(value, bytes=True) # (nxmx4) + value[invalid_mask] = 0 + img = value[..., :3] + return img + + +def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray: + if not len(imgs): + return None + assert len(imgs) == rows * cols + h, w = imgs[0].shape[:2] + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste( + Image.fromarray(img.astype(np.uint8)).resize( + (w, h), resample=Image.BILINEAR + ), + box=(i % cols * w, i // cols * h), + ) + + return np.array(grid) + + +def get_pointcloud_from_rgbd( + image: np.array, + depth: np.array, + mask: np.ndarray, + intrinsic_matrix: np.array, + extrinsic_matrix: np.array = None, +): + depth = np.array(depth).squeeze() + mask = np.array(mask).squeeze() + # Mask the depth array + masked_depth = np.ma.masked_where(mask == False, depth) + # masked_depth = np.ma.masked_greater(masked_depth, 8000) + # Create idx array + idxs = np.indices(masked_depth.shape) + u_idxs = idxs[1] + v_idxs = idxs[0] + # Get only non-masked depth and idxs + z = masked_depth[~masked_depth.mask] + compressed_u_idxs = u_idxs[~masked_depth.mask] + compressed_v_idxs = v_idxs[~masked_depth.mask] + image = np.stack( + [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1 + ) + + # Calculate local position of each point + # Apply vectorized math to depth using compressed arrays + cx = intrinsic_matrix[0, 2] + fx = intrinsic_matrix[0, 0] + x = (compressed_u_idxs - cx) * z / fx + cy = intrinsic_matrix[1, 2] + fy = intrinsic_matrix[1, 1] + # Flip y as we want +y pointing up not down + y = 
-((compressed_v_idxs - cy) * z / fy) + + # # Apply camera_matrix to pointcloud as to get the pointcloud in world coords + # if extrinsic_matrix is not None: + # # Calculate camera pose from extrinsic matrix + # camera_matrix = np.linalg.inv(extrinsic_matrix) + # # Create homogenous array of vectors by adding 4th entry of 1 + # # At the same time flip z as for eye space the camera is looking down the -z axis + # w = np.ones(z.shape) + # x_y_z_eye_hom = np.vstack((x, y, -z, w)) + # # Transform the points from eye space to world space + # x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3] + # return x_y_z_world.T + # else: + x_y_z_local = np.stack((x, y, z), axis=-1) + return np.concatenate([x_y_z_local, image], axis=-1) + + +def save_file_ply(xyz, rgb, pc_file): + if rgb.max() < 1.001: + rgb = rgb * 255.0 + rgb = rgb.astype(np.uint8) + # print(rgb) + with open(pc_file, "w") as f: + # headers + f.writelines( + [ + "ply\n" "format ascii 1.0\n", + "element vertex {}\n".format(xyz.shape[0]), + "property float x\n", + "property float y\n", + "property float z\n", + "property uchar red\n", + "property uchar green\n", + "property uchar blue\n", + "end_header\n", + ] + ) + + for i in range(xyz.shape[0]): + str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format( + xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2] + ) + f.write(str_v) + + +# really awful fct... FIXME +def log_train_artifacts(rgbs, gts, preds, ds_name, step, infos={}): + rgbs = [ + (127.5 * (rgb + 1)) + .clip(0, 255) + .to(torch.uint8) + .cpu() + .detach() + .permute(1, 2, 0) + .numpy() + for rgb in rgbs + ] + + new_gts, new_preds = [], [] + if len(gts) > 0: + for i, gt in enumerate(gts): + scale, shift = ssi_helper( + gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach() + ) + gt = gts[i].cpu().detach().squeeze().numpy() + pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy() + vmin = gt[gt > 0].min() if (gt > 0).any() else 0.0 + vmax = gt.max() if (gt > 0).any() else 0.1 + new_gts.append(colorize(gt, vmin=vmin, vmax=vmax)) + new_preds.append(colorize(pred, vmin=vmin, vmax=vmax)) + gts, preds = new_gts, new_preds + else: + preds = [ + colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0) + for i, pred in enumerate(preds) + ] + + num_additional, additionals = 0, [] + for name, info in infos.items(): + num_additional += 1 + if info.shape[1] == 3: + additionals.extend( + [ + (127.5 * (x + 1)) + .clip(0, 255) + .to(torch.uint8) + .cpu() + .detach() + .permute(1, 2, 0) + .numpy() + for x in info[:4] + ] + ) + else: + additionals.extend( + [ + colorize(x.cpu().detach().squeeze().numpy()) + for i, x in enumerate(info[:4]) + ] + ) + + num_rows = 2 + int(len(gts) > 0) + num_additional + artifacts_grid = image_grid( + [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs) + ) + try: + wandb.log({f"{ds_name}_training": [wandb.Image(artifacts_grid)]}, step=step) + except: + Image.fromarray(artifacts_grid).save( + os.path.join(os.environ["HOME"], "Workspace", f"art_grid{step}.png") + ) + print("Logging training images failed")
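The generated `rsh_cart_*` helpers above evaluate a real spherical-harmonics basis directly from Cartesian components. Below is a minimal calling sketch, assuming (as the stacked return value suggests) that each `rsh_cart_l` takes an `(..., 3)` tensor of unit directions and appends `(l + 1) ** 2` basis values along a new last dimension; the import path `unidepth.utils.sht` is a guess and should be adjusted to wherever this module actually lives in the repo.

import torch

from unidepth.utils.sht import rsh_cart_3  # hypothetical module path

# Unit directions on the sphere; the Cartesian polynomials assume |xyz| = 1.
xyz = torch.nn.functional.normalize(torch.randn(4096, 3), dim=-1)
basis = rsh_cart_3(xyz)
print(basis.shape)  # expected (4096, 16), i.e. (3 + 1) ** 2 components for degree 3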
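`SphHarm`, by contrast, works in spherical coordinates: its forward pass expects a `(B, N, 2)` tensor whose last dimension holds `(theta, phi)`, where `phi` is the colatitude that is passed through `cos()` before the associated Legendre recurrences, and it returns one value per retained `(m, n)` pair. A rough usage sketch under those assumptions, again with a hypothetical import path:

import torch

from unidepth.utils.sht import SphHarm  # hypothetical module path

sh = SphHarm(m=4, n=4)  # orders |m| <= 3, degrees l <= 3
theta = torch.rand(2, 1024) * 2 * torch.pi  # azimuth in [0, 2*pi)
phi = torch.rand(2, 1024) * torch.pi        # colatitude in [0, pi]
points = torch.stack([theta, phi], dim=-1)  # (B, N, 2)
harmonics = sh(points)                      # (B, N, n_harmonics)
print(harmonics.shape[:2])                  # torch.Size([2, 1024])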
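On the visualization side, `colorize`, `get_pointcloud_from_rgbd`, and `save_file_ply` compose into a small post-processing pipeline: render a predicted depth map as an image, back-project the valid pixels into a colored point cloud, and write it out as ASCII PLY. A sketch with synthetic stand-in inputs (the depth, image, and intrinsics below are placeholders, not outputs of the model):

import numpy as np
from PIL import Image

from unidepth.utils.visualization import colorize, get_pointcloud_from_rgbd, save_file_ply

# Placeholder inputs: (H, W) metric depth, (H, W, 3) uint8 image, 3x3 intrinsics.
H, W = 480, 640
depth = np.random.uniform(0.5, 10.0, size=(H, W)).astype(np.float32)
rgb = np.random.randint(0, 256, size=(H, W, 3), dtype=np.uint8)
K = np.array([[500.0, 0.0, W / 2], [0.0, 500.0, H / 2], [0.0, 0.0, 1.0]])

# Colorized depth for quick visual inspection.
Image.fromarray(colorize(depth, vmin=0.0, vmax=10.0, cmap="magma_r")).save("depth_vis.png")

# Keep a non-trivial subset of pixels, back-project them, and dump a colored PLY.
mask = depth < 8.0
points = get_pointcloud_from_rgbd(rgb, depth, mask, K)  # (N, 6): x, y, z, r, g, b
save_file_ply(points[:, :3], points[:, 3:], "cloud.ply")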