# Spaces: Running on T4
import os | |
import sys | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
class AlignLoss(nn.Module):
    """L1 loss between each target frame and its aligned reference frames.

    For every target frame, target and aligned reference frames are compared only
    where the target is unmasked (weight ``1 - mask``) and the aligned visibility
    map is non-zero. Per-reference L1 terms are summed, then summed over target
    frames.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        # ``reduction`` is forwarded unchanged to nn.L1Loss.
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, frames, masks, aligned_vs, aligned_rs):
        """
        :param frames: The original frames (GT), B, C, T, H, W; a 4-D B, C, H, W
            tensor is treated as a single frame (T = 1)
        :param masks: Original masks, laid out like ``frames``
        :param aligned_vs: per-target-frame list of aligned visibility maps from the
            reference frames (List: B, C, T, H, W)
        :param aligned_rs: per-target-frame list of aligned reference frames
            (List: B, C, T, H, W)
        :return: sum over target frames of the single-frame alignment loss
        """
        if frames.dim() == 4:
            # Single-frame input: insert the temporal axis so indexing below works.
            frames = frames.unsqueeze(2)
            masks = masks.unsqueeze(2)
        _, _, num_frames, _, _ = frames.shape
        return sum(
            self._singleFrameAlignLoss(frames[:, :, t], masks[:, :, t], aligned_vs[t], aligned_rs[t])
            for t in range(num_frames)
        )

    def _singleFrameAlignLoss(self, targetFrame, targetMask, aligned_v, aligned_r):
        """
        :param targetFrame: target frame to be aligned -> B, C, H, W
        :param targetMask: the mask of the target frame -> B, C_m, H, W
        :param aligned_v: aligned visibility maps from the reference frames
        :param aligned_r: aligned reference frames -> B, C, T, H, W
        :return: sum over reference frames (dim 2 of ``aligned_r``) of the L1 loss
        """
        # Give the target's visible region and the target frame a singleton
        # reference axis (dim 2) so they broadcast against the aligned stacks.
        visibility_map = (1. - targetMask).unsqueeze(2) * aligned_v
        target_visible = visibility_map * targetFrame.unsqueeze(2)
        reference_visible = visibility_map * aligned_r
        return sum(
            self.loss_fn(target_visible[:, :, k], reference_visible[:, :, k])
            for k in range(aligned_r.shape[2])
        )
class HoleVisibleLoss(nn.Module):
    """L1 loss on pixels selected by both ``masks`` and ``c_masks``.

    Prediction and ground truth are both multiplied by ``mask * c_mask`` before
    the L1 loss; per-frame losses are summed over time. (Presumably ``c_masks``
    marks hole pixels that are visible elsewhere, per the class name; confirm.)
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs, c_masks):
        """
        :param outputs: predicted frames, B, C, T, H, W (or B, C, H, W for one frame)
        :param masks: hole masks, broadcastable against ``outputs``
        :param GTs: ground-truth frames, laid out like ``outputs``
        :param c_masks: secondary masks, broadcastable against ``outputs``
        :return: sum over frames of the per-frame masked L1 loss
        """
        if outputs.dim() == 4:
            # Single frame: add the temporal axis to every input.
            outputs, masks, GTs, c_masks = (t.unsqueeze(2) for t in (outputs, masks, GTs, c_masks))
        _, _, num_frames, _, _ = outputs.shape
        return sum(
            self._singleFrameHoleVisibleLoss(outputs[:, :, t], masks[:, :, t], c_masks[:, :, t], GTs[:, :, t])
            for t in range(num_frames)
        )

    def _singleFrameHoleVisibleLoss(self, targetFrame, targetMask, c_mask, GT):
        weight = targetMask * c_mask
        return self.loss_fn(weight * targetFrame, weight * GT)
class HoleInvisibleLoss(nn.Module):
    """L1 loss on pixels selected by ``masks`` but not by ``c_masks``.

    Prediction and ground truth are both multiplied by ``mask * (1 - c_mask)``
    before the L1 loss; per-frame losses are summed over time.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs, c_masks):
        """
        :param outputs: predicted frames, B, C, T, H, W (or B, C, H, W for one frame)
        :param masks: hole masks, broadcastable against ``outputs``
        :param GTs: ground-truth frames, laid out like ``outputs``
        :param c_masks: secondary masks, broadcastable against ``outputs``
        :return: sum over frames of the per-frame masked L1 loss
        """
        if outputs.dim() == 4:
            outputs = outputs.unsqueeze(2)
            masks = masks.unsqueeze(2)
            GTs = GTs.unsqueeze(2)
            c_masks = c_masks.unsqueeze(2)
        _, _, num_frames, _, _ = outputs.shape
        per_frame = (
            self._singleFrameHoleInvisibleLoss(outputs[:, :, t], masks[:, :, t], c_masks[:, :, t], GTs[:, :, t])
            for t in range(num_frames)
        )
        return sum(per_frame)

    def _singleFrameHoleInvisibleLoss(self, targetFrame, targetMask, c_mask, GT):
        weight = targetMask * (1. - c_mask)
        return self.loss_fn(weight * targetFrame, weight * GT)
class NonHoleLoss(nn.Module):
    """L1 loss outside the hole: both tensors are weighted by ``1 - mask``.

    Per-frame losses are summed over time.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs):
        """
        :param outputs: predicted frames, B, C, T, H, W (or B, C, H, W for one frame)
        :param masks: hole masks, broadcastable against ``outputs``
        :param GTs: ground-truth frames, laid out like ``outputs``
        :return: sum over frames of the per-frame masked L1 loss
        """
        if outputs.dim() == 4:
            outputs, masks, GTs = outputs.unsqueeze(2), masks.unsqueeze(2), GTs.unsqueeze(2)
        _, _, num_frames, _, _ = outputs.shape
        return sum(
            self._singleNonHoleLoss(outputs[:, :, t], masks[:, :, t], GTs[:, :, t])
            for t in range(num_frames)
        )

    def _singleNonHoleLoss(self, targetFrame, targetMask, GT):
        keep = 1. - targetMask
        return self.loss_fn(keep * targetFrame, keep * GT)
class ReconLoss(nn.Module):
    """L1 reconstruction loss, optionally restricted to ``mask``.

    :param reduction: forwarded to nn.L1Loss
    :param masked: if True, both tensors are multiplied by ``mask`` first
    """

    def __init__(self, reduction='mean', masked=False):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)
        self.masked = masked

    def forward(self, model_output, target, mask):
        if not self.masked:
            # L1 loss over the whole region; ``mask`` is ignored.
            return self.loss_fn(model_output, target)
        # L1 loss in the masked region.
        return self.loss_fn(model_output * mask, target * mask)
class VGGLoss(nn.Module):
    """Perceptual loss: summed L1 distance between VGG feature maps.

    ``vgg`` must be a callable whose result exposes ``relu2_2``, ``relu3_3``
    and ``relu4_3`` attributes.
    """

    def __init__(self, vgg):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.vgg = vgg

    def vgg_loss(self, output, target):
        out_feats, tgt_feats = self.vgg(output), self.vgg(target)
        return sum(
            self.l1_loss(getattr(out_feats, name), getattr(tgt_feats, name))
            for name in ('relu2_2', 'relu3_3', 'relu4_3')
        )

    def forward(self, data_input, model_output):
        # data_input holds the targets, model_output the predictions.
        return self.vgg_loss(model_output, data_input)
class StyleLoss(nn.Module):
    """Gram-matrix style loss on VGG features.

    Implements the style term of "Image Inpainting for Irregular Holes Using
    Partial Convolutions", Liu et al., 2018, with per-image Gram matrices.

    :param vgg: callable whose result exposes ``relu2_2``, ``relu3_3`` and
        ``relu4_3`` feature maps (each B, C, H, W)
    :param original_channel_norm: if True, layer ``i`` is divided by
        ``2 ** (i + 1)``; otherwise by ``C ** 2``, where the channels are asserted
        to be 128, 256 and 512
    """

    def __init__(self, vgg, original_channel_norm=True):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.vgg = vgg
        self.original_channel_norm = original_channel_norm

    # From https://github.com/pytorch/tutorials/blob/master/advanced_source/neural_style_tutorial.py
    def gram_matrix(self, input):
        """Gram matrix as in the PyTorch neural-style tutorial (batch size 1).

        Batch and channel axes are flattened together. For batch size > 1 the
        (a*b, a*b) result therefore also contains products between different
        images. Kept unchanged for existing callers; ``style_loss`` uses the
        per-sample ``_batched_gram_matrix`` instead.
        """
        a, b, c, d = input.size()  # a=batch size(=1)
        # b=number of feature maps
        # (c,d)=dimensions of a f. map (N=c*d)
        features = input.view(a * b, c * d)  # resize F_XL into hat F_XL
        G = torch.mm(features, features.t())  # compute the gram product
        # Normalize the values of the gram matrix by the number of elements
        # in the feature maps.
        return G.div(a * b * c * d)

    @staticmethod
    def _batched_gram_matrix(feature):
        """Per-image Gram matrices (B, C, C), each divided by C * H * W.

        For B == 1 this matches ``gram_matrix`` up to a leading batch axis.
        """
        B, C, H, W = feature.shape
        flat = feature.reshape(B, C, H * W)
        return torch.bmm(flat, flat.transpose(1, 2)).div(C * H * W)

    # Implement "Image Inpainting for Irregular Holes Using Partial Convolutions", Liu et al., 2018
    def style_loss(self, output, target):
        output_features = self.vgg(output)
        target_features = self.vgg(target)
        layers = ['relu2_2', 'relu3_3', 'relu4_3']  # n_channel: 128 (=2 ** 7), 256 (=2 ** 8), 512 (=2 ** 9)
        loss = 0
        for i, layer in enumerate(layers):
            output_feature = getattr(output_features, layer)
            target_feature = getattr(target_features, layer)
            C_P = output_feature.shape[1]
            # Compute the Gram matrix per image. Flattening batch and channels
            # together (as gram_matrix does) would also correlate features of
            # different images, making the loss depend on the batch composition.
            output_gram_matrix = self._batched_gram_matrix(output_feature)
            target_gram_matrix = self._batched_gram_matrix(target_feature)
            if self.original_channel_norm:
                C_P_square_divider = 2 ** (i + 1)  # original design (avoid too small loss)
            else:
                C_P_square_divider = C_P ** 2
                assert C_P == 128 * 2 ** i
            loss += self.l1_loss(output_gram_matrix, target_gram_matrix) / C_P_square_divider
        return loss

    def forward(self, data_input, model_output):
        # data_input holds the targets, model_output the predictions.
        targets = data_input
        outputs = model_output
        mean_image_loss = self.style_loss(outputs, targets)
        return mean_image_loss
class L1LossMaskedMean(nn.Module):
    """Summed L1 error over the region weighted by ``1 - mask``, divided by that weight's sum."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss(reduction='sum')

    def forward(self, x, y, mask):
        """
        :param x: prediction tensor
        :param y: target tensor, broadcastable against ``x``
        :param mask: mask broadcastable against ``x``; ``1 - mask`` is used as the
            per-element weight
        :return: sum(|x - y| * (1 - mask)) / sum(1 - mask). Returns 0 when the
            weight sums to 0.

        NOTE(review): the normalizer counts elements of ``mask``, so if ``mask``
        has fewer channels than ``x``, the result is a per-pixel (channel-summed)
        mean rather than a per-element one.
        """
        # Original author's note: by default the mask value is 0 in the missing
        # region and 1 in the original region.
        # NOTE(review): other losses in this file treat mask == 1 as the hole;
        # confirm which region callers intend ``1 - mask`` to select.
        masked = 1 - mask
        l1_sum = self.l1(x * masked, y * masked)
        # Clamp the normalizer so an all-ones mask yields 0 instead of
        # 0 / 0 = NaN, which would also turn the gradients into NaN.
        weight_sum = torch.sum(masked).to(l1_sum.dtype)
        return l1_sum / weight_sum.clamp(min=torch.finfo(l1_sum.dtype).eps)
class L2LossMaskedMean(nn.Module):
    """Squared error over the region weighted by ``1 - mask``, divided by that weight's sum.

    :param reduction: forwarded to nn.MSELoss. With the default 'sum' the
        result is a scalar; with 'none' the element-wise squared errors are
        each divided by the weight sum.
    """

    def __init__(self, reduction='sum'):
        super().__init__()
        self.l2 = nn.MSELoss(reduction=reduction)

    def forward(self, x, y, mask):
        """
        :param x: prediction tensor
        :param y: target tensor, broadcastable against ``x``
        :param mask: mask broadcastable against ``x``; ``1 - mask`` is the weight
        :return: MSE term (per ``reduction``) divided by sum(1 - mask). Gives 0
            instead of NaN when the weight sums to 0.
        """
        masked = 1 - mask
        l2_sum = self.l2(x * masked, y * masked)
        # Clamp the normalizer so an all-ones mask yields 0 instead of
        # 0 / 0 = NaN (and NaN gradients).
        weight_sum = torch.sum(masked).to(l2_sum.dtype)
        return l2_sum / weight_sum.clamp(min=torch.finfo(l2_sum.dtype).eps)
class ImcompleteVideoReconLoss(nn.Module):
    """Masked-mean L1 loss at half spatial resolution.

    Compares ``model_output['imcomplete_video']`` against the targets and masks
    from ``data_input``, which are first transposed (dims 1 and 2) and
    down-sampled by 0.5 along the last two axes.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = L1LossMaskedMean()

    def forward(self, data_input, model_output):
        def _half_resolution(video):
            # After transpose(1, 2), interpolate treats the last three axes as
            # spatial: the first is kept (scale 1) and the last two are halved
            # using interpolate's default mode. Presumably the inputs are
            # (B, T, C, H, W) -> (B, C, T, H/2, W/2); confirm against
            # the layout of 'imcomplete_video'.
            return nn.functional.interpolate(video.transpose(1, 2), scale_factor=[1, 0.5, 0.5])

        return self.loss_fn(
            model_output['imcomplete_video'],
            _half_resolution(data_input['targets']),
            _half_resolution(data_input['masks']),
        )
class CompleteFramesReconLoss(nn.Module):
    """Masked-mean L1 loss between ``model_output['outputs']`` and ``data_input['targets']``."""

    def __init__(self):
        super().__init__()
        self.loss_fn = L1LossMaskedMean()

    def forward(self, data_input, model_output):
        predicted = model_output['outputs']
        expected, hole_masks = data_input['targets'], data_input['masks']
        return self.loss_fn(predicted, expected, hole_masks)
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        :param type: nsgan | lsgan | hinge
        :param target_real_label: label value used for real samples (nsgan/lsgan)
        :param target_fake_label: label value used for fake samples (nsgan/lsgan)
        :raises ValueError: if ``type`` is not one of the supported names
        """
        super(AdversarialLoss, self).__init__()
        # NOTE: this instance attribute shadows nn.Module.type(); kept for
        # callers that read ``.type``.
        self.type = type
        # Buffers follow the module across .to(device) calls.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if type == 'nsgan':
            self.criterion = nn.BCELoss()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        elif type == 'hinge':
            self.criterion = nn.ReLU()
        else:
            # Fail here rather than with an AttributeError on the first call.
            raise ValueError(
                "Unknown adversarial loss type {!r}; expected 'nsgan', 'lsgan' or 'hinge'".format(type))

    def forward(self, outputs, is_real, is_disc=None):
        """Compute the adversarial loss for a batch of discriminator outputs.

        Implemented as ``forward`` (not ``__call__``) so nn.Module hooks run.

        :param outputs: discriminator outputs; for ``nsgan`` they are fed to
            nn.BCELoss and therefore must already be probabilities in [0, 1]
        :param is_real: selects the real/fake target label (nsgan/lsgan) or the
            sign of the hinge discriminator term
        :param is_disc: hinge only; truthy for the discriminator loss, otherwise
            the generator loss ``-mean(outputs)`` is returned
        :return: scalar loss tensor
        """
        if self.type == 'hinge':
            if is_disc:
                # Discriminator hinge: mean(relu(1 - D(real))) / mean(relu(1 + D(fake))).
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            return (-outputs).mean()
        labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
        return self.criterion(outputs, labels)
# # From https://github.com/phoenix104104/fast_blind_video_consistency | |
# class TemporalWarpingLoss(nn.Module): | |
# def __init__(self, opts, flownet_checkpoint_path=None, alpha=50): | |
# super().__init__() | |
# self.loss_fn = L1LossMaskedMean() | |
# self.alpha = alpha | |
# self.opts = opts | |
# | |
# assert flownet_checkpoint_path is not None, "Flownet2 pretrained models must be provided" | |
# | |
# self.flownet_checkpoint_path = flownet_checkpoint_path | |
# raise NotImplementedError | |
# | |
# def get_flownet_checkpoint_path(self): | |
# return self.flownet_checkpoint_path | |
# | |
# def _flownetwrapper(self): | |
# Flownet = FlowNet2(self.opts, requires_grad=False) | |
# Flownet2_ckpt = torch.load(self.flownet_checkpoint_path) | |
# Flownet.load_state_dict(Flownet2_ckpt['state_dict']) | |
# Flownet.to(device) | |
# Flownet.exal() | |
# return Flownet | |
# | |
# def _setup(self): | |
# self.flownet = self._flownetwrapper() | |
# | |
# def _get_non_occlusuib_mask(self, targets, warped_targets): | |
# non_occlusion_masks = torch.exp( | |
# -self.alpha * torch.sum(targets[:, 1:] - warped_targets, dim=2).pow(2) | |
# ).unsqueeze(2) | |
# return non_occlusion_masks | |
# | |
# def _get_loss(self, outputs, warped_outputs, non_occlusion_masks, masks): | |
# return self.loss_fn( | |
# outputs[:, 1:] * non_occlusion_masks, | |
# warped_outputs * non_occlusion_masks, | |
# masks[:, 1:] | |
# ) | |
# | |
# def forward(self, data_input, model_output): | |
# if self.flownet is None: | |
# self._setup() | |
# | |
# targets = data_input['targets'].to(device) | |
# outputs = model_output['outputs'].to(device) | |
# flows = self.flownet.infer_video(targets).to(device) | |
# | |
# from utils.flow_utils import warp_optical_flow | |
# warped_targets = warp_optical_flow(targets[:, :-1], -flows).detach() | |
# warped_outputs = warp_optical_flow(outputs[:, :-1], -flows).detach() | |
# non_occlusion_masks = self._get_non_occlusion_mask(targets, warped_targets) | |
# | |
# # model_output is passed by name and dictionary is mutable | |
# # These values are sent to trainer for visualization | |
# model_output['warped_outputs'] = warped_outputs[0] | |
# model_output['warped_targets'] = warped_targets[0] | |
# model_output['non_occlusion_masks'] = non_occlusion_masks[0] | |
# from utils.flow_utils import flow_to_image | |
# flow_imgs = [] | |
# for flow in flows[0]: | |
# flow_img = flow_to_image(flow.cpu().permute(1, 2, 0).detach().numpy()).transpose(2, 0, 1) | |
# flow_imgs.append(torch.Tensor(flow_img)) | |
# model_output['flow_imgs'] = flow_imgs | |
# | |
# masks = data_input['masks'].to(device) | |
# return self._get_loss(outputs, warped_outputs, non_occlusion_masks, masks) | |
# | |
# | |
# class TemporalWarpingError(TemporalWarpingLoss): | |
# def __init__(self, flownet_checkpoint_path, alpha=50): | |
# super().__init__(flownet_checkpoint_path, alpha) | |
# self.loss_fn = L2LossMaskedMean(reduction='none') | |
# | |
# def _get_loss(self, outputs, warped_outputs, non_occlusion_masks, masks): | |
# # See https://arxiv.org/pdf/1808.00449.pdf 4.3 | |
# # The sum of non_occlusion_masks is different for each video, | |
# # So the batch dim is kept | |
# loss = self.loss_fn( | |
# outputs[:, 1:] * non_occlusion_masks, | |
# warped_outputs * non_occlusion_masks, | |
# masks[:, 1:] | |
# ).sum(1).sum(1).sum(1).sum(1) | |
# | |
# loss = loss / non_occlusion_masks.sum(1).sum(1).sum(1).sum(1) | |
# return loss.sum() | |
class ValidLoss(nn.Module):
    """Mean L1 loss with both tensors weighted by ``1 - mk``."""

    def __init__(self):
        super(ValidLoss, self).__init__()
        self.loss_fn = nn.L1Loss(reduction='mean')

    def forward(self, model_output, target, mk):
        # Elements with mk == 1 are zeroed in both tensors: they add no error but
        # still count in the mean's denominator.
        keep = 1 - mk
        return self.loss_fn(model_output * keep, target * keep)
class TVLoss(nn.Module):
    """Squared total variation per frame, divided by that frame's mask sum, averaged.

    Each frame's vertical and horizontal squared neighbour differences are
    summed over channels and pixels. The total is divided by the frame's mask
    sum, and the result is averaged over all B * T frames.
    """

    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, mask_input, model_output):
        """
        :param mask_input: masks, B, C_m, T, H, W (or B, C_m, H, W for one frame)
        :param model_output: frames, B, C, T, H, W (or B, C, H, W for one frame)
        :return: scalar tensor

        NOTE(review): a frame whose mask sums to 0 divides by zero (inf/NaN).
        """
        # Promote single-frame (4-D) inputs to 5-D with T = 1.
        frames = model_output.unsqueeze(2) if model_output.dim() == 4 else model_output
        mask = mask_input.unsqueeze(2) if mask_input.dim() == 4 else mask_input
        # View the 3D data as 2D: fold time into batch, (B, C, T, H, W) -> (B*T, C, H, W).
        frames = frames.permute((0, 2, 1, 3, 4)).contiguous()
        batch, length, channels, height, width = frames.shape
        flat = frames.view([batch * length, channels, height, width])
        mask = mask.permute((0, 2, 1, 3, 4)).contiguous()
        mask_area = mask.view([batch * length, -1]).sum(dim=1)
        # Finite differences approximate the image gradient, so this is a sum of
        # squared gradients.
        vertical = (flat[..., 1:, :] - flat[..., :-1, :]).pow(2).sum(1).sum(1).sum(1)
        horizontal = (flat[..., :, 1:] - flat[..., :, :-1]).pow(2).sum(1).sum(1).sum(1)
        return ((vertical + horizontal) / mask_area).mean()
# for debug
def show_images(image, name):
    """Debug helper: binarize a CHW image at 0.5 and write it with OpenCV.

    Values greater than 0.5 become 255 and all others 0. The result is written
    as an 8-bit (H, W, C) image.

    :param image: array-like or torch tensor laid out as (C, H, W)
    :param name: output path passed to ``cv2.imwrite``
    """
    import cv2  # only needed by this debug helper

    if isinstance(image, torch.Tensor):
        # np.array() fails on CUDA tensors and on tensors that require grad.
        image = image.detach().cpu().numpy()
    image = np.asarray(image)
    # Build an explicit uint8 image. Previously a float array holding 255s
    # and raw sub-threshold values was left for cv2 to convert implicitly.
    binary = np.where(image > 0.5, 255, 0).astype(np.uint8)
    cv2.imwrite(name, binary.transpose((1, 2, 0)))
if __name__ == '__main__':
    # Manual smoke test of the hole / non-hole losses on one 32x32 frame.
    # Prediction is all ones and ground truth all twos, so every selected pixel
    # has an absolute error of 1.
    targetFrame = torch.ones(1, 3, 32, 32)
    GT = torch.ones(1, 3, 32, 32) + 1
    mask = torch.zeros(1, 1, 32, 32)
    mask[:, :, 8:24, 8:24] = 1.  # 16x16 hole
    c_mask = torch.zeros(1, 1, 32, 32)
    c_mask[:, :, 8:16, 16:24] = 1.  # 8x8 patch inside the hole

    # AlignLoss check (disabled):
    # referenceFrames = torch.ones(1, 3, 4, 32, 32)
    # referenceMasks = torch.zeros(1, 1, 4, 32, 32)
    # referenceMasks[:, :, 0, 4:12, 4:12] = 1.
    # referenceFrames[:, :, 0, 4:12, 4:12] = 2.
    # referenceMasks[:, :, 1, 4:12, 20:28] = 1.
    # referenceFrames[:, :, 1, 4:12, 20:28] = 2.
    # referenceMasks[:, :, 2, 20:28, 4:12] = 1.
    # referenceFrames[:, :, 2, 20:28, 4:12] = 2.
    # referenceMasks[:, :, 3, 20:28, 20:28] = 1.
    # referenceFrames[:, :, 3, 20:28, 20:28] = 2.
    #
    # aligned_v = referenceMasks
    # aligned_v, referenceFrames = [aligned_v], [referenceFrames]
    #
    # result = AlignLoss()(targetFrame, mask, aligned_v, referenceFrames)
    # print(result)

    result1 = HoleVisibleLoss()(targetFrame, mask, GT, c_mask)
    result2 = HoleInvisibleLoss()(targetFrame, mask, GT, c_mask)
    result3 = NonHoleLoss()(targetFrame, mask, GT)
    # Expected 0.0625, 0.1875 and 0.75 (64, 192 and 768 pixels x 3 channels,
    # divided by 3 * 32 * 32 elements). The 'gt' label is the NonHoleLoss value.
    print('vis: {}, invis: {}, gt: {}'.format(result1, result2, result3))