# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Union

from mmengine.model import BaseModule
from torch import Tensor

from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig
from ..utils import unpack_gt_instances


class BaseMaskHead(BaseModule, metaclass=ABCMeta):
    """Abstract base for the mask heads of one-stage instance segmentors.

    Subclasses implement :meth:`loss_by_feat` and :meth:`predict_by_feat`;
    :meth:`loss` and :meth:`predict` run the forward pass and delegate to
    them.
    """

    def __init__(self, init_cfg: OptMultiConfig = None) -> None:
        """Forward ``init_cfg`` unchanged to the parent constructor."""
        super().__init__(init_cfg=init_cfg)

    @abstractmethod
    def loss_by_feat(self, *args, **kwargs):
        """Compute the losses from the features produced by the mask head."""

    @abstractmethod
    def predict_by_feat(self, *args, **kwargs):
        """Convert a batch of head output features into mask results."""

    def loss(self,
             x: Union[List[Tensor], Tuple[Tensor]],
             batch_data_samples: SampleList,
             positive_infos: OptInstanceList = None,
             **kwargs) -> dict:
        """Run the head forward on upstream features and compute the losses.

        Args:
            x (list[Tensor] | tuple[Tensor]): Features from FPN. Each has
                a shape (B, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                holds the meta information of one image together with
                its annotations.
            positive_infos (list[:obj:`InstanceData`], optional):
                Information of positive samples. Supplied when label
                assignment happens outside the mask head (e.g. in the
                BboxHead of YOLACT or CondInst) and left as None when the
                mask head assigns labels itself (e.g. SOLO, SOLOv2). All
                values in it should have shape (num_positive_samples, *).

        Returns:
            dict: A dictionary of loss components.
        """
        # ``positive_infos`` is only handed to forward when it was given.
        forward_args = (x, ) if positive_infos is None \
            else (x, positive_infos)
        outs = self(*forward_args)

        # ``outs`` is splatted into ``loss_by_feat`` below, so it has to be
        # a tuple even when forward yields a single item.
        assert isinstance(outs, tuple), (
            'Forward results should be a tuple, '
            'even if only one item is returned')

        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = unpack_gt_instances(batch_data_samples)

        # Pad each image's GT masks to the batch input shape. The padded
        # masks are written back onto the same ``gt_instances`` object.
        for gt_instances, img_metas in zip(batch_gt_instances,
                                           batch_img_metas):
            padded_shape = img_metas['batch_input_shape']
            gt_instances.masks = gt_instances.masks.pad(padded_shape)

        return self.loss_by_feat(
            *outs,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            positive_infos=positive_infos,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            **kwargs)

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False,
                results_list: OptInstanceList = None,
                **kwargs) -> InstanceList:
        """Run inference without test-time augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            results_list (list[obj:`InstanceData`], optional): Per-image
                detection results after post-processing. Only exist if
                there is a `bbox_head`, like `YOLACT`, `CondInst`, etc.

        Returns:
            list[obj:`InstanceData`]: Instance segmentation results of
            each image after the post process. Each item usually contains
            following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance,)
                - labels (Tensor): Has a shape (num_instances,).
                - masks (Tensor): Processed mask results, has a
                  shape (num_instances, h, w).
        """
        batch_img_metas = [sample.metainfo for sample in batch_data_samples]

        # ``results_list`` is only handed to forward when it was given.
        forward_args = (x, ) if results_list is None else (x, results_list)
        outs = self(*forward_args)

        return self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            rescale=rescale,
            results_list=results_list,
            **kwargs)