# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F

from mmpretrain.models.selfsup.mae import MAE, MAEViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MFFViT(MAEViT):
    """Vision Transformer for MFF pre-training.

    This class inherits all its functionality from ``MAEViT`` and adds
    multi-level feature fusion on top of it. For more details, please refer
    to `Improving Pixel-based MIM by Reducing Wasted Modeling Capability`.

    Args:
        arch (str | dict): Vision Transformer architecture. Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        out_indices (Sequence | int): Output from which stages. Defaults to
            -1, which means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            It only works without input mask. Defaults to ``"avg_featmap"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an
            empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        mask_ratio (float): The ratio of total number of patches to be
            masked. Defaults to 0.75.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'b',
                 img_size: int = 224,
                 patch_size: int = 16,
                 out_indices: Union[Sequence, int] = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 out_type: str = 'raw',
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
                 mask_ratio: float = 0.75,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            mask_ratio=mask_ratio,
            init_cfg=init_cfg)
        # One linear projection per fused layer, except the last layer,
        # whose features are used as-is.
        proj_layers = [
            torch.nn.Linear(self.embed_dims, self.embed_dims)
            for _ in range(len(self.out_indices) - 1)
        ]
        self.proj_layers = torch.nn.ModuleList(proj_layers)
        # Learnable fusion weights, one scalar per fused layer, shaped for
        # broadcasting over (B, L, C) features.
        self.proj_weights = torch.nn.Parameter(
            torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
        if len(self.out_indices) == 1:
            # A single output layer needs no fusion, so freeze the weight.
            self.proj_weights.requires_grad = False

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate features for masked images.

        The function supports two kinds of forward behaviors. If the
        ``mask`` is ``True``, the function will generate a mask to randomly
        mask some patches and compute the hidden features of the visible
        patches, i.e. it runs as masked image modeling pre-training. If the
        ``mask`` is ``None`` or ``False``, the function will call
        ``super().forward()``, which extracts features from images without
        a mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): To indicate whether the forward function
                generates a ``mask`` or not.

        Returns:
            Tuple[torch.Tensor, ...]: Hidden features, mask, the ids to
            restore the original image and the fusion weights.

            - ``x`` (torch.Tensor): hidden features, which is of shape
              B x (L * mask_ratio) x C.
            - ``mask`` (torch.Tensor): mask used to mask image.
            - ``ids_restore`` (torch.Tensor): ids to restore original image.
            - ``proj_weights`` (torch.Tensor): softmax-normalized weights
              used to fuse the multi-level features.
        """
        if mask is None or mask is False:
            return super().forward(x)

        B = x.shape[0]
        x = self.patch_embed(x)[0]
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # collect the (projected) features from the selected layers
        res = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                if i != self.out_indices[-1]:
                    proj_x = self.proj_layers[self.out_indices.index(i)](x)
                else:
                    proj_x = x
                res.append(proj_x)

        # fuse the multi-level features with softmax-normalized weights
        res = torch.stack(res)
        proj_weights = F.softmax(self.proj_weights, dim=0)
        res = res * proj_weights
        res = res.sum(dim=0)

        # Use final norm
        x = self.norm1(res)

        return (x, mask, ids_restore, proj_weights.view(-1))
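
# NOTE: An illustrative sketch, not part of the original mmpretrain API. It
# reproduces the fusion rule in ``MFFViT.forward`` in isolation: per-layer
# features are stacked, scaled by a softmax over learnable scalar weights,
# and summed over the layer axis. All tensor shapes below are arbitrary
# assumptions for demonstration.
def _fusion_sketch() -> torch.Tensor:
    """Toy reproduction of the softmax-weighted multi-level fusion."""
    num_layers, B, L, C = 4, 2, 50, 768
    # Features collected from ``num_layers`` selected transformer layers.
    feats = torch.randn(num_layers, B, L, C)
    # Learnable scalars, shaped for broadcasting, as in ``MFFViT.__init__``.
    weights = torch.ones(num_layers).view(-1, 1, 1, 1)
    probs = F.softmax(weights, dim=0)  # normalize so the weights sum to 1
    return (feats * probs).sum(dim=0)  # fused features, shape (B, L, C)
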

@MODELS.register_module()
class MFF(MAE):
    """MFF.

    Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
    Capability`.
    """

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required during
                the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # ids_restore: the same as that in the original repo, which is used
        # to recover the original order of tokens in the decoder.
        latent, mask, ids_restore, weights = self.backbone(inputs)
        pred = self.neck(latent, ids_restore)
        loss = self.head.loss(pred, inputs, mask)
        # Log each fusion weight so the contribution of every fused layer
        # can be tracked during training.
        weight_params = {
            f'weight_{i}': weights[i]
            for i in range(weights.size(0))
        }
        losses = dict(loss=loss)
        losses.update(weight_params)
        return losses
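
# The block below is an illustrative smoke test, not part of the original
# module. It assumes mmpretrain and torch are installed; the ``out_indices``
# chosen here are an arbitrary example, while the MFF configs shipped with
# mmpretrain select their own set of layers to fuse.
if __name__ == '__main__':
    model = MFFViT(arch='b', out_indices=[2, 5, 8, 11])
    images = torch.randn(2, 3, 224, 224)  # B x C x H x W
    with torch.no_grad():
        latent, mask, ids_restore, weights = model(images, mask=True)
    # 224 / 16 = 14, so 196 patches; 25% stay visible (49) plus the cls
    # token gives 50 tokens in the latent.
    print(latent.shape)       # torch.Size([2, 50, 768])
    print(mask.shape)         # torch.Size([2, 196])
    print(ids_restore.shape)  # torch.Size([2, 196])
    print(weights)            # 4 softmax weights, initially 0.25 each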