# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

from torch import Tensor

from mmdet.models.mot import BaseMOTModel
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample, TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig


@MODELS.register_module()
class Mask2FormerVideo(BaseMOTModel):
r"""Implementation of `Masked-attention Mask
Transformer for Universal Image Segmentation
<https://arxiv.org/pdf/2112.01527>`_.
Args:
backbone (dict): Configuration of backbone. Defaults to None.
track_head (dict): Configuration of track head. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`TrackDataPreprocessor`. it usually includes,
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
Defaults to None.
init_cfg (dict or list[dict]): Configuration of initialization.
Defaults to None.
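
    Examples:
        >>> # A minimal, illustrative config sketch. The backbone/head
        >>> # settings shown here are assumptions, not a tested recipe;
        >>> # see the mmdet model zoo for complete configs.
        >>> cfg = dict(
        ...     type='Mask2FormerVideo',
        ...     backbone=dict(type='ResNet', depth=50),
        ...     track_head=dict(type='Mask2FormerTrackHead'))  # head args omitted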
"""

    def __init__(self,
backbone: Optional[dict] = None,
track_head: Optional[dict] = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None):
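        # Note: this intentionally calls BaseModel.__init__ (the class above
        # BaseMOTModel in the MRO) rather than BaseMOTModel.__init__.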
super(BaseMOTModel, self).__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
if backbone is not None:
self.backbone = MODELS.build(backbone)
if track_head is not None:
self.track_head = MODELS.build(track_head)
self.num_classes = self.track_head.num_classes

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Overload in order to load mmdet pretrained ckpt."""
for key in list(state_dict):
if key.startswith('panoptic_head'):
state_dict[key.replace('panoptic',
'track')] = state_dict.pop(key)

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)

    def loss(self, inputs: Tensor, data_samples: TrackSampleList,
**kwargs) -> Union[dict, tuple]:
"""
Args:
inputs (Tensor): Input images of shape (N, T, C, H, W).
These should usually be mean centered and std scaled.
data_samples (list[:obj:`TrackDataSample`]): The batch
data samples. It usually includes information such
as `gt_instance`.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
        assert inputs.dim() == 5, \
            'The img must be a 5D Tensor (N, T, C, H, W).'
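        # Flatten batch and temporal dims so each frame goes through the
        # backbone as an independent image.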
# shape (N * T, C, H, W)
img = inputs.flatten(0, 1)
x = self.backbone(img)
losses = self.track_head.loss(x, data_samples)
return losses

    def predict(self,
inputs: Tensor,
data_samples: TrackSampleList,
rescale: bool = True) -> TrackSampleList:
"""Predict results from a batch of inputs and data samples with
postprocessing.
Args:
inputs (Tensor): of shape (N, T, C, H, W) encoding
input images. The N denotes batch size.
The T denotes the number of frames in a video.
data_samples (list[:obj:`TrackDataSample`]): The batch
data samples. It usually includes information such
as `video_data_samples`.
rescale (bool, Optional): If False, then returned bboxes and masks
will fit the scale of img, otherwise, returned bboxes and masks
will fit the scale of original image shape. Defaults to True.
Returns:
TrackSampleList: Tracking results of the inputs.
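
        Examples:
            >>> # Illustrative shapes only (assumed values); a real call
            >>> # needs a built model and fully populated data samples.
            >>> # inputs: a Tensor of shape (1, T, 3, H, W), i.e. one video
            >>> # data_samples: [TrackDataSample] with T frame-level samples
            >>> # results = model.predict(inputs, data_samples, rescale=True)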
"""
        assert inputs.dim() == 5, \
            'The img must be a 5D Tensor (N, T, C, H, W).'
        assert len(data_samples) == 1, \
            'Mask2FormerVideo only supports batch size 1 per GPU for now.'
        # Batch size is asserted to be 1, so take the single video's
        # frames: [T, C, H, W].
        img = inputs[0]
track_data_sample = data_samples[0]
feats = self.backbone(img)
pred_track_ins_list = self.track_head.predict(feats, track_data_sample,
rescale)
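
        # Attach each frame's predicted instances to its frame-level data
        # sample, then repack all frames into one TrackDataSample.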
det_data_samples_list = []
for idx, pred_track_ins in enumerate(pred_track_ins_list):
img_data_sample = track_data_sample[idx]
img_data_sample.pred_track_instances = pred_track_ins
det_data_samples_list.append(img_data_sample)
results = TrackDataSample()
results.video_data_samples = det_data_samples_list
return [results]