import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


# compute loss
class compute_loss(nn.Module):
    """Surface-normal loss whose concrete form is selected by ``args.loss_fn``.

    The first three channels of a prediction are treated as the normal
    vector and the remaining channel(s) as the concentration ``kappa``.
    Plain losses are computed by :meth:`forward_R`, the ``UG_*`` losses by
    :meth:`forward_UG`; :meth:`forward` dispatches to whichever one was
    selected at construction time.
    """

    # Pixels whose cosine similarity with the ground truth lies outside
    # (-_DOT_LIMIT, _DOT_LIMIT) are excluded from the angular / NLL losses.
    # NOTE(review): presumably this keeps acos() away from +-1 where its
    # gradient is unbounded -- confirm against the training setup.
    _DOT_LIMIT = 0.999

    def __init__(self, args):
        """args.loss_fn can be one of following:
            - L1            - L1 loss       (no uncertainty)
            - L2            - L2 loss       (no uncertainty)
            - AL            - Angular loss  (no uncertainty)
            - NLL_vMF       - NLL of vonMF distribution
            - NLL_ours      - NLL of Angular vonMF distribution
            - UG_NLL_vMF    - NLL of vonMF distribution
                              (+ pixel-wise MLP + uncertainty-guided sampling)
            - UG_NLL_ours   - NLL of Angular vonMF distribution
                              (+ pixel-wise MLP + uncertainty-guided sampling)

        Raises:
            ValueError: if ``args.loss_fn`` is none of the names above.
        """
        super(compute_loss, self).__init__()
        self.loss_type = args.loss_fn
        if self.loss_type in ['L1', 'L2', 'AL', 'NLL_vMF', 'NLL_ours']:
            self.loss_fn = self.forward_R
        elif self.loss_type in ['UG_NLL_vMF', 'UG_NLL_ours']:
            self.loss_fn = self.forward_UG
        else:
            raise ValueError('invalid loss type')

    def forward(self, *args):
        """Forward all positional arguments to the selected loss function."""
        return self.loss_fn(*args)

    @staticmethod
    def _valid_mask(gt_mask, dot, threshold):
        """Return the boolean mask of entries that contribute to the loss.

        An entry is kept when ``gt_mask.float() > threshold`` and its cosine
        similarity ``dot`` lies strictly inside (-0.999, 0.999). This equals
        the product-of-float-masks-then-threshold formulation for any
        ``threshold >= 0``, because the two range factors are only 0 or 1.
        """
        limit = compute_loss._DOT_LIMIT
        cos = dot.detach()
        return (gt_mask.float() > threshold) & (cos < limit) & (cos > -limit)

    @staticmethod
    def _nll_vmf(dot, kappa):
        """Element-wise NLL of the vonMF distribution."""
        return - torch.log(kappa) \
               - (kappa * (dot - 1)) \
               + torch.log(1 - torch.exp(- 2 * kappa))

    @staticmethod
    def _nll_angular_vmf(dot, kappa):
        """Element-wise NLL of the Angular vonMF distribution."""
        return - torch.log(torch.square(kappa) + 1) \
               + kappa * torch.acos(dot) \
               + torch.log(1 + torch.exp(-kappa * np.pi))

    def _masked_nll(self, nll_fn, pred_norm, pred_kappa, gt_norm, gt_mask,
                    threshold):
        """Mean of ``nll_fn`` over the valid entries.

        Works for both dense ``(B, C, H, W)`` and sampled ``(B, C, N)``
        tensors: the channel axis is always dim 1, and channel 0 of
        ``pred_kappa`` / ``gt_mask`` is used.

        The mean of an empty selection is NaN, so an input without any valid
        entry produces a NaN loss.
        """
        dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
        valid = self._valid_mask(gt_mask[:, 0], dot, threshold)
        return torch.mean(nll_fn(dot[valid], pred_kappa[:, 0][valid]))

    def forward_R(self, norm_out, gt_norm, gt_norm_mask):
        """Compute the L1 / L2 / AL / NLL_vMF / NLL_ours loss.

        Args:
            norm_out: 4-D prediction; channels 0-2 are the normal, channels
                from 3 on are kappa (only read by the NLL losses).
            gt_norm: ground-truth normals, broadcast-compatible with the
                first three channels of ``norm_out``.
            gt_norm_mask: validity mask. It is used directly as an index for
                L1 / L2, and its channel 0 is used for the other losses.

        Returns:
            Scalar tensor: the loss averaged over the selected pixels (NaN if
            no pixel is selected).

        Raises:
            ValueError: if ``self.loss_type`` is not handled here.
        """
        pred_norm, pred_kappa = norm_out[:, 0:3, :, :], norm_out[:, 3:, :, :]

        if self.loss_type == 'L1':
            l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
            return torch.mean(l1[gt_norm_mask])

        if self.loss_type == 'L2':
            l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
            return torch.mean(l2[gt_norm_mask])

        if self.loss_type == 'AL':
            dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
            valid = self._valid_mask(gt_norm_mask[:, 0], dot, 0.0)
            return torch.mean(torch.acos(dot[valid]))

        if self.loss_type == 'NLL_vMF':
            return self._masked_nll(self._nll_vmf, pred_norm, pred_kappa,
                                    gt_norm, gt_norm_mask, 0.0)

        if self.loss_type == 'NLL_ours':
            return self._masked_nll(self._nll_angular_vmf, pred_norm, pred_kappa,
                                    gt_norm, gt_norm_mask, 0.0)

        raise ValueError('invalid loss type')

    def forward_UG(self, pred_list, coord_list, gt_norm, gt_norm_mask):
        """Compute the UG_NLL_vMF / UG_NLL_ours loss summed over predictions.

        ``pred_list`` and ``coord_list`` are iterated in lockstep. For each
        pair:

        * ``coord is None``: ``pred`` is a dense map that is bilinearly
          resized to the spatial size of ``gt_norm`` and compared against
          the full ground truth.
        * otherwise: ``coord`` is ``(B, 1, N, 2)`` and ``pred`` is
          ``(B, 4, N)``; ground truth and mask are sampled at ``coord`` with
          nearest-neighbour ``grid_sample`` before the comparison.

        Returns:
            The sum of the per-prediction mean NLLs (the float ``0.0`` when
            ``pred_list`` is empty). A prediction without any valid entry
            contributes NaN.

        Raises:
            ValueError: if ``self.loss_type`` is not a ``UG_*`` loss. The
                check runs before the loop, so it also fires for empty lists.
        """
        # The loss type is loop-invariant: resolve it once.
        if self.loss_type == 'UG_NLL_vMF':
            nll_fn = self._nll_vmf
        elif self.loss_type == 'UG_NLL_ours':
            nll_fn = self._nll_angular_vmf
        else:
            raise ValueError('invalid loss type')

        loss = 0.0
        for (pred, coord) in zip(pred_list, coord_list):
            if coord is None:
                pred = F.interpolate(pred,
                                     size=[gt_norm.size(2), gt_norm.size(3)],
                                     mode='bilinear', align_corners=True)
                target, target_mask = gt_norm, gt_norm_mask
            else:
                # coord: B, 1, N, 2
                # pred: B, 4, N
                target = F.grid_sample(gt_norm, coord, mode='nearest',
                                       align_corners=True)          # (B, 3, 1, N)
                target_mask = F.grid_sample(gt_norm_mask.float(), coord,
                                            mode='nearest',
                                            align_corners=True)     # (B, 1, 1, N)
                target = target[:, :, 0, :]                         # (B, 3, N)
                target_mask = target_mask[:, :, 0, :] > 0.5         # (B, 1, N)

            # Slicing only dim 1 handles both the dense and sampled layouts.
            loss = loss + self._masked_nll(nll_fn, pred[:, 0:3], pred[:, 3:],
                                           target, target_mask, 0.5)

        return loss