# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from mmcls.registry import MODELS
from .utils import weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        label (torch.Tensor): The gt label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Default None.

    Returns:
        torch.Tensor: The calculated loss
    """
    # element-wise losses
    loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def soft_cross_entropy(pred,
                       label,
                       weight=None,
                       reduction='mean',
                       class_weight=None,
                       avg_factor=None):
    """Calculate the Soft CrossEntropy loss. The label can be float.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        label (torch.Tensor): The gt label of the prediction with shape
            (N, C). When using "mixup", the label can be float.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Default None.

    Returns:
        torch.Tensor: The calculated loss
    """
    # element-wise losses
    loss = -label * F.log_softmax(pred, dim=-1)
    if class_weight is not None:
        loss *= class_weight
    loss = loss.sum(dim=-1)

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
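
# A minimal usage sketch (the shapes and values below are illustrative
# assumptions) showing that ``soft_cross_entropy`` with one-hot float labels
# is expected to agree with ``cross_entropy`` on integer labels, since
# ``-(onehot * log_softmax(pred)).sum(-1)`` picks out the log-probability of
# the ground-truth class:
#
#     import torch
#     import torch.nn.functional as F
#     pred = torch.randn(4, 10)            # (N, C) logits
#     label = torch.randint(0, 10, (4, ))  # (N, ) integer labels
#     onehot = F.one_hot(label, num_classes=10).float()
#     hard = cross_entropy(pred, label)
#     soft = soft_cross_entropy(pred, onehot)
#     # hard and soft should be numerically close
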

def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         pos_weight=None):
    r"""Calculate the binary CrossEntropy loss with logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*).
        label (torch.Tensor): The gt label with shape (N, \*).
        weight (torch.Tensor, optional): Element-wise weight of loss with
            shape (N, ). Defaults to None.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none',
            loss is same shape as pred and label. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class with
            shape (C), C is the number of classes. Default None.
        pos_weight (torch.Tensor, optional): The positive weight for each
            class with shape (C), C is the number of classes. Default None.

    Returns:
        torch.Tensor: The calculated loss
    """
    # Ensure that the size of class_weight is consistent with pred and label
    # to avoid automatic broadcasting.
    assert pred.dim() == label.dim()

    if class_weight is not None:
        N = pred.size()[0]
        class_weight = class_weight.repeat(N, 1)
    loss = F.binary_cross_entropy_with_logits(
        pred,
        label,
        weight=class_weight,
        pos_weight=pos_weight,
        reduction='none')

    # apply weights and do the reduction
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            weight = weight.reshape(-1, 1)
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Defaults to False.
        use_soft (bool): Whether to use the soft version of CrossEntropyLoss.
            Defaults to False.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class with
            shape (C), C is the number of classes. Default None.
        pos_weight (List[float], optional): The positive weight for each
            class with shape (C), C is the number of classes. Only enabled in
            BCE loss when ``use_sigmoid`` is True. Default None.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_soft=False,
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super(CrossEntropyLoss, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.use_soft = use_soft
        assert not (
            self.use_soft and self.use_sigmoid
        ), 'use_sigmoid and use_soft cannot be set simultaneously'

        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.pos_weight = pos_weight

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_soft:
            self.cls_criterion = soft_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted loss between ``cls_score`` and ``label``."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # only BCE loss has pos_weight
        if self.pos_weight is not None and self.use_sigmoid:
            pos_weight = cls_score.new_tensor(self.pos_weight)
            kwargs.update({'pos_weight': pos_weight})
        else:
            pos_weight = None

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
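
# A minimal usage sketch for ``CrossEntropyLoss``; the shapes, values and the
# sigmoid/multi-label setting below are illustrative assumptions:
#
#     import torch
#     criterion = CrossEntropyLoss(use_sigmoid=True, pos_weight=[2.0] * 5)
#     cls_score = torch.randn(8, 5)                  # (N, C) logits
#     target = torch.randint(0, 2, (8, 5)).float()   # (N, C) binary targets
#     loss = criterion(cls_score, target)            # scalar, reduction='mean'
#
# In config files, the same loss can be built through the registry, e.g.
# ``MODELS.build(dict(type='CrossEntropyLoss', use_sigmoid=True))``.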