from functools import partial

import math
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.config import LazyCall as L
from detectron2.layers import Conv2d, get_norm
from detectron2.modeling import SimpleFeaturePyramid as BaseSimpleFeaturePyramid
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.fpn import LastLevelMaxPool, _assert_strides_are_log2_contiguous

sys.path.append('../../')
from modeling_pretrain_cleaned import PretrainVisionTransformer
from models.mask_rcnn_fpn_v2 import model, constants

model.pixel_mean = constants['imagenet_rgb256_mean']
model.pixel_std = constants['imagenet_rgb256_std']
model.input_format = "RGB"


class ViT(Backbone):
    """Detectron2 Backbone wrapper around PretrainVisionTransformer."""

    def __init__(
        self,
        img_size=224,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        patch_size=(8, 8),
        num_frames=3,
        tubelet_size=1,
        use_flash_attention=True,
        return_detectron_format=True,
        out_feature='last_feat',
    ):
        super().__init__()
        # Single-scale ViT backbone
        self.model = PretrainVisionTransformer(
            img_size=img_size,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            encoder_num_classes=encoder_num_classes,
            decoder_embed_dim=decoder_embed_dim,
            decoder_num_heads=decoder_num_heads,
            decoder_depth=decoder_depth,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            k_bias=k_bias,
            norm_layer=norm_layer,
            patch_size=patch_size,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            use_flash_attention=use_flash_attention,
            return_detectron_format=return_detectron_format,
            out_feature=out_feature,
        )
        self._out_features = [out_feature]
        self._out_feature_channels = {out_feature: encoder_embed_dim * 2}
        self._out_feature_strides = {out_feature: patch_size[0]}
        # Token grid of the 512x512 square-padded detection input (see square_pad below).
        self.patch_hw = 512 // patch_size[0]
        self.num_frames = num_frames
        # Resize the pretrained positional embeddings to the detection token grid and keep
        # only the first (num_frames - 1) frames, matching the tokens passed in forward().
        pos_embed = self.get_abs_pos(
            self.model.encoder.pos_embed, num_frames, [self.patch_hw, self.patch_hw]
        )
        self.model.encoder.pos_embed = pos_embed[:, 0:self.patch_hw ** 2 * (self.num_frames - 1), :]

    def forward(self, x):
        B = x.shape[0]
        # Replicate the single image across (num_frames - 1) temporal slots: (B, C, H, W) -> (B, C, T-1, H, W).
        x = x.unsqueeze(2).expand(-1, -1, self.num_frames - 1, -1, -1)
        # No tokens are masked at detection time.
        mask = torch.zeros(B, self.patch_hw ** 2 * (self.num_frames - 1), dtype=torch.bool).to(x.device)
        return self.model(x, mask)

    def get_abs_pos(self, abs_pos, num_frames, hw):
        """
        Calculate absolute positional embeddings, resizing the per-frame grid to the
        target token grid when the sizes differ.

        Args:
            abs_pos (Tensor): absolute positional embeddings with shape (1, num_position, C).
            num_frames (int): number of frames covered by the embeddings.
            hw (Tuple): (height, width) of the target grid of image tokens.

        Returns:
            Absolute positional embeddings with shape (1, num_frames * h * w, C).
        """
        h, w = hw
        xy_num = abs_pos.shape[1] // num_frames
        size = int(math.sqrt(xy_num))
        assert size * size * num_frames == abs_pos.shape[1]
        abs_pos = abs_pos.view(num_frames, xy_num, -1)

        if size != h or size != w:
            new_abs_pos = F.interpolate(
                abs_pos.reshape(num_frames, size, size, -1).permute(0, 3, 1, 2),
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            )
            return new_abs_pos.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
        else:
            # Keep the same (1, num_frames * h * w, C) layout as the resized branch.
            return abs_pos.flatten(0, 1).unsqueeze(0)
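
# Shape sketch for the wrapper above (a minimal illustration, never called by the config).
# It assumes PretrainVisionTransformer really returns a {out_feature: NCHW tensor} dict when
# return_detectron_format=True and that the default arguments (e.g. use_flash_attention=True)
# are usable in the target environment. With patch_size=(16, 16), get_abs_pos resizes the
# pretrained 14x14 per-frame grid (224 / 16) to the 32x32 grid implied by the 512-pixel
# square padding used below.
def _backbone_shape_sketch():
    backbone = ViT(patch_size=(16, 16))    # hypothetical stand-alone instantiation
    images = torch.randn(2, 3, 512, 512)   # (B, C, H, W) after square padding to 512
    feats = backbone(images)
    # Declared output spec: stride 16 and 2 * encoder_embed_dim channels per location.
    return feats['last_feat'].shape        # expected: torch.Size([2, 1536, 32, 32])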


class SimpleFeaturePyramid(BaseSimpleFeaturePyramid):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): module representing the subnetwork backbone.
                Must be a subclass of :class:`Backbone`.
            in_feature (str): names of the input feature maps coming from the net.
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                pyramid output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra pyramid levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            norm (str): the normalization to use.
            square_pad (int): If > 0, require input images to be padded to specific square size.
        """
        super(BaseSimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)
        self.scale_factors = scale_factors

        input_shapes = net.output_shape()
        strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        dim = input_shapes[in_feature].channels
        self.stages = []
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif scale == 0.25:
                layers = [nn.MaxPool2d(kernel_size=4, stride=4)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad
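
# Worked example of the stride/feature-name bookkeeping above, for the values used in the
# config below (backbone stride 16, scale_factors=(4.0, 2.0, 1.0, 0.5, 0.25), LastLevelMaxPool):
#   strides   = [16/4, 16/2, 16/1, 16/0.5, 16/0.25] = [4, 8, 16, 32, 64]  ->  "p2" .. "p6"
#   top_block adds one extra level on top of "p6"                         ->  "p7" (stride 128)
#   size divisibility is strides[-1] = 64; square_pad=512 satisfies it.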


# Base
embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1

# Creates Simple Feature Pyramid from ViT backbone
model.backbone = L(SimpleFeaturePyramid)(
    net=L(ViT)(
        img_size=224,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        patch_size=(16, 16),  # (8, 8)
        num_frames=3,
        tubelet_size=1,
        return_detectron_format=True,
        use_flash_attention=True,
        out_feature='last_feat',
    ),
    in_feature="${.net.out_feature}",
    out_channels=256,
    scale_factors=(4.0, 2.0, 1.0, 0.5, 0.25),
    top_block=L(LastLevelMaxPool)(),
    norm="LN",
    square_pad=512,
)

model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"

# 2conv in RPN:
model.proposal_generator.head.conv_dims = [-1, -1]

# 4conv1fc box head
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]
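
# Usage sketch (run from a separate script; the config path is a placeholder):
#
#     from detectron2.config import LazyConfig, instantiate
#
#     cfg = LazyConfig.load("path/to/this_config.py")
#     detector = instantiate(cfg.model)  # Mask R-CNN with the ViT + SimpleFeaturePyramid backbone
#
# The file can also be passed to detectron2's lazy-config training entry point
# (tools/lazyconfig_train_net.py --config-file <this file>), provided the dataloader,
# optimizer, lr_multiplier, and train entries expected by that script are defined as well.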