KyanChen's picture
init
f549064
raw
history blame
4.23 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import bbox2roi
from .base_sampler import BaseSampler
@TASK_UTILS.register_module()
class OHEMSampler(BaseSampler):
    r"""Sampler implementing Online Hard Example Mining (OHEM).

    See `Training Region-based Object Detectors with Online Hard Example
    Mining <https://arxiv.org/abs/1604.03540>`_. Whenever more candidates
    are available than requested, the candidates with the largest loss
    (as reported by the box head) are the ones kept.

    Args:
        num: Forwarded unchanged to ``BaseSampler.__init__``.
        pos_fraction: Forwarded unchanged to ``BaseSampler.__init__``.
        context: Object owning the box head. It must expose ``bbox_head``
            and ``_bbox_forward``. When it also has a ``num_stages``
            attribute, ``bbox_head`` is indexed with ``current_stage`` and
            ``current_stage`` becomes the first argument of
            ``_bbox_forward``.
        neg_pos_ub: Forwarded unchanged to ``BaseSampler.__init__``.
            Defaults to -1.
        add_gt_as_proposals: Forwarded unchanged to
            ``BaseSampler.__init__``. Defaults to True.
        loss_key (str): Key used to look up the loss tensor in the result
            of ``bbox_head.loss``. Defaults to 'loss_cls'.
        **kwargs: Accepted but not used by this class.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 loss_key='loss_cls',
                 **kwargs):
        super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
        self.context = context
        # A context exposing ``num_stages`` holds an indexable collection of
        # heads; pick the one belonging to the stage active right now.
        head = context.bbox_head
        if hasattr(context, 'num_stages'):
            head = head[context.current_stage]
        self.bbox_head = head
        self.loss_key = loss_key

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        """Keep the ``num_expected`` candidates with the largest loss.

        Args:
            inds (torch.Tensor): Indices of the candidate samples.
            num_expected (int): Number of candidates to keep.
            bboxes (torch.Tensor): Boxes of the candidates; turned into
                RoIs with ``bbox2roi``.
            labels (torch.Tensor): Labels of the candidates, passed to
                ``bbox_head.loss``.
            feats: Features handed to ``context._bbox_forward``.

        Returns:
            torch.Tensor: The entries of ``inds`` whose loss is among the
            ``num_expected`` largest.
        """
        ctx = self.context
        # Mining only ranks candidates, so no gradients are recorded.
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            # Staged contexts expect the stage index as leading argument.
            forward_args = (feats, rois)
            if hasattr(ctx, 'num_stages'):
                forward_args = (ctx.current_stage, ) + forward_args
            cls_score = ctx._bbox_forward(*forward_args)['cls_score']
            uniform_weights = cls_score.new_ones(cls_score.size(0))
            # Only the classification inputs are supplied; the regression
            # inputs are None. ``reduction_override='none'`` is requested
            # so that ``topk`` can rank the individual loss entries.
            losses = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                rois=rois,
                labels=labels,
                label_weights=uniform_weights,
                bbox_targets=None,
                bbox_weights=None,
                reduction_override='none')
            hardest = losses[self.loss_key].topk(num_expected)[1]
        return inds[hardest]

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard positive samples.

        Args:
            assign_result (:obj:`AssignResult`): Assignment results; entries
                with ``gt_inds > 0`` are treated as positives.
            num_expected (int): Maximum number of positives to return.
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of the selected positive samples. All
            positives are returned when there are at most ``num_expected``
            of them; otherwise hard mining chooses among them.
        """
        positives = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        # An empty result is left in the shape produced by ``nonzero``.
        if positives.numel() != 0:
            positives = positives.squeeze(1)
        if positives.numel() <= num_expected:
            return positives
        return self.hard_mining(positives, num_expected, bboxes[positives],
                                assign_result.labels[positives], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Pick hard negative samples.

        Args:
            assign_result (:obj:`AssignResult`): Assignment results; entries
                with ``gt_inds == 0`` are treated as negatives.
            num_expected (int): Maximum number of negatives to return.
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of the selected negative samples. All
            negatives are returned when there are at most ``num_expected``
            of them; otherwise hard mining chooses among them.
        """
        negatives = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        # An empty result is left in the shape produced by ``nonzero``.
        if negatives.numel() != 0:
            negatives = negatives.squeeze(1)
        if len(negatives) <= num_expected:
            return negatives
        # Every negative gets ``bbox_head.num_classes`` as its label.
        background_labels = assign_result.labels.new_full(
            (negatives.size(0), ), self.bbox_head.num_classes)
        return self.hard_mining(negatives, num_expected, bboxes[negatives],
                                background_labels, feats)