|
|
|
import math |
|
import fvcore.nn.weight_init as weight_init |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from annotator.oneformer.detectron2.layers import Conv2d, ShapeSpec, get_norm |
|
|
|
from .backbone import Backbone |
|
from .build import BACKBONE_REGISTRY |
|
from .resnet import build_resnet_backbone |
|
|
|
# Public API of this module: the FPN backbone class and its two registered builders.
__all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN"]
|
|
|
|
|
class FPN(Backbone):
    """
    This module implements :paper:`FPN`.
    It creates pyramid features built on top of some input feature maps.
    """

    _fuse_type: torch.jit.Final[str]

    def __init__(
        self,
        bottom_up,
        in_features,
        out_channels,
        norm="",
        top_block=None,
        fuse_type="sum",
        square_pad=0,
    ):
        """
        Args:
            bottom_up (Backbone): module representing the bottom up subnetwork.
                Must be a subclass of :class:`Backbone`. The multi-scale feature
                maps generated by the bottom up network, and listed in `in_features`,
                are used to generate FPN levels.
            in_features (list[str]): names of the input feature maps coming
                from the backbone to which FPN is attached. For example, if the
                backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
                of these may be used; order must be from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            norm (str): the normalization to use.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra FPN levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            fuse_type (str): types for fusing the top down features and the lateral
                ones. It can be "sum" (default), which sums up element-wise; or "avg",
                which takes the element-wise mean of the two.
            square_pad (int): If > 0, require input images to be padded to specific square size.
        """
        super().__init__()
        assert isinstance(bottom_up, Backbone)
        assert in_features, in_features

        # Strides and channel counts of the selected bottom-up feature maps,
        # in the order given by `in_features` (high to low resolution).
        input_shapes = bottom_up.output_shape()
        strides = [input_shapes[f].stride for f in in_features]
        in_channels_per_feature = [input_shapes[f].channels for f in in_features]

        _assert_strides_are_log2_contiguous(strides)
        lateral_convs = []
        output_convs = []

        # Convs only carry a bias when no normalization layer follows them.
        use_bias = norm == ""
        for in_channels, stride in zip(in_channels_per_feature, strides):
            lateral_norm = get_norm(norm, out_channels)
            output_norm = get_norm(norm, out_channels)

            lateral_conv = Conv2d(
                in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm
            )
            output_conv = Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
            )
            weight_init.c2_xavier_fill(lateral_conv)
            weight_init.c2_xavier_fill(output_conv)
            # Modules are registered by pyramid stage (stride == 2 ** stage),
            # e.g. "fpn_lateral2" / "fpn_output2" for a stride-4 input.
            stage = int(math.log2(stride))
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Store convs in top-down order (lowest resolution first) so that
        # forward() can walk the pyramid starting from the coarsest level.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.top_block = top_block
        self.in_features = tuple(in_features)
        self.bottom_up = bottom_up

        # Output names follow "p<stage>" with stride == 2 ** stage, e.g. "p2".."p5".
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        if self.top_block is not None:
            # Extra levels produced by the top block continue after the coarsest
            # bottom-up stage, each doubling the stride of the previous one.
            last_stage = int(math.log2(strides[-1]))
            for s in range(last_stage, last_stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad
        assert fuse_type in {"avg", "sum"}
        self._fuse_type = fuse_type

    @property
    def size_divisibility(self):
        return self._size_divisibility

    @property
    def padding_constraints(self):
        return {"square_size": self._square_pad}

    def forward(self, x):
        """
        Args:
            x: input passed unchanged to ``self.bottom_up``. The bottom-up network
                must return a dict mapping feature map names (e.g., "res5") to
                tensors that contains every name in ``self.in_features``.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to FPN feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                paper convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.bottom_up(x)
        results = []
        prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]])
        results.append(self.output_convs[0](prev_features))

        # Walk the remaining levels top-down (low to high resolution). All conv
        # pairs are iterated and the first one, handled above, is skipped.
        for idx, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs, self.output_convs)
        ):
            if idx > 0:
                bottom_up_feature = bottom_up_features[self.in_features[-idx - 1]]
                lateral_features = lateral_conv(bottom_up_feature)
                # Upsample to the lateral map's exact spatial size. For exact 2x
                # size ratios this equals scale_factor=2.0, and it also handles odd
                # sizes where a fixed 2x upsample would cause a shape mismatch.
                top_down_features = F.interpolate(
                    prev_features, size=lateral_features.shape[-2:], mode="nearest"
                )
                prev_features = lateral_features + top_down_features
                if self._fuse_type == "avg":
                    prev_features /= 2
                results.insert(0, output_conv(prev_features))

        if self.top_block is not None:
            # The top block may consume either a raw bottom-up feature (e.g. "res5")
            # or one of the FPN outputs computed above (e.g. "p5").
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
|
|
|
|
|
def _assert_strides_are_log2_contiguous(strides): |
|
""" |
|
Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2". |
|
""" |
|
for i, stride in enumerate(strides[1:], 1): |
|
assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format( |
|
stride, strides[i - 1] |
|
) |
|
|
|
|
|
class LastLevelMaxPool(nn.Module):
    """
    Extra top block used by the original FPN: derives a downsampled P6 feature
    from P5.

    Attributes:
        num_levels (int): number of extra pyramid levels this block adds (1).
        in_feature (str): name of the feature this block consumes ("p5").
    """

    def __init__(self):
        super().__init__()
        self.in_feature = "p5"
        self.num_levels = 1

    def forward(self, x):
        # A 1x1 window with stride 2 keeps every other row and column.
        p6 = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
        return [p6]
|
|
|
|
|
class LastLevelP6P7(nn.Module):
    """
    Extra top block used by RetinaNet: builds P6 and P7 from a single input
    feature map (by default the bottom-up "res5" map, i.e. C5).

    P6 is a 3x3 stride-2 convolution of the input; P7 is a 3x3 stride-2
    convolution of ReLU(P6).
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        super().__init__()
        self.num_levels = 2
        self.in_feature = in_feature
        # p6 is created and initialized before p7 so parameter order and the
        # random-initialization sequence stay the same.
        self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        weight_init.c2_xavier_fill(self.p6)
        weight_init.c2_xavier_fill(self.p7)

    def forward(self, c5):
        p6 = self.p6(c5)
        return [p6, self.p7(F.relu(p6))]
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet bottom-up network topped with an FPN whose extra level is
    produced by :class:`LastLevelMaxPool`.

    Args:
        cfg: a detectron2 CfgNode. It is passed to `build_resnet_backbone`, and
            MODEL.FPN.{IN_FEATURES, OUT_CHANNELS, NORM, FUSE_TYPE} configure the FPN.
        input_shape (ShapeSpec): forwarded to `build_resnet_backbone`.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    resnet = build_resnet_backbone(cfg, input_shape)
    fpn_cfg = cfg.MODEL.FPN
    return FPN(
        bottom_up=resnet,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_cfg.OUT_CHANNELS,
        norm=fpn_cfg.NORM,
        top_block=LastLevelMaxPool(),
        fuse_type=fpn_cfg.FUSE_TYPE,
    )
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet bottom-up network topped with an FPN whose extra P6/P7 levels
    come from :class:`LastLevelP6P7`, as used by RetinaNet.

    The P6/P7 block's input channel count is taken from the ResNet's "res5" output.

    Args:
        cfg: a detectron2 CfgNode. It is passed to `build_resnet_backbone`, and
            MODEL.FPN.{IN_FEATURES, OUT_CHANNELS, NORM, FUSE_TYPE} configure the FPN.
        input_shape (ShapeSpec): forwarded to `build_resnet_backbone`.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    resnet = build_resnet_backbone(cfg, input_shape)
    fpn_cfg = cfg.MODEL.FPN
    fpn_channels = fpn_cfg.OUT_CHANNELS
    res5_channels = resnet.output_shape()["res5"].channels
    # The top block is built after the ResNet and before the FPN's own layers,
    # the same order in which the modules are otherwise constructed.
    p6p7 = LastLevelP6P7(res5_channels, fpn_channels)
    return FPN(
        bottom_up=resnet,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_channels,
        norm=fpn_cfg.NORM,
        top_block=p6p7,
        fuse_type=fpn_cfg.FUSE_TYPE,
    )
|
|