# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple, Union

import torch.nn.functional as F
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
    """Base module of Semantic Head.

    Args:
        num_classes (int): The number of classes.
        seg_rescale_factor (float): The rescale factor of ``gt_sem_seg``,
            which equals ``1 / output_strides``. The ``output_strides`` is
            that of ``seg_preds``. Defaults to 1 / 4.
        loss_seg (Union[:obj:`ConfigDict`, dict]): The loss of the semantic
            head.
        init_cfg (Optional[Union[:obj:`ConfigDict`, dict]]): The
            initialization config.
    """

    def __init__(self,
                 num_classes: int,
                 seg_rescale_factor: float = 1 / 4.,
                 loss_seg: ConfigType = dict(
                     type='CrossEntropyLoss',
                     ignore_index=255,
                     loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss_seg = MODELS.build(loss_seg)
        self.num_classes = num_classes
        self.seg_rescale_factor = seg_rescale_factor

    @abstractmethod
    def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Dict[str, Tensor]:
        """Placeholder of forward function.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.

        Returns:
            Dict[str, Tensor]: A dictionary containing the features and the
            predicted scores. Required keys: 'seg_preds' and 'feats'.
        """
        pass

    @abstractmethod
    def loss(self, x: Union[Tensor, Tuple[Tensor]],
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Compute the semantic segmentation loss.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch data
                samples. It usually includes information such as
                ``gt_instance``, ``gt_panoptic_seg`` or ``gt_sem_seg``.

        Returns:
            Dict[str, Tensor]: The loss of the semantic head.
        """
        pass

    def predict(self,
                x: Union[Tensor, Tuple[Tensor]],
                batch_img_metas: List[dict],
                rescale: bool = False) -> List[Tensor]:
        """Test without augmentation.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_img_metas (List[dict]): List of image information.
            rescale (bool): Whether to rescale the results back to the
                original image shape. Defaults to False.

        Returns:
            list[Tensor]: Semantic segmentation logits.
        """
        seg_preds = self.forward(x)['seg_preds']
        # Upsample the logits to the padded batch input shape.
        seg_preds = F.interpolate(
            seg_preds,
            size=batch_img_metas[0]['batch_input_shape'],
            mode='bilinear',
            align_corners=False)
        seg_preds = [seg_preds[i] for i in range(len(batch_img_metas))]

        if rescale:
            seg_pred_list = []
            for i in range(len(batch_img_metas)):
                # Crop away the padded region, then resize to the original
                # image shape.
                h, w = batch_img_metas[i]['img_shape']
                seg_pred = seg_preds[i][:, :h, :w]

                h, w = batch_img_metas[i]['ori_shape']
                seg_pred = F.interpolate(
                    seg_pred[None],
                    size=(h, w),
                    mode='bilinear',
                    align_corners=False)[0]
                seg_pred_list.append(seg_pred)
        else:
            seg_pred_list = seg_preds

        return seg_pred_list
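

# -------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): a minimal, hypothetical
# subclass showing how the abstract interface above is expected to be filled
# in. ``ToySemanticHead``, its single 1x1 conv, the ``in_channels`` argument
# and the ground-truth handling below are assumptions for illustration only;
# real heads in MMDetection (e.g. ``PanopticFPNHead``) are more involved.
import torch
from torch import nn


class ToySemanticHead(BaseSemanticHead):
    """Toy head: ``forward`` returns 'seg_preds' and 'feats', and ``loss``
    returns a dict of loss tensors computed against ``gt_sem_seg``."""

    def __init__(self, in_channels: int, **kwargs) -> None:
        super().__init__(**kwargs)
        # Per-pixel classifier producing ``num_classes`` logit channels.
        self.conv_logits = nn.Conv2d(in_channels, self.num_classes, 1)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Return both the logits and the features they were computed from,
        # as required by the base-class contract.
        return dict(seg_preds=self.conv_logits(x), feats=x)

    def loss(self, x: Tensor,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        seg_preds = self.forward(x)['seg_preds']
        # Stack the per-image semantic GT maps (shape (1, H, W) each) and
        # downsample them to the prediction stride with
        # ``seg_rescale_factor`` before applying the configured loss.
        gt_sem_seg = torch.stack(
            [s.gt_sem_seg.sem_seg for s in batch_data_samples]).float()
        gt_sem_seg = F.interpolate(
            gt_sem_seg,
            scale_factor=self.seg_rescale_factor,
            mode='nearest').squeeze(1).long()
        return dict(loss_sem_seg=self.loss_seg(seg_preds, gt_sem_seg))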