# Source: AiOS/detrsmpl/models/losses/mse_loss.py
# (commit d7e58f0 by ttxskk, message "update"; original size 6.68 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import weighted_loss
def gmof(x, sigma):
    """Geman-McClure robust error.

    Computes ``sigma**2 * x**2 / (sigma**2 + x**2)`` element-wise. The value
    behaves like ``x**2`` when ``x**2`` is much smaller than ``sigma**2`` and
    saturates towards ``sigma**2`` as ``|x|`` grows.

    Args:
        x: Residual, a Python number or ``torch.Tensor``.
        sigma: Scale parameter controlling where the penalty saturates.

    Returns:
        The broadcast result of ``x`` and ``sigma`` under the formula above.
    """
    scale_sq = sigma**2
    residual_sq = x**2
    return scale_sq * residual_sq / (scale_sq + residual_sq)
@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error between ``pred`` and ``target``.

    Thin wrapper around ``torch.nn.functional.mse_loss`` with
    ``reduction='none'``. The extra ``weight``, ``reduction`` and
    ``avg_factor`` arguments that callers pass are handled by the
    ``weighted_loss`` decorator, not by this function.
    """
    elementwise = F.mse_loss(pred, target, reduction='none')
    return elementwise
@weighted_loss
def mse_loss_with_gmof(pred, target, sigma):
    """Element-wise squared error passed through the Geman-McClure function.

    Weighting and reduction arguments supplied by callers are handled by the
    ``weighted_loss`` decorator.

    NOTE(review): ``gmof`` squares its input again. With
    ``e = (pred - target)**2`` the returned penalty is therefore
    ``sigma**2 * e**2 / (sigma**2 + e**2)``, i.e. it uses the fourth power of
    the residual rather than GMoF on the raw residual. Confirm this is
    intended before relying on the usual meaning of ``sigma``.
    """
    return gmof(F.mse_loss(pred, target, reduction='none'), sigma)
class MSELoss(nn.Module):
    """Mean-squared-error loss module.

    Args:
        reduction (str, optional): How the element-wise loss is reduced.
            One of "none", "mean" or "sum"; ``None`` is accepted and
            treated as "none". Defaults to "mean".
        loss_weight (float, optional): Scalar multiplier applied to the
            loss. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        # Normalise ``None`` to the string form used by the reduction code.
        self.reduction = 'none' if reduction is None else reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the (optionally weighted and reduced) MSE loss.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): Per-element weight of the loss.
                Defaults to None.
            avg_factor (int, optional): Factor used to average the loss.
                Defaults to None.
            reduction_override (str, optional): If given (and truthy),
                replaces the reduction chosen at construction time.
                Defaults to None.

        Returns:
            torch.Tensor: The loss scaled by ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        if reduction_override:
            reduction = reduction_override
        else:
            reduction = self.reduction
        elementwise = mse_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * elementwise
class KeypointMSELoss(nn.Module):
    """MSE loss for 2D and 3D keypoints with a Geman-McClure penalty.

    The element-wise squared error is passed through ``gmof`` (via
    ``mse_loss_with_gmof``) and weighted by the product of per-keypoint
    weights and prediction/target confidences.

    Args:
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum"; ``None`` is
            treated as "none". Defaults to "mean".
        loss_weight (float, optional): The weight of the loss.
            Defaults to 1.0.
        sigma (float, optional): Scale parameter of the Geman-McClure
            error function. Defaults to 1.0. The robust penalty is always
            applied, so ``sigma=1.0`` is not a no-op.
        keypoint_weight (List[float], optional): Weight for each keypoint.
            Length should be K (number of keypoints). Defaults to None
            (all keypoints weighted equally).
    """

    def __init__(self,
                 reduction='mean',
                 loss_weight=1.0,
                 sigma=1.0,
                 keypoint_weight=None):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        reduction = 'none' if reduction is None else reduction
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.sigma = sigma
        if keypoint_weight is None:
            self.keypoint_weight = None
        else:
            # Stored as a plain attribute (not a buffer); it is moved to
            # pred's dtype/device on every forward call.
            self.keypoint_weight = torch.Tensor(keypoint_weight)

    def forward(self,
                pred,
                target,
                pred_conf=None,
                target_conf=None,
                keypoint_weight=None,
                avg_factor=None,
                loss_weight_override=None,
                reduction_override=None):
        """Forward function of loss.

        Args:
            pred (torch.Tensor): The prediction. Shape should be (B, K, D)
                with D = 2 or 3. B: batch size. K: number of keypoints.
            target (torch.Tensor): The learning target of the prediction.
                Shape should be the same as pred.
            pred_conf (torch.Tensor, optional): Confidence of predicted
                keypoints. Must have B * K elements, e.g. shape (B, K).
            target_conf (torch.Tensor, optional): Confidence of target
                keypoints. Shape should be the same as pred_conf.
            keypoint_weight (torch.Tensor, optional): Keypoint-wise weight
                of shape (K,), allowing different weights for different
                body parts. Overrides the weight given at construction and
                is cast to pred's dtype and device.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            loss_weight_override (float, optional): The overall weight of
                loss used to override the original weight of loss.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override
                     if reduction_override else self.reduction)
        loss_weight = (loss_weight_override
                       if loss_weight_override is not None else
                       self.loss_weight)

        # Unpacking three dims also enforces that pred has rank 3.
        B, K, _ = pred.shape

        # Each factor is either the float 1.0 (absent) or a tensor that
        # broadcasts against the (B, K, D) element-wise loss. ``reshape``
        # (unlike ``view``) also accepts non-view-compatible inputs.
        if pred_conf is None:
            pred_conf = 1.0
        else:
            pred_conf = pred_conf.reshape(B, K, 1)

        if target_conf is None:
            target_conf = 1.0
        else:
            target_conf = target_conf.reshape(B, K, 1)

        if keypoint_weight is not None:
            # Match pred's dtype/device, as is done for the stored weight.
            keypoint_weight = keypoint_weight.reshape(1, K, 1).type_as(pred)
        elif self.keypoint_weight is not None:
            keypoint_weight = self.keypoint_weight.reshape(
                1, K, 1).type_as(pred)
        else:
            keypoint_weight = 1.0

        weight = keypoint_weight * pred_conf * target_conf
        assert isinstance(weight, float) or weight.shape in (
            (B, K, 1), (1, K, 1)), (
                f'unexpected combined weight shape {tuple(weight.shape)}')

        loss = loss_weight * mse_loss_with_gmof(pred,
                                                target,
                                                weight,
                                                reduction=reduction,
                                                avg_factor=avg_factor,
                                                sigma=self.sigma)
        return loss