# Page banner captured with this file (not part of the code): "Spaces: Running on CPU Upgrade"
from functools import partial | |
import torch.nn as nn | |
from detectron2.config import LazyCall as L | |
from detectron2.modeling import ViT | |
from detectron2.modeling import SimpleFeaturePyramid as BaseSimpleFeaturePyramid | |
from detectron2.modeling.backbone.fpn import LastLevelMaxPool | |
from detectron2.layers import CNNBlockBase, Conv2d, get_norm | |
import sys | |
sys.path.append('../../') | |
from modeling_pretrain_cleaned import PretrainVisionTransformer | |
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous | |
from models.mask_rcnn_fpn_v2 import model, constants | |
from detectron2.modeling.backbone import Backbone | |
import torch | |
import math | |
import torch.nn.functional as F | |
import time | |
# Input preprocessing of the imported Mask R-CNN config: normalise with the
# "imagenet_rgb256" mean/std entries of `constants` and declare RGB channel order.
model.pixel_mean = constants["imagenet_rgb256_mean"]
model.pixel_std = constants["imagenet_rgb256_std"]
model.input_format = "RGB"
class ViT(Backbone):
    """Detectron2 ``Backbone`` wrapper around ``PretrainVisionTransformer``.

    Note that this class shadows the ``ViT`` name imported from
    ``detectron2.modeling`` at the top of the file.

    ``forward`` repeats the input along a new temporal axis
    ``num_frames - 1`` times and calls the wrapped model with an all-``False``
    mask.  At construction time the encoder's positional embedding is
    resampled to the token grid of a 512-pixel square input and truncated to
    the first ``num_frames - 1`` frames.

    Args:
        img_size, encoder_embed_dim, encoder_depth, encoder_num_heads,
        encoder_num_classes, decoder_embed_dim, decoder_num_heads,
        decoder_depth, mlp_ratio, qkv_bias, k_bias, norm_layer, tubelet_size,
        use_flash_attention, return_detectron_format:
            forwarded unchanged to ``PretrainVisionTransformer``.
        patch_size (Tuple[int, int]): forwarded to the wrapped model;
            ``patch_size[0]`` is also reported as the output stride and used
            to derive the token grid size.
        num_frames (int): forwarded to the wrapped model; also the number of
            frames the stored positional embedding is assumed to cover.
        out_feature (str): name of the single output feature map.
    """

    def __init__(self,
                 img_size=224,
                 encoder_embed_dim=768,
                 encoder_depth=12,
                 encoder_num_heads=12,
                 encoder_num_classes=0,
                 decoder_embed_dim=384,
                 decoder_num_heads=16,
                 decoder_depth=8,
                 mlp_ratio=4,
                 qkv_bias=True,
                 k_bias=True,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 patch_size=(8, 8),
                 num_frames=3,
                 tubelet_size=1,
                 use_flash_attention=True,
                 return_detectron_format=True,
                 out_feature='last_feat'
                 ):
        super().__init__()
        self.model = PretrainVisionTransformer(  # Single-scale ViT backbone
            img_size=img_size,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            encoder_num_classes=encoder_num_classes,
            decoder_embed_dim=decoder_embed_dim,
            decoder_num_heads=decoder_num_heads,
            decoder_depth=decoder_depth,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            k_bias=k_bias,
            norm_layer=norm_layer,
            patch_size=patch_size,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            use_flash_attention=use_flash_attention,
            return_detectron_format=return_detectron_format,
            out_feature=out_feature
        )

        # Backbone metadata read by downstream detectron2 modules.
        self._out_features = [out_feature]
        # NOTE(review): channel count is declared as twice the encoder width;
        # presumably the wrapped model concatenates two feature sets -- confirm
        # against PretrainVisionTransformer.
        self._out_feature_channels = {out_feature: encoder_embed_dim * 2}
        self._out_feature_strides = {out_feature: patch_size[0]}

        # Token grid side length for a 512-pixel square input (hard-coded).
        self.patch_hw = 512 // patch_size[0]
        self.num_frames = num_frames

        # Resample the encoder's positional embedding to the new grid and keep
        # only the tokens of the first (num_frames - 1) frames, matching the
        # clip length built in forward().
        # NOTE(review): plain-tensor assignment -- assumes encoder.pos_embed is
        # not registered as an nn.Parameter; confirm.
        pos_embed = self.get_abs_pos(
            self.model.encoder.pos_embed, num_frames, [self.patch_hw, self.patch_hw]
        )
        kept_tokens = self.patch_hw ** 2 * (self.num_frames - 1)
        self.model.encoder.pos_embed = pos_embed[:, 0:kept_tokens, :]

    def forward(self, x):
        """Run the wrapped model on ``x`` repeated along a temporal axis.

        Args:
            x (Tensor): input batch; a new axis is inserted at dim 2, so a
                4-D input is assumed (presumably (B, C, H, W) -- TODO confirm).

        Returns:
            Whatever the wrapped ``PretrainVisionTransformer`` returns.
        """
        batch_size = x.shape[0]
        # expand() creates a broadcast view; the frames share storage.
        clip = x.unsqueeze(2).expand(-1, -1, self.num_frames - 1, -1, -1)
        # One all-False entry per token; allocated directly on the input's
        # device instead of on the CPU followed by a copy.
        mask = torch.zeros(
            batch_size,
            self.patch_hw ** 2 * (self.num_frames - 1),
            dtype=torch.bool,
            device=x.device,
        )
        return self.model(clip, mask)

    def get_abs_pos(self, abs_pos, num_frames, hw):
        """
        Resize per-frame absolute positional embeddings to a new token grid.

        Args:
            abs_pos (Tensor): absolute positional embeddings whose dim 1 holds
                ``num_frames * size * size`` positions (square grid per frame).
                Positions are assumed to be ordered frame by frame -- TODO confirm.
            num_frames (int): number of frames covered by ``abs_pos``.
            hw (Tuple): target (height, width) of the token grid.

        Returns:
            Positional embeddings of shape (1, num_frames * h * w, C), in both
            the resized and the pass-through case.
        """
        h, w = hw
        xy_num = abs_pos.shape[1] // num_frames
        size = int(math.sqrt(xy_num))
        assert size * size * num_frames == abs_pos.shape[1], (
            f"cannot split {abs_pos.shape[1]} positions into "
            f"{num_frames} square frames"
        )

        if size == h and size == w:
            # Bug fix: this path used to return the (num_frames, xy_num, C)
            # view, which does not match the (1, N, C) layout produced by the
            # resize path and sliced by the caller along dim 1.
            return abs_pos.reshape(1, num_frames * xy_num, -1)

        # (num_frames, size, size, C) -> (num_frames, C, size, size) for interpolate.
        per_frame = abs_pos.reshape(num_frames, size, size, -1).permute(0, 3, 1, 2)
        resized = torch.nn.functional.interpolate(
            per_frame,
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )
        # Back to channels-last, then flatten frames and grid into one token axis.
        return resized.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
class SimpleFeaturePyramid(BaseSimpleFeaturePyramid):
    """
    SimpleFeaturePyramid from :paper:`vitdet` with its constructor
    re-implemented in this file.

    Pyramid levels are derived from a single backbone feature map by
    up/down-sampling it with the scale factors 4.0, 2.0, 1.0, 0.5 and 0.25.
    """

    @staticmethod
    def _resampling_layers(scale, dim, norm):
        """Return ``(layers, out_dim)`` that rescale a ``dim``-channel map by ``scale``."""
        if scale == 4.0:
            half, quarter = dim // 2, dim // 4
            return [
                nn.ConvTranspose2d(dim, half, kernel_size=2, stride=2),
                get_norm(norm, half),
                nn.GELU(),
                nn.ConvTranspose2d(half, quarter, kernel_size=2, stride=2),
            ], quarter
        if scale == 2.0:
            return [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)], dim // 2
        if scale == 1.0:
            return [], dim
        if scale == 0.5:
            return [nn.MaxPool2d(kernel_size=2, stride=2)], dim
        if scale == 0.25:
            return [nn.MaxPool2d(kernel_size=4, stride=4)], dim
        raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): backbone whose output feeds the pyramid; must be
                an instance of :class:`Backbone`.
            in_feature (str): name of the backbone feature map to build on.
            out_channels (int): channel count of every pyramid output.
            scale_factors (list[float]): one rescaling factor per pyramid
                level, applied to the input feature map.
            top_block (nn.Module or None): optional module appended after the
                last pyramid level. Its "num_levels" attribute is read here to
                register the extra output levels.
            norm (str): normalization passed to ``get_norm``; an empty string
                enables the convolution bias instead.
            square_pad (int): stored as ``_square_pad``.
        """
        # Start the MRO lookup *after* BaseSimpleFeaturePyramid so that the
        # parent's own __init__ is skipped; everything is set up below.
        super(BaseSimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)

        self.scale_factors = scale_factors

        in_shape = net.output_shape()[in_feature]
        strides = [int(in_shape.stride / factor) for factor in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        in_channels = in_shape.channels
        conv_bias = norm == ""
        self.stages = []

        for factor, stride in zip(scale_factors, strides):
            resample, mid_channels = self._resampling_layers(factor, in_channels, norm)
            # 1x1 projection to out_channels followed by a 3x3 convolution.
            project = Conv2d(
                mid_channels,
                out_channels,
                kernel_size=1,
                bias=conv_bias,
                norm=get_norm(norm, out_channels),
            )
            refine = Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=conv_bias,
                norm=get_norm(norm, out_channels),
            )
            block = nn.Sequential(*resample, project, refine)

            stage = int(math.log2(stride))
            self.add_module(f"simfp_{stage}", block)
            self.stages.append(block)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block

        # Output names are "p<stage>", e.g. ["p2", "p3", ..., "p6"].
        level_strides = {}
        for stride in strides:
            level_strides["p{}".format(int(math.log2(stride)))] = stride
        self._out_feature_strides = level_strides

        # Extra levels contributed by the top block continue after the last stage.
        if self.top_block is not None:
            first_extra = stage + 1
            for level in range(first_extra, first_extra + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(level)] = 2 ** level

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = dict.fromkeys(self._out_features, out_channels)
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad
# ViT-Base hyper-parameters; the first three are reused in the backbone below.
embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1

# Backbone: a Simple Feature Pyramid built on the single-scale ViT wrapper.
model.backbone = L(SimpleFeaturePyramid)(
    net=L(ViT)(
        img_size=224,
        encoder_embed_dim=embed_dim,
        encoder_depth=depth,
        encoder_num_heads=num_heads,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        patch_size=(16, 16),  # (8, 8),
        num_frames=3,
        tubelet_size=1,
        return_detectron_format=True,
        use_flash_attention=True,
        out_feature='last_feat',
    ),
    # Relative interpolation: resolves to the `out_feature` of `net` above.
    in_feature="${.net.out_feature}",
    out_channels=256,
    scale_factors=(4.0, 2.0, 1.0, 0.5, 0.25),
    top_block=L(LastLevelMaxPool)(),
    norm="LN",
    square_pad=512,
)

# LayerNorm in the conv layers of both ROI heads.
model.roi_heads.box_head.conv_norm = "LN"
model.roi_heads.mask_head.conv_norm = "LN"

# 2conv in RPN:
model.proposal_generator.head.conv_dims = [-1, -1]

# 4conv1fc box head
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]