# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from numpy import ndarray
from torch import Tensor

from mmdet.registry import TASK_UTILS
from ..assigners import AssignResult
from .base_sampler import BaseSampler


@TASK_UTILS.register_module()
class RandomSampler(BaseSampler):
    """Sampler that picks positive and negative indices at random.

    Args:
        num (int): Number of samples.
        pos_fraction (float): Fraction of positive samples.
        neg_pos_ub (int): Upper bound number of negative and positive
            samples. Defaults to -1.
        add_gt_as_proposals (bool): Whether to add ground truth boxes as
            proposals. Defaults to True.
        **kwargs: Only the optional ``rng`` entry is read; it is handed to
            ``ensure_rng`` and the result is kept as ``self.rng``.
    """

    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 **kwargs):
        # Imported inside the method, as in the original implementation.
        from .sampling_result import ensure_rng

        super().__init__(
            num=num,
            pos_fraction=pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals)
        # NOTE: ``random_choice`` below draws from ``torch.randperm`` and
        # does not consult ``self.rng``.
        self.rng = ensure_rng(kwargs.get('rng'))

    def random_choice(self, gallery: Union[Tensor, ndarray, list],
                      num: int) -> Union[Tensor, ndarray]:
        """Draw ``num`` distinct elements from ``gallery`` at random.

        The container type of the result follows the input: a Tensor
        ``gallery`` yields a Tensor, while an ndarray or list ``gallery``
        yields an ndarray.

        Args:
            gallery (Tensor | ndarray | list): Pool of indices to draw from.
            num (int): Number of elements to draw; must not exceed
                ``len(gallery)``.

        Returns:
            Tensor or ndarray: The sampled indices.
        """
        assert len(gallery) >= num

        given_tensor = isinstance(gallery, torch.Tensor)
        if not given_tensor:
            # Non-tensor pools are promoted to a long tensor, placed on the
            # current CUDA device whenever CUDA is available.
            target = (
                torch.cuda.current_device()
                if torch.cuda.is_available() else 'cpu')
            gallery = torch.tensor(gallery, dtype=torch.long, device=target)

        # Temporary workaround: the permutation is created without a device
        # argument and only afterwards moved next to ``gallery``. This can be
        # reverted once PyTorch fixes the abnormal return of torch.randperm.
        # See: https://github.com/open-mmlab/mmdetection/pull/5014
        order = torch.randperm(gallery.numel())
        chosen = gallery[order[:num].to(device=gallery.device)]

        # Hand back the same kind of container family the caller passed in.
        return chosen if given_tensor else chosen.cpu().numpy()

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly pick positive samples (entries with ``gt_inds > 0``).

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected positive samples.

        Returns:
            Tensor or ndarray: Sampled indices. Every positive index is
            returned when there are no more than ``num_expected`` of them.
        """
        candidates = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)

        # Nothing to subsample when the pool already fits the budget.
        if candidates.numel() <= num_expected:
            return candidates
        return self.random_choice(candidates, num_expected)

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly pick negative samples (entries with ``gt_inds == 0``).

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected negative samples.

        Returns:
            Tensor or ndarray: Sampled indices. Every negative index is
            returned when there are no more than ``num_expected`` of them.
        """
        candidates = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)

        # Nothing to subsample when the pool already fits the budget.
        if len(candidates) <= num_expected:
            return candidates
        return self.random_choice(candidates, num_expected)