# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.models.losses.utils import weighted_loss
from mmdet.registry import MODELS


@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
    r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted joint representation of classification
            and quality (IoU) estimation with shape (N, C), C is the number
            of classes.
        target (tuple([torch.Tensor])): Target category label with shape (N,)
            and target quality label with shape (N,).
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    assert len(target) == 2, """target for QFL must be a tuple of two
        elements, including category label and quality label, respectively"""
    # label denotes the category id, score denotes the quality score
    label, score = target

    # negatives are supervised by 0 quality score
    pred_sigmoid = pred.sigmoid()
    scale_factor = pred_sigmoid
    zerolabel = scale_factor.new_zeros(pred.shape)
    loss = F.binary_cross_entropy_with_logits(
        pred, zerolabel, reduction='none') * scale_factor.pow(beta)

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = pred.size(1)
    pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
    pos_label = label[pos].long()
    # positives are supervised by bbox quality (IoU) score
    scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
    loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
        pred[pos, pos_label], score[pos],
        reduction='none') * scale_factor.abs().pow(beta)

    loss = loss.sum(dim=1, keepdim=False)
    return loss


@weighted_loss
def quality_focal_loss_tensor_target(pred, target, beta=2.0, activated=False):
    r"""Quality Focal Loss (QFL) with a tensor target, from `Generalized
    Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense
    Object Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning target of the iou-aware
            classification score with shape (N, C), C is the number of
            classes.
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.
        activated (bool): Whether the input is activated.
            If True, it means the input has been activated and can be
            treated as probabilities. Else, it should be treated as logits.
            Defaults to False.
    """
    # pred and target should be of the same size
    assert pred.size() == target.size()
    if activated:
        pred_sigmoid = pred
        loss_function = F.binary_cross_entropy
    else:
        pred_sigmoid = pred.sigmoid()
        loss_function = F.binary_cross_entropy_with_logits

    scale_factor = pred_sigmoid
    target = target.type_as(pred)

    zerolabel = scale_factor.new_zeros(pred.shape)
    loss = loss_function(
        pred, zerolabel, reduction='none') * scale_factor.pow(beta)

    pos = (target != 0)
    scale_factor = target[pos] - pred_sigmoid[pos]
    loss[pos] = loss_function(
        pred[pos], target[pos],
        reduction='none') * scale_factor.abs().pow(beta)

    loss = loss.sum(dim=1, keepdim=False)
    return loss
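
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): exercises the
# two functional QFL variants above on random data. The prior count, class
# count and use of the default beta below are assumptions for demonstration,
# not values required by the loss.
# ---------------------------------------------------------------------------
def _demo_quality_focal_loss():
    num_priors, num_classes = 8, 4
    pred = torch.randn(num_priors, num_classes)  # raw classification logits
    # labels in [0, num_classes - 1] are foreground; num_classes is background
    label = torch.randint(0, num_classes + 1, (num_priors,))
    score = torch.rand(num_priors)  # IoU quality targets in [0, 1]

    # tuple-target form; @weighted_loss adds weight/reduction/avg_factor
    loss = quality_focal_loss(pred, (label, score), beta=2.0)

    # analogous tensor-target form: a soft one-hot target scaled by quality
    soft_target = torch.zeros_like(pred)
    pos = label < num_classes
    soft_target[pos, label[pos]] = score[pos]
    loss_tensor = quality_focal_loss_tensor_target(pred, soft_target, beta=2.0)
    return loss, loss_tensor
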
@weighted_loss
def quality_focal_loss_with_prob(pred, target, beta=2.0):
    r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Different from `quality_focal_loss`, this function accepts probability
    as input.

    Args:
        pred (torch.Tensor): Predicted joint representation of classification
            and quality (IoU) estimation with shape (N, C), C is the number
            of classes.
        target (tuple([torch.Tensor])): Target category label with shape (N,)
            and target quality label with shape (N,).
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    assert len(target) == 2, """target for QFL must be a tuple of two
        elements, including category label and quality label, respectively"""
    # label denotes the category id, score denotes the quality score
    label, score = target

    # negatives are supervised by 0 quality score
    pred_sigmoid = pred
    scale_factor = pred_sigmoid
    zerolabel = scale_factor.new_zeros(pred.shape)
    loss = F.binary_cross_entropy(
        pred, zerolabel, reduction='none') * scale_factor.pow(beta)

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = pred.size(1)
    pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
    pos_label = label[pos].long()
    # positives are supervised by bbox quality (IoU) score
    scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
    loss[pos, pos_label] = F.binary_cross_entropy(
        pred[pos, pos_label], score[pos],
        reduction='none') * scale_factor.abs().pow(beta)

    loss = loss.sum(dim=1, keepdim=False)
    return loss


@weighted_loss
def distribution_focal_loss(pred, label):
    r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted general distribution of bounding boxes
            (before softmax) with shape (N, n+1), n is the max value of the
            integral set `{0, ..., n}` in paper.
        label (torch.Tensor): Target distance label for bounding boxes with
            shape (N,).

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    dis_left = label.long()
    dis_right = dis_left + 1
    weight_left = dis_right.float() - label
    weight_right = label - dis_left.float()
    loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
        + F.cross_entropy(pred, dis_right, reduction='none') * weight_right
    return loss
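
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): DFL expects
# per-bin logits of shape (N, n + 1) and continuous distance targets in
# [0, n). n = 7 below is an assumption mirroring a common `reg_max` setting
# in GFL configs; the loss itself works for any n.
# ---------------------------------------------------------------------------
def _demo_distribution_focal_loss():
    n = 7
    pred = torch.randn(16, n + 1)  # logits over the integral set {0, ..., n}
    # keep targets strictly below n so dis_right = floor(label) + 1 stays a
    # valid class index for F.cross_entropy
    label = torch.rand(16) * n
    return distribution_focal_loss(pred, label)
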
@MODELS.register_module()
class QualityFocalLoss(nn.Module):
    r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        use_sigmoid (bool): Whether sigmoid operation is conducted in QFL.
            Defaults to True.
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Loss weight of current loss.
        activated (bool, optional): Whether the input is activated.
            If True, it means the input has been activated and can be
            treated as probabilities. Else, it should be treated as logits.
            Defaults to False.
    """

    def __init__(self,
                 use_sigmoid=True,
                 beta=2.0,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        super(QualityFocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid in QFL is supported now.'
        self.use_sigmoid = use_sigmoid
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted joint representation of
                classification and quality (IoU) estimation with shape
                (N, C), C is the number of classes.
            target (Union[tuple([torch.Tensor]), torch.Tensor]): If the
                target is a tuple, it contains the target category label with
                shape (N,) and the target quality label with shape (N,). If
                the target is a torch.Tensor, it should be in one-hot form
                with soft weights.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = quality_focal_loss_with_prob
            else:
                calculate_loss_func = quality_focal_loss
            if isinstance(target, torch.Tensor):
                # the target shape is (N, C) or (N, C, ...), which means
                # the target is in one-hot form with soft weights
                calculate_loss_func = partial(
                    quality_focal_loss_tensor_target,
                    activated=self.activated)

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                beta=self.beta,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls


@MODELS.register_module()
class DistributionFocalLoss(nn.Module):
    r"""Distribution Focal Loss (DFL) is a variant of `Generalized Focal
    Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(DistributionFocalLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted general distribution of bounding
                boxes (before softmax) with shape (N, n+1), n is the max
                value of the integral set `{0, ..., n}` in paper.
            target (torch.Tensor): Target distance label for bounding boxes
                with shape (N,).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_cls = self.loss_weight * distribution_focal_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_cls
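
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): the
# registered loss modules used directly as nn.Modules. The shapes, the 0.25
# DFL loss weight and reg_max = 7 are assumptions borrowed from typical GFL
# configs, not values fixed by this file.
# ---------------------------------------------------------------------------
def _demo_loss_modules():
    qfl = QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0)
    dfl = DistributionFocalLoss(reduction='mean', loss_weight=0.25)

    num_priors, num_classes, reg_max = 8, 4, 7
    cls_pred = torch.randn(num_priors, num_classes)
    label = torch.randint(0, num_classes + 1, (num_priors,))
    score = torch.rand(num_priors)
    loss_qfl = qfl(cls_pred, (label, score))

    reg_pred = torch.randn(num_priors, reg_max + 1)
    reg_target = torch.rand(num_priors) * reg_max
    loss_dfl = dfl(reg_pred, reg_target)
    return loss_qfl, loss_dfl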