# AiOS / detrsmpl/models/losses/balanced_mse_loss.py
# (author: ttxskk, commit d7e58f0 "update", 6.29 kB)
# ------------------------------------------------------------------------------
# Adapted from https://github.com/jiawei-ren/BalancedMSE
# Original licence: Copyright (c) 2022 Jiawei Ren, under the MIT License.
# ------------------------------------------------------------------------------
from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmcv.runner import get_dist_info
from torch.nn.modules.loss import _Loss
from .utils import weighted_loss
def _all_gather_variable_rows(tensor: torch.Tensor,
                              world_size: int) -> torch.Tensor:
    """Gather a 2-D tensor from every rank and concatenate along dim 0.

    ``dist.all_gather`` requires identically shaped (and typed) buffers, but
    the number of rows may differ between ranks. So the row counts are
    exchanged first. Each rank then pads its tensor to the largest count, and
    the padding is stripped after the gather.

    Args:
        tensor (torch.Tensor): Local tensor of shape (B, L).
        world_size (int): Number of participating processes.

    Returns:
        torch.Tensor: Tensor of shape (sum of B over all ranks, L), ordered by
        rank.
    """
    num_rows, length = tensor.shape
    device = tensor.device

    # Exchange row counts; an integer dtype keeps them exact.
    local_rows = torch.tensor([num_rows], dtype=torch.long, device=device)
    all_rows = [torch.zeros_like(local_rows) for _ in range(world_size)]
    dist.all_gather(all_rows, local_rows)
    row_counts = [int(r.item()) for r in all_rows]

    # The padded send buffer must share the dtype of the receive buffers,
    # otherwise all_gather fails (e.g. under fp16 / bf16 training).
    padded = tensor.new_zeros(max(row_counts), length)
    padded[:num_rows] = tensor
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # Drop each rank's padding rows.
    return torch.cat(
        [chunk[:count] for chunk, count in zip(gathered, row_counts)], dim=0)


@weighted_loss
def bmc_loss_md(pred: torch.Tensor, target: torch.Tensor,
                noise_var: torch.Tensor, all_gather: bool,
                loss_mse_weight: float,
                loss_debias_weight: float) -> torch.Tensor:
    """Per-sample Balanced MSE loss (batch Monte-Carlo, multi-dimensional).

    The debias term is a log-sum-exp over the batch of targets, optionally
    gathered from all processes. Each target is scored with a Gaussian kernel
    centred at the prediction, with variance ``noise_var``. The
    ``weight`` / ``reduction`` / ``avg_factor`` arguments that callers pass
    are consumed by the ``weighted_loss`` decorator, not by this body.

    Args:
        pred (torch.Tensor): The prediction. Shape should be (N, L).
        target (torch.Tensor): The learning target of the prediction,
            shape (N, L).
        noise_var (torch.Tensor): Noise var of ground truth distribution
            (a scalar tensor).
        all_gather (bool): Whether gather tensors across all sub-processes.
            Only used in DDP training scheme; it has no effect when a single
            process is running.
        loss_mse_weight (float): The weight of the mse term.
        loss_debias_weight (float): The weight of the debiased term.

    Returns:
        torch.Tensor: The calculated per-sample loss, shape (N,).
    """
    num_dims = pred.shape[1]

    # Squared error of each prediction against its own target, scaled by 1/var.
    loss_mse = F.mse_loss(pred, target, reduction='none').sum(-1) / noise_var

    if all_gather:
        _, world_size = get_dist_info()
        # With one process the gathered batch equals the local batch, so skip
        # the collective. It would also raise without an initialised group.
        if world_size > 1:
            target = _all_gather_variable_rows(target, world_size)

    # Debias term: (N, M) matrix of -0.5 * ||pred_i - target_j||^2 / var,
    # reduced with log-sum-exp over the M targets. Broadcasting avoids
    # materialising a repeated copy of the targets.
    sq_dist = (pred.unsqueeze(1) - target.unsqueeze(0)).pow(2).sum(-1)
    loss_debias = torch.logsumexp(-0.5 * sq_dist / noise_var, dim=1)

    loss = loss_mse * loss_mse_weight + loss_debias * loss_debias_weight
    # Recover the scale of a per-element mse_loss. Undo the 1/var scaling
    # with a detached var, and average over the L dimensions.
    return loss / num_dims * noise_var.detach()
class BMCLossMD(_Loss):
    """Balanced MSE loss, use batch monte-carlo to estimate distribution.

    https://arxiv.org/abs/2203.16427.

    The noise scale is stored as the learnable parameter ``noise_sigma``.
    Its square is used as the noise variance of the ground-truth distribution.

    Args:
        init_noise_sigma (float, optional): The initial value of noise sigma.
            This sigma is used to represent ground truth distribution.
            Defaults to 1.0.
        all_gather (bool, optional): Whether gather tensors across all
            sub-processes. If set True, BMC will have more precise estimation
            with more time cost. Default: False.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum"; ``None`` is treated
            as "none". Defaults to "mean".
        loss_mse_weight (float, optional): The weight of the mse term.
            Defaults to 1.0.
        loss_debias_weight (float, optional): The weight of the debiased term.
            Defaults to 1.0.
    """

    def __init__(self,
                 init_noise_sigma: float = 1.0,
                 all_gather: bool = False,
                 reduction: Optional[str] = 'mean',
                 loss_mse_weight: float = 1.0,
                 loss_debias_weight: float = 1.0):
        super(BMCLossMD, self).__init__()
        # Registered as a Parameter so the noise scale is trained with the
        # model rather than kept fixed.
        self.noise_sigma = torch.nn.Parameter(
            torch.tensor(init_noise_sigma).float())
        self.all_gather = all_gather
        assert reduction in (None, 'none', 'mean', 'sum')
        self.reduction = 'none' if reduction is None else reduction
        self.loss_mse_weight = loss_mse_weight
        self.loss_debias_weight = loss_debias_weight

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> torch.Tensor:
        """Forward function of loss.

        Args:
            pred (torch.Tensor): The prediction. All dimensions after the
                first are flattened, giving shape (N, L).
            target (torch.Tensor): The learning target of the prediction.
                It is flattened the same way as ``pred``.
            weight (torch.Tensor, optional): Weight of the loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override
                     if reduction_override else self.reduction)
        noise_var = (self.noise_sigma**2).type_as(pred)
        # reshape rather than view: view raises on non-contiguous inputs
        # (e.g. sliced or permuted tensors), while reshape copies if needed.
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        return bmc_loss_md(pred,
                           target,
                           noise_var=noise_var,
                           all_gather=self.all_gather,
                           loss_mse_weight=self.loss_mse_weight,
                           loss_debias_weight=self.loss_debias_weight,
                           weight=weight,
                           reduction=reduction,
                           avg_factor=avg_factor)