# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Returns:
        Tensor: Reduced loss tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported options.
    """
    # none: 0, (elementwise_)mean: 1, sum: 2
    reduction_enum = F._Reduction.get_enum(reduction)
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()
    # Defensive: never fall through and silently return None.
    raise ValueError(f'Invalid reduction "{reduction}"')


def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights that are
            multiplied into ``loss`` before reduction. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (Optional[float], optional): Average factor when
            computing the mean of losses. When given, ``reduction`` must be
            'mean' (the loss is summed and divided by ``avg_factor``) or
            'none' (the loss is returned unreduced). Defaults to None.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given together with
            ``reduction='sum'`` or an unrecognised ``reduction``.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # Avoid causing ZeroDivisionError when avg_factor is 0.0,
        # i.e., all labels of an image belong to ignore index.
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)
    if reduction == 'none':
        # avg_factor has no effect when no reduction is requested
        return loss
    if reduction == 'sum':
        raise ValueError('avg_factor can not be used with reduction="sum"')
    # Any other value is invalid; report it instead of blaming "sum".
    raise ValueError(f'Invalid reduction "{reduction}" with avg_factor; '
                     'expected "mean" or "none"')


def weighted_loss(loss_func: Callable) -> Callable:
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                reduction: str = 'mean',
                avg_factor: Optional[float] = None,
                **kwargs) -> Tensor:
        """
        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            reduction (str, optional): Options are "none", "mean" and "sum".
                Defaults to 'mean'.
            avg_factor (Optional[float], optional): Average factor that is
                used to average the loss. Defaults to None.

        Returns:
            Tensor: Loss tensor.
        """
        # get element-wise loss, then apply weighting and reduction
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper