# Copyright (c) Facebook, Inc. and its affiliates.
import math

import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from torch import nn

from annotator.oneformer.detectron2.layers import Conv2d, ShapeSpec, get_norm

from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .resnet import build_resnet_backbone

__all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN"]


class FPN(Backbone):
    """
    This module implements :paper:`FPN`.
    It creates pyramid features built on top of some input feature maps.
    """

    _fuse_type: torch.jit.Final[str]

    def __init__(
        self,
        bottom_up,
        in_features,
        out_channels,
        norm="",
        top_block=None,
        fuse_type="sum",
        square_pad=0,
    ):
        """
        Args:
            bottom_up (Backbone): module representing the bottom up subnetwork.
                Must be a subclass of :class:`Backbone`. The multi-scale feature
                maps generated by the bottom up network, and listed in `in_features`,
                are used to generate FPN levels.
            in_features (list[str]): names of the input feature maps coming
                from the backbone to which FPN is attached. For example, if the
                backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
                of these may be used; order must be from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            norm (str): the normalization to use.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra FPN levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            fuse_type (str): types for fusing the top down features and the lateral
                ones. It can be "sum" (default), which sums up element-wise; or "avg",
                which takes the element-wise mean of the two.
            square_pad (int): If > 0, require input images to be padded to specific square size.

        Raises:
            AssertionError: if `bottom_up` is not a :class:`Backbone`, `in_features`
                is empty, `fuse_type` is not "sum"/"avg", or the strides of the
                selected features are not each 2x the preceding one.
        """
        super().__init__()
        assert isinstance(bottom_up, Backbone)
        assert in_features, in_features
        # Validate the cheap argument first so a typo fails before any conv is built.
        assert fuse_type in {"avg", "sum"}, fuse_type

        # Feature map strides and channels from the bottom up network (e.g. ResNet)
        input_shapes = bottom_up.output_shape()
        strides = [input_shapes[f].stride for f in in_features]
        in_channels_per_feature = [input_shapes[f].channels for f in in_features]

        _assert_strides_are_log2_contiguous(strides)
        lateral_convs = []
        output_convs = []

        # Convs only carry a bias when no normalization layer follows them.
        use_bias = norm == ""
        for idx, in_channels in enumerate(in_channels_per_feature):
            lateral_norm = get_norm(norm, out_channels)
            output_norm = get_norm(norm, out_channels)

            lateral_conv = Conv2d(
                in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm
            )
            output_conv = Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
            )
            weight_init.c2_xavier_fill(lateral_conv)
            weight_init.c2_xavier_fill(output_conv)

            # Submodules are registered under names derived from log2(stride),
            # e.g. "fpn_lateral2" for a stride-4 feature map.
            stage = int(math.log2(strides[idx]))
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.top_block = top_block
        self.in_features = tuple(in_features)
        self.bottom_up = bottom_up

        # Return feature names are "p{stage}", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}

        # top block output feature maps.
        if self.top_block is not None:
            # Stage of the lowest-resolution input feature, computed explicitly
            # rather than relying on the loop variable leaking out of the loop above.
            last_stage = int(math.log2(strides[-1]))
            for s in range(last_stage, last_stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad
        self._fuse_type = fuse_type

    @property
    def size_divisibility(self):
        """Stride of the last (lowest-resolution) entry of `in_features`."""
        return self._size_divisibility

    @property
    def padding_constraints(self):
        """Dict with the single key "square_size" holding the `square_pad` value."""
        return {"square_size": self._square_pad}

    def forward(self, x):
        """
        Args:
            x: input forwarded unchanged to the bottom-up backbone. The backbone
                is expected to return a dict mapping feature map names
                (e.g., "res5") to feature map tensors, containing every name
                listed in `in_features`.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to FPN feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                paper convention: "p{stage}", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.bottom_up(x)
        results = []
        # Start from the lowest-resolution input feature; it has no top-down input.
        prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]])
        results.append(self.output_convs[0](prev_features))

        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs, self.output_convs)
        ):
            # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336
            # Therefore we loop over all modules but skip the first one
            if idx > 0:
                features = self.in_features[-idx - 1]
                features = bottom_up_features[features]
                # Upsample the coarser map by 2x so it can be fused with the lateral one.
                top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest")
                lateral_features = lateral_conv(features)
                prev_features = lateral_features + top_down_features
                if self._fuse_type == "avg":
                    prev_features /= 2
                # Prepend so that `results` stays in high-to-low resolution order.
                results.insert(0, output_conv(prev_features))

        if self.top_block is not None:
            # The top block may consume either a bottom-up feature (e.g. "res5")
            # or one of the FPN outputs computed above (e.g. "p5").
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}

    def output_shape(self):
        """Return a dict mapping each output feature name to its :class:`ShapeSpec`."""
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }


def _assert_strides_are_log2_contiguous(strides):
    """
    Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2".
    """
    for i, stride in enumerate(strides[1:], 1):
        assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format(
            stride, strides[i - 1]
        )


class LastLevelMaxPool(nn.Module):
    """
    This module is used in the original FPN to generate a downsampled
    P6 feature from P5.
    """

    def __init__(self):
        super().__init__()
        self.num_levels = 1
        self.in_feature = "p5"

    def forward(self, x):
        """Return a one-element list holding `x` downsampled with stride 2."""
        # kernel_size=1 with stride=2 keeps every other element in each spatial dim.
        return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]


class LastLevelP6P7(nn.Module):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7 from
    C5 feature.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        """
        Args:
            in_channels (int): number of channels of the input feature map.
            out_channels (int): number of channels of both P6 and P7.
            in_feature (str): name of the feature map this block consumes.
        """
        super().__init__()
        self.num_levels = 2
        self.in_feature = in_feature
        # Two 3x3 convs with stride 2 and padding 1.
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            weight_init.c2_xavier_fill(module)

    def forward(self, c5):
        """Return `[p6, p7]`; P7 is computed from the ReLU of P6."""
        p6 = self.p6(c5)
        p7 = self.p7(F.relu(p6))
        return [p6, p7]


@BACKBONE_REGISTRY.register()
def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Args:
        cfg: a detectron2 CfgNode
        input_shape (ShapeSpec): forwarded to `build_resnet_backbone`.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_resnet_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=LastLevelMaxPool(),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone


@BACKBONE_REGISTRY.register()
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Args:
        cfg: a detectron2 CfgNode
        input_shape (ShapeSpec): forwarded to `build_resnet_backbone`.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_resnet_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    # P6/P7 are computed from the bottom-up "res5" feature, so the top block
    # needs that feature's channel count rather than the FPN output channels.
    in_channels_p6p7 = bottom_up.output_shape()["res5"].channels
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=LastLevelP6P7(in_channels_p6p7, out_channels),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone