# Copyright (c) OpenMMLab. All rights reserved.
import torch


def mask_matrix_nms(masks,
                    labels,
                    scores,
                    filter_thr=-1,
                    nms_pre=-1,
                    max_num=-1,
                    kernel='gaussian',
                    sigma=2.0,
                    mask_area=None):
    """Matrix NMS for multi-class masks.

    Instead of discarding overlapping masks, every score is multiplied by a
    decay coefficient derived from its IoU with higher-scoring masks of the
    same label.

    Args:
        masks (Tensor): Has shape (num_instances, h, w).
        labels (Tensor): Labels of corresponding masks,
            has shape (num_instances,).
        scores (Tensor): Mask scores of corresponding masks,
            has shape (num_instances,).
        filter_thr (float): Score threshold to filter the masks
            after matrix nms. Default: -1, which means do not
            use filter_thr.
        nms_pre (int): The max number of instances to do the matrix nms.
            Default: -1, which means do not use nms_pre.
        max_num (int, optional): If there are more than max_num masks after
            matrix nms, only top max_num will be kept. Default: -1, which
            means do not use max_num.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): std in gaussian method.
        mask_area (Tensor): The sum of seg_masks, has shape
            (num_instances,). Computed from ``masks`` when None.

    Returns:
        tuple(Tensor): Processed mask results.

            - scores (Tensor): Updated scores, has shape (n,).
            - labels (Tensor): Remained labels, has shape (n,).
            - masks (Tensor): Remained masks, has shape (n, h, w).
            - keep_inds (Tensor): The indices number of
              the remaining mask in the input mask, has shape (n,).

    Raises:
        NotImplementedError: If ``kernel`` is neither 'gaussian' nor
            'linear'.
    """
    assert len(labels) == len(masks) == len(scores)
    if len(labels) == 0:
        return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
            0, *masks.shape[-2:]), labels.new_zeros(0)
    if mask_area is None:
        mask_area = masks.sum((1, 2)).float()
    else:
        assert len(masks) == len(mask_area)

    # sort and keep top nms_pre
    scores, sort_inds = torch.sort(scores, descending=True)

    keep_inds = sort_inds
    if nms_pre > 0 and len(sort_inds) > nms_pre:
        sort_inds = sort_inds[:nms_pre]
        keep_inds = keep_inds[:nms_pre]
        scores = scores[:nms_pre]
    masks = masks[sort_inds]
    mask_area = mask_area[sort_inds]
    labels = labels[sort_inds]

    num_masks = len(labels)
    flatten_masks = masks.reshape(num_masks, -1).float()
    # inter.
    inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
    expanded_mask_area = mask_area.expand(num_masks, num_masks)
    union_matrix = (
        expanded_mask_area + expanded_mask_area.transpose(1, 0) -
        inter_matrix)
    # Two empty masks have a zero union; dividing by it would yield NaN and
    # that NaN would spread to every score through the max/min reductions
    # below. Substituting 1 makes such pairs have an IoU of 0.
    union_matrix = torch.where(union_matrix == 0,
                               torch.ones_like(union_matrix), union_matrix)
    # Upper triangle iou matrix.
    iou_matrix = (inter_matrix / union_matrix).triu(diagonal=1)
    # label_specific matrix.
    expanded_labels = labels.expand(num_masks, num_masks)
    # Upper triangle label matrix.
    label_matrix = (expanded_labels == expanded_labels.transpose(
        1, 0)).triu(diagonal=1)

    # IoU decay: only masks sharing a label suppress each other.
    decay_iou = iou_matrix * label_matrix

    # IoU compensation: for each mask, its largest IoU with any
    # higher-scoring mask of the same label, broadcast along rows.
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou.expand(num_masks,
                                           num_masks).transpose(1, 0)

    # Calculate the decay_coefficient
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        numerator = 1 - decay_iou
        denominator = 1 - compensate_iou
        decay_matrix = numerator / denominator
        # 0 / 0 happens when a mask exactly duplicates a mask that is itself
        # an exact duplicate of a higher-scoring one. Such a mask is fully
        # suppressed, so use 0 rather than letting NaN reach the scores.
        # (x / 0 with x > 0 is +inf, which the min below ignores.)
        undefined = (numerator == 0) & (denominator == 0)
        decay_matrix = torch.where(undefined,
                                   torch.zeros_like(decay_matrix),
                                   decay_matrix)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError(
            f'{kernel} kernel is not supported in matrix nms!')
    # update the score.
    scores = scores * decay_coefficient

    if filter_thr > 0:
        keep = scores >= filter_thr
        keep_inds = keep_inds[keep]
        if not keep.any():
            return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
                0, *masks.shape[-2:]), labels.new_zeros(0)
        masks = masks[keep]
        scores = scores[keep]
        labels = labels[keep]

    # sort and keep top max_num
    scores, sort_inds = torch.sort(scores, descending=True)
    keep_inds = keep_inds[sort_inds]
    if max_num > 0 and len(sort_inds) > max_num:
        sort_inds = sort_inds[:max_num]
        keep_inds = keep_inds[:max_num]
        scores = scores[:max_num]
    masks = masks[sort_inds]
    labels = labels[sort_inds]

    return scores, labels, masks, keep_inds