# ai-photo-gallery/mmcls/models/heads/multi_label_csra_head.py
# KyanChen - init (f549064)
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/Kevinz-code/CSRA
from typing import Tuple
import torch
import torch.nn as nn
from mmengine.model import BaseModule, ModuleList
from mmcls.registry import MODELS
from .multi_label_cls_head import MultiLabelClsHead
@MODELS.register_module()
class CSRAClsHead(MultiLabelClsHead):
    """Class-specific residual attention classifier head.

    Please refer to the `Residual Attention: A Simple but Effective Method for
    Multi-Label Recognition (ICCV 2021) <https://arxiv.org/abs/2108.02456>`_
    for details.

    The head owns ``num_heads`` :class:`CSRAModule` branches, one per
    temperature listed in ``temperature_settings[num_heads]``, and sums
    their logits.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        num_heads (int): Number of residual attention heads. Must be one of
            the keys of ``temperature_settings`` (1, 2, 4, 6 or 8).
        lam (float): Lambda passed to every ``CSRAModule``, where it weights
            the attention score added to the average-pooled score. Must be
            greater than 0.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            ``MultiLabelClsHead.__init__`` (presumably including the ``loss``
            config - confirm against the parent class).
    """

    # Softmax temperatures used by the branches, keyed by the number of
    # heads. The value 99 is the sentinel that ``CSRAModule`` treats as
    # max-pooling instead of a softmax temperature.
    temperature_settings = {  # softmax temperature settings
        1: [1],
        2: [1, 99],
        4: [1, 2, 4, 99],
        6: [1, 2, 3, 4, 5, 99],
        8: [1, 2, 3, 4, 5, 6, 7, 99]
    }

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_heads: int,
                 lam: float,
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01),
                 **kwargs):
        assert num_heads in self.temperature_settings, \
            'The num of heads is not in temperature setting.'
        # Only positivity is enforced here, so the message says exactly that
        # (it previously claimed a 0-1 range that was never checked).
        assert lam > 0, 'Lambda should be greater than 0.'
        super().__init__(init_cfg=init_cfg, **kwargs)

        self.temp_list = self.temperature_settings[num_heads]
        # One branch per temperature; every list in ``temperature_settings``
        # has exactly ``num_heads`` entries.
        self.csra_heads = ModuleList([
            CSRAModule(num_classes, in_channels, temperature, lam)
            for temperature in self.temp_list
        ])

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``CSRAClsHead``, we just obtain the
        feature of the last stage.
        """
        # The CSRAClsHead doesn't have other module, just return after
        # unpacking.
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Args:
            feats (Tuple[torch.Tensor]): Features of the backbone stages;
                only the last element is used.

        Returns:
            torch.Tensor: Sum of the logits produced by all CSRA branches.
        """
        pre_logits = self.pre_logits(feats)
        return sum(head(pre_logits) for head in self.csra_heads)
class CSRAModule(BaseModule):
    """One CSRA branch operating at a single, fixed temperature.

    The branch scores every spatial position with a bias-free 1x1
    convolution, normalises the scores by the norm of each class weight,
    and returns the spatial average plus ``lam`` times an attention-pooled
    score.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        T (int): Temperature setting. The value 99 selects max-pooling
            over positions instead of softmax attention.
        lam (float): Lambda that scales the attention score before it is
            added to the average-pooled score.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 T: int,
                 lam: float,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.T = T  # temperature
        self.lam = lam  # Lambda
        # Bias-free 1x1 convolution acting as the per-class classifier.
        self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False)
        # Attention weights are normalised over the flattened spatial axis.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """Compute the logits of this branch.

        Args:
            x (torch.Tensor): Input feature map; assumed to be laid out as
                (N, in_channels, H, W) - confirm against the caller.

        Returns:
            torch.Tensor: ``base_logit + lam * att_logit``, reduced over the
            flattened spatial axis.
        """
        # Norm of each class weight over the input-channel axis, moved to
        # the channel position so it broadcasts against the conv output.
        weight_norm = torch.norm(self.head.weight, dim=1, keepdim=True)
        score = self.head(x) / weight_norm.transpose(0, 1)
        # Collapse all spatial dimensions into a single axis (dim 2).
        score = score.flatten(2)

        base_logit = score.mean(dim=2)

        if self.T == 99:  # max-pooling
            att_logit = score.max(dim=2).values
        else:
            # Softmax over positions, sharpened by the temperature, gives
            # the weights used to pool the scores.
            attention = self.softmax(score * self.T)
            att_logit = (score * attention).sum(dim=2)

        return base_logit + self.lam * att_logit