Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Dict, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
from mmcls.registry import MODELS | |
from .multi_label_cls_head import MultiLabelClsHead | |
@MODELS.register_module()
class MultiLabelLinearClsHead(MultiLabelClsHead):
    """Linear classification head for multilabel task.

    Args:
        num_classes (int): Number of categories, i.e. the output size of the
            linear classifier. Must be positive.
        in_channels (int): Input feature size of the linear classifier; the
            last dimension of the feature fed to ``forward`` must match it.
        loss (dict): Config of classification loss. Defaults to
            dict(type='CrossEntropyLoss', use_sigmoid=True).
        thr (float, optional): Predictions with scores under the thresholds
            are considered as negative. Defaults to None.
        topk (int, optional): Predictions with the k-th highest scores are
            considered as positive. Defaults to None.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to use dict(type='Normal', layer='Linear', std=0.01).

    Notes:
        If both ``thr`` and ``topk`` are set, use ``thr`` to determine
        positive predictions. If neither is set, use ``thr=0.5`` as
        default.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 loss: Dict = dict(type='CrossEntropyLoss', use_sigmoid=True),
                 thr: Optional[float] = None,
                 topk: Optional[int] = None,
                 init_cfg: Optional[dict] = dict(
                     type='Normal', layer='Linear', std=0.01)):
        # Loss, thresholding and init config are handled by the base head.
        super().__init__(loss=loss, thr=thr, topk=topk, init_cfg=init_cfg)

        # Reject non-positive class counts before building the classifier.
        assert num_classes > 0, (
            f'num_classes ({num_classes}) must be a positive integer.')

        self.in_channels = in_channels
        self.num_classes = num_classes
        # Single linear layer mapping in_channels -> num_classes scores.
        self.fc = nn.Linear(self.in_channels, self.num_classes)

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensors, and each tensor is the
        feature of a backbone stage. ``MultiLabelLinearClsHead`` only needs
        the feature of the last stage.

        Args:
            feats (Tuple[torch.Tensor]): Features of the backbone stages.

        Returns:
            torch.Tensor: The last element of ``feats``, unchanged.
        """
        # This head has no extra module before ``fc``, so simply take the
        # last stage's feature.
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Args:
            feats (Tuple[torch.Tensor]): Features of the backbone stages.
                Only the last one is used; its last dimension must equal
                ``in_channels``.

        Returns:
            torch.Tensor: Raw classification scores from ``self.fc``. No
            sigmoid or thresholding is applied here.
        """
        pre_logits = self.pre_logits(feats)
        # The final classification head.
        cls_score = self.fc(pre_logits)
        return cls_score