KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from mmpretrain.models.selfsup.mae import MAE, MAEViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
@MODELS.register_module()
class MFFViT(MAEViT):
"""Vision Transformer for MFF Pretraining.
This class inherits all these functionalities from ``MAEViT``, and
add multi-level feature fusion to it. For more details, you can
refer to `Improving Pixel-based MIM by Reducing Wasted Modeling
Capability`.
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
It only works without input mask. Defaults to ``"avg_featmap"``.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
mask_ratio (bool): The ratio of total number of patches to be masked.
Defaults to 0.75.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'b',
img_size: int = 224,
patch_size: int = 16,
out_indices: Union[Sequence, int] = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
out_type: str = 'raw',
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
mask_ratio: float = 0.75,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
out_type=out_type,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
mask_ratio=mask_ratio,
init_cfg=init_cfg)
proj_layers = [
torch.nn.Linear(self.embed_dims, self.embed_dims)
for _ in range(len(self.out_indices) - 1)
]
self.proj_layers = torch.nn.ModuleList(proj_layers)
self.proj_weights = torch.nn.Parameter(
torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
if len(self.out_indices) == 1:
self.proj_weights.requires_grad = False
def forward(
self,
x: torch.Tensor,
mask: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate features for masked images.
The function supports two kind of forward behaviors. If the ``mask`` is
``True``, the function will generate mask to masking some patches
randomly and get the hidden features for visible patches, which means
the function will be executed as masked imagemodeling pre-training;
if the ``mask`` is ``None`` or ``False``, the forward function will
call ``super().forward()``, which extract features from images without
mask.
Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
mask (bool, optional): To indicate whether the forward function
generating ``mask`` or not.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
mask and the ids to restore original image.
- ``x`` (torch.Tensor): hidden features, which is of shape
B x (L * mask_ratio) x C.
- ``mask`` (torch.Tensor): mask used to mask image.
- ``ids_restore`` (torch.Tensor): ids to restore original image.
"""
if mask is None or False:
return super().forward(x)
else:
B = x.shape[0]
x = self.patch_embed(x)[0]
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
res = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
if i != self.out_indices[-1]:
proj_x = self.proj_layers[self.out_indices.index(i)](x)
else:
proj_x = x
res.append(proj_x)
res = torch.stack(res)
proj_weights = F.softmax(self.proj_weights, dim=0)
res = res * proj_weights
res = res.sum(dim=0)
# Use final norm
x = self.norm1(res)
return (x, mask, ids_restore, proj_weights.view(-1))
@MODELS.register_module()
class MFF(MAE):
"""MFF.
Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
Capability`.
"""
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (torch.Tensor): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
# ids_restore: the same as that in original repo, which is used
# to recover the original order of tokens in decoder.
latent, mask, ids_restore, weights = self.backbone(inputs)
pred = self.neck(latent, ids_restore)
loss = self.head.loss(pred, inputs, mask)
weight_params = {
f'weight_{i}': weights[i]
for i in range(weights.size(0))
}
losses = dict(loss=loss)
losses.update(weight_params)
return losses