diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..525424d02cc3281a973c14423db36e911a273bf8 --- /dev/null +++ b/app.py @@ -0,0 +1,68 @@ +import gradio as gr +import torch +from PIL import Image +import cv2 +import numpy as np +from geobench.modeling.archs.dam.dam import DepthAnything +from geobench.utils.image_util import colorize_depth_maps +from geobench.midas.transforms import Resize, NormalizeImage, PrepareForNet +from torchvision.transforms import Compose +import os + +# Helper function to load model (same as your original code) +def load_model_by_name(arch_name, checkpoint_path, device): + if arch_name == 'depthanything': + if '.safetensors' in checkpoint_path: + model = DepthAnything.from_pretrained(os.path.dirname(checkpoint_path)).to(device) + else: + raise NotImplementedError("Model architecture not implemented.") + else: + raise NotImplementedError(f"Unknown architecture: {arch_name}") + return model + +# Image processing function (same as your original code, modified for Gradio) +def process_image(image, model, device, mode='rel_depth'): + # Preprocess the image + image_np = np.array(image)[..., ::-1] / 255 + transform = Compose([ + Resize(512, 512, resize_target=None, keep_aspect_ratio=False, ensure_multiple_of=32, image_interpolation_method=cv2.INTER_CUBIC), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet() + ]) + + image_tensor = transform({'image': image_np})['image'] + image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device) + + with torch.no_grad(): # Disable autograd since we don't need gradients on CPU + pred_disp, _ = model(image_tensor) + pred_disp_np = pred_disp.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0) + pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min()) + + # Colorize depth map + cmap = "Spectral_r" if mode != 'metric' else 'Spectral_r' + depth_colored = colorize_depth_maps(pred_disp[None, ...], 0, 1, cmap=cmap).squeeze() + depth_colored = (depth_colored * 255).astype(np.uint8) + + depth_image = Image.fromarray(depth_colored) + return depth_image + +# Gradio interface function +def gradio_interface(image, mode='rel_depth'): + # Set device to CPU explicitly + device = torch.device("cpu") # Force using CPU + model = load_model_by_name("depthanything", "your_checkpoint_path_here", device) + + # Process image and return output + return process_image(image, model, device, mode) + +# Create Gradio interface +iface = gr.Interface( + fn=gradio_interface, + inputs=[gr.Image(type="pil"), gr.Dropdown(choices=['rel_depth', 'metric_depth', 'disparity'], label="Mode")], + outputs=gr.Image(type="pil"), + title="Depth Estimation Demo", + description="Upload an image to see the depth estimation results." +) + +# Launch the Gradio interface +iface.launch() diff --git a/geobench/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc b/geobench/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6a15915683609339d8129b595c5278022acea03 Binary files /dev/null and b/geobench/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/__pycache__/dpt.cpython-310.pyc b/geobench/depth_anything_v2/__pycache__/dpt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea1fbe6a4a8d498304e93c00d3feb022af49eeee Binary files /dev/null and b/geobench/depth_anything_v2/__pycache__/dpt.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2.py b/geobench/depth_anything_v2/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..83d250818c721c6df3b30d3f4352945527701615 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/geobench/depth_anything_v2/dinov2_layers/__init__.py b/geobench/depth_anything_v2/dinov2_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dd393fc567b9cc65872afe47346e76a761b7abb Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da0e2f459524df7e47010fe8e1340059269c7388 Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..902440fa2a34abf52c8150768dcb9d773692ddaa Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b765b50891c1da6b2bef79430a28cfd960da95f7 Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb66da3f8831a4ccac5ed828abf2ffcb6f51443f Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e322d677121e7cf693fabe95e472a2bb20977151 Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5ccb963bcf5f28c8928f140544fae24a2e6c569 Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc b/geobench/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2605c9d6468c4fff8adefe7feb07fa06e956801f Binary files /dev/null and b/geobench/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/dinov2_layers/attention.py b/geobench/depth_anything_v2/dinov2_layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + \ No newline at end of file diff --git a/geobench/depth_anything_v2/dinov2_layers/block.py b/geobench/depth_anything_v2/dinov2_layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/geobench/depth_anything_v2/dinov2_layers/drop_path.py b/geobench/depth_anything_v2/dinov2_layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/geobench/depth_anything_v2/dinov2_layers/layer_scale.py b/geobench/depth_anything_v2/dinov2_layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/geobench/depth_anything_v2/dinov2_layers/mlp.py b/geobench/depth_anything_v2/dinov2_layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/geobench/depth_anything_v2/dinov2_layers/patch_embed.py b/geobench/depth_anything_v2/dinov2_layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/geobench/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/geobench/depth_anything_v2/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/geobench/depth_anything_v2/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/geobench/depth_anything_v2/dpt.py b/geobench/depth_anything_v2/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b4ee82f4813a23777048645bf4afe87d35b121 --- /dev/null +++ b/geobench/depth_anything_v2/dpt.py @@ -0,0 +1,230 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from .dinov2 import DINOv2 +from .util.blocks import FeatureFusionBlock, _make_scratch +from .util.transform import Resize, NormalizeImage, PrepareForNet +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download +from geobench.modeling.necks.image_projector import ImageProjModel + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DepthAnythingV2(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False + ): + super(DepthAnythingV2, self).__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + self.encoder = encoder + self.pretrained = DINOv2(model_name=encoder) + # self.proj = ImageProjModel() + + self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x): + bs, _, h, w = x.shape + + patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 + + features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) + # features_output = self.proj(features[3][0]) + + depth = self.depth_head(features, patch_h, patch_w) + # depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) + + depth = F.relu(depth) + # import pdb; pdb.set_trace() + # return depth.squeeze(1), features[3][0] + return depth, features[3][0] + + @torch.no_grad() + def infer_image(self, raw_image, input_size=518): + image, (h, w) = self.image2tensor(raw_image, input_size) + + depth = self.forward(image) + + depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] + + return depth.cpu().numpy() + + def image2tensor(self, raw_image, input_size=518): + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + h, w = raw_image.shape[:2] + + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0) + + DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + image = image.to(DEVICE) + + return image, (h, w) diff --git a/geobench/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc b/geobench/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4631cad91e22d3f0279d45131d8428c7f11cfa6c Binary files /dev/null and b/geobench/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc b/geobench/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cae45860398bc552366ab59195cc36d43870606 Binary files /dev/null and b/geobench/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc differ diff --git a/geobench/depth_anything_v2/util/blocks.py b/geobench/depth_anything_v2/util/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37 --- /dev/null +++ b/geobench/depth_anything_v2/util/blocks.py @@ -0,0 +1,148 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/geobench/depth_anything_v2/util/transform.py b/geobench/depth_anything_v2/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73 --- /dev/null +++ b/geobench/depth_anything_v2/util/transform.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) + + # resize sample + sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) + + if self.__resize_target: + if "depth" in sample: + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) + + if "mask" in sample: + sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + return sample \ No newline at end of file diff --git a/geobench/midas/__pycache__/base_model.cpython-310.pyc b/geobench/midas/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb9f70b646742553babbc4068cdd28c42858270b Binary files /dev/null and b/geobench/midas/__pycache__/base_model.cpython-310.pyc differ diff --git a/geobench/midas/__pycache__/blocks.cpython-310.pyc b/geobench/midas/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b87e6d57f1d7681e77234d1355b7aca39057b34 Binary files /dev/null and b/geobench/midas/__pycache__/blocks.cpython-310.pyc differ diff --git a/geobench/midas/__pycache__/dpt_depth.cpython-310.pyc b/geobench/midas/__pycache__/dpt_depth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..560f55dd58a259b12bd8c6d1632f24f123486c30 Binary files /dev/null and b/geobench/midas/__pycache__/dpt_depth.cpython-310.pyc differ diff --git a/geobench/midas/__pycache__/midas_net.cpython-310.pyc b/geobench/midas/__pycache__/midas_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c64d450e363b9264ee98343f8dfde21d1c2d562 Binary files /dev/null and b/geobench/midas/__pycache__/midas_net.cpython-310.pyc differ diff --git a/geobench/midas/__pycache__/midas_net_custom.cpython-310.pyc b/geobench/midas/__pycache__/midas_net_custom.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14543793e5bf8a20ef1557b081a87fed9a90ba6b Binary files /dev/null and b/geobench/midas/__pycache__/midas_net_custom.cpython-310.pyc differ diff --git a/geobench/midas/__pycache__/model_loader.cpython-310.pyc b/geobench/midas/__pycache__/model_loader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d95d830ad7e6e240554898cc87b4b44d8e3e824 Binary files /dev/null and b/geobench/midas/__pycache__/model_loader.cpython-310.pyc differ diff --git a/geobench/midas/__pycache__/transforms.cpython-310.pyc b/geobench/midas/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7dd31378ec3f20355aa0e9381188550e13d68ee Binary files /dev/null and b/geobench/midas/__pycache__/transforms.cpython-310.pyc differ diff --git a/geobench/midas/backbones/__pycache__/beit.cpython-310.pyc b/geobench/midas/backbones/__pycache__/beit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9f881e8280d45d926528e8cfcfc5c5332e1c73f Binary files /dev/null and b/geobench/midas/backbones/__pycache__/beit.cpython-310.pyc differ diff --git a/geobench/midas/backbones/__pycache__/levit.cpython-310.pyc b/geobench/midas/backbones/__pycache__/levit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e2837af0707026387f8cecf3762c20c60e21dcc Binary files /dev/null and b/geobench/midas/backbones/__pycache__/levit.cpython-310.pyc differ diff --git a/geobench/midas/backbones/__pycache__/swin.cpython-310.pyc b/geobench/midas/backbones/__pycache__/swin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..452aafdee94b2e1e8da2825fc4b79940bbd2c089 Binary files /dev/null and b/geobench/midas/backbones/__pycache__/swin.cpython-310.pyc differ diff --git a/geobench/midas/backbones/__pycache__/swin2.cpython-310.pyc b/geobench/midas/backbones/__pycache__/swin2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c02ef83659b09d73bfc3ffd1b6028ca40f4cf50 Binary files /dev/null and b/geobench/midas/backbones/__pycache__/swin2.cpython-310.pyc differ diff --git a/geobench/midas/backbones/__pycache__/swin_common.cpython-310.pyc b/geobench/midas/backbones/__pycache__/swin_common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff0b0f3cd695f3f3b1df4747f34fb45549535781 Binary files /dev/null and b/geobench/midas/backbones/__pycache__/swin_common.cpython-310.pyc differ diff --git a/geobench/midas/backbones/__pycache__/utils.cpython-310.pyc b/geobench/midas/backbones/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af9bb75e11c7e007debcec9efa8bbb304919deae Binary files /dev/null and b/geobench/midas/backbones/__pycache__/utils.cpython-310.pyc differ diff --git a/geobench/midas/backbones/__pycache__/vit.cpython-310.pyc b/geobench/midas/backbones/__pycache__/vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee2909f312ca65efcadb245f389bcd2733567a91 Binary files /dev/null and b/geobench/midas/backbones/__pycache__/vit.cpython-310.pyc differ diff --git a/geobench/midas/backbones/beit.py b/geobench/midas/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..7a24e02cd2b979844bf638b46ac60949ee9ce691 --- /dev/null +++ b/geobench/midas/backbones/beit.py @@ -0,0 +1,196 @@ +import timm +import torch +import types + +import numpy as np +import torch.nn.functional as F + +from .utils import forward_adapted_unflatten, make_backbone_default +from timm.models.beit import gen_relative_position_index +from torch.utils.checkpoint import checkpoint +from typing import Optional + + +def forward_beit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_features") + + +def patch_embed_forward(self, x): + """ + Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes. + """ + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +def _get_rel_pos_bias(self, window_size): + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear") + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) + + key = str(window_size[1]) + "," + str(window_size[0]) + if key not in self.relative_position_indices.keys(): + self.relative_position_indices[key] = gen_relative_position_index(window_size) + + relative_position_bias = new_relative_position_bias_table[ + self.relative_position_indices[key].view(-1)].view( + window_size[0] * window_size[1] + 1, + window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + +def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. + """ + B, N, C = x.shape + + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + window_size = tuple(np.array(resolution) // 16) + attn = attn + self._get_rel_pos_bias(window_size) + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. + """ + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution, + shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +def beit_forward_features(self, x): + """ + Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. + """ + resolution = x.shape[2:] + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) + x = self.norm(x) + return x + + +def _make_beit_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[0, 4, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) + backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) + + for block in backbone.model.blocks: + attn = block.attn + attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) + attn.forward = types.MethodType(attention_forward, attn) + attn.relative_position_indices = {} + + block.forward = types.MethodType(block_forward, block) + + return backbone + + +def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + + features = [256, 512, 1024, 1024] + + return _make_beit_backbone( + model, + features=features, + size=[512, 512], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + ) diff --git a/geobench/midas/backbones/levit.py b/geobench/midas/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d023a98702a0451806d26f33f8bccf931814f10 --- /dev/null +++ b/geobench/midas/backbones/levit.py @@ -0,0 +1,106 @@ +import timm +import torch +import torch.nn as nn +import numpy as np + +from .utils import activations, get_activation, Transpose + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=[3, 11, 21], + patch_grid=[14, 14] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. + """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. + """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks == None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) diff --git a/geobench/midas/backbones/next_vit.py b/geobench/midas/backbones/next_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8afdd8b743b5ab023a359dc3b721e601b1a40d11 --- /dev/null +++ b/geobench/midas/backbones/next_vit.py @@ -0,0 +1,39 @@ +import timm + +import torch.nn as nn + +from pathlib import Path +from .utils import activations, forward_default, get_activation + +from ..external.next_vit.classification.nextvit import * + + +def forward_next_vit(pretrained, x): + return forward_default(pretrained, x, "forward") + + +def _make_next_vit_backbone( + model, + hooks=[2, 6, 36, 39], +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + return pretrained + + +def _make_pretrained_next_vit_large_6m(hooks=None): + model = timm.create_model("nextvit_large") + + hooks = [2, 6, 36, 39] if hooks == None else hooks + return _make_next_vit_backbone( + model, + hooks=hooks, + ) diff --git a/geobench/midas/backbones/swin.py b/geobench/midas/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c71367e3e78b087f80b2ab3e2f495a9c372f1a --- /dev/null +++ b/geobench/midas/backbones/swin.py @@ -0,0 +1,13 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swinl12_384(pretrained, hooks=None): + model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) diff --git a/geobench/midas/backbones/swin2.py b/geobench/midas/backbones/swin2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4c8f1d6fc1807a207dc6b9a261c6f7b14a87a3 --- /dev/null +++ b/geobench/midas/backbones/swin2.py @@ -0,0 +1,34 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swin2l24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2b24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2t16_256(pretrained, hooks=None): + model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) + + hooks = [1, 1, 5, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks, + patch_grid=[64, 64] + ) diff --git a/geobench/midas/backbones/swin_common.py b/geobench/midas/backbones/swin_common.py new file mode 100644 index 0000000000000000000000000000000000000000..94d63d408f18511179d90b3ac6f697385d1e556d --- /dev/null +++ b/geobench/midas/backbones/swin_common.py @@ -0,0 +1,52 @@ +import torch + +import torch.nn as nn +import numpy as np + +from .utils import activations, forward_default, get_activation, Transpose + + +def forward_swin(pretrained, x): + return forward_default(pretrained, x) + + +def _make_swin_backbone( + model, + hooks=[1, 1, 17, 1], + patch_grid=[96, 96] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if hasattr(model, "patch_grid"): + used_patch_grid = model.patch_grid + else: + used_patch_grid = patch_grid + + patch_grid_size = np.array(used_patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) + ) + pretrained.act_postprocess4 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) + ) + + return pretrained diff --git a/geobench/midas/backbones/utils.py b/geobench/midas/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0558899dddcfccec5f01a764d4f21738eb612149 --- /dev/null +++ b/geobench/midas/backbones/utils.py @@ -0,0 +1,249 @@ +import torch + +import torch.nn as nn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def forward_default(pretrained, x, function_name="forward_features"): + exec(f"pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + if hasattr(pretrained, "act_postprocess1"): + layer_1 = pretrained.act_postprocess1(layer_1) + if hasattr(pretrained, "act_postprocess2"): + layer_2 = pretrained.act_postprocess2(layer_2) + if hasattr(pretrained, "act_postprocess3"): + layer_3 = pretrained.act_postprocess3(layer_3) + if hasattr(pretrained, "act_postprocess4"): + layer_4 = pretrained.act_postprocess4(layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def forward_adapted_unflatten(pretrained, x, function_name="forward_features"): + b, c, h, w = x.shape + + exec(f"glob = pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def make_backbone_default( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + return pretrained diff --git a/geobench/midas/backbones/vit.py b/geobench/midas/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..413f9693bd4548342280e329c9128c1a52cea920 --- /dev/null +++ b/geobench/midas/backbones/vit.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + +from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, + make_backbone_default, Transpose) + + +def forward_vit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_flex") + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + if self.no_embed_class: + x = x + pos_embed + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if not self.no_embed_class: + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + patch_size=[16, 16], + number_stages=2, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + used_number_stages = 0 if use_vit_only else number_stages + for s in range(used_number_stages): + pretrained.model.patch_embed.backbone.stages[s].register_forward_hook( + get_activation(str(s + 1)) + ) + for s in range(used_number_stages, 4): + pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1))) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + for s in range(used_number_stages): + value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + exec(f"pretrained.act_postprocess{s + 1}=value") + for s in range(used_number_stages, 4): + if s < number_stages: + final_layer = nn.ConvTranspose2d( + in_channels=features[s], + out_channels=features[s], + kernel_size=4 // (2 ** s), + stride=4 // (2 ** s), + padding=0, + bias=True, + dilation=1, + groups=1, + ) + elif s > number_stages: + final_layer = nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ) + else: + final_layer = None + + layers = [ + readout_oper[s], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[s], + kernel_size=1, + stride=1, + padding=0, + ), + ] + if final_layer is not None: + layers.append(final_layer) + + value = nn.Sequential(*layers) + exec(f"pretrained.act_postprocess{s + 1}=value") + + pretrained.model.start_index = start_index + pretrained.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/geobench/midas/base_model.py b/geobench/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f2dc2d2be16cf19c0fce59751d94d4388a6b4820 --- /dev/null +++ b/geobench/midas/base_model.py @@ -0,0 +1,17 @@ +import torch +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download + + +class BaseModel(torch.nn.Module, PyTorchModelHubMixin): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters, strict=False) diff --git a/geobench/midas/blocks.py b/geobench/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..6d87a00680bb6ed9a6d7c3043ea30a1e90361794 --- /dev/null +++ b/geobench/midas/blocks.py @@ -0,0 +1,439 @@ +import torch +import torch.nn as nn + +from .backbones.beit import ( + _make_pretrained_beitl16_512, + _make_pretrained_beitl16_384, + _make_pretrained_beitb16_384, + forward_beit, +) +from .backbones.swin_common import ( + forward_swin, +) +from .backbones.swin2 import ( + _make_pretrained_swin2l24_384, + _make_pretrained_swin2b24_384, + _make_pretrained_swin2t16_256, +) +from .backbones.swin import ( + _make_pretrained_swinl12_384, +) +from .backbones.levit import ( + _make_pretrained_levit_384, + forward_levit, +) +from .backbones.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]): + if backbone == "beitl16_512": + pretrained = _make_pretrained_beitl16_512( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_512-L (backbone) + elif backbone == "beitl16_384": + pretrained = _make_pretrained_beitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_384-L (backbone) + elif backbone == "beitb16_384": + pretrained = _make_pretrained_beitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # BEiT_384-B (backbone) + elif backbone == "swin2l24_384": + pretrained = _make_pretrained_swin2l24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin2-L/12to24 (backbone) + elif backbone == "swin2b24_384": + pretrained = _make_pretrained_swin2b24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [128, 256, 512, 1024], features, groups=groups, expand=expand + ) # Swin2-B/12to24 (backbone) + elif backbone == "swin2t16_256": + pretrained = _make_pretrained_swin2t16_256( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # Swin2-T/16 (backbone) + elif backbone == "swinl12_384": + pretrained = _make_pretrained_swinl12_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin-L/12 (backbone) + elif backbone == "next_vit_large_6m": + from .backbones.next_vit import _make_pretrained_next_vit_large_6m + pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) + scratch = _make_scratch( + in_features, features, groups=groups, expand=expand + ) # Next-ViT-L on ImageNet-1K-6M (backbone) + elif backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + elif backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/geobench/midas/dpt_depth.py b/geobench/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..16e21746102ceaa7d2ee4af636e7a1249bf73a5c --- /dev/null +++ b/geobench/midas/dpt_depth.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_beit, + forward_swin, + forward_levit, + forward_vit, +) +from .backbones.levit import stem_b4_transpose +from timm.models.layers import get_act_layer + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. + hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + size_refinenet3 = None + self.scratch.stem_transpose = None + + if "beit" in backbone: + self.forward_transformer = forward_beit + elif "swin" in backbone: + self.forward_transformer = forward_swin + elif "next_vit" in backbone: + from .backbones.next_vit import forward_next_vit + self.forward_transformer = forward_next_vit + elif "levit" in backbone: + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + else: + self.forward_transformer = forward_vit + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + if self.number_layers >= 4: + layer_4_rn = self.scratch.layer4_rn(layer_4) + + if self.number_layers == 3: + path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) + else: + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + if self.scratch.stem_transpose is not None: + path_1 = self.scratch.stem_transpose(path_1) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features + head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 + kwargs.pop("head_features_1", None) + kwargs.pop("head_features_2", None) + + head = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x) diff --git a/geobench/midas/midas_net.py b/geobench/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/geobench/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/geobench/midas/midas_net_custom.py b/geobench/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/geobench/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/geobench/midas/model_loader.py b/geobench/midas/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..7243d54abe18cdc7d5d34e47045f07436708631e --- /dev/null +++ b/geobench/midas/model_loader.py @@ -0,0 +1,242 @@ +import cv2 +import torch + +from geobench.midas.dpt_depth import DPTDepthModel +from geobench.midas.midas_net import MidasNet +from geobench.midas.midas_net_custom import MidasNet_small +from geobench.midas.transforms import Resize, NormalizeImage, PrepareForNet + +from torchvision.transforms import Compose + +default_models = { + "dpt_beit_large_512": "weights/dpt_beit_large_512.pt", + "dpt_beit_large_384": "weights/dpt_beit_large_384.pt", + "dpt_beit_base_384": "weights/dpt_beit_base_384.pt", + "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt", + "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt", + "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt", + "dpt_swin_large_384": "weights/dpt_swin_large_384.pt", + "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt", + "dpt_levit_224": "weights/dpt_levit_224.pt", + "dpt_large_384": "weights/dpt_large_384.pt", + "dpt_hybrid_384": "weights/dpt_hybrid_384.pt", + "midas_v21_384": "weights/midas_v21_384.pt", + "midas_v21_small_256": "weights/midas_v21_small_256.pt", + "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml", +} + + +def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False): + """Load the specified network. + + Args: + device (device): the torch device used + model_path (str): path to saved model + model_type (str): the type of the model to be loaded + optimize (bool): optimize the model to half-integer on CUDA? + height (int): inference encoder image height + square (bool): resize to a square resolution? + + Returns: + The loaded network, the transform which prepares images as input to the network and the dimensions of the + network input + """ + if "openvino" in model_type: + from openvino.runtime import Core + + keep_aspect_ratio = not square + + if model_type == "dpt_beit_large_512": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_512", + non_negative=True, + ) + net_w, net_h = 512, 512 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_base_384": + model = DPTDepthModel( + path=model_path, + backbone="beitb16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2l24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_base_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2b24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_tiny_256": + model = DPTDepthModel( + path=model_path, + backbone="swin2t16_256", + non_negative=True, + ) + net_w, net_h = 256, 256 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swinl12_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_next_vit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="next_vit_large_6m", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers + # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of + # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py + # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e) + elif model_type == "dpt_levit_224": + model = DPTDepthModel( + path=model_path, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + net_w, net_h = 224, 224 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_large_384": + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid_384": + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21_384": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small_256": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "openvino_midas_v21_small_256": + ie = Core() + uncompiled_model = ie.read_model(model=model_path) + model = ie.compile_model(uncompiled_model, "CPU") + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + if not "openvino" in model_type: + print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) + else: + print("Model loaded, optimized with OpenVINO") + + if "openvino" in model_type: + keep_aspect_ratio = False + + if height is not None: + net_w, net_h = height, height + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + if not "openvino" in model_type: + model.eval() + + if optimize and (device == torch.device("cuda")): + if not "openvino" in model_type: + model = model.to(memory_format=torch.channels_last) + model = model.half() + else: + print("Error: OpenVINO models are already optimized. No optimization to half-float possible.") + exit() + + if not "openvino" in model_type: + model.to(device) + + return model, transform, net_w, net_h diff --git a/geobench/midas/transforms.py b/geobench/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6c1251691a774e0a6c98d50b31916ab7946579 --- /dev/null +++ b/geobench/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + # sample["mask"] = cv2.resize( + # sample["mask"].astype(np.float32), + # (width, height), + # interpolation=cv2.INTER_NEAREST, + # ) + # sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/geobench/modeling/__pycache__/__init__.cpython-310.pyc b/geobench/modeling/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9473c4b817cd1e54d4b64f447fd1e2730fbf6021 Binary files /dev/null and b/geobench/modeling/__pycache__/__init__.cpython-310.pyc differ diff --git a/geobench/modeling/archs/dam/dam.py b/geobench/modeling/archs/dam/dam.py new file mode 100644 index 0000000000000000000000000000000000000000..1355c6e9aa70b913fc8f9a404b66da20fdb6b7e7 --- /dev/null +++ b/geobench/modeling/archs/dam/dam.py @@ -0,0 +1,510 @@ +import argparse +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download +from geobench.modeling.backbones.vit.ViT_DINO import vit_large, vit_giant2, vit_base +from geobench.modeling.backbones.vit.ViT_DINO_reg import vit_large_reg, vit_giant2_reg +from timm.models.vision_transformer import vit_large_patch16_224, vit_large_patch14_224 + +def compute_depth_expectation(prob, depth_values): + depth_values = depth_values.view(*depth_values.shape, 1, 1) + depth = torch.sum(prob * depth_values, 1) + return depth + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +class DPTHead(nn.Module): + def __init__( + self, + mode, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + head_out_channels=1, + use_clstoken=False, + num_depth_regressor_anchor=512, + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + + # if nclass > 1: + # self.scratch.output_conv = nn.Sequential( + # nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1), + # nn.ReLU(True), + # nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0), + # ) + # else: + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + + # if 'metric' in mode: + # num_depth_regressor_anchor = 512 + # # head_features_2 = num_depth_regressor_anchor + # self.scratch.output_conv2 = nn.Sequential( + # nn.Conv2d(head_features_1 // 2, num_depth_regressor_anchor, kernel_size=3, stride=1, padding=1), + # nn.ReLU(True), + # nn.Conv2d(num_depth_regressor_anchor, num_depth_regressor_anchor, kernel_size=1, stride=1, padding=0), + # # nn.Sigmoid() + # ) + + # elif 'disparity' in mode or 'rel_depth' in mode: + # head_features_2 = 32 + + # self.scratch.output_conv2 = nn.Sequential( + # nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + # nn.ReLU(True), + # nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + # nn.ReLU(True), + # nn.Identity(), + # ) + # else: + # raise NotImplementedError + + head_features_2 = 32 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, head_out_channels, kernel_size=1, stride=1, padding=0), + # nn.ReLU(True), + # nn.Identity(), + ) + + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + # import pdb;pdb.set_trace() + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + out = self.scratch.output_conv1(path_1) + + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + # print(out.min()) + # import pdb;pdb.set_trace() + + return out + + +class DepthAnything(nn.Module, PyTorchModelHubMixin): + # @register_to_config + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + head_out_channels=1, + wo_relu_1_2_channel=False, + use_bn=False, + use_clstoken=False, + # localhub=None + use_registers=False, + max_depth=1.0, + mode='disparity', + num_depth_regressor_anchor=512, + depth_normalize=(0.1, 150), + pretrain_type='dinov2', # sam, imagenet + del_mask_token=True, + ): + super(DepthAnything, self).__init__() + + self.pretrain_type = pretrain_type + self.max_depth = max_depth + self.mode = mode + + assert encoder in ['vits', 'vitb', 'vitl', "vitg"] + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + self.backbone_name = encoder + # in case the Internet connection is not stable, please load the DINOv2 locally + # if localhub: + # assert type(localhub) == str + # # self.backbone = torch.hub.load(localhub, 'dinov2_{:}14'.format(encoder), source='local', pretrained=False) + # self.backbone = torch.hub.load(localhub, 'dinov2_{:}14'.format(encoder), source='local', pretrained=True) + # else: + # self.backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder)) + + if use_registers: + if encoder == 'vitl': + checkpoint='data/weights/dinov2/dinov2_vitl14_reg4_pretrain.pth' + self.backbone = vit_large_reg(checkpoint=checkpoint) + elif encoder == 'vitg': + checkpoint='data/weights/dinov2/dinov2_vitg14_reg4_pretrain.pth' + self.backbone = vit_giant2_reg(checkpoint=checkpoint) + else: + raise NotImplementedError + else: + if encoder == 'vitl': + if pretrain_type == 'dinov2': + self.backbone = vit_large(checkpoint=None, del_mask_token=del_mask_token) + + # import pdb;pdb.set_trace() + elif encoder == 'vitb': + self.backbone = vit_base(checkpoint=None, del_mask_token=del_mask_token) + + elif encoder == 'vitg': + from geobench.depthanything_v2.dinov2 import DINOv2 + checkpoint='data/weights/dinov2/dinov2_vitg14_pretrain.pth' + self.backbone = DINOv2(model_name=encoder) + miss, unexpected = self.backbone.load_state_dict(torch.load(checkpoint, map_location='cpu'), strict=False) + print('missing keys:', miss) + print('unexpected keys:', unexpected) + # import pdb;pdb.set_trace() + # self.backbone = vit_giant2(checkpoint=checkpoint) + else: + raise NotImplementedError + + + # dim = self.backbone.blocks[0].attn.qkv.in_features + dim = self.backbone.embed_dim + + self.min_depth = depth_normalize[0] + self.max_depth = depth_normalize[1] + self.num_depth_regressor_anchor = num_depth_regressor_anchor + self.depth_head = DPTHead(mode, dim, features, use_bn, + out_channels=out_channels, + head_out_channels=head_out_channels, + use_clstoken=use_clstoken, + num_depth_regressor_anchor=num_depth_regressor_anchor, + ) + # import pdb;pdb.set_trace() + + self.wo_relu_1_2_channel = wo_relu_1_2_channel + + def get_bins(self, bins_num): + depth_bins_vec = torch.linspace(math.log(self.min_depth), math.log(self.max_depth), bins_num, device='cuda') + depth_bins_vec = torch.exp(depth_bins_vec) + return depth_bins_vec + + + def register_depth_expectation_anchor(self, bins_num, B): + depth_bins_vec = self.get_bins(bins_num) + depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) + self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) + + + def forward(self, x): + + bs, _, h, w = x.shape + + # features = self.backbone.get_intermediate_layers(x, 4, return_class_token=True) + + if self.pretrain_type=='dinov2': + features = self.backbone.get_intermediate_layers(x, self.intermediate_layer_idx[self.backbone_name], return_class_token=True) + patch_h, patch_w = h // 14, w // 14 + + elif self.pretrain_type=='imagenet': + features = self.backbone.get_intermediate_layers(x, self.intermediate_layer_idx[self.backbone_name], return_prefix_tokens=True) + patch_h, patch_w = h // 16, w // 16 + else: + raise NotImplementedError + + # import pdb;pdb.set_trace() + + # if 'metric' in self.mode: + # prob_feature = self.depth_head(features, patch_h, patch_w) + # prob_feature = F.interpolate(prob_feature, size=(h, w), mode="bilinear", align_corners=True) + # prob = prob_feature.softmax(dim=1) + + # if "depth_expectation_anchor" not in self._buffers: + # self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, bs) + + # depth = compute_depth_expectation( + # prob, + # self.depth_expectation_anchor[:bs, ...] + # ).unsqueeze(1) + + # elif 'disparity' in self.mode or 'rel_depth' in self.mode: + # depth = self.depth_head(features, patch_h, patch_w) + # depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) + # # import pdb;pdb.set_trace() + # depth = F.relu(depth) + # else: + # raise NotImplementedError + + depth = self.depth_head(features, patch_h, patch_w) + depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) + + if not self.wo_relu_1_2_channel: + depth = F.relu(depth) + else: + depth[:, 2:] = F.relu(depth[:, 2:]) + # import pdb;pdb.set_trace() + + return depth, features[3][0] + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--encoder", + default="vits", + type=str, + choices=["vits", "vitb", "vitl", "vitg"], + ) + args = parser.parse_args() + + # model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder)) + # model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder)) + # print(model) + device = 'cuda' + image = torch.randn(1,3, 420, 420).to(device) + local_hub = "~/.cache/torch/hub/facebookresearch_dinov2_main/" + model = DepthAnything(localhub=local_hub,).to(device) + output = model(image) + import pdb;pdb.set_trace() + # .from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder)) + print(model) \ No newline at end of file diff --git a/geobench/modeling/backbones/vit/ViT_DINO.py b/geobench/modeling/backbones/vit/ViT_DINO.py new file mode 100644 index 0000000000000000000000000000000000000000..00a0ecd459d985bd032b957d7a0e9af0679c9b1e --- /dev/null +++ b/geobench/modeling/backbones/vit/ViT_DINO.py @@ -0,0 +1,1496 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List + +import torch +import torch.nn as nn +from torch import Tensor +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + +class ConvBlock(nn.Module): + def __init__(self, channels): + super(ConvBlock, self).__init__() + + self.act = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d( + channels, + channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.norm1 = nn.BatchNorm2d(channels) + self.conv2 = nn.Conv2d( + channels, + channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.norm2 = nn.BatchNorm2d(channels) + + def forward(self, x): + + out = self.norm1(x) + out = self.act(out) + out = self.conv1(out) + out = self.norm2(out) + out = self.act(out) + out = self.conv2(out) + return x + out + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + # import pdb;pdb.set_trace() + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + from xformers.components.attention import ScaledDotProduct + from xformers.components import MultiHeadDispatch + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + window_size: int = 0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + #if not self.training: + # + # self.attn = ScaledDotProduct() + #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + if attn_bias is not None: + attn = attn + attn_bias[:, :, :N] + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + #if True: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x, attn_bias) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + if attn_bias is not None: + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) + else: + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + def attn_residual_func(x: Tensor, attn_bias) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + attn_bias=attn_bias + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, attn_bias)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, attn_bias) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, attn_bias=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset, attn_bias) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, attn_bias=None): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list, attn_bias) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x, others=None): + for b in self: + if others == None: + x = b(x) + else: + x = b(x, others) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + #init_values=None, # for layerscale: None or 0 => no layerscale + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=NestedTensorBlock, + ffn_layer="mlp", + block_chunks=1, + window_size=37, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.window_size = window_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + x = self.prepare_tokens_with_masks(x, masks) + + features = [] + for blk in self.blocks: + x = blk(x) + # for idx in range(len(self.blocks[0])): + # x = self.blocks[0][idx](x) + # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: + # features.append(x) + + #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + x_norm = self.norm(x) + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_patchtokens": x_norm[:, 1:], + # "x_prenorm": x, + # "masks": masks, + # } + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim), + ) + self.stride = stride + + def forward(self, x, size): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + #def no_weight_decay(self): + #return ['proj.%d.weight' % i for i in range(4)] + + +class DinoWindowVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + #init_values=None, # for layerscale: None or 0 => no layerscale + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=NestedTensorBlock, + ffn_layer="mlp", + block_chunks=1, + window_size=7, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + + self.pos_conv = PosConv(self.embed_dim, self.embed_dim) + + self.window_size = window_size + #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)]) + #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)]) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.nh = -1 + self.nw = -1 + try: + H = cfg.data_basic['crop_size'][0] + W = cfg.data_basic['crop_size'][1] + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + self.nh = (H + pad_h) // self.patch_size + self.nw = (W + pad_w) // self.patch_size + self.prepare_attn_bias((self.nh, self.nw)) + except: + pass + self.init_weights() + + self.total_step = 10000 # For PE -> GPE transfer + self.start_step = 2000 + self.current_step = 20000 + + def init_weights(self): + #trunc_normal_(self.pos_embed, std=0.02) + #nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + for i in range(4): + try: + nn.init.constant_(self.conv_block[i].conv2.weight, 0.0) + except: + pass + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + #npatch = x.shape[1] - 1 + #N = self.pos_embed.shape[1] - 1 + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + #class_pos_embed = pos_embed[:, 0] + #patch_pos_embed = pos_embed[:, 1:] + patch_pos_embed = pos_embed + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed.to(previous_dtype) + #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + if conv_feature == False: + B, N, C = x.shape + H, W = hw[0], hw[1] + + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C) + else: + B, C, H, W = x.shape + + x = x.view(B, C, H // window_size, window_size, W // window_size, window_size) + + windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C) + + #y = torch.cat((x_cls, windows), dim=1) + return windows #, (Hp, Wp) + + + def window_unpartition(self, + windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False + ) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + H, W = hw + + B = windows.shape[0] // (H * W // window_size // window_size) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + + if conv_feature == False: + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp * Wp, -1) + else: + C = windows.shape[-1] + x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W) + + # if Hp > H or Wp > W: + # x = x[:, :H, :W, :].contiguous() + return x + + def prepare_tokens_with_masks(self, x, masks=None, step=-1): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if step == -1: + step = self.current_step + else: + self.current_step = step + + if step < self.start_step: + coef = 0.0 + elif step < self.total_step: + coef = (step - self.start_step) / (self.total_step - self.start_step) + else: + coef = 1.0 + + x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw)) + + return x + + def prepare_attn_bias(self, shape): + window_size = self.window_size + if window_size <= 0: + return + + import xformers.components.attention.attention_patterns as AP + + nh, nw = shape + radius = (window_size-1)//2 + mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() + + pad = (8 - (nh * nw) % 8) + if pad == 8: + pad = 0 + mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous() + if pad > 0: + mask = mask_pad[:, :-pad].view(nh, nw, nh, nw) + else: + mask = mask_pad[:, :].view(nh, nw, nh, nw) + + # angle + mask[:radius+1, :radius+1, :window_size, :window_size] = True + mask[:radius+1, -radius-1:, :window_size, -window_size:] = True + mask[-radius-1:, :radius+1, -window_size:, :window_size] = True + mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True + + # edge + mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] + mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] + mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] + mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] + + mask = mask.view(nh*nw, nh*nw) + bias_pad = torch.log(mask_pad) + #bias = bias_pad[:, :-pad] + self.register_buffer('attn_bias', bias_pad) + + return bias_pad + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None, **kwargs): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + nh = (H+pad_h)//self.patch_size + nw = (W+pad_w)//self.patch_size + + if self.window_size > 0: + if nh == self.nh and nw == self.nw: + attn_bias = self.attn_bias + else: + attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size)) + self.nh = nh + self.nw = nw + attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1) + else: + attn_bias = None + + x = self.prepare_tokens_with_masks(x, masks) + #x = self.patch_embed(x) + + features = [] + #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) + for blk in self.blocks: + x = blk(x, attn_bias) + #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) + + # for idx in range(len(self.blocks[0])): + # x = self.blocks[0][idx](x, attn_bias) + + # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: + # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) + # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x) + # if idx + 1 != len(self.blocks[0]): + # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) + # else: + # b, c, h, w = x.size() + # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c) + #features.append(x) + + #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + x_norm = self.norm(x) + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_patchtokens": x_norm[:, 1:], + # "x_prenorm": x, + # "masks": masks, + # } + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=14, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=14, **kwargs): + model = DinoWindowVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=14, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + img_size = 518, + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + #del model.norm + if 'del_mask_token' in kwargs and kwargs['del_mask_token'] == True: + del model.mask_token + return model + + +def vit_giant2(patch_size=14, checkpoint=None, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + img_size = 518, + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + # block_fn=partial(Block, attn_class=MemEffAttention), + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + import pdb;pdb.set_trace() + model.load_state_dict(new_state_dict, strict=True) + #del model.norm + del model.mask_token + + return model + + +if __name__ == '__main__': + try: + from mmcv.utils import Config + except: + from mmengine import Config + + #rgb = torch.rand((2, 3, 518, 518)).cuda() + + #cfg.data_basic['crop_size']['0'] + #cfg.data_basic['crop_size']['1'] + cfg = Config.fromfile('mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py') + + #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) + rgb = torch.zeros(1, 3, 1400, 1680).cuda() + model = vit_large(checkpoint="pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda() + + #import timm + #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() + #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) + + out1 = model(rgb) + #out2 = model2(rgb) + temp = 0 + + + +# import time +# window_size = 37 +# def prepare_window_masks(shape): +# if window_size <= 0: +# return None +# import xformers.components.attention.attention_patterns as AP + +# B, nh, nw, _, _ = shape +# radius = (window_size-1)//2 +# #time0 = time.time() +# d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() +# #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() +# # mask = mask.view(nh, nw, nh, nw) +# # #time1 = time.time() - time0 + +# # # angle +# # mask[:radius+1, :radius+1, :window_size, :window_size] = True +# # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True +# # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True +# # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True +# # time2 = time.time() - time0 - time1 + +# # # edge +# # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] +# # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] +# # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] +# # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] +# # time3 = time.time() - time0 - time2 +# # print(time1, time2, time3) + +# # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1) + +# shape = (1, 55, 55, None, None) +# mask = prepare_window_masks(shape) +# # temp = 1 \ No newline at end of file diff --git a/geobench/modeling/backbones/vit/ViT_DINO_reg.py b/geobench/modeling/backbones/vit/ViT_DINO_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7208fb4f671924b3c664b5415bd4093048c5b0 --- /dev/null +++ b/geobench/modeling/backbones/vit/ViT_DINO_reg.py @@ -0,0 +1,1328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List + +import torch +import torch.nn as nn +from torch import Tensor +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ +import torch.nn.init +import torch.nn.functional as F + +#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + +# SSF finetuning originally by dongzelian +def init_ssf_scale_shift(dim): + scale = nn.Parameter(torch.ones(dim)) + shift = nn.Parameter(torch.zeros(dim)) + + nn.init.normal_(scale, mean=1, std=.02) + nn.init.normal_(shift, std=.02) + + return scale, shift + +def ssf_ada(x, scale, shift): + assert scale.shape == shift.shape + if x.shape[-1] == scale.shape[0]: + return x * scale + shift + elif x.shape[1] == scale.shape[0]: + return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) + else: + raise ValueError('the input tensor shape does not match the shape of the scale factor.') + +# LoRA finetuning originally by edwardjhu +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + +class LoRALinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + # import pdb;pdb.set_trace() + + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + + if hasattr(self, 'lora_A'): + # initialize B the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + else: + nn.Linear.reset_parameters(self) + + # def train(self, mode: bool = True): + # def T(w): + # return w.transpose(0, 1) if self.fan_in_fan_out else w + # nn.Linear.train(self, mode) + # if mode: + # if self.merge_weights and self.merged: + # # Make sure that the weights are not merged + # if self.r > 0: + # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # # Merge the weights and mark it + # if self.r > 0: + # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + tuning_mode: Optional[str] = None + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.drop(x) + return x + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + if self.tuning_mode == 'ssf': + x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1) + + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + out = self.w3(hidden) + + if self.tuning_mode == 'ssf': + out = ssf_ada(out, self.ssf_scale_2, self.ssf_scale_2) + + return out + +try: + from xformers.ops import SwiGLU + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + from xformers.components.attention import ScaledDotProduct + from xformers.components import MultiHeadDispatch + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + window_size: int = 0, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + if tuning_mode == 'lora': + self.tuning_mode = tuning_mode + self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + + if tuning_mode == 'lora': + self.tuning_mode = tuning_mode + self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8) + else: + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + #if not self.training: + # + # self.attn = ScaledDotProduct() + #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + if self.tuning_mode == 'ssf': + qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + if attn_bias is not None: + attn = attn + attn_bias[:, :, :N] + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + #if True: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x, attn_bias) + + B, N, C = x.shape + if self.tuning_mode == 'ssf': + qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + if attn_bias is not None: + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) + else: + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.proj_drop(x) + return x + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + tuning_mode=tuning_mode + ) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + def attn_residual_func(x: Tensor, attn_bias) -> Tensor: + if self.tuning_mode == 'ssf': + return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias)) + else: + return self.ls1(self.attn(self.norm1(x), attn_bias)) + + def ffn_residual_func(x: Tensor) -> Tensor: + if self.tuning_mode == 'ssf': + return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))) + else: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + attn_bias=attn_bias + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, attn_bias)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, attn_bias) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, attn_bias=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset, attn_bias) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, attn_bias=None): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list, attn_bias) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x, others=None): + for b in self: + if others == None: + x = b(x) + else: + x = b(x, others) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=518, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + multi_output=False, + tuning_mode=None, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + tuning_mode_list = [tuning_mode] * depth + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.multi_output = multi_output + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + tuning_mode=tuning_mode_list[i] + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + # import pdb;pdb.set_trace() + if len(x.size()) == 3: + x = x[None] + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + x = self.prepare_tokens_with_masks(x, masks) + + #for blk in self.blocks: + #x = blk(x) + + #x_norm = self.norm(x) + #if self.tuning_mode == 'ssf': + #x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1) + + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + # "x_prenorm": x, + # "masks": masks, + # } + # features = [] + # features.append(x_norm) + # features.append(x_norm) + # features.append(x_norm) + # features.append(x_norm) + # return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] + + if self.multi_output == False: + for blk in self.blocks: + x = blk(x) + x_norm = self.norm(x) + if self.tuning_mode == 'ssf': + x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1) + + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] + else: + features = [] + for blk in self.blocks: + for idx, sub_blk in enumerate(blk): + x = sub_blk(x) + if (idx + 1) % (len(blk) // 4) == 0: + features.append(x) + + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def load_ckpt_dino(checkpoint, model): + + if checkpoint is not None: + try: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + print('loading dinov2 backbone checkpoint from %s' %checkpoint) + except: + print('NO pretrained imagenet ckpt available! Check your path!') + del model.mask_token + return + + # import pdb;pdb.set_trace() + try: + model.load_state_dict(state_dict, strict=True) + print('load dinov2 pretrain weights successfully ') + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + print('load dinov2 pretrain weights successfully') + del model.mask_token + return + else: + return + + +def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + del model.mask_token + return model + + +def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + ffn_layer='swiglu', + **kwargs, + ) + return model + + +def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + tuning_mode=tuning_mode, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + model = DinoVisionTransformer( + img_size = 518, + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + tuning_mode=tuning_mode, + **kwargs, + ) + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + ffn_layer='swiglu', + tuning_mode=tuning_mode, + multi_output=True, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +if __name__ == '__main__': + try: + from mmcv.utils import Config + except: + from mmengine import Config + + #rgb = torch.rand((2, 3, 518, 518)).cuda() + + #cfg.data_basic['crop_size']['0'] + #cfg.data_basic['crop_size']['1'] + # cfg = Config.fromfile('configs/metric3d_configs/RAFTDecoder_xugk/vit.raft5.large.kitti.py') + # cfg['tuning_mode'] = 'ssf' + #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) + + rgb = torch.zeros(1, 3, 616, 1064).cuda() + model = vit_large_reg( + checkpoint="./data/weights/dinov2/dinov2_vitl14_reg4_pretrain.pth", + tuning_mode=None, + # kwarg=cfg + ).cuda() + + # model = vit_large_reg(tuning_mode='ssf').cuda() + + out1 = model(rgb) + import pdb;pdb.set_trace() + #import timm + #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() + #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) + #out2 = model2(rgb) + temp = 0 + + diff --git a/geobench/modeling/backbones/vit/__init__.py b/geobench/modeling/backbones/vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d18d620aad544b53d011b849f05e0858a39e59a4 --- /dev/null +++ b/geobench/modeling/backbones/vit/__init__.py @@ -0,0 +1,6 @@ +from .ViT_DINO import vit_large +from .ViT_DINO_reg import vit_small_reg, vit_large_reg, vit_giant2_reg + +__all__ = [ + 'vit_small_reg', 'vit_large_reg', 'vit_giant2_reg' +] diff --git a/geobench/modeling/backbones/vit/dinov2.py b/geobench/modeling/backbones/vit/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..83d250818c721c6df3b30d3f4352945527701615 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__init__.py b/geobench/modeling/backbones/vit/dinov2_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/__init__.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02f90e12afbad811fb3e26f727cea67ea234bde7 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/attention.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..594f2629ca3d80cca18b229e867622a3571c2e97 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/attention.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/block.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35dbc29c6a3bd51103539222dbc8399672d2e1f4 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/block.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/drop_path.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/drop_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4925e95133928cd2727fc8a8105a3ce2cfd24352 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/drop_path.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4eba724995ffc77b8943939e23088fc0ca6b5c4 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/mlp.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247cd50f9e7e0bf96e61a3a905c1b5b597eb2b12 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed67df12ff416d5c807c85787071873602df35c5 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff580c1c0895049ca159d0188fc0a0ce7fc81d21 Binary files /dev/null and b/geobench/modeling/backbones/vit/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc differ diff --git a/geobench/modeling/backbones/vit/dinov2_layers/attention.py b/geobench/modeling/backbones/vit/dinov2_layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + \ No newline at end of file diff --git a/geobench/modeling/backbones/vit/dinov2_layers/block.py b/geobench/modeling/backbones/vit/dinov2_layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/geobench/modeling/backbones/vit/dinov2_layers/drop_path.py b/geobench/modeling/backbones/vit/dinov2_layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/geobench/modeling/backbones/vit/dinov2_layers/layer_scale.py b/geobench/modeling/backbones/vit/dinov2_layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/geobench/modeling/backbones/vit/dinov2_layers/mlp.py b/geobench/modeling/backbones/vit/dinov2_layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/geobench/modeling/backbones/vit/dinov2_layers/patch_embed.py b/geobench/modeling/backbones/vit/dinov2_layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/geobench/modeling/backbones/vit/dinov2_layers/swiglu_ffn.py b/geobench/modeling/backbones/vit/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/geobench/modeling/backbones/vit/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/geobench/utils/image_util.py b/geobench/utils/image_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e580fd50eea2009d29213362edb459241a0f9f72 --- /dev/null +++ b/geobench/utils/image_util.py @@ -0,0 +1,338 @@ +import matplotlib +import numpy as np +import torch +from PIL import Image +from PIL.Image import Resampling +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import resize +import cv2 +import re + +def load_pfm(file): + color = None + width = None + height = None + scale = None + data_type = None + header = file.readline().decode('UTF-8').rstrip() + + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('UTF-8')) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + # scale = float(file.readline().rstrip()) + scale = float((file.readline()).decode('UTF-8').rstrip()) + if scale < 0: # little-endian + data_type = '= 2, "Invalid dimension" + + if isinstance(depth_map, torch.Tensor): + depth = depth_map.detach().clone().squeeze().cpu().numpy() + elif isinstance(depth_map, np.ndarray): + depth = depth_map.copy().squeeze() + # reshape to [ (B,) H, W ] + if depth.ndim < 3: + depth = depth[np.newaxis, :, :] + + # colorize + cm = matplotlib.colormaps[cmap] + + # if min_depth is None or max_depth is None: + # if cmap == "magma_r": + # min_depth = np.percentile(depth, 2) + # max_depth = np.percentile(depth, 85) + # elif cmap == "Spectral": + # min_depth = np.percentile(depth, 2) + # max_depth = np.percentile(depth, 98) + + if min_depth != max_depth: + depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) + else: + # Avoid 0-division + depth = depth * 0. + + img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 + img_colored_np = np.rollaxis(img_colored_np, 3, 1) + + if valid_mask is not None: + if isinstance(depth_map, torch.Tensor): + valid_mask = valid_mask.detach().numpy() + valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] + if valid_mask.ndim < 3: + valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] + else: + valid_mask = valid_mask[:, np.newaxis, :, :] + valid_mask = np.repeat(valid_mask, 3, axis=1) + img_colored_np[~valid_mask] = 0 + + if isinstance(depth_map, torch.Tensor): + img_colored = torch.from_numpy(img_colored_np).float() + elif isinstance(depth_map, np.ndarray): + img_colored = img_colored_np + + return img_colored + + +def chw2hwc(chw): + assert 3 == len(chw.shape) + if isinstance(chw, torch.Tensor): + hwc = torch.permute(chw, (1, 2, 0)) + elif isinstance(chw, np.ndarray): + hwc = np.moveaxis(chw, 0, -1) + return hwc + + +def resize_max_res_torch( + img: torch.Tensor, + max_edge_resolution: int, + resample_method: InterpolationMode = InterpolationMode.BILINEAR, +) -> torch.Tensor: + """ + Resize image to limit maximum edge length while keeping aspect ratio. + + Args: + img (`torch.Tensor`): + Image tensor to be resized. + max_edge_resolution (`int`): + Maximum edge length (pixel). + resample_method (`PIL.Image.Resampling`): + Resampling method used to resize images. + + Returns: + `torch.Tensor`: Resized image. + """ + assert 3 == img.dim() + _, original_height, original_width = img.shape + downscale_factor = min( + max_edge_resolution / original_width, max_edge_resolution / original_height + ) + + new_width = int(original_width * downscale_factor) + new_height = int(original_height * downscale_factor) + + round_num = 16 + new_width = round(new_width / round_num) * round_num + new_height = round(new_height / round_num) * round_num + + resized_img = resize(img, (new_height, new_width), resample_method, antialias=True) + return resized_img + + +def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image: + """ + Resize image to limit maximum edge length while keeping aspect ratio + + Args: + img (Image.Image): Image to be resized + max_edge_resolution (int): Maximum edge length (px). + + Returns: + Image.Image: Resized image. + """ + # import pdb;pdb.set_trace() + if isinstance(img, torch.Tensor): + return resize_max_res_torch(img, max_edge_resolution, resample_method) + + original_width, original_height = img.size + downscale_factor = min( + max_edge_resolution / original_width, max_edge_resolution / original_height + ) + + new_width = int(original_width * downscale_factor) + new_height = int(original_height * downscale_factor) + + resized_img = img.resize((new_width, new_height), resample=resample_method) + return resized_img + + +def get_pil_resample_method(method_str: str) -> Resampling: + resample_method_dict = { + "bilinear": Resampling.BILINEAR, + "bicubic": Resampling.BICUBIC, + "nearest": Resampling.NEAREST, + } + resample_method = resample_method_dict.get(method_str, None) + if resample_method is None: + raise ValueError(f"Unknown resampling method: {resample_method}") + else: + return resample_method + + +def get_tv_resample_method(method_str: str) -> InterpolationMode: + resample_method_dict = { + "bilinear": InterpolationMode.BILINEAR, + "bicubic": InterpolationMode.BICUBIC, + # "nearest": InterpolationMode.NEAREST_EXACT, + } + resample_method = resample_method_dict.get(method_str, None) + if resample_method is None: + raise ValueError(f"Unknown resampling method: {resample_method}") + else: + return resample_method + + +def create_point_cloud(depth_map, camera_matrix, extrinsic_matrix): + + """Create point cloud from depth map and camera parameters.""" + + # Get shape of depth map + height, width = depth_map.shape + + # Create meshgrid for pixel coordinates + x = np.linspace(0, width - 1, width) + y = np.linspace(0, height - 1, height) + x, y = np.meshgrid(x, y) + + # Normalize pixel coordinates + normalized_x = (x - camera_matrix[0, 2]) / camera_matrix[0, 0] + normalized_y = (y - camera_matrix[1, 2]) / camera_matrix[1, 1] + normalized_z = np.ones_like(x) + + # Homogeneous coordinates in camera frame + depth_map_reshaped = np.repeat(depth_map[:, :, np.newaxis], 3, axis=2) + homogeneous_camera_coords = depth_map_reshaped * np.dstack((normalized_x, + normalized_y, + normalized_z)) + + # Add ones to the last dimension + ones = np.ones((height, width, 1)) + homogeneous_camera_coords = np.dstack((homogeneous_camera_coords, ones)) + + # Transform points to world coordinates + homogeneous_world_coords = np.dot(homogeneous_camera_coords, + extrinsic_matrix.T) + + # Divide by the fourth coordinate (homogeneous normalization) + point_cloud = (homogeneous_world_coords[:, :, :3] / + homogeneous_world_coords[:, :, 3:]) + + point_cloud = point_cloud.reshape(-1, 3) + + return point_cloud + + +def write_ply_mask(points,colors,path_ply,mask=None): + if mask is not None: + num = np.sum(mask) + else: + num = points.shape[0] + ply_header = ''' + ply format ascii 1.0 + element vertex {} + property float x + property float y + property float z + property uchar red + property uchar green + property uchar blue + end_header + '''.format(num) + # points.shape[0] + # import ipdb;ipdb.set_trace() + # if mask is not None: + with open(path_ply, 'w') as f: + f.write(ply_header) + for i in range(points.shape[0]): + if mask.reshape(-1)[i]: + f.write('{} {} {} {} {} {}\n'.format(points[i,0], points[i,1], points[i,2], + int(colors[i, 2]*255), int(colors[i, 1]*255), int(colors[i, 0]*255))) + + +def write_ply(points,colors,path_ply,mask=None): + if mask is not None: + num = np.sum(mask) + else: + num = points.shape[0] + ply_header = '''ply + format ascii 1.0 + element vertex {} + property float x + property float y + property float z + property uchar red + property uchar green + property uchar blue + end_header + '''.format(num) + + with open(path_ply, 'w') as f: + f.write(ply_header) + for i in range(points.shape[0]): + f.write('{} {} {} {} {} {}\n'.format(points[i,0], points[i,1], points[i,2], + int(colors[i, 2]*255), int(colors[i, 1]*255), int(colors[i, 0]*255))) + + +def Disparity_Normalization(disparity): + min_value = torch.min(disparity) + max_value = torch.max(disparity) + normalized_disparity = ((disparity -min_value)/(max_value-min_value+1e-5) - 0.5) * 2 + return normalized_disparity + + +def Disparity_Normalization_mask(disparity, min_value, max_value): + normalized_disparity = ((disparity -min_value)/(max_value-min_value+1e-5) - 0.5) * 2 + return normalized_disparity + + +def resize_max_res_tensor(input_tensor,is_disp=False,recom_resolution=768): + + original_H, original_W = input_tensor.shape[2:] + + downscale_factor = min(recom_resolution/original_H, + recom_resolution/original_W) + + resized_input_tensor = F.interpolate(input_tensor, + scale_factor=downscale_factor,mode='bilinear', + align_corners=False) + if is_disp: + return resized_input_tensor * downscale_factor, downscale_factor + else: + return resized_input_tensor \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cd1e6765536a552ffdbeca91f1235ae034a413c6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +gradio_imageslider +gradio==4.36.0 +torch +torchvision +opencv-python +matplotlib +huggingface_hub \ No newline at end of file