# Copyright (c) OpenMMLab. All rights reserved. """copy from https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py.""" import torch from torch import Tensor from ..assigners import AssignResult from .sampling_result import SamplingResult class MaskSamplingResult(SamplingResult): """Mask sampling result.""" def __init__(self, pos_inds: Tensor, neg_inds: Tensor, masks: Tensor, gt_masks: Tensor, assign_result: AssignResult, gt_flags: Tensor, avg_factor_with_neg: bool = True) -> None: self.pos_inds = pos_inds self.neg_inds = neg_inds self.num_pos = max(pos_inds.numel(), 1) self.num_neg = max(neg_inds.numel(), 1) self.avg_factor = self.num_pos + self.num_neg \ if avg_factor_with_neg else self.num_pos self.pos_masks = masks[pos_inds] self.neg_masks = masks[neg_inds] self.pos_is_gt = gt_flags[pos_inds] self.num_gts = gt_masks.shape[0] self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 if gt_masks.numel() == 0: # hack for index error case assert self.pos_assigned_gt_inds.numel() == 0 self.pos_gt_masks = torch.empty_like(gt_masks) else: self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] @property def masks(self) -> Tensor: """torch.Tensor: concatenated positive and negative masks.""" return torch.cat([self.pos_masks, self.neg_masks]) def __nice__(self) -> str: data = self.info.copy() data['pos_masks'] = data.pop('pos_masks').shape data['neg_masks'] = data.pop('neg_masks').shape parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] body = ' ' + ',\n '.join(parts) return '{\n' + body + '\n}' @property def info(self) -> dict: """Returns a dictionary of info about the object.""" return { 'pos_inds': self.pos_inds, 'neg_inds': self.neg_inds, 'pos_masks': self.pos_masks, 'neg_masks': self.neg_masks, 'pos_is_gt': self.pos_is_gt, 'num_gts': self.num_gts, 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, }