from functools import partial
import math
import sys

import torch
import torch.nn as nn

from detectron2.config import LazyCall as L
from detectron2.layers import Conv2d, get_norm
from detectron2.modeling import SimpleFeaturePyramid as BaseSimpleFeaturePyramid
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.fpn import LastLevelMaxPool, _assert_strides_are_log2_contiguous

sys.path.append('../../')
from modeling_pretrain_cleaned import PretrainVisionTransformer
from models.mask_rcnn_fpn_v2 import model, constants

model.pixel_mean = constants['imagenet_rgb256_mean']
model.pixel_std = constants['imagenet_rgb256_std']
model.input_format = "RGB"


class ViT(Backbone):
    """
    Detectron2 Backbone wrapper around the video-pretrained PretrainVisionTransformer,
    exposing a single-scale feature map under `out_feature`.
    """

    def __init__(self,
img_size=224,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_num_classes=0,
decoder_embed_dim=384,
decoder_num_heads=16,
decoder_depth=8,
mlp_ratio=4,
qkv_bias=True,
k_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
patch_size=(8, 8),
num_frames=3,
tubelet_size=1,
use_flash_attention=True,
return_detectron_format=True,
out_feature='last_feat'
):
super().__init__()
self.model = PretrainVisionTransformer( # Single-scale ViT backbone
img_size=img_size,
encoder_embed_dim=encoder_embed_dim,
encoder_depth=encoder_depth,
encoder_num_heads=encoder_num_heads,
encoder_num_classes=encoder_num_classes,
decoder_embed_dim=decoder_embed_dim,
decoder_num_heads=decoder_num_heads,
decoder_depth=decoder_depth,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
k_bias=k_bias,
norm_layer=norm_layer,
patch_size=patch_size,
num_frames=num_frames,
tubelet_size=tubelet_size,
use_flash_attention=use_flash_attention,
return_detectron_format=return_detectron_format,
out_feature=out_feature
)
        self._out_features = [out_feature]
        # Output metadata required by the detectron2 Backbone interface; the wrapped
        # model's feature map has twice the encoder width.
        self._out_feature_channels = {out_feature: encoder_embed_dim * 2}
        self._out_feature_strides = {out_feature: patch_size[0]}
        # Token grid for the 512x512 detection input (images are padded to a square).
        self.patch_hw = 512 // patch_size[0]
        self.num_frames = num_frames
        # Resize the pretrained positional embeddings to the detection-resolution token
        # grid, keeping only the first (num_frames - 1) frames' worth to match the
        # frame replication done in forward().
        pos_embed = self.get_abs_pos(self.model.encoder.pos_embed, num_frames, [self.patch_hw, self.patch_hw])
        self.model.encoder.pos_embed = pos_embed[:, 0:self.patch_hw**2 * (self.num_frames - 1), :]

    def forward(self, x):
        B = x.shape[0]
        # Replicate the single image along a time axis so the video-pretrained model
        # sees (num_frames - 1) identical frames: (B, C, H, W) -> (B, C, T, H, W).
        x = x.unsqueeze(2).expand(-1, -1, self.num_frames - 1, -1, -1)
        # Empty mask: no tokens are masked at detection time.
        mask = torch.zeros(B, self.patch_hw**2 * (self.num_frames - 1), dtype=torch.bool, device=x.device)
        return self.model(x, mask)

    def get_abs_pos(self, abs_pos, num_frames, hw):
        """
        Resize absolute positional embeddings to a new token grid, frame by frame
        (bicubic interpolation; no cls token is assumed).

        Args:
            abs_pos (Tensor): absolute positional embeddings of shape (1, num_position, C),
                where num_position = num_frames * size * size.
            num_frames (int): number of frames encoded in abs_pos.
            hw (Tuple): target (height, width) of the token grid.

        Returns:
            Positional embeddings of shape (1, num_frames * h * w, C).
        """
h, w = hw
xy_num = abs_pos.shape[1] // num_frames
size = int(math.sqrt(xy_num))
assert size * size * num_frames == abs_pos.shape[1]
abs_pos = abs_pos.view(num_frames, xy_num, -1)
if size != h or size != w:
new_abs_pos = torch.nn.functional.interpolate(
abs_pos.reshape(num_frames, size, size, -1).permute(0, 3, 1, 2),
size=(h, w),
mode="bicubic",
align_corners=False,
)
return new_abs_pos.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
        else:
            # Flatten frames so the return shape matches the resized branch above.
            return abs_pos.flatten(0, 1).unsqueeze(0)
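

# Usage sketch (hedged): a quick shape check for the backbone above, left commented
# out so nothing runs when detectron2 loads this config. Assuming the wrapped
# PretrainVisionTransformer honors return_detectron_format=True and returns a
# {'last_feat': tensor} dict, a 512x512 input with patch_size=(16, 16) yields a
# 32x32 map with 2 * encoder_embed_dim = 1536 channels:
#
#   backbone = ViT(patch_size=(16, 16), num_frames=3)
#   feats = backbone(torch.zeros(1, 3, 512, 512))
#   assert feats['last_feat'].shape == (1, 1536, 32, 32)
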
class SimpleFeaturePyramid(BaseSimpleFeaturePyramid):
"""
This module implements SimpleFeaturePyramid in :paper:`vitdet`.
It creates pyramid features built on top of the input feature map.
"""
def __init__(
self,
net,
in_feature,
out_channels,
scale_factors,
top_block=None,
norm="LN",
square_pad=0,
):
"""
Args:
net (Backbone): module representing the subnetwork backbone.
Must be a subclass of :class:`Backbone`.
            in_feature (str): name of the input feature map coming
                from the net.
out_channels (int): number of channels in the output feature maps.
scale_factors (list[float]): list of scaling factors to upsample or downsample
the input features for creating pyramid features.
top_block (nn.Module or None): if provided, an extra operation will
be performed on the output of the last (smallest resolution)
pyramid output, and the result will extend the result list. The top_block
further downsamples the feature map. It must have an attribute
"num_levels", meaning the number of extra pyramid levels added by
this block, and "in_feature", which is a string representing
its input feature (e.g., p5).
norm (str): the normalization to use.
square_pad (int): If > 0, require input images to be padded to specific square size.
"""
        # Intentionally skip BaseSimpleFeaturePyramid.__init__: this subclass rebuilds
        # the pyramid stages itself and only reuses the parent's forward().
        super(BaseSimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)
self.scale_factors = scale_factors
input_shapes = net.output_shape()
strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
_assert_strides_are_log2_contiguous(strides)
dim = input_shapes[in_feature].channels
self.stages = []
use_bias = norm == ""
for idx, scale in enumerate(scale_factors):
out_dim = dim
if scale == 4.0:
layers = [
nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
get_norm(norm, dim // 2),
nn.GELU(),
nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
]
out_dim = dim // 4
elif scale == 2.0:
layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
out_dim = dim // 2
elif scale == 1.0:
layers = []
elif scale == 0.5:
layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
elif scale == 0.25:
layers = [nn.MaxPool2d(kernel_size=4, stride=4)]
else:
raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
layers.extend(
[
Conv2d(
out_dim,
out_channels,
kernel_size=1,
bias=use_bias,
norm=get_norm(norm, out_channels),
),
Conv2d(
out_channels,
out_channels,
kernel_size=3,
padding=1,
bias=use_bias,
norm=get_norm(norm, out_channels),
),
]
)
layers = nn.Sequential(*layers)
stage = int(math.log2(strides[idx]))
self.add_module(f"simfp_{stage}", layers)
self.stages.append(layers)
self.net = net
self.in_feature = in_feature
self.top_block = top_block
# Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
# top block output feature maps.
if self.top_block is not None:
for s in range(stage, stage + self.top_block.num_levels):
self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
self._out_features = list(self._out_feature_strides.keys())
self._out_feature_channels = {k: out_channels for k in self._out_features}
self._size_divisibility = strides[-1]
self._square_pad = square_pad
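

# How the scale factors used below map onto pyramid levels, given the backbone's
# stride-16 feature map and the stride arithmetic in __init__ above:
#
#   scale 4.0  -> stride  4 -> p2  (two stacked 2x transposed convs)
#   scale 2.0  -> stride  8 -> p3  (one 2x transposed conv)
#   scale 1.0  -> stride 16 -> p4  (identity)
#   scale 0.5  -> stride 32 -> p5  (2x max pool)
#   scale 0.25 -> stride 64 -> p6  (4x max pool)
#
# LastLevelMaxPool then appends one extra level, p7, at stride 128.
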

# ViT-Base hyperparameters, kept for reference; the LazyCall below spells its
# arguments out explicitly.
embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1

# Build a SimpleFeaturePyramid on top of the ViT backbone.
model.backbone = L(SimpleFeaturePyramid)(
net=L(ViT)(
img_size=224,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_num_classes=0,
decoder_embed_dim=384,
decoder_num_heads=16,
decoder_depth=8,
mlp_ratio=4,
qkv_bias=True,
k_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
patch_size=(16, 16), #(8, 8),
num_frames=3,
tubelet_size=1,
return_detectron_format=True,
use_flash_attention=True,
out_feature='last_feat'
),
in_feature="${.net.out_feature}",
out_channels=256,
scale_factors=(4.0, 2.0, 1.0, 0.5, 0.25),
top_block=L(LastLevelMaxPool)(),
norm="LN",
square_pad=512,
)
model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"
# 2conv in RPN:
model.proposal_generator.head.conv_dims = [-1, -1]
# 4conv1fc box head
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]
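
# Usage sketch (hedged): this file is a detectron2 LazyConfig, so the standard way to
# build the detector it describes is LazyConfig.load followed by instantiate. The
# path below is hypothetical:
#
#   from detectron2.config import LazyConfig, instantiate
#   cfg = LazyConfig.load("configs/app_config.py")  # hypothetical path
#   detector = instantiate(cfg.model)  # Mask R-CNN with the ViT + SimpleFeaturePyramid backbone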