Spaces:
Running
on
L40S
Running
on
L40S
import itertools | |
import os | |
import pickle | |
import sys | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from detrsmpl.core.conventions.joints_mapping.standard_joint_angles import ( | |
STANDARD_JOINT_ANGLE_LIMITS, | |
TRANSFORMATION_AA_TO_SJA, | |
TRANSFORMATION_SJA_TO_AA, | |
) | |
from detrsmpl.utils.keypoint_utils import search_limbs | |
from detrsmpl.utils.transforms import aa_to_rot6d, aa_to_sja | |
class ShapePriorLoss(nn.Module): | |
"""Prior loss for body shape parameters. | |
Args: | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
""" | |
def __init__(self, reduction='mean', loss_weight=1.0): | |
super().__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
def forward(self, | |
betas, | |
loss_weight_override=None, | |
reduction_override=None): | |
"""Forward function of loss. | |
Args: | |
betas (torch.Tensor): The body shape parameters | |
loss_weight_override (float, optional): The weight of loss used to | |
override the original weight of loss | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
shape_prior_loss = loss_weight * betas**2 | |
if reduction == 'mean': | |
shape_prior_loss = shape_prior_loss.mean() | |
elif reduction == 'sum': | |
shape_prior_loss = shape_prior_loss.sum() | |
return shape_prior_loss | |
class ShapeThresholdPriorLoss(nn.Module): | |
"""Threshold loss for betas. Soft constraint to prevent parameters for | |
leaving feasible set. Implements a penalty constraint that encourages the | |
parameters to stay in the feasible set of solutions. | |
Args: | |
margin (int, optional): The threshold value | |
norm (str, optional): The loss method. Options are 'l1', l2' | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
""" | |
def __init__(self, margin=1, norm='l2', epsilon=1e-7, loss_weight=1.0): | |
super().__init__() | |
self.margin = margin | |
assert norm in ['l1', 'l2'], 'Norm variable must me l1 or l2' | |
self.norm = norm | |
self.epsilon = epsilon | |
self.loss_weight = loss_weight | |
def forward(self, betas): | |
"""Forward function of loss. | |
Args: | |
betas (torch.Tensor): The body shape parameters | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
abs_values = betas.abs() | |
mask = abs_values.gt(self.margin) | |
invalid_values = torch.masked_select(betas, mask) | |
if self.norm == 'l1': | |
return self.loss_weight * invalid_values.abs().sum() / ( | |
mask.to(dtype=betas.dtype).sum() + self.epsilon) | |
elif self.norm == 'l2': | |
return self.loss_weight * invalid_values.pow(2).sum() / ( | |
mask.to(dtype=betas.dtype).sum() + self.epsilon) | |
class PoseRegLoss(nn.Module): | |
"""Regulizer loss for body pose parameters. | |
Args: | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
""" | |
def __init__(self, reduction='mean', loss_weight=1.0): | |
super().__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
def forward(self, | |
body_pose, | |
weight=None, | |
avg_factor=None, | |
loss_weight_override=None, | |
reduction_override=None): | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
pose_prior_loss = loss_weight * (body_pose**2) | |
if reduction == 'mean': | |
pose_prior_loss = pose_prior_loss.mean() | |
elif reduction == 'sum': | |
pose_prior_loss = pose_prior_loss.sum() | |
return pose_prior_loss | |
class LimbLengthLoss(nn.Module): | |
"""Limb length loss for body shape parameters. As betas are associated with | |
the height of a person, fitting on limb length help determine body shape | |
parameters. It penalizes the L2 distance between target limb length and | |
pred limb length. Note that it should take keypoints3d as input, as limb | |
length computed from keypoints2d varies with camera. | |
Args: | |
convention (str): Limb convention to search for keypoint connections. | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
eps (float, optional): epsilon for computing normalized limb vector. | |
Defaults to 1e-4. | |
""" | |
def __init__(self, | |
convention, | |
reduction='mean', | |
loss_weight=1.0, | |
eps=1e-4): | |
super().__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
self.eps = eps | |
limb_idxs, _ = search_limbs(data_source=convention) | |
limb_idxs = sorted(limb_idxs['body']) | |
self.limb_idxs = np.array( | |
list(x for x, _ in itertools.groupby(limb_idxs))) | |
def _compute_limb_length(self, keypoints3d): | |
kp_src = keypoints3d[:, self.limb_idxs[:, 0], :3] | |
kp_dst = keypoints3d[:, self.limb_idxs[:, 1], :3] | |
limb_vec = kp_dst - kp_src | |
limb_length = torch.norm(limb_vec, dim=2) | |
return limb_length | |
def _keypoint_conf_to_limb_conf(self, keypoint_conf): | |
limb_conf = torch.min(keypoint_conf[:, self.limb_idxs[:, 1]], | |
keypoint_conf[:, self.limb_idxs[:, 0]]) | |
return limb_conf | |
def forward(self, | |
pred, | |
target, | |
pred_conf=None, | |
target_conf=None, | |
loss_weight_override=None, | |
reduction_override=None): | |
"""Forward function of LimbLengthLoss. | |
Args: | |
pred (torch.Tensor): The predicted smpl keypoints3d. | |
Shape should be (N, K, 3). | |
B: batch size. K: number of keypoints. | |
target (torch.Tensor): The ground-truth keypoints3d. | |
Shape should be (N, K, 3). | |
pred_conf (torch.Tensor, optional): Confidence of | |
predicted keypoints. Shape should be (N, K). | |
target_conf (torch.Tensor, optional): Confidence of | |
target keypoints. Shape should be (N, K). | |
loss_weight_override (float, optional): The weight of loss used to | |
override the original weight of loss. Defaults to None. | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
assert pred.dim() == 3 and pred.shape[-1] == 3 | |
assert pred.shape == target.shape | |
if pred_conf is not None: | |
assert pred_conf.dim() == 2 | |
assert pred_conf.shape == pred.shape[:2] | |
if target_conf is not None: | |
assert target_conf.dim() == 2 | |
assert target_conf.shape == target.shape[:2] | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
limb_len_target = self._compute_limb_length(target) | |
limb_len_pred = self._compute_limb_length(pred) | |
if target_conf is None: | |
target_conf = torch.ones_like(target[..., 0]) | |
if pred_conf is None: | |
pred_conf = torch.ones_like(pred[..., 0]) | |
limb_conf_target = self._keypoint_conf_to_limb_conf(target_conf) | |
limb_conf_pred = self._keypoint_conf_to_limb_conf(pred_conf) | |
limb_conf = limb_conf_target * limb_conf_pred | |
diff_len = limb_len_target - limb_len_pred | |
loss = diff_len**2 * limb_conf | |
if reduction == 'mean': | |
loss = loss.mean() | |
elif reduction == 'sum': | |
loss = loss.sum() | |
loss *= loss_weight | |
return loss | |
class JointPriorLoss(nn.Module): | |
"""Prior loss for joint angles. | |
Args: | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
use_full_body (bool, optional): Use full set of joint constraints | |
(in standard joint angles). | |
smooth_spine (bool, optional): Ensuring smooth spine rotations | |
smooth_spine_loss_weight (float, optional): An additional weight | |
factor multiplied on smooth spine loss | |
""" | |
def __init__(self, | |
reduction='mean', | |
loss_weight=1.0, | |
use_full_body=False, | |
smooth_spine=False, | |
smooth_spine_loss_weight=1.0): | |
super().__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
self.use_full_body = use_full_body | |
self.smooth_spine = smooth_spine | |
self.smooth_spine_loss_weight = smooth_spine_loss_weight | |
if self.use_full_body: | |
self.register_buffer('R_t', TRANSFORMATION_AA_TO_SJA) | |
self.register_buffer('R_t_inv', TRANSFORMATION_SJA_TO_AA) | |
self.register_buffer('sja_limits', STANDARD_JOINT_ANGLE_LIMITS) | |
def forward(self, | |
body_pose, | |
loss_weight_override=None, | |
reduction_override=None): | |
"""Forward function of loss. | |
Args: | |
body_pose (torch.Tensor): The body pose parameters | |
loss_weight_override (float, optional): The weight of loss used to | |
override the original weight of loss | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
if self.use_full_body: | |
batch_size = body_pose.shape[0] | |
body_pose_reshape = body_pose.reshape(batch_size, -1, 3) | |
assert body_pose_reshape.shape[1] in (21, 23) # smpl-x, smpl | |
body_pose_reshape = body_pose_reshape[:, :21, :] | |
body_pose_sja = aa_to_sja(body_pose_reshape, self.R_t, | |
self.R_t_inv) | |
lower_limits = self.sja_limits[:, :, 0] # shape: (21, 3) | |
upper_limits = self.sja_limits[:, :, 1] # shape: (21, 3) | |
lower_loss = (torch.exp(F.relu(lower_limits - body_pose_sja)) - | |
1).pow(2) | |
upper_loss = (torch.exp(F.relu(body_pose_sja - upper_limits)) - | |
1).pow(2) | |
standard_joint_angle_prior_loss = (lower_loss + upper_loss).view( | |
body_pose.shape[0], -1) # shape: (n, 3) | |
joint_prior_loss = standard_joint_angle_prior_loss | |
else: | |
# default joint prior loss applied on elbows and knees | |
joint_prior_loss = (torch.exp( | |
body_pose[:, [55, 58, 12, 15]] * | |
torch.tensor([1., -1., -1, -1.], device=body_pose.device)) - | |
1)**2 | |
if self.smooth_spine: | |
spine1 = body_pose[:, [9, 10, 11]] | |
spine2 = body_pose[:, [18, 19, 20]] | |
spine3 = body_pose[:, [27, 28, 29]] | |
smooth_spine_loss_12 = (torch.exp(F.relu(-spine1 * spine2)) - | |
1).pow(2) * self.smooth_spine_loss_weight | |
smooth_spine_loss_23 = (torch.exp(F.relu(-spine2 * spine3)) - | |
1).pow(2) * self.smooth_spine_loss_weight | |
joint_prior_loss = torch.cat( | |
[joint_prior_loss, smooth_spine_loss_12, smooth_spine_loss_23], | |
axis=1) | |
joint_prior_loss = loss_weight * joint_prior_loss | |
if reduction == 'mean': | |
joint_prior_loss = joint_prior_loss.mean() | |
elif reduction == 'sum': | |
joint_prior_loss = joint_prior_loss.sum() | |
return joint_prior_loss | |
class SmoothJointLoss(nn.Module): | |
"""Smooth loss for joint angles. | |
Args: | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
degree (bool, optional): The flag which represents whether the input | |
tensor is in degree or radian. | |
""" | |
def __init__(self, | |
reduction='mean', | |
loss_weight=1.0, | |
degree=False, | |
loss_func='L1'): | |
super().__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
assert loss_func in ('L1', 'L2') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
self.degree = degree | |
self.loss_func = loss_func | |
def forward(self, | |
body_pose, | |
loss_weight_override=None, | |
reduction_override=None): | |
"""Forward function of SmoothJointLoss. | |
Args: | |
body_pose (torch.Tensor): The body pose parameters | |
loss_weight_override (float, optional): The weight of loss used to | |
override the original weight of loss | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
theta = body_pose.reshape(body_pose.shape[0], -1, 3) | |
if self.degree: | |
theta = torch.deg2rad(theta) | |
rot_6d = aa_to_rot6d(theta) | |
rot_6d_diff = rot_6d[1:] - rot_6d[:-1] | |
if self.loss_func == 'L2': | |
smooth_joint_loss = (rot_6d_diff**2).sum(dim=[1, 2]) | |
elif self.loss_func == 'L1': | |
smooth_joint_loss = rot_6d_diff.abs().sum(dim=[1, 2]) | |
else: | |
raise TypeError(f'{self.func} is not defined') | |
# add zero padding to retain original batch_size | |
smooth_joint_loss = torch.cat( | |
[torch.zeros_like(smooth_joint_loss)[:1], smooth_joint_loss]) | |
if reduction == 'mean': | |
smooth_joint_loss = smooth_joint_loss.mean() | |
elif reduction == 'sum': | |
smooth_joint_loss = smooth_joint_loss.sum() | |
smooth_joint_loss *= loss_weight | |
return smooth_joint_loss | |
class SmoothPelvisLoss(nn.Module): | |
"""Smooth loss for pelvis angles. | |
Args: | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
degree (bool, optional): The flag which represents whether the input | |
tensor is in degree or radian. | |
""" | |
def __init__(self, reduction='mean', loss_weight=1.0, degree=False): | |
super().__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
self.degree = degree | |
def forward(self, | |
global_orient, | |
loss_weight_override=None, | |
reduction_override=None): | |
"""Forward function of SmoothPelvisLoss. | |
Args: | |
global_orient (torch.Tensor): The global orientation parameters | |
loss_weight_override (float, optional): The weight of loss used to | |
override the original weight of loss | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
if self.degree: | |
global_orient = torch.deg2rad(global_orient) | |
pelvis = global_orient.unsqueeze(1) | |
rot_6d = aa_to_rot6d(pelvis) | |
rot_6d_diff = rot_6d[1:] - rot_6d[:-1] | |
smooth_pelvis_loss = rot_6d_diff.abs().sum(dim=-1) | |
# add zero padding to retain original batch_size | |
smooth_pelvis_loss = torch.cat( | |
[torch.zeros_like(smooth_pelvis_loss)[:1], | |
smooth_pelvis_loss]).sum(dim=-1) | |
smooth_pelvis_loss = loss_weight * smooth_pelvis_loss | |
if reduction == 'mean': | |
smooth_pelvis_loss = smooth_pelvis_loss.mean() | |
elif reduction == 'sum': | |
smooth_pelvis_loss = smooth_pelvis_loss.sum() | |
return smooth_pelvis_loss | |
class SmoothTranslationLoss(nn.Module): | |
"""Smooth loss for translations. | |
Args: | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
""" | |
def __init__(self, reduction='mean', loss_weight=1.0): | |
super().__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
def forward(self, | |
translation, | |
loss_weight_override=None, | |
reduction_override=None): | |
"""Forward function of loss. | |
Args: | |
translation (torch.Tensor): The body translation parameters | |
loss_weight_override (float, optional): The weight of loss used to | |
override the original weight of loss | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
translation_diff = translation[1:] - translation[:-1] | |
smooth_translation_loss = translation_diff.abs().sum(dim=-1, | |
keepdim=True) | |
# add zero padding to retain original batch_size | |
smooth_translation_loss = torch.cat([ | |
torch.zeros_like(smooth_translation_loss)[:1], | |
smooth_translation_loss | |
]).sum(dim=-1) | |
smooth_translation_loss *= 1e3 | |
smooth_translation_loss = loss_weight * \ | |
smooth_translation_loss | |
if reduction == 'mean': | |
smooth_translation_loss = smooth_translation_loss.mean() | |
elif reduction == 'sum': | |
smooth_translation_loss = smooth_translation_loss.sum() | |
return smooth_translation_loss | |
class CameraPriorLoss(nn.Module): | |
"""Prior loss for predicted camera. | |
Args: | |
reduction (str, optional): The method that reduces the loss to a | |
scalar. Options are "none", "mean" and "sum". | |
scale (float, optional): The scale coefficient for regularizing camera | |
parameters. Defaults to 10 | |
loss_weight (float, optional): The weight of the loss. Defaults to 1.0 | |
""" | |
def __init__(self, scale=10, reduction='mean', loss_weight=1.0): | |
super().__init__() | |
self.scale = scale | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
def forward(self, | |
cameras, | |
loss_weight_override=None, | |
reduction_override=None): | |
"""Forward function of loss. | |
Args: | |
cameras (torch.Tensor): The predicted camera parameters | |
loss_weight_override (float, optional): The weight of loss used to | |
override the original weight of loss | |
reduction_override (str, optional): The reduction method used to | |
override the original reduction method of the loss. | |
Defaults to None | |
Returns: | |
torch.Tensor: The calculated loss | |
""" | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
camera_prior_loss = torch.exp(-cameras[:, 0] * self.scale) | |
camera_prior_loss = torch.pow(camera_prior_loss, 2) * loss_weight | |
if reduction == 'mean': | |
camera_prior_loss = camera_prior_loss.mean() | |
elif reduction == 'sum': | |
camera_prior_loss = camera_prior_loss.sum() | |
return camera_prior_loss | |
class MaxMixturePrior(nn.Module): | |
"""Ref: SMPLify-X | |
https://github.com/vchoutas/smplify-x/blob/master/smplifyx/prior.py | |
""" | |
def __init__(self, | |
prior_folder='data', | |
num_gaussians=8, | |
dtype=torch.float32, | |
epsilon=1e-16, | |
use_merged=True, | |
reduction=None, | |
loss_weight=1.0): | |
super(MaxMixturePrior, self).__init__() | |
assert reduction in (None, 'none', 'mean', 'sum') | |
self.reduction = reduction | |
self.loss_weight = loss_weight | |
if dtype == torch.float32: | |
np_dtype = np.float32 | |
elif dtype == torch.float64: | |
np_dtype = np.float64 | |
else: | |
print('Unknown float type {}, exiting!'.format(dtype)) | |
sys.exit(-1) | |
self.num_gaussians = num_gaussians | |
self.epsilon = epsilon | |
self.use_merged = use_merged | |
gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) | |
full_gmm_fn = os.path.join(prior_folder, gmm_fn) | |
if not os.path.exists(full_gmm_fn): | |
print('The path to the mixture prior "{}"'.format(full_gmm_fn) + | |
' does not exist, exiting!') | |
sys.exit(-1) | |
with open(full_gmm_fn, 'rb') as f: | |
gmm = pickle.load(f, encoding='latin1') | |
if type(gmm) == dict: | |
means = gmm['means'].astype(np_dtype) | |
covs = gmm['covars'].astype(np_dtype) | |
weights = gmm['weights'].astype(np_dtype) | |
elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): | |
means = gmm.means_.astype(np_dtype) | |
covs = gmm.covars_.astype(np_dtype) | |
weights = gmm.weights_.astype(np_dtype) | |
else: | |
print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) | |
sys.exit(-1) | |
self.register_buffer('means', torch.tensor(means, dtype=dtype)) | |
self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) | |
precisions = [np.linalg.inv(cov) for cov in covs] | |
precisions = np.stack(precisions).astype(np_dtype) | |
self.register_buffer('precisions', torch.tensor(precisions, | |
dtype=dtype)) | |
# The constant term: | |
sqrdets = np.array([(np.sqrt(np.linalg.det(c))) | |
for c in gmm['covars']]) | |
const = (2 * np.pi)**(69 / 2.) | |
nll_weights = np.asarray(gmm['weights'] / (const * | |
(sqrdets / sqrdets.min()))) | |
nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) | |
self.register_buffer('nll_weights', nll_weights) | |
weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) | |
self.register_buffer('weights', weights) | |
self.register_buffer('pi_term', | |
torch.log(torch.tensor(2 * np.pi, dtype=dtype))) | |
cov_dets = [ | |
np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) | |
for cov in covs | |
] | |
self.register_buffer('cov_dets', torch.tensor(cov_dets, dtype=dtype)) | |
# The dimensionality of the random variable | |
self.random_var_dim = self.means.shape[1] | |
def get_mean(self): | |
"""Returns the mean of the mixture.""" | |
mean_pose = torch.matmul(self.weights, self.means) | |
return mean_pose | |
def merged_log_likelihood(self, pose): | |
diff_from_mean = pose.unsqueeze(dim=1) - self.means | |
prec_diff_prod = torch.einsum('mij,bmj->bmi', | |
[self.precisions, diff_from_mean]) | |
diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) | |
curr_loglikelihood = 0.5 * diff_prec_quadratic - \ | |
torch.log(self.nll_weights) | |
# curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + | |
# self.random_var_dim * self.pi_term + | |
# diff_prec_quadratic | |
# ) - torch.log(self.weights) | |
min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) | |
return min_likelihood | |
def log_likelihood(self, pose): | |
"""Create graph operation for negative log-likelihood calculation.""" | |
likelihoods = [] | |
for idx in range(self.num_gaussians): | |
mean = self.means[idx] | |
prec = self.precisions[idx] | |
cov = self.covs[idx] | |
diff_from_mean = pose - mean | |
curr_loglikelihood = torch.einsum('bj,ji->bi', | |
[diff_from_mean, prec]) | |
curr_loglikelihood = torch.einsum( | |
'bi,bi->b', [curr_loglikelihood, diff_from_mean]) | |
cov_term = torch.log(torch.det(cov) + self.epsilon) | |
curr_loglikelihood += 0.5 * (cov_term + | |
self.random_var_dim * self.pi_term) | |
likelihoods.append(curr_loglikelihood) | |
log_likelihoods = torch.stack(likelihoods, dim=1) | |
min_idx = torch.argmin(log_likelihoods, dim=1) | |
weight_component = self.nll_weights[:, min_idx] | |
weight_component = -torch.log(weight_component) | |
return weight_component + log_likelihoods[:, min_idx] | |
def forward(self, | |
body_pose, | |
loss_weight_override=None, | |
reduction_override=None): | |
assert reduction_override in (None, 'none', 'mean', 'sum') | |
reduction = (reduction_override | |
if reduction_override else self.reduction) | |
loss_weight = (loss_weight_override if loss_weight_override is not None | |
else self.loss_weight) | |
if self.use_merged: | |
pose_prior_loss = self.merged_log_likelihood(body_pose) | |
else: | |
pose_prior_loss = self.log_likelihood(body_pose) | |
pose_prior_loss = loss_weight * pose_prior_loss | |
if reduction == 'mean': | |
pose_prior_loss = pose_prior_loss.mean() | |
elif reduction == 'sum': | |
pose_prior_loss = pose_prior_loss.sum() | |
return pose_prior_loss | |