# ------------------------------------------------------------------------------
# Adapted from https://github.com/jiawei-ren/BalancedMSE
# Original licence: Copyright (c) 2022 Jiawei Ren, under the MIT License.
# ------------------------------------------------------------------------------
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmcv.runner import get_dist_info
from torch.nn.modules.loss import _Loss

from .utils import weighted_loss


@weighted_loss
def bmc_loss_md(pred: torch.Tensor, target: torch.Tensor,
                noise_var: torch.Tensor, all_gather: bool,
                loss_mse_weight: float,
                loss_debias_weight: float) -> torch.Tensor:
""" | |
Args: | |
pred (torch.Tensor): The prediction. Shape should be (N, L). | |
target (torch.Tensor): The learning target of the prediction. | |
noise_var (torch.Tensor): Noise var of ground truth distribution. | |
all_gather (bool): Whether gather tensors across all sub-processes. | |
Only used in DDP training scheme. | |
loss_mse_weight (float, optional): The weight of the mse term. | |
loss_debias_weight (float, optional): The weight of the debiased term. | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
    N = pred.shape[0]
    L = pred.shape[1]
    device = pred.device
    # Per-sample MSE, summed over the L output dimensions and scaled by the
    # (learnable) noise variance.
    loss_mse = F.mse_loss(pred, target, reduction='none').sum(-1)
    loss_mse = loss_mse / noise_var
    if all_gather:
        rank, world_size = get_dist_info()
        bs, length = target.shape
        # Batch sizes may differ across ranks (e.g. in the last iteration),
        # so gather each rank's batch size first.
        all_bs = [torch.zeros(1).to(device) for _ in range(world_size)]
        dist.all_gather(all_bs, torch.Tensor([bs]).to(device))
        all_bs_int = [int(v.item()) for v in all_bs]
        max_bs_int = max(all_bs_int)
        # Pad the local targets to the largest batch size so that
        # all_gather operates on same-shaped tensors.
        target_padding = torch.zeros(max_bs_int, length).to(device)
        target_padding[:bs] = target
        all_tensor = []
        for _ in range(world_size):
            all_tensor.append(torch.zeros(max_bs_int, length).type_as(target))
        dist.all_gather(all_tensor, target_padding)
        # Remove the padding rows before concatenating.
        for i in range(world_size):
            all_tensor[i] = all_tensor[i][:all_bs_int[i]]
        target = torch.cat(all_tensor, dim=0)
    # Debias term: logsumexp over the pairwise distances between each
    # prediction and every (gathered) target in the batch.
    target = target.unsqueeze(0).repeat(N, 1, 1)
    pred = pred.unsqueeze(1).expand_as(target)
    debias_term = F.mse_loss(pred, target, reduction='none').sum(-1)
    debias_term = -0.5 * debias_term / noise_var
    loss_debias = torch.logsumexp(debias_term, dim=1).squeeze(-1)
    loss = loss_mse * loss_mse_weight + loss_debias * loss_debias_weight
    # Recover the loss scale of mse_loss; detach so the rescaling factor
    # does not feed gradients back into noise_var.
    loss = loss / L * noise_var.detach()
    return loss
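

# A minimal shape sketch (hypothetical values; single process, so
# all_gather=False). Once wrapped by @weighted_loss, the function also
# accepts weight/reduction/avg_factor and, assuming mmcv's default
# reduction='mean', reduces the per-sample losses to a scalar:
#
#     pred = torch.randn(4, 2)
#     target = torch.randn(4, 2)
#     loss = bmc_loss_md(pred, target, noise_var=torch.tensor(1.0),
#                        all_gather=False, loss_mse_weight=1.0,
#                        loss_debias_weight=1.0)
#     assert loss.dim() == 0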


class BMCLossMD(_Loss):
    """Balanced MSE loss, using batch Monte-Carlo to estimate the target
    distribution.

    https://arxiv.org/abs/2203.16427.

    Args:
        init_noise_sigma (float, optional): The initial value of the noise
            sigma that represents the ground truth distribution.
            Defaults to 1.0.
        all_gather (bool, optional): Whether to gather targets across all
            sub-processes. If set True, BMC gives a more precise estimate
            at extra time cost. Defaults to False.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum".
            Defaults to "mean".
        loss_mse_weight (float, optional): The weight of the MSE term.
            Defaults to 1.0.
        loss_debias_weight (float, optional): The weight of the debiased
            term. Defaults to 1.0.
    """

    def __init__(self,
                 init_noise_sigma: float = 1.0,
                 all_gather: bool = False,
                 reduction: Optional[str] = 'mean',
                 loss_mse_weight: float = 1.0,
                 loss_debias_weight: float = 1.0):
        super().__init__()
        # Learnable noise sigma; forward() uses noise_var = sigma ** 2.
        self.noise_sigma = torch.nn.Parameter(
            torch.tensor(init_noise_sigma).float())
        self.all_gather = all_gather
        assert reduction in (None, 'none', 'mean', 'sum')
        reduction = 'none' if reduction is None else reduction
        self.reduction = reduction
        self.loss_mse_weight = loss_mse_weight
        self.loss_debias_weight = loss_debias_weight

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> torch.Tensor:
"""Forward function of loss. | |
Args: | |
pred (torch.Tensor): The prediction. | |
target (torch.Tensor): The learning target of the prediction. | |
weight (torch.Tensor, optional): Weight of the loss for each | |
prediction. Defaults to None. | |
avg_factor (int, optional): Average factor that is used to average | |
the loss. Defaults to None. | |
weight (torch.Tensor, optional): Weight of the loss for each | |
prediction. Defaults to None. | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None. | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override
                     if reduction_override else self.reduction)
        noise_var = (self.noise_sigma**2).type_as(pred)
        # Flatten everything after the batch dimension to shape (N, L).
        pred = pred.view(pred.shape[0], -1)
        target = target.view(target.shape[0], -1)
        loss = bmc_loss_md(pred,
                           target,
                           noise_var=noise_var,
                           all_gather=self.all_gather,
                           loss_mse_weight=self.loss_mse_weight,
                           loss_debias_weight=self.loss_debias_weight,
                           weight=weight,
                           reduction=reduction,
                           avg_factor=avg_factor)
        return loss
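

# A minimal usage sketch (hypothetical model and loader; single-process
# training, so all_gather=False). Because noise_sigma is a learnable
# nn.Parameter, the criterion's parameters must be passed to the optimizer
# along with the model's, as in the original BalancedMSE recipe:
#
#     criterion = BMCLossMD(init_noise_sigma=1.0, all_gather=False)
#     optimizer = torch.optim.Adam(
#         list(model.parameters()) + list(criterion.parameters()), lr=1e-3)
#     for inputs, targets in loader:
#         loss = criterion(model(inputs), targets)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()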