Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
"""Copied from
https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
import torch | |
from torch import Tensor | |
from ..assigners import AssignResult | |
from .sampling_result import SamplingResult | |
class MaskSamplingResult(SamplingResult):
    """Mask sampling result.

    Holds the positive/negative split of the samples together with the
    predicted masks and the ground-truth masks gathered for the positives.

    Args:
        pos_inds (Tensor): Indices of the positive samples.
        neg_inds (Tensor): Indices of the negative samples.
        masks (Tensor): Predicted masks, indexed along their first
            dimension by ``pos_inds`` / ``neg_inds``.
        gt_masks (Tensor): Ground-truth masks. ``gt_masks.shape[0]`` is
            taken as the number of ground truths.
        assign_result (AssignResult): Assignment result whose ``gt_inds``
            are read for the positive samples. One is subtracted from them
            before indexing ``gt_masks``, so they are presumably 1-based
            with 0 meaning "not assigned to a ground truth" -- confirm
            against the assigner.
        gt_flags (Tensor): Per-sample flags. The entries at ``pos_inds``
            are stored as ``pos_is_gt``.
        avg_factor_with_neg (bool): If True, ``avg_factor`` counts both
            positives and negatives; otherwise only positives.
            Defaults to True.
    """

    def __init__(self,
                 pos_inds: Tensor,
                 neg_inds: Tensor,
                 masks: Tensor,
                 gt_masks: Tensor,
                 assign_result: AssignResult,
                 gt_flags: Tensor,
                 avg_factor_with_neg: bool = True) -> None:
        # NOTE(review): SamplingResult.__init__ is not called; every
        # attribute this class exposes is assigned directly below.
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        # Counts are clamped to at least 1, so ``avg_factor`` is always >= 1
        # (presumably because it is used as a normaliser -- confirm at call
        # sites).
        self.num_pos = max(pos_inds.numel(), 1)
        self.num_neg = max(neg_inds.numel(), 1)
        self.avg_factor = (self.num_pos + self.num_neg
                           if avg_factor_with_neg else self.num_pos)

        self.pos_masks = masks[pos_inds]
        self.neg_masks = masks[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_masks.shape[0]
        # Shift the assigner's gt indices down by one so they index
        # directly into ``gt_masks`` below.
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_masks.numel() == 0:
            # Hack for the index-error case: with no ground-truth masks
            # there is nothing to gather, so no sample may have been
            # assigned as positive, and an empty tensor stands in.
            assert self.pos_assigned_gt_inds.numel() == 0
            self.pos_gt_masks = torch.empty_like(gt_masks)
        else:
            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]

    @property
    def masks(self) -> Tensor:
        """torch.Tensor: Positive masks followed by negative masks,
        concatenated along the first dimension."""
        return torch.cat([self.pos_masks, self.neg_masks])

    def __nice__(self) -> str:
        """Return a readable, key-sorted summary of :attr:`info`.

        The mask tensors are replaced by their shapes to keep the output
        short.
        """
        # ``info`` must be a property here: ``.copy()`` is called on the
        # returned dict, not on a bound method.
        data = self.info.copy()
        data['pos_masks'] = data.pop('pos_masks').shape
        data['neg_masks'] = data.pop('neg_masks').shape
        parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
        body = ' ' + ',\n '.join(parts)
        return '{\n' + body + '\n}'

    @property
    def info(self) -> dict:
        """dict: The sampling indices, masks and assignment details held by
        this result, keyed by attribute name."""
        return {
            'pos_inds': self.pos_inds,
            'neg_inds': self.neg_inds,
            'pos_masks': self.pos_masks,
            'neg_masks': self.neg_masks,
            'pos_is_gt': self.pos_is_gt,
            'num_gts': self.num_gts,
            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
        }