Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from abc import ABCMeta, abstractmethod | |
from typing import Dict, List, Tuple, Union | |
import torch.nn.functional as F | |
from mmengine.model import BaseModule | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from mmdet.structures import SampleList | |
from mmdet.utils import ConfigType, OptMultiConfig | |
class BaseSemanticHead(BaseModule, metaclass=ABCMeta): | |
"""Base module of Semantic Head. | |
Args: | |
num_classes (int): the number of classes. | |
seg_rescale_factor (float): the rescale factor for ``gt_sem_seg``, | |
which equals to ``1 / output_strides``. The output_strides is | |
for ``seg_preds``. Defaults to 1 / 4. | |
init_cfg (Optional[Union[:obj:`ConfigDict`, dict]]): the initialization | |
config. | |
loss_seg (Union[:obj:`ConfigDict`, dict]): the loss of the semantic | |
head. | |
""" | |
def __init__(self, | |
num_classes: int, | |
seg_rescale_factor: float = 1 / 4., | |
loss_seg: ConfigType = dict( | |
type='CrossEntropyLoss', | |
ignore_index=255, | |
loss_weight=1.0), | |
init_cfg: OptMultiConfig = None) -> None: | |
super().__init__(init_cfg=init_cfg) | |
self.loss_seg = MODELS.build(loss_seg) | |
self.num_classes = num_classes | |
self.seg_rescale_factor = seg_rescale_factor | |
def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Dict[str, Tensor]: | |
"""Placeholder of forward function. | |
Args: | |
x (Tensor): Feature maps. | |
Returns: | |
Dict[str, Tensor]: A dictionary, including features | |
and predicted scores. Required keys: 'seg_preds' | |
and 'feats'. | |
""" | |
pass | |
def loss(self, x: Union[Tensor, Tuple[Tensor]], | |
batch_data_samples: SampleList) -> Dict[str, Tensor]: | |
""" | |
Args: | |
x (Union[Tensor, Tuple[Tensor]]): Feature maps. | |
batch_data_samples (list[:obj:`DetDataSample`]): The batch | |
data samples. It usually includes information such | |
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. | |
Args: | |
x (Tensor): Feature maps. | |
Returns: | |
Dict[str, Tensor]: The loss of semantic head. | |
""" | |
pass | |
def predict(self, | |
x: Union[Tensor, Tuple[Tensor]], | |
batch_img_metas: List[dict], | |
rescale: bool = False) -> List[Tensor]: | |
"""Test without Augmentation. | |
Args: | |
x (Union[Tensor, Tuple[Tensor]]): Feature maps. | |
batch_img_metas (List[dict]): List of image information. | |
rescale (bool): Whether to rescale the results. | |
Defaults to False. | |
Returns: | |
list[Tensor]: semantic segmentation logits. | |
""" | |
seg_preds = self.forward(x)['seg_preds'] | |
seg_preds = F.interpolate( | |
seg_preds, | |
size=batch_img_metas[0]['batch_input_shape'], | |
mode='bilinear', | |
align_corners=False) | |
seg_preds = [seg_preds[i] for i in range(len(batch_img_metas))] | |
if rescale: | |
seg_pred_list = [] | |
for i in range(len(batch_img_metas)): | |
h, w = batch_img_metas[i]['img_shape'] | |
seg_pred = seg_preds[i][:, :h, :w] | |
h, w = batch_img_metas[i]['ori_shape'] | |
seg_pred = F.interpolate( | |
seg_pred[None], | |
size=(h, w), | |
mode='bilinear', | |
align_corners=False)[0] | |
seg_pred_list.append(seg_pred) | |
else: | |
seg_pred_list = seg_preds | |
return seg_pred_list | |