KyanChen's picture
init
f549064
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_project
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .base import BaseDetector
@MODELS.register_module()
class SemiBaseDetector(BaseDetector):
    """Base class for semi-supervised detectors.

    Semi-supervised detectors typically consist of a teacher model
    updated by exponential moving average and a student model updated
    by gradient descent. Here both models are built from the same
    ``detector`` config.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config. Keys read by this class
            are ``freeze_teacher``, ``sup_weight``, ``unsup_weight``,
            ``min_pseudo_bbox_wh`` and ``cls_pseudo_thr`` (the latter is
            required when computing pseudo-label losses).
            Defaults to None.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config. Keys read by this class
            are ``predict_on``, ``forward_on`` and ``extract_feat_on``;
            the value ``'teacher'`` (the default) selects the teacher and
            any other value selects the student. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # Student and teacher are two independent instances built from the
        # same detector config.
        self.student = MODELS.build(detector)
        self.teacher = MODELS.build(detector)
        self.semi_train_cfg = semi_train_cfg
        self.semi_test_cfg = semi_test_cfg
        # Only the literal ``True`` (also the default, including when no
        # train config is given) freezes the teacher.
        if self._get_cfg_value(self.semi_train_cfg, 'freeze_teacher',
                               True) is True:
            self.freeze(self.teacher)

    @staticmethod
    def _get_cfg_value(cfg: OptConfigType, key: str, default):
        """Return ``cfg.get(key, default)``, or ``default`` when ``cfg`` is
        None (both ``semi_train_cfg`` and ``semi_test_cfg`` are optional).
        """
        if cfg is None:
            return default
        return cfg.get(key, default)

    @staticmethod
    def freeze(model: nn.Module):
        """Put ``model`` in eval mode and disable gradients for all of its
        parameters."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def loss(self, multi_batch_inputs: Dict[str, Tensor],
             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
        """Calculate losses from multi-branch inputs and data samples.

        The ``'sup'`` branch is trained on ground truth. The teacher
        predicts pseudo instances on the ``'unsup_teacher'`` branch. These
        are projected onto the ``'unsup_student'`` branch, which is then
        trained on them. Note that ``multi_batch_data_samples
        ['unsup_student']`` is replaced in place by the projected samples.

        Args:
            multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch
                input images, each value with shape (N, C, H, W).
                Each value should usually be mean centered and std scaled.
            multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]):
                The dict of multi-branch data samples.

        Returns:
            dict: A dictionary of loss components, with supervised losses
            prefixed by ``sup_`` and unsupervised ones by ``unsup_``.
        """
        losses = dict()
        losses.update(**self.loss_by_gt_instances(
            multi_batch_inputs['sup'], multi_batch_data_samples['sup']))

        origin_pseudo_data_samples, batch_info = self.get_pseudo_instances(
            multi_batch_inputs['unsup_teacher'],
            multi_batch_data_samples['unsup_teacher'])
        multi_batch_data_samples[
            'unsup_student'] = self.project_pseudo_instances(
                origin_pseudo_data_samples,
                multi_batch_data_samples['unsup_student'])
        losses.update(**self.loss_by_pseudo_instances(
            multi_batch_inputs['unsup_student'],
            multi_batch_data_samples['unsup_student'], batch_info))
        return losses

    def loss_by_gt_instances(self, batch_inputs: Tensor,
                             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and ground-truth data
        samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of student loss components scaled by
            ``sup_weight`` (default 1.0) and prefixed with ``sup_``.
        """
        losses = self.student.loss(batch_inputs, batch_data_samples)
        sup_weight = self._get_cfg_value(self.semi_train_cfg, 'sup_weight',
                                         1.)
        return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight))

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process. Not used by this base
                implementation. Defaults to None.

        Returns:
            dict: A dictionary of student loss components prefixed with
            ``unsup_``. They are scaled by ``unsup_weight`` (default 1.0),
            or by 0 when no pseudo instance survives score filtering.
        """
        batch_data_samples = filter_gt_instances(
            batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
        losses = self.student.loss(batch_inputs, batch_data_samples)
        pseudo_instances_num = sum(
            len(data_samples.gt_instances)
            for data_samples in batch_data_samples)
        # With no surviving pseudo instances the unsupervised losses are
        # zero-weighted rather than dropped, keeping the loss dict keys
        # stable.
        if pseudo_instances_num > 0:
            unsup_weight = self._get_cfg_value(self.semi_train_cfg,
                                               'unsup_weight', 1.)
        else:
            unsup_weight = 0.
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from the teacher model.

        Each sample's ``gt_instances`` is replaced (in place) by the
        teacher's ``pred_instances``. Their boxes are projected with the
        inverse of the sample's ``homography_matrix`` and clipped to
        ``ori_shape`` — presumably undoing the teacher-branch augmentation
        (TODO confirm against the data pipeline).

        Returns:
            tuple: The (mutated) ``batch_data_samples`` and a
            ``batch_info`` dict, which is always empty in this base class.
        """
        self.teacher.eval()
        results_list = self.teacher.predict(
            batch_inputs, batch_data_samples, rescale=False)
        batch_info = {}
        device = self.data_preprocessor.device
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results.pred_instances
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    device), data_samples.ori_shape)
        return batch_data_samples, batch_info

    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
                                 batch_data_samples: SampleList) -> SampleList:
        """Project pseudo instances onto the student-branch samples.

        Each target sample receives a deep copy of the matching pseudo
        ``gt_instances``. Their boxes are projected with the target's
        ``homography_matrix`` into its ``img_shape``. Boxes smaller than
        ``min_pseudo_bbox_wh`` (default ``(1e-2, 1e-2)``) are then
        filtered out.
        """
        device = self.data_preprocessor.device
        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
                                                  batch_data_samples):
            data_samples.gt_instances = copy.deepcopy(
                pseudo_instances.gt_instances)
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.tensor(data_samples.homography_matrix).to(device),
                data_samples.img_shape)
        wh_thr = self._get_cfg_value(self.semi_train_cfg,
                                     'min_pseudo_bbox_wh', (1e-2, 1e-2))
        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        The teacher is used unless ``semi_test_cfg.predict_on`` is set to a
        value other than ``'teacher'``, in which case the student is used.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The returns value is DetDataSample,
            which usually contain 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        if self._get_cfg_value(self.semi_test_cfg, 'predict_on',
                               'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='predict')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='predict')

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> SampleList:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        The teacher is used unless ``semi_test_cfg.forward_on`` is set to a
        value other than ``'teacher'``, in which case the student is used.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples forwarded to the selected model.

        Returns:
            Whatever the selected detector returns in ``'tensor'`` mode
            (for two-stage detectors, typically a tuple of features from
            ``rpn_head`` and ``roi_head`` forward).
        """
        if self._get_cfg_value(self.semi_test_cfg, 'forward_on',
                               'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='tensor')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='tensor')

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        The teacher is used unless ``semi_test_cfg.extract_feat_on`` is set
        to a value other than ``'teacher'``, in which case the student is
        used.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
        if self._get_cfg_value(self.semi_test_cfg, 'extract_feat_on',
                               'teacher') == 'teacher':
            return self.teacher.extract_feat(batch_inputs)
        else:
            return self.student.extract_feat(batch_inputs)

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Add teacher and student prefixes to model parameter names.

        Only keys that belong to this module (those starting with
        ``prefix``) are considered. If none of them mentions ``student`` or
        ``teacher`` after ``prefix``, the checkpoint is treated as a plain
        detector checkpoint. Each key ``prefix + name`` is then replaced
        in place by ``prefix + 'teacher.' + name`` and
        ``prefix + 'student.' + name``, which both share the same value.
        Honoring ``prefix`` keeps this working when the detector is nested
        inside another module. For ``prefix == ''`` the behavior is
        unchanged.
        """
        start = len(prefix)
        own_keys = [key for key in state_dict.keys() if key.startswith(prefix)]
        if not any('student' in key[start:] or 'teacher' in key[start:]
                   for key in own_keys):
            state_dict.update({
                prefix + 'teacher.' + key[start:]: state_dict[key]
                for key in own_keys
            })
            state_dict.update({
                prefix + 'student.' + key[start:]: state_dict[key]
                for key in own_keys
            })
            for key in own_keys:
                state_dict.pop(key)
        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )