# Header captured from the hosting file viewer (not part of the code):
#   atatakun's picture
#   Duplicate from atatakun/testapp2
#   18dd6ad
#   raw
#   history blame
#   No virus
#   7.57 kB
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
# compute loss
class compute_loss(nn.Module):
    """Surface-normal training loss, selected by name at construction time.

    ``__init__`` binds ``self.loss_fn`` to one of two entry points, which
    ``forward`` then calls with whatever positional arguments it receives:

    * ``forward_R(norm_out, gt_norm, gt_norm_mask)`` for the plain losses.
    * ``forward_UG(pred_list, coord_list, gt_norm, gt_norm_mask)`` for the
      uncertainty-guided (``UG_*``) losses, summed over a list of predictions.

    In every prediction, channels ``0:3`` hold the predicted normal and, for
    the NLL losses, channel ``3`` holds the concentration ``kappa``.

    If a mask leaves no pixel to average over, the term evaluates to a zero
    that is still attached to the autograd graph, instead of the ``NaN`` that
    ``torch.mean`` of an empty tensor would produce.
    """

    # Pixels whose cosine similarity lies outside (-_DOT_LIMIT, _DOT_LIMIT)
    # are dropped; this keeps acos(dot) away from +/-1, where its gradient
    # -1/sqrt(1 - x^2) is unbounded.
    _DOT_LIMIT = 0.999

    def __init__(self, args):
        """Select the loss implementation named by ``args.loss_fn``.

        args.loss_fn can be one of following:
            - L1 - L1 loss (no uncertainty)
            - L2 - L2 loss (no uncertainty)
            - AL - Angular loss (no uncertainty)
            - NLL_vMF - NLL of vonMF distribution
            - NLL_ours - NLL of Angular vonMF distribution
            - UG_NLL_vMF - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
            - UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)

        Raises:
            ValueError: if ``args.loss_fn`` is not one of the names above.
        """
        super(compute_loss, self).__init__()
        self.loss_type = args.loss_fn
        if self.loss_type in ['L1', 'L2', 'AL', 'NLL_vMF', 'NLL_ours']:
            self.loss_fn = self.forward_R
        elif self.loss_type in ['UG_NLL_vMF', 'UG_NLL_ours']:
            self.loss_fn = self.forward_UG
        else:
            raise ValueError('invalid loss type')

    def forward(self, *args):
        """Dispatch to the entry point chosen in ``__init__``."""
        return self.loss_fn(*args)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _masked_mean(values):
        """Mean of ``values``, or a graph-connected 0 when ``values`` is empty."""
        if values.numel() == 0:
            # sum() of an empty tensor is 0 and keeps the grad_fn, so
            # backward() still works and no NaN reaches the optimizer.
            return values.sum()
        return torch.mean(values)

    @classmethod
    def _valid_mask(cls, dot, gt_mask, threshold):
        """Boolean mask: ground truth valid AND dot strictly inside +/-_DOT_LIMIT.

        ``gt_mask`` must already have the channel axis removed, so that it
        matches the shape of ``dot``. The result is the product of the 0/1
        factors compared against ``threshold``.
        """
        d = dot.detach()
        keep = gt_mask.float() \
            * (d < cls._DOT_LIMIT).float() \
            * (d > -cls._DOT_LIMIT).float()
        return keep > threshold

    @staticmethod
    def _nll_vmf(dot, kappa):
        """Per-pixel NLL of the von Mises-Fisher distribution."""
        return - torch.log(kappa) \
            - (kappa * (dot - 1)) \
            + torch.log(1 - torch.exp(- 2 * kappa))

    @staticmethod
    def _nll_angular_vmf(dot, kappa):
        """Per-pixel NLL of the angular von Mises-Fisher distribution."""
        return - torch.log(torch.square(kappa) + 1) \
            + kappa * torch.acos(dot) \
            + torch.log(1 + torch.exp(-kappa * np.pi))

    def _nll_loss(self, pred_norm, pred_kappa, gt_norm, gt_mask, threshold, angular):
        """Masked mean NLL for one prediction.

        The normals have channels on dim 1. ``pred_kappa[:, 0]`` must have the
        same shape as ``gt_mask``. ``angular`` selects the angular-vMF formula
        instead of the vMF formula.
        """
        dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
        valid = self._valid_mask(dot, gt_mask, threshold)
        dot = dot[valid]
        kappa = pred_kappa[:, 0][valid]
        if angular:
            loss_pixelwise = self._nll_angular_vmf(dot, kappa)
        else:
            loss_pixelwise = self._nll_vmf(dot, kappa)
        return self._masked_mean(loss_pixelwise)

    # ------------------------------------------------------------------
    # entry points
    # ------------------------------------------------------------------
    def forward_R(self, norm_out, gt_norm, gt_norm_mask):
        """Loss for a single dense prediction.

        Args:
            norm_out: 4-D prediction; channels 0:3 are the normal and channel
                3 is kappa (needed only by the NLL losses).
            gt_norm: ground-truth normals, compared channel-wise with
                ``norm_out[:, 0:3]``.
            gt_norm_mask: validity mask. It is used directly as a boolean
                index for L1/L2, so it is presumably a bool tensor shaped like
                ``(B, 1, H, W)``; confirm against the caller.

        Returns:
            Scalar tensor: the mean loss over valid pixels, or 0 if none are
            valid.
        """
        pred_norm, pred_kappa = norm_out[:, 0:3, :, :], norm_out[:, 3:, :, :]

        if self.loss_type == 'L1':
            l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
            return self._masked_mean(l1[gt_norm_mask])

        if self.loss_type == 'L2':
            l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
            return self._masked_mean(l2[gt_norm_mask])

        if self.loss_type == 'AL':
            dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
            valid = self._valid_mask(dot, gt_norm_mask[:, 0, :, :], 0.0)
            return self._masked_mean(torch.acos(dot[valid]))

        if self.loss_type in ('NLL_vMF', 'NLL_ours'):
            return self._nll_loss(pred_norm, pred_kappa, gt_norm,
                                  gt_norm_mask[:, 0, :, :], threshold=0.0,
                                  angular=(self.loss_type == 'NLL_ours'))

        raise ValueError('invalid loss type')

    def forward_UG(self, pred_list, coord_list, gt_norm, gt_norm_mask):
        """Sum of uncertainty-guided NLL losses over several predictions.

        Args:
            pred_list: predictions paired element-wise with ``coord_list``.
            coord_list: ``None`` for a dense 4-D prediction, which is first
                bilinearly resized to the ground truth's spatial size.
                Otherwise a sampling grid of shape (B, 1, N, 2) for a sparse
                prediction of shape (B, 4, N); the ground truth is sampled
                at those points with nearest-neighbour lookup.
            gt_norm: dense ground-truth normals (4-D).
            gt_norm_mask: dense validity mask with one channel.

        Returns:
            Sum of the per-prediction losses (``0.0`` if ``pred_list`` is
            empty).

        Raises:
            ValueError: if the configured loss type is not a ``UG_*`` loss.
        """
        if self.loss_type == 'UG_NLL_vMF':
            angular = False
        elif self.loss_type == 'UG_NLL_ours':
            angular = True
        else:
            raise ValueError('invalid loss type')

        loss = 0.0
        for pred, coord in zip(pred_list, coord_list):
            if coord is None:
                pred = F.interpolate(pred, size=[gt_norm.size(2), gt_norm.size(3)],
                                     mode='bilinear', align_corners=True)
                pred_norm, pred_kappa = pred[:, 0:3, :, :], pred[:, 3:, :, :]
                gt_norm_s = gt_norm
                gt_mask_s = gt_norm_mask[:, 0, :, :]
            else:
                # coord: B, 1, N, 2
                # pred: B, 4, N
                gt_norm_s = F.grid_sample(gt_norm, coord, mode='nearest',
                                          align_corners=True)[:, :, 0, :]  # (B, 3, N)
                gt_mask_s = F.grid_sample(gt_norm_mask.float(), coord, mode='nearest',
                                          align_corners=True)[:, :, 0, :] > 0.5  # (B, 1, N)
                gt_mask_s = gt_mask_s[:, 0, :]  # (B, N)
                pred_norm, pred_kappa = pred[:, 0:3, :], pred[:, 3:, :]

            loss = loss + self._nll_loss(pred_norm, pred_kappa, gt_norm_s, gt_mask_s,
                                         threshold=0.5, angular=angular)
        return loss