ai-photo-gallery / mmcls /models /heads /multi_label_linear_head.py
KyanChen's picture
init
f549064
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from mmcls.registry import MODELS
from .multi_label_cls_head import MultiLabelClsHead
@MODELS.register_module()
class MultiLabelLinearClsHead(MultiLabelClsHead):
    """Multi-label classification head made of a single linear layer.

    The head takes the feature of the last backbone stage and projects it
    with ``nn.Linear`` to one output per class. ``loss``, ``thr``, ``topk``
    and ``init_cfg`` are forwarded unchanged to :class:`MultiLabelClsHead`.

    Args:
        num_classes (int): Number of categories. Must be positive.
        in_channels (int): Number of channels of the input feature.
        loss (dict): Config of classification loss. Defaults to
            ``dict(type='CrossEntropyLoss', use_sigmoid=True)``.
        thr (float, optional): Predictions with scores below this threshold
            are considered negative. Defaults to None.
        topk (int, optional): Predictions with the ``topk`` highest scores
            are considered positive. Defaults to None.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.

    Raises:
        AssertionError: If ``num_classes`` is not greater than zero.

    Notes:
        If both ``thr`` and ``topk`` are set, ``thr`` is used to determine
        positive predictions. If neither is set, ``thr=0.5`` is used as the
        default.
    """

    # NOTE(review): the dict defaults below are shared across calls; this
    # class only forwards them to the parent, which presumably copies rather
    # than mutates them -- confirm in MultiLabelClsHead / BaseModule.
    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
                 thr: Optional[float] = None,
                 topk: Optional[int] = None,
                 init_cfg: Optional[dict] = dict(
                     type='Normal', layer='Linear', std=0.01)):
        # The parent class receives the loss, threshold and init settings
        # before this head creates its own layer.
        super().__init__(loss=loss, thr=thr, topk=topk, init_cfg=init_cfg)

        # NOTE(review): ``assert`` is skipped under ``python -O``; it is kept
        # so the raised exception type stays AssertionError.
        assert num_classes > 0, (
            f'num_classes ({num_classes}) must be a positive integer.')

        self.in_channels = in_channels
        self.num_classes = num_classes
        # The only learnable layer: maps in_channels -> num_classes outputs.
        self.fc = nn.Linear(in_channels, num_classes)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """Select the input of the final linear layer.

        ``feats`` is a tuple with one tensor per backbone stage. This head
        has no module between the backbone output and ``fc``, so the
        last-stage feature is returned as-is.

        Args:
            feats (Tuple[torch.Tensor]): Features of the backbone stages.

        Returns:
            torch.Tensor: The last element of ``feats``.
        """
        last_stage_feat = feats[-1]
        return last_stage_feat

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """Compute the classification outputs (no activation applied here).

        Args:
            feats (Tuple[torch.Tensor]): Features of the backbone stages.

        Returns:
            torch.Tensor: ``self.fc`` applied to ``self.pre_logits(feats)``;
            presumably of shape ``(N, num_classes)`` for an
            ``(N, in_channels)`` feature -- confirm against the backbone.
        """
        return self.fc(self.pre_logits(feats))