# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
                                reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_project
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .base import BaseDetector


@MODELS.register_module()
class SemiBaseDetector(BaseDetector):
    """Base class for semi-supervised detectors.

    Semi-supervised detectors typically consist of a teacher model
    updated by exponential moving average and a student model updated
    by gradient descent.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.student = MODELS.build(detector)
        self.teacher = MODELS.build(detector)
        self.semi_train_cfg = semi_train_cfg
        self.semi_test_cfg = semi_test_cfg
        if self.semi_train_cfg.get('freeze_teacher', True) is True:
            self.freeze(self.teacher)

    @staticmethod
    def freeze(model: nn.Module):
        """Freeze the model."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
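
    # Note: freezing the teacher only stops gradient updates. In a typical
    # training setup the teacher weights are still refreshed from the student
    # by an exponential-moving-average hook (e.g. MMDetection's
    # MeanTeacherHook), roughly
    #   teacher_param = (1 - m) * teacher_param + m * student_param,
    # where the momentum m and the exact update convention come from the hook
    # configuration, not from this file.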

    def loss(self, multi_batch_inputs: Dict[str, Tensor],
             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
        """Calculate losses from multi-branch inputs and data samples.

        Args:
            multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch
                input images, each value with shape (N, C, H, W).
                Each value should usually be mean centered and std scaled.
            multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]):
                The dict of multi-branch data samples.

        Returns:
            dict: A dictionary of loss components.
        """
        losses = dict()
        losses.update(**self.loss_by_gt_instances(
            multi_batch_inputs['sup'], multi_batch_data_samples['sup']))

        origin_pseudo_data_samples, batch_info = self.get_pseudo_instances(
            multi_batch_inputs['unsup_teacher'],
            multi_batch_data_samples['unsup_teacher'])
        multi_batch_data_samples[
            'unsup_student'] = self.project_pseudo_instances(
                origin_pseudo_data_samples,
                multi_batch_data_samples['unsup_student'])
        losses.update(**self.loss_by_pseudo_instances(
            multi_batch_inputs['unsup_student'],
            multi_batch_data_samples['unsup_student'], batch_info))
        return losses
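
    # Expected layout of the multi-branch batch (an assumption inferred from
    # how the keys are consumed above; the actual keys are produced by the
    # multi-branch data pipeline / data preprocessor):
    #   multi_batch_inputs = {
    #       'sup': Tensor(N, C, H, W),            # labelled images
    #       'unsup_teacher': Tensor(M, C, H, W),  # typically weakly augmented unlabelled images
    #       'unsup_student': Tensor(M, C, H, W),  # typically strongly augmented unlabelled images
    #   }
    #   multi_batch_data_samples uses the same keys, each mapping to a list of
    #   DetDataSample objects for that branch.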

    def loss_by_gt_instances(self, batch_inputs: Tensor,
                             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and ground-truth data
        samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        losses = self.student.loss(batch_inputs, batch_data_samples)
        sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
        return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight))
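
    # Illustration of the renaming/reweighting above (hypothetical values):
    # a student loss dict such as {'loss_cls': 1.2, 'loss_bbox': 0.8} with
    # sup_weight=2.0 becomes {'sup_loss_cls': 2.4, 'sup_loss_bbox': 1.6}, so
    # supervised and unsupervised terms stay distinguishable in the logs.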

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process. Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_data_samples = filter_gt_instances(
            batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
        losses = self.student.loss(batch_inputs, batch_data_samples)
        pseudo_instances_num = sum([
            len(data_samples.gt_instances)
            for data_samples in batch_data_samples
        ])
        unsup_weight = self.semi_train_cfg.get(
            'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))
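
    # Pseudo labels scored below `cls_pseudo_thr` are dropped before the
    # student loss is computed, and if no pseudo instance survives in the
    # whole batch the unsupervised loss is weighted by zero so it cannot
    # inject noise. Note that `cls_pseudo_thr` is accessed as an attribute of
    # `semi_train_cfg`, so omitting it from the config raises an error.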

    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model."""
        self.teacher.eval()
        results_list = self.teacher.predict(
            batch_inputs, batch_data_samples, rescale=False)
        batch_info = {}
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results.pred_instances
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)
        return batch_data_samples, batch_info
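
    # The teacher predicts on its own branch's view, so the inverse of the
    # recorded homography matrix maps those boxes back to the original image
    # space; `project_pseudo_instances` below then maps them forward into the
    # student branch's view. Subclasses may use `batch_info` to pass extra
    # teacher outputs (e.g. features) on to `loss_by_pseudo_instances`; this
    # base class leaves it empty.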

    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
                                 batch_data_samples: SampleList) -> SampleList:
        """Project pseudo instances."""
        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
                                                  batch_data_samples):
            data_samples.gt_instances = copy.deepcopy(
                pseudo_instances.gt_instances)
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.tensor(data_samples.homography_matrix).to(
                    self.data_preprocessor.device), data_samples.img_shape)
        wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)
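
    # After projection into the student view, boxes whose width or height
    # falls below `min_pseudo_bbox_wh` (default (1e-2, 1e-2)) are filtered
    # out, discarding degenerate pseudo boxes produced by the geometric
    # transform.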

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. Each returned value is a DetDataSample,
            which usually contains 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instances, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arranged as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        if self.semi_test_cfg.get('predict_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='predict')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='predict')
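
    # `semi_test_cfg.predict_on` selects which branch serves predictions:
    # 'teacher' (the default here) is the usual choice since the EMA teacher
    # is typically the stronger model, while any other value falls back to
    # the gradient-updated student via the else-branch above.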

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> SampleList:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            tuple: A tuple of features from ``rpn_head`` and ``roi_head``
            forward.
        """
        if self.semi_test_cfg.get('forward_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='tensor')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='tensor')

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H, W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
        if self.semi_test_cfg.get('extract_feat_on', 'teacher') == 'teacher':
            return self.teacher.extract_feat(batch_inputs)
        else:
            return self.student.extract_feat(batch_inputs)

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Add teacher and student prefixes to model parameter names."""
        if not any([
                'student' in key or 'teacher' in key
                for key in state_dict.keys()
        ]):
            keys = list(state_dict.keys())
            state_dict.update({'teacher.' + k: state_dict[k] for k in keys})
            state_dict.update({'student.' + k: state_dict[k] for k in keys})
            for k in keys:
                state_dict.pop(k)
        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
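
    # This hook lets a plain (fully supervised) detector checkpoint initialise
    # both branches: when no key already contains 'teacher' or 'student', each
    # entry is duplicated under both prefixes before normal loading, e.g.
    #   {'backbone.conv1.weight': w}
    #     -> {'teacher.backbone.conv1.weight': w,
    #         'student.backbone.conv1.weight': w}
    # Checkpoints saved from a SemiBaseDetector already carry the prefixes and
    # are loaded unchanged.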