# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import torch
from mmengine.structures import InstanceData

from mmdet.structures.bbox import BaseBoxes, cat_boxes
from ..assigners import AssignResult
from .sampling_result import SamplingResult


class BaseSampler(metaclass=ABCMeta):
    """Base class of samplers.

    Subclasses implement :meth:`_sample_pos` and :meth:`_sample_neg`;
    :meth:`sample` combines the two into a :obj:`SamplingResult`.

    Args:
        num (int): Total number of samples requested (positives plus
            negatives).
        pos_fraction (float): Fraction of ``num`` requested as positive
            samples.
        neg_pos_ub (int): Upper bound of the ratio of negative to positive
            samples: at most ``neg_pos_ub * max(1, num_sampled_pos)``
            negatives are requested. A negative value disables the bound.
            Defaults to -1.
        add_gt_as_proposals (bool): Whether to add ground truth boxes as
            proposals. Defaults to True.
        **kwargs: Accepted but not used by this base class.
    """

    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 **kwargs) -> None:
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        # `sample` dispatches through these two attributes; both default to
        # this sampler itself.
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> torch.Tensor:
        """Sample positive samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results.
            num_expected (int): Number of positives requested.
            **kwargs: Extra arguments; :meth:`sample` passes the
                (possibly gt-augmented) priors as ``bboxes``.

        Returns:
            Tensor: Indices of the sampled positives into those priors.
        """

    @abstractmethod
    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> torch.Tensor:
        """Sample negative samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results.
            num_expected (int): Number of negatives requested.
            **kwargs: Extra arguments; :meth:`sample` passes the
                (possibly gt-augmented) priors as ``bboxes``.

        Returns:
            Tensor: Indices of the sampled negatives into those priors.
        """

    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
               gt_instances: InstanceData, **kwargs) -> SamplingResult:
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        When ``add_gt_as_proposals`` is True and there is at least one gt
        box, the gt boxes are prepended to the priors and
        ``assign_result.add_gt_(gt_labels)`` is called before sampling.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results.
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape (k, 4),
                and ``labels``, with shape (k, ).

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Example:
            >>> import torch
            >>> from mmengine.structures import InstanceData
            >>> from mmdet.models.task_modules.samplers import RandomSampler
            >>> from mmdet.models.task_modules.assigners import AssignResult
            >>> from mmdet.models.task_modules.samplers.sampling_result import (
            ...     ensure_rng, random_boxes)
            >>> rng = ensure_rng(None)
            >>> assign_result = AssignResult.random(rng=rng)
            >>> pred_instances = InstanceData()
            >>> pred_instances.priors = random_boxes(assign_result.num_preds,
            ...                                      rng=rng)
            >>> gt_instances = InstanceData()
            >>> gt_instances.bboxes = random_boxes(assign_result.num_gts,
            ...                                    rng=rng)
            >>> gt_instances.labels = torch.randint(
            ...     0, 5, (assign_result.num_gts,), dtype=torch.long)
            >>> sampler = RandomSampler(num=32, pos_fraction=0.5,
            ...                         neg_pos_ub=-1,
            ...                         add_gt_as_proposals=False)
            >>> sampling_result = sampler.sample(
            ...     assign_result, pred_instances, gt_instances)
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        gt_labels = gt_instances.labels
        # Promote a single 1-D prior to a batch of one.
        if len(priors.shape) < 2:
            priors = priors[None, :]

        # One flag per prior: 1 marks a gt box prepended below, 0 an
        # original prior.
        gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            # When `gt_bboxes` and `priors` are all box type, convert
            # `gt_bboxes` type to `priors` type.
            if (isinstance(gt_bboxes, BaseBoxes)
                    and isinstance(priors, BaseBoxes)):
                gt_bboxes_ = gt_bboxes.convert_to(type(priors))
            else:
                gt_bboxes_ = gt_bboxes
            # gt boxes go first so their indices are 0..k-1 in `priors`.
            priors = cat_boxes([gt_bboxes_, priors], dim=0)
            # NOTE(review): the trailing underscore suggests `add_gt_`
            # mutates `assign_result` in place — confirm in AssignResult.
            assign_result.add_gt_(gt_labels)
            gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(
            assign_result, num_expected_pos, bboxes=priors, **kwargs)
        # We found that sampled indices have duplicated items occasionally.
        # (may be a bug of PyTorch)
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()

        # Negatives fill the remainder of the budget, optionally capped
        # relative to the number of positives actually drawn (at least 1,
        # so the cap is not zero when no positives were sampled).
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            neg_upper_bound = int(self.neg_pos_ub * max(1, num_sampled_pos))
            num_expected_neg = min(num_expected_neg, neg_upper_bound)
        neg_inds = self.neg_sampler._sample_neg(
            assign_result, num_expected_neg, bboxes=priors, **kwargs)
        neg_inds = neg_inds.unique()

        sampling_result = SamplingResult(
            pos_inds=pos_inds,
            neg_inds=neg_inds,
            priors=priors,
            gt_bboxes=gt_bboxes,
            assign_result=assign_result,
            gt_flags=gt_flags)
        return sampling_result