Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod | |
from typing import List, Tuple | |
from mmengine.model import BaseModel | |
from mmengine.structures import PixelData | |
from torch import Tensor | |
from mmseg.structures import SegDataSample | |
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig, | |
OptSampleList, SampleList) | |
from ..utils import resize | |
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
    """Base class for segmentors.

    Concrete segmentors implement the abstract hooks (``extract_feat``,
    ``encode_decode``, ``loss``, ``predict`` and ``_forward``). This base
    class supplies the mode dispatcher :meth:`forward` and the shared
    :meth:`postprocess_result` helper.

    Args:
        data_preprocessor (dict, optional): Model preprocessing config
            for processing the input data. it usually includes
            ``to_rgb``, ``pad_size_divisor``, ``pad_val``,
            ``mean`` and ``std``. Default to None.
        init_cfg (dict, optional): the config to control the
            initialization. Default to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    @property
    def with_neck(self) -> bool:
        """bool: whether the segmentor has neck"""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_auxiliary_head(self) -> bool:
        """bool: whether the segmentor has auxiliary head"""
        return hasattr(self,
                       'auxiliary_head') and self.auxiliary_head is not None

    @property
    def with_decode_head(self) -> bool:
        """bool: whether the segmentor has decode head"""
        return hasattr(self, 'decode_head') and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Placeholder for extract features from images.

        NOTE(review): the return annotation assumes a list of feature
        tensors; the exact container is decided by the subclass.
        """

    @abstractmethod
    def encode_decode(self, inputs: Tensor,
                      batch_data_samples: SampleList) -> Tensor:
        """Placeholder for encode images with backbone and decode into a
        semantic segmentation map of the same size as input."""

    def forward(self,
                inputs: Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`SegDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
                general.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`. Default to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`SegDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three supported
                values.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""

    @abstractmethod
    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""

    @abstractmethod
    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """

    def postprocess_result(self,
                           seg_logits: Tensor,
                           data_samples: OptSampleList = None) -> SampleList:
        """Convert results list to `SegDataSample`.

        When ``data_samples`` is given, each sample's logits are cropped to
        remove padding, un-flipped if the meta info marks the image as
        flipped, and resized to ``ori_shape``. When it is None, the raw
        logits are used and fresh :obj:`SegDataSample` objects are created.

        This method reads ``self.align_corners`` (only when ``data_samples``
        is given) and ``self.decode_head.threshold`` (only for single-channel
        logits); neither is set in this class, so subclasses must provide
        them.

        Args:
            seg_logits (Tensor): The segmentation results, seg_logits from
                model of each input image. Must be 4-D (N, C, H, W).
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`. Default to None.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contain:

            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
            - ``seg_logits``(PixelData): Predicted logits of semantic
              segmentation before normalization.
        """
        batch_size, C, H, W = seg_logits.shape

        # Without meta info there is nothing to undo (padding/flip/resize),
        # so only the raw prediction is attached to fresh samples.
        only_prediction = data_samples is None
        if only_prediction:
            data_samples = [SegDataSample() for _ in range(batch_size)]

        for i in range(batch_size):
            if not only_prediction:
                img_meta = data_samples[i].metainfo
                # remove padding area
                if 'img_padding_size' not in img_meta:
                    padding_size = img_meta.get('padding_size', [0] * 4)
                else:
                    padding_size = img_meta['img_padding_size']
                padding_left, padding_right, padding_top, padding_bottom = \
                    padding_size
                # i_seg_logits shape is 1, C, H, W after remove padding
                i_seg_logits = seg_logits[i:i + 1, :,
                                          padding_top:H - padding_bottom,
                                          padding_left:W - padding_right]

                # Undo test-time flipping so the logits line up with the
                # un-augmented image.
                flip = img_meta.get('flip', None)
                if flip:
                    flip_direction = img_meta.get('flip_direction', None)
                    assert flip_direction in ['horizontal', 'vertical']
                    if flip_direction == 'horizontal':
                        i_seg_logits = i_seg_logits.flip(dims=(3, ))
                    else:
                        i_seg_logits = i_seg_logits.flip(dims=(2, ))

                # resize as original shape
                i_seg_logits = resize(
                    i_seg_logits,
                    size=img_meta['ori_shape'],
                    mode='bilinear',
                    align_corners=self.align_corners,
                    warning=False).squeeze(0)
            else:
                i_seg_logits = seg_logits[i]

            if C > 1:
                # Multi-class: pick the highest-scoring channel per pixel.
                i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
            else:
                # Single channel: binarise against the decode head's
                # threshold, then cast back to the logits' dtype/device.
                i_seg_pred = (i_seg_logits >
                              self.decode_head.threshold).to(i_seg_logits)
            data_samples[i].set_data({
                'seg_logits':
                PixelData(**{'data': i_seg_logits}),
                'pred_sem_seg':
                PixelData(**{'data': i_seg_pred})
            })

        return data_samples