Spaces:
Sleeping
Sleeping
File size: 6,676 Bytes
d7e58f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import weighted_loss
def gmof(x, sigma):
    """Geman-McClure error function.

    Evaluates ``sigma**2 * x**2 / (sigma**2 + x**2)``. Only ``**``, ``*``,
    ``+`` and ``/`` are used, so the function works on Python numbers as
    well as on tensors (element-wise).

    Args:
        x: Residual value(s) to be robustified.
        sigma: Scale parameter of the error function.

    Returns:
        The robustified value(s). For real inputs the result grows like
        ``x**2`` near zero and saturates towards ``sigma**2`` for large
        ``|x|``.

    Note:
        When ``x`` and ``sigma`` are both zero the denominator is zero;
        this case is not handled specially.
    """
    x_squared = x**2
    sigma_squared = sigma**2
    return (sigma_squared * x_squared) / (sigma_squared + x_squared)
@weighted_loss
def mse_loss(pred, target):
    """Wrapper of the element-wise MSE loss.

    The undecorated function returns the unreduced squared error. The
    ``weighted_loss`` decorator is imported from ``.utils`` and is not
    visible here; callers in this file additionally pass ``weight``,
    ``reduction`` and ``avg_factor`` through it.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.

    Returns:
        torch.Tensor: Squared error computed with ``reduction='none'``
        (before whatever the decorator applies).
    """
    return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss_with_gmof(pred, target, sigma):
    """Extended MSE loss passed through the Geman-McClure error function.

    The ``weighted_loss`` decorator is imported from ``.utils`` and is not
    visible here; callers in this file additionally pass ``weight``,
    ``reduction`` and ``avg_factor`` through it.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.
        sigma (float): Scale parameter forwarded to ``gmof``.

    Returns:
        torch.Tensor: Unreduced, robustified element-wise loss (before
        whatever the decorator applies).
    """
    loss = F.mse_loss(pred, target, reduction='none')
    # NOTE(review): ``gmof`` squares its input again, so it operates on the
    # already squared error here (the residual enters at the fourth power).
    # Kept as is; confirm this is intended.
    loss = gmof(loss, sigma)
    return loss
class MSELoss(nn.Module):
    """Mean squared error loss module.

    Args:
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum". ``None`` is
            accepted as an alias of "none". Defaults to "mean".
        loss_weight (float, optional): The weight of the loss. Defaults
            to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        # Normalise ``None`` so that ``self.reduction`` is always a string.
        reduction = 'none' if reduction is None else reduction
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function of loss.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): Weight of the loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None, which keeps ``self.reduction``.

        Returns:
            torch.Tensor: The calculated loss, scaled by
            ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Any truthy override wins; ``None`` falls back to the configured
        # reduction.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * mse_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss
class KeypointMSELoss(nn.Module):
    """MSELoss for 2D and 3D keypoints.

    The squared error is always passed through the Geman-McClure error
    function (see ``mse_loss_with_gmof``) using ``sigma``.

    Args:
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum". ``None`` is
            accepted as an alias of "none". Defaults to "mean".
        loss_weight (float, optional): The weight of the loss. Defaults
            to 1.0.
        sigma (float, optional): Scale parameter of the Geman-McClure
            error function. Defaults to 1.0.
        keypoint_weight (List[float], optional): Weighing parameter for
            each keypoint. Shape should be (K). K: number of keypoints.
            Defaults to None (no keypoint-wise weighting).
    """

    def __init__(self,
                 reduction='mean',
                 loss_weight=1.0,
                 sigma=1.0,
                 keypoint_weight=None):
        super().__init__()
        assert reduction in (None, 'none', 'mean', 'sum')
        # Normalise ``None`` so that ``self.reduction`` is always a string.
        reduction = 'none' if reduction is None else reduction
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.sigma = sigma
        if keypoint_weight is None:
            self.keypoint_weight = None
        else:
            # Stored as a plain attribute (not a registered buffer);
            # ``forward`` converts it with ``type_as(pred)`` on use.
            self.keypoint_weight = torch.Tensor(keypoint_weight)

    def forward(self,
                pred,
                target,
                pred_conf=None,
                target_conf=None,
                keypoint_weight=None,
                avg_factor=None,
                loss_weight_override=None,
                reduction_override=None):
        """Forward function of loss.

        Args:
            pred (torch.Tensor): The prediction. Must be 3-dimensional,
                shape (B, K, 2/3). B: batch size. K: number of keypoints.
            target (torch.Tensor): The learning target of the prediction.
                Shape should be the same as pred.
            pred_conf (optional, torch.Tensor): Confidence of predicted
                keypoints. Shape should be (B, K).
            target_conf (optional, torch.Tensor): Confidence of target
                keypoints. Shape should be the same as pred_conf.
            keypoint_weight (optional, torch.Tensor): Keypoint-wise
                weight. Shape should be (K,). This weight allows
                different weights to be assigned to different body parts.
                When given, it takes precedence over the weight passed at
                construction.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            loss_weight_override (float, optional): The overall weight of
                loss used to override the original weight of loss.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_weight = (
            loss_weight_override
            if loss_weight_override is not None else self.loss_weight)

        # Unpacking also enforces that ``pred`` is 3-dimensional; the last
        # dimension is not needed below.
        B, K, _ = pred.shape

        # Missing confidences act as a neutral factor of 1.0.
        pred_conf = (
            pred_conf.view((B, K, 1)) if pred_conf is not None else 1.0)
        target_conf = (
            target_conf.view((B, K, 1)) if target_conf is not None else 1.0)

        # Call-time weight > construction-time weight > neutral factor.
        if keypoint_weight is not None:
            keypoint_weight = keypoint_weight.view((1, K, 1))
        elif self.keypoint_weight is not None:
            keypoint_weight = self.keypoint_weight.view(
                (1, K, 1)).type_as(pred)
        else:
            keypoint_weight = 1.0

        weight = keypoint_weight * pred_conf * target_conf
        # ``weight`` is a plain float only when every factor was omitted.
        assert (isinstance(weight, float) or weight.shape == (B, K, 1)
                or weight.shape == (1, K, 1))

        loss = loss_weight * mse_loss_with_gmof(
            pred,
            target,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            sigma=self.sigma)
        return loss
|