harlanhong commited on
Commit
bcec73a
1 Parent(s): cd0695c
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import torch
4
+ import cv2
5
+ import gradio as gr
6
+ from PIL import Image
7
+
8
+ #os.chdir('Restormer')
9
+
10
+ # Download sample images
11
+ os.system("wget https://github.com/swz30/Restormer/releases/download/v1.0/sample_images.zip")
12
+ shutil.unpack_archive('sample_images.zip')
13
+ os.remove('sample_images.zip')
14
+
15
+
16
+ examples = [['project/cartoon2.jpg'],
17
+ ['project/cartoon3.jpg'],
18
+ ['project/celeb1.jpg'],
19
+ ['project/celeb2.jpg']
20
+ ]
21
+
22
+
23
+ inference_on = ['Full Resolution Image', 'Downsampled Image']
24
+
25
+ title = "DaGAN"
26
+ description = """
27
+ Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</b>, CVPR 2022L. <a href='https://arxiv.org/abs/2203.06605'>[Paper]</a><a href='https://github.com/harlanhong/CVPR2022-DaGAN'>[Github Code]</a>\n
28
+ """
29
+ ##With Restormer, you can perform: (1) Image Denoising, (2) Defocus Deblurring, (3) Motion Deblurring, and (4) Image Deraining.
30
+ ##To use it, simply upload your own image, or click one of the examples provided below.
31
+
32
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>"
33
+
34
+
35
+ def inference(img, video):
36
+ if not os.path.exists('temp'):
37
+ os.system('mkdir temp')
38
+
39
+ #### Resize the longer edge of the input image
40
+ max_res = 256
41
+ width, height = img.size
42
+ if max(width,height) > max_res:
43
+ scale = max_res /max(width,height)
44
+ width = int(scale*width)
45
+ height = int(scale*height)
46
+ img = img.resize((width,height), Image.ANTIALIAS)
47
+
48
+ img.save("temp/image.jpg", "JPEG")
49
+ video.save('temp/video.mp4')
50
+ os.system("python demo_dagan.py --source_image 'temp/image.jpg' --driving_video {} --output 'temp/rst.mp4'".format(video))
51
+
52
+ return f'temp/rst.mp4'
53
+
54
+ gr.Interface(
55
+ inference,
56
+ [
57
+ gr.inputs.Image(type="pil", label="Input"),
58
+ gr.inputs.Video(label="Input"),
59
+ ],
60
+ gr.outputs.Video(type="mp4", label="Output"),
61
+ title=title,
62
+ description=description,
63
+ article=article,
64
+ theme ="huggingface",
65
+ examples=examples,
66
+ allow_flagging=False,
67
+ ).launch(debug=False,enable_queue=True)
config/vox-adv-256.yaml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_params:
2
+ root_dir: /data/fhongac/origDataset/vox1_frames
3
+ frame_shape: [256, 256, 3]
4
+ id_sampling: True
5
+ pairs_list: data/vox256.csv
6
+ augmentation_params:
7
+ flip_param:
8
+ horizontal_flip: True
9
+ time_flip: True
10
+ jitter_param:
11
+ brightness: 0.1
12
+ contrast: 0.1
13
+ saturation: 0.1
14
+ hue: 0.1
15
+
16
+
17
+ model_params:
18
+ common_params:
19
+ num_kp: 15
20
+ num_channels: 3
21
+ estimate_jacobian: True
22
+ kp_detector_params:
23
+ temperature: 0.1
24
+ block_expansion: 32
25
+ max_features: 1024
26
+ scale_factor: 0.25
27
+ num_blocks: 5
28
+ generator_params:
29
+ block_expansion: 64
30
+ max_features: 512
31
+ num_down_blocks: 2
32
+ num_bottleneck_blocks: 6
33
+ estimate_occlusion_map: True
34
+ dense_motion_params:
35
+ block_expansion: 64
36
+ max_features: 1024
37
+ num_blocks: 5
38
+ scale_factor: 0.25
39
+ discriminator_params:
40
+ scales: [1]
41
+ block_expansion: 32
42
+ max_features: 512
43
+ num_blocks: 4
44
+ use_kp: True
45
+
46
+
47
+ train_params:
48
+ num_epochs: 150
49
+ num_repeats: 75
50
+ epoch_milestones: []
51
+ lr_generator: 2.0e-4
52
+ lr_discriminator: 2.0e-4
53
+ lr_kp_detector: 2.0e-4
54
+ batch_size: 4
55
+ scales: [1, 0.5, 0.25, 0.125]
56
+ checkpoint_freq: 10
57
+ transform_params:
58
+ sigma_affine: 0.05
59
+ sigma_tps: 0.005
60
+ points_tps: 5
61
+ loss_weights:
62
+ generator_gan: 1
63
+ discriminator_gan: 1
64
+ feature_matching: [10, 10, 10, 10]
65
+ perceptual: [10, 10, 10, 10, 10]
66
+ equivariance_value: 10
67
+ equivariance_jacobian: 10
68
+ kp_distance: 10
69
+ kp_prior: 0
70
+ kp_scale: 0
71
+ depth_constraint: 0
72
+
73
+ reconstruction_params:
74
+ num_videos: 1000
75
+ format: '.mp4'
76
+
77
+ animate_params:
78
+ num_pairs: 50
79
+ format: '.mp4'
80
+ normalization_params:
81
+ adapt_movement_scale: False
82
+ use_relative_movement: True
83
+ use_relative_jacobian: True
84
+
85
+ visualizer_params:
86
+ kp_size: 5
87
+ draw_border: True
88
+ colormap: 'gist_rainbow'
demo_dagan.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2
+ ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3
+ ## https://arxiv.org/abs/2111.09881
4
+
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import os
9
+ from skimage import img_as_ubyte
10
+ import cv2
11
+ import argparse
12
+ import imageio
13
+ from skimage.transform import resize
14
+ from scipy.spatial import ConvexHull
15
+ from tqdm import tqdm
16
+ import numpy as np
17
+ import modules.generator as G
18
+ import modules.keypoint_detector as KPD
19
+ import yaml
20
+ from collections import OrderedDict
21
+ import depth
22
+ parser = argparse.ArgumentParser(description='Test DaGAN on your own images')
23
+ parser.add_argument('--source_image', default='./temp/source.jpg', type=str, help='Directory of input source image')
24
+ parser.add_argument('--driving_video', default='./temp/driving.mp4', type=str, help='Directory for driving video')
25
+ parser.add_argument('--output', default='./temp/result.mp4', type=str, help='Directory for driving video')
26
+
27
+
28
+ args = parser.parse_args()
29
+ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
30
+ use_relative_movement=False, use_relative_jacobian=False):
31
+ if adapt_movement_scale:
32
+ source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
33
+ driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
34
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
35
+ else:
36
+ adapt_movement_scale = 1
37
+
38
+ kp_new = {k: v for k, v in kp_driving.items()}
39
+
40
+ if use_relative_movement:
41
+ kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
42
+ kp_value_diff *= adapt_movement_scale
43
+ kp_new['value'] = kp_value_diff + kp_source['value']
44
+
45
+ if use_relative_jacobian:
46
+ jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
47
+ kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
48
+ return kp_new
49
+ def find_best_frame(source, driving, cpu=False):
50
+ import face_alignment
51
+
52
+ def normalize_kp(kp):
53
+ kp = kp - kp.mean(axis=0, keepdims=True)
54
+ area = ConvexHull(kp[:, :2]).volume
55
+ area = np.sqrt(area)
56
+ kp[:, :2] = kp[:, :2] / area
57
+ return kp
58
+
59
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
60
+ device='cpu' if cpu else 'cuda')
61
+ kp_source = fa.get_landmarks(255 * source)[0]
62
+ kp_source = normalize_kp(kp_source)
63
+ norm = float('inf')
64
+ frame_num = 0
65
+ for i, image in tqdm(enumerate(driving)):
66
+ kp_driving = fa.get_landmarks(255 * image)[0]
67
+ kp_driving = normalize_kp(kp_driving)
68
+ new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
69
+ if new_norm < norm:
70
+ norm = new_norm
71
+ frame_num = i
72
+ return frame_num
73
+
74
+
75
+ def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
76
+ sources = []
77
+ drivings = []
78
+ with torch.no_grad():
79
+ predictions = []
80
+ depth_gray = []
81
+ source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
82
+ driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
83
+ if not cpu:
84
+ source = source.cuda()
85
+ driving = driving.cuda()
86
+ outputs = depth_decoder(depth_encoder(source))
87
+ depth_source = outputs[("disp", 0)]
88
+
89
+ outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
90
+ depth_driving = outputs[("disp", 0)]
91
+ source_kp = torch.cat((source,depth_source),1)
92
+ driving_kp = torch.cat((driving[:, :, 0],depth_driving),1)
93
+
94
+ kp_source = kp_detector(source_kp)
95
+ kp_driving_initial = kp_detector(driving_kp)
96
+
97
+ # kp_source = kp_detector(source)
98
+ # kp_driving_initial = kp_detector(driving[:, :, 0])
99
+
100
+ for frame_idx in tqdm(range(driving.shape[2])):
101
+ driving_frame = driving[:, :, frame_idx]
102
+
103
+ if not cpu:
104
+ driving_frame = driving_frame.cuda()
105
+ outputs = depth_decoder(depth_encoder(driving_frame))
106
+ depth_map = outputs[("disp", 0)]
107
+
108
+ gray_driving = np.transpose(depth_map.data.cpu().numpy(), [0, 2, 3, 1])[0]
109
+ gray_driving = 1-gray_driving/np.max(gray_driving)
110
+
111
+ frame = torch.cat((driving_frame,depth_map),1)
112
+ kp_driving = kp_detector(frame)
113
+
114
+ kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
115
+ kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
116
+ use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
117
+ out = generator(source, kp_source=kp_source, kp_driving=kp_norm,source_depth = depth_source, driving_depth = depth_map)
118
+
119
+ drivings.append(np.transpose(driving_frame.data.cpu().numpy(), [0, 2, 3, 1])[0])
120
+ sources.append(np.transpose(source.data.cpu().numpy(), [0, 2, 3, 1])[0])
121
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
122
+ depth_gray.append(gray_driving)
123
+ return sources, drivings, predictions,depth_gray
124
+ with open("config/vox-adv-256.yaml") as f:
125
+ config = yaml.load(f)
126
+ generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
127
+ config['model_params']['common_params']['num_channels'] = 4
128
+ kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
129
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
130
+
131
+
132
+ g_checkpoint = torch.load("generator.pt", map_location=device)
133
+ kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
134
+
135
+ ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
136
+ generator.load_state_dict(ckp_generator)
137
+ ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
138
+ kp_detector.load_state_dict(ckp_kp_detector)
139
+
140
+ depth_encoder = depth.ResnetEncoder(18, False)
141
+ depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
142
+ loaded_dict_enc = torch.load('encoder.pth')
143
+ loaded_dict_dec = torch.load('depth.pth')
144
+ filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
145
+ depth_encoder.load_state_dict(filtered_dict_enc)
146
+ ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
147
+ depth_decoder.load_state_dict(ckp_depth_decoder)
148
+ depth_encoder.eval()
149
+ depth_decoder.eval()
150
+
151
+ # device = torch.device('cpu')
152
+ # stx()
153
+
154
+ generator = generator.to(device)
155
+ kp_detector = kp_detector.to(device)
156
+ depth_encoder = depth_encoder.to(device)
157
+ depth_decoder = depth_decoder.to(device)
158
+
159
+ generator.eval()
160
+ kp_detector.eval()
161
+ depth_encoder.eval()
162
+ depth_decoder.eval()
163
+
164
+ img_multiple_of = 8
165
+
166
+ with torch.inference_mode():
167
+ if torch.cuda.is_available():
168
+ torch.cuda.ipc_collect()
169
+ torch.cuda.empty_cache()
170
+ source_image = imageio.imread(args.source_image)
171
+ reader = imageio.get_reader(args.driving_video)
172
+ fps = reader.get_meta_data()['fps']
173
+ driving_video = []
174
+ try:
175
+ for im in reader:
176
+ driving_video.append(im)
177
+ except RuntimeError:
178
+ pass
179
+ reader.close()
180
+
181
+ source_image = resize(source_image, (256, 256))[..., :3]
182
+ driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
183
+
184
+
185
+
186
+ i = find_best_frame(source_image, driving_video)
187
+ print ("Best frame: " + str(i))
188
+ driving_forward = driving_video[i:]
189
+ driving_backward = driving_video[:(i+1)][::-1]
190
+ sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
191
+ sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
192
+ predictions = predictions_backward[::-1] + predictions_forward[1:]
193
+ sources = sources_backward[::-1] + sources_forward[1:]
194
+ drivings = drivings_backward[::-1] + drivings_forward[1:]
195
+ depth_gray = depth_backward[::-1] + depth_forward[1:]
196
+
197
+ imageio.mimsave(args.output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
198
+ imageio.mimsave("gray.mp4", depth_gray, fps=fps)
199
+ # merge the gray video
200
+ animation = np.array(imageio.mimread(args.output,memtest=False))
201
+ gray = np.array(imageio.mimread("gray.mp4",memtest=False))
202
+
203
+ src_dst = animation[:,:,:512,:]
204
+ animate = animation[:,:,512:,:]
205
+ merge = np.concatenate((src_dst,gray,animate),2)
206
+ imageio.mimsave(args.output, merge, fps=fps)
207
+
208
+ # print(f"\nRestored images are saved at {out_dir}")
depth/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .resnet_encoder import ResnetEncoder
2
+ from .depth_decoder import DepthDecoder
3
+ from .pose_decoder import PoseDecoder
4
+ from .pose_cnn import PoseCNN
5
+
depth/depth_decoder.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from collections import OrderedDict
14
+ from depth.layers import *
15
+
16
+
17
+ class DepthDecoder(nn.Module):
18
+ def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
19
+ super(DepthDecoder, self).__init__()
20
+
21
+ self.num_output_channels = num_output_channels
22
+ self.use_skips = use_skips
23
+ self.upsample_mode = 'nearest'
24
+ self.scales = scales
25
+
26
+ self.num_ch_enc = num_ch_enc
27
+ self.num_ch_dec = np.array([16, 32, 64, 128, 256])
28
+
29
+ # decoder
30
+ self.convs = OrderedDict()
31
+ for i in range(4, -1, -1):
32
+ # upconv_0
33
+ num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
34
+ num_ch_out = self.num_ch_dec[i]
35
+ self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
36
+
37
+ # upconv_1
38
+ num_ch_in = self.num_ch_dec[i]
39
+ if self.use_skips and i > 0:
40
+ num_ch_in += self.num_ch_enc[i - 1]
41
+ num_ch_out = self.num_ch_dec[i]
42
+ self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
43
+
44
+ for s in self.scales:
45
+ self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
46
+
47
+ self.decoder = nn.ModuleList(list(self.convs.values()))
48
+ self.sigmoid = nn.Sigmoid()
49
+
50
+ def forward(self, input_features):
51
+ self.outputs = {}
52
+
53
+ # decoder
54
+ x = input_features[-1]
55
+ for i in range(4, -1, -1):
56
+ x = self.convs[("upconv", i, 0)](x)
57
+ x = [upsample(x)]
58
+ if self.use_skips and i > 0:
59
+ x += [input_features[i - 1]]
60
+ x = torch.cat(x, 1)
61
+ x = self.convs[("upconv", i, 1)](x)
62
+ if i in self.scales:
63
+ self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
64
+
65
+ return self.outputs
depth/layers.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import numpy as np
10
+ import pdb
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def disp_to_depth(disp, min_depth, max_depth):
17
+ """Convert network's sigmoid output into depth prediction
18
+ The formula for this conversion is given in the 'additional considerations'
19
+ section of the paper.
20
+ """
21
+ min_disp = 1 / max_depth
22
+ max_disp = 1 / min_depth
23
+ scaled_disp = min_disp + (max_disp - min_disp) * disp
24
+ depth = 1 / scaled_disp
25
+ return scaled_disp, depth
26
+
27
+
28
+ def transformation_from_parameters(axisangle, translation, invert=False):
29
+ """Convert the network's (axisangle, translation) output into a 4x4 matrix
30
+ """
31
+ R = rot_from_axisangle(axisangle)
32
+ t = translation.clone()
33
+
34
+ if invert:
35
+ R = R.transpose(1, 2)
36
+ t *= -1
37
+
38
+ T = get_translation_matrix(t)
39
+
40
+ if invert:
41
+ M = torch.matmul(R, T)
42
+ else:
43
+ M = torch.matmul(T, R)
44
+
45
+ return M
46
+
47
+
48
+ def get_translation_matrix(translation_vector):
49
+ """Convert a translation vector into a 4x4 transformation matrix
50
+ """
51
+ T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
52
+
53
+ t = translation_vector.contiguous().view(-1, 3, 1)
54
+
55
+ T[:, 0, 0] = 1
56
+ T[:, 1, 1] = 1
57
+ T[:, 2, 2] = 1
58
+ T[:, 3, 3] = 1
59
+ T[:, :3, 3, None] = t
60
+
61
+ return T
62
+
63
+
64
+ def rot_from_axisangle(vec):
65
+ """Convert an axisangle rotation into a 4x4 transformation matrix
66
+ (adapted from https://github.com/Wallacoloo/printipi)
67
+ Input 'vec' has to be Bx1x3
68
+ """
69
+ angle = torch.norm(vec, 2, 2, True)
70
+ axis = vec / (angle + 1e-7)
71
+
72
+ ca = torch.cos(angle)
73
+ sa = torch.sin(angle)
74
+ C = 1 - ca
75
+
76
+ x = axis[..., 0].unsqueeze(1)
77
+ y = axis[..., 1].unsqueeze(1)
78
+ z = axis[..., 2].unsqueeze(1)
79
+
80
+ xs = x * sa
81
+ ys = y * sa
82
+ zs = z * sa
83
+ xC = x * C
84
+ yC = y * C
85
+ zC = z * C
86
+ xyC = x * yC
87
+ yzC = y * zC
88
+ zxC = z * xC
89
+
90
+ rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
91
+
92
+ rot[:, 0, 0] = torch.squeeze(x * xC + ca)
93
+ rot[:, 0, 1] = torch.squeeze(xyC - zs)
94
+ rot[:, 0, 2] = torch.squeeze(zxC + ys)
95
+ rot[:, 1, 0] = torch.squeeze(xyC + zs)
96
+ rot[:, 1, 1] = torch.squeeze(y * yC + ca)
97
+ rot[:, 1, 2] = torch.squeeze(yzC - xs)
98
+ rot[:, 2, 0] = torch.squeeze(zxC - ys)
99
+ rot[:, 2, 1] = torch.squeeze(yzC + xs)
100
+ rot[:, 2, 2] = torch.squeeze(z * zC + ca)
101
+ rot[:, 3, 3] = 1
102
+
103
+ return rot
104
+
105
+
106
+ class ConvBlock(nn.Module):
107
+ """Layer to perform a convolution followed by ELU
108
+ """
109
+ def __init__(self, in_channels, out_channels):
110
+ super(ConvBlock, self).__init__()
111
+
112
+ self.conv = Conv3x3(in_channels, out_channels)
113
+ self.nonlin = nn.ELU(inplace=True)
114
+
115
+ def forward(self, x):
116
+ out = self.conv(x)
117
+ out = self.nonlin(out)
118
+ return out
119
+
120
+
121
+ class Conv3x3(nn.Module):
122
+ """Layer to pad and convolve input
123
+ """
124
+ def __init__(self, in_channels, out_channels, use_refl=True):
125
+ super(Conv3x3, self).__init__()
126
+
127
+ if use_refl:
128
+ self.pad = nn.ReflectionPad2d(1)
129
+ else:
130
+ self.pad = nn.ZeroPad2d(1)
131
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
132
+
133
+ def forward(self, x):
134
+ out = self.pad(x)
135
+ out = self.conv(out)
136
+ return out
137
+
138
+
139
+ class BackprojectDepth(nn.Module):
140
+ """Layer to transform a depth image into a point cloud
141
+ """
142
+ def __init__(self, batch_size, height, width):
143
+ super(BackprojectDepth, self).__init__()
144
+
145
+ self.batch_size = batch_size
146
+ self.height = height
147
+ self.width = width
148
+
149
+ meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
150
+ self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
151
+ self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
152
+ requires_grad=False)
153
+
154
+ self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
155
+ requires_grad=False)
156
+
157
+ self.pix_coords = torch.unsqueeze(torch.stack(
158
+ [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
159
+ self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
160
+ self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
161
+ requires_grad=False)
162
+
163
+ def forward(self, depth, K,scale):
164
+ K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc()
165
+ b,n,n = K.shape
166
+ inv_K = torch.linalg.inv(K)
167
+ #inv_K = torch.cholesky_inverse(K)
168
+ pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda()
169
+ inv_K = torch.cat([inv_K,pad],-1)
170
+ pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda()
171
+ inv_K = torch.cat([inv_K,pad],1)
172
+ cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
173
+ cam_points = depth.view(self.batch_size, 1, -1) * cam_points
174
+ cam_points = torch.cat([cam_points, self.ones], 1)
175
+
176
+ return cam_points
177
+
178
+
179
+ class Project3D(nn.Module):
180
+ """Layer which projects 3D points into a camera with intrinsics K and at position T
181
+ """
182
+ def __init__(self, batch_size, height, width, eps=1e-7):
183
+ super(Project3D, self).__init__()
184
+
185
+ self.batch_size = batch_size
186
+ self.height = height
187
+ self.width = width
188
+ self.eps = eps
189
+
190
+ def forward(self, points, K, T,scale=0):
191
+ # K[0, :] *= self.width // (2 ** scale)
192
+ # K[1, :] *= self.height // (2 ** scale)
193
+ K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc()
194
+ b,n,n = K.shape
195
+ pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda()
196
+ K = torch.cat([K,pad],-1)
197
+ pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda()
198
+ K = torch.cat([K,pad],1)
199
+ P = torch.matmul(K, T)[:, :3, :]
200
+
201
+ cam_points = torch.matmul(P, points)
202
+
203
+ pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
204
+ pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
205
+ pix_coords = pix_coords.permute(0, 2, 3, 1)
206
+ pix_coords[..., 0] /= self.width - 1
207
+ pix_coords[..., 1] /= self.height - 1
208
+ pix_coords = (pix_coords - 0.5) * 2
209
+ return pix_coords
210
+
211
+
212
+ def upsample(x):
213
+ """Upsample input tensor by a factor of 2
214
+ """
215
+ return F.interpolate(x, scale_factor=2, mode="nearest")
216
+
217
+
218
+ def get_smooth_loss(disp, img):
219
+ """Computes the smoothness loss for a disparity image
220
+ The color image is used for edge-aware smoothness
221
+ """
222
+ grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
223
+ grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
224
+
225
+ grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
226
+ grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
227
+
228
+ grad_disp_x *= torch.exp(-grad_img_x)
229
+ grad_disp_y *= torch.exp(-grad_img_y)
230
+
231
+ return grad_disp_x.mean() + grad_disp_y.mean()
232
+
233
+
234
+ class SSIM(nn.Module):
235
+ """Layer to compute the SSIM loss between a pair of images
236
+ """
237
+ def __init__(self):
238
+ super(SSIM, self).__init__()
239
+ self.mu_x_pool = nn.AvgPool2d(3, 1)
240
+ self.mu_y_pool = nn.AvgPool2d(3, 1)
241
+ self.sig_x_pool = nn.AvgPool2d(3, 1)
242
+ self.sig_y_pool = nn.AvgPool2d(3, 1)
243
+ self.sig_xy_pool = nn.AvgPool2d(3, 1)
244
+
245
+ self.refl = nn.ReflectionPad2d(1)
246
+
247
+ self.C1 = 0.01 ** 2
248
+ self.C2 = 0.03 ** 2
249
+
250
+ def forward(self, x, y):
251
+ x = self.refl(x)
252
+ y = self.refl(y)
253
+
254
+ mu_x = self.mu_x_pool(x)
255
+ mu_y = self.mu_y_pool(y)
256
+
257
+ sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
258
+ sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
259
+ sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
260
+
261
+ SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
262
+ SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
263
+
264
+ return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
265
+
266
+
267
+ def compute_depth_errors(gt, pred):
268
+ """Computation of error metrics between predicted and ground truth depths
269
+ """
270
+ thresh = torch.max((gt / pred), (pred / gt))
271
+ a1 = (thresh < 1.25 ).float().mean()
272
+ a2 = (thresh < 1.25 ** 2).float().mean()
273
+ a3 = (thresh < 1.25 ** 3).float().mean()
274
+
275
+ rmse = (gt - pred) ** 2
276
+ rmse = torch.sqrt(rmse.mean())
277
+
278
+ rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
279
+ rmse_log = torch.sqrt(rmse_log.mean())
280
+
281
+ abs_rel = torch.mean(torch.abs(gt - pred) / gt)
282
+
283
+ sq_rel = torch.mean((gt - pred) ** 2 / gt)
284
+
285
+ return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
depth/pose_cnn.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ class PoseCNN(nn.Module):
14
+ def __init__(self, num_input_frames):
15
+ super(PoseCNN, self).__init__()
16
+
17
+ self.num_input_frames = num_input_frames
18
+
19
+ self.convs = {}
20
+ self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
21
+ self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
22
+ self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
23
+ self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
24
+ self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
25
+ self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
26
+ self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)
27
+
28
+ self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)
29
+
30
+ self.num_convs = len(self.convs)
31
+
32
+ self.relu = nn.ReLU(True)
33
+
34
+ self.net = nn.ModuleList(list(self.convs.values()))
35
+
36
+ def forward(self, out):
37
+
38
+ for i in range(self.num_convs):
39
+ out = self.convs[i](out)
40
+ out = self.relu(out)
41
+
42
+ out = self.pose_conv(out)
43
+ out = out.mean(3).mean(2)
44
+
45
+ out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)
46
+
47
+ axisangle = out[..., :3]
48
+ translation = out[..., 3:]
49
+
50
+ return axisangle, translation
depth/pose_decoder.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from collections import OrderedDict
12
+ import pdb
13
+ import torch.nn.functional as F
14
+ # from options import MonodepthOptions
15
+ # options = MonodepthOptions()
16
+ # opts = options.parse()
17
+ class PoseDecoder(nn.Module):
18
+ def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
19
+ super(PoseDecoder, self).__init__()
20
+ self.num_ch_enc = num_ch_enc
21
+ self.num_input_features = num_input_features
22
+
23
+ if num_frames_to_predict_for is None:
24
+ num_frames_to_predict_for = num_input_features - 1
25
+ self.num_frames_to_predict_for = num_frames_to_predict_for
26
+
27
+ self.convs = OrderedDict()
28
+ self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
29
+ self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
30
+ self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
31
+ self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
32
+ self.convs[("intrinsics", 'focal')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1)
33
+ self.convs[("intrinsics", 'offset')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1)
34
+
35
+ self.relu = nn.ReLU()
36
+ self.net = nn.ModuleList(list(self.convs.values()))
37
+
38
+ def forward(self, input_features):
39
+ last_features = [f[-1] for f in input_features]
40
+
41
+ cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
42
+ cat_features = torch.cat(cat_features, 1)
43
+
44
+ feat = cat_features
45
+ for i in range(2):
46
+ feat = self.convs[("pose", i)](feat)
47
+ feat = self.relu(feat)
48
+ out = self.convs[("pose", 2)](feat)
49
+
50
+ out = out.mean(3).mean(2)
51
+ out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
52
+
53
+ axisangle = out[..., :3]
54
+ translation = out[..., 3:]
55
+
56
+ #add_intrinsics_head
57
+ scales = torch.tensor([256,256]).cuda()
58
+ focals = F.softplus(self.convs[("intrinsics", 'focal')](feat)).mean(3).mean(2)*scales
59
+ offset = (F.softplus(self.convs[("intrinsics", 'offset')](feat)).mean(3).mean(2)+0.5)*scales
60
+ #focals = F.softplus(self.convs[("intrinsics",'focal')](feat).mean(3).mean(2))
61
+ #offset = F.softplus(self.convs[("intrinsics",'offset')](feat).mean(3).mean(2))
62
+ eyes = torch.eye(2).cuda()
63
+ b,xy = focals.shape
64
+ focals = focals.unsqueeze(-1).expand(b,xy,xy)
65
+ eyes = eyes.unsqueeze(0).expand(b,xy,xy)
66
+ intrin = focals*eyes
67
+ offset = offset.view(b,2,1).contiguous()
68
+ intrin = torch.cat([intrin,offset],-1)
69
+ pad = torch.tensor([0.0,0.0,1.0]).view(1,1,3).expand(b,1,3).cuda()
70
+ intrinsics = torch.cat([intrin,pad],1)
71
+ return axisangle, translation,intrinsics
depth/resnet_encoder.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import numpy as np
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torchvision.models as models
14
+ import torch.utils.model_zoo as model_zoo
15
+
16
+
17
+ class ResNetMultiImageInput(models.ResNet):
18
+ """Constructs a resnet model with varying number of input images.
19
+ Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
20
+ """
21
+ def __init__(self, block, layers, num_classes=1000, num_input_images=1):
22
+ super(ResNetMultiImageInput, self).__init__(block, layers)
23
+ self.inplanes = 64
24
+ self.conv1 = nn.Conv2d(
25
+ num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
26
+ self.bn1 = nn.BatchNorm2d(64)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
29
+ self.layer1 = self._make_layer(block, 64, layers[0])
30
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
31
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
32
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
33
+
34
+ for m in self.modules():
35
+ if isinstance(m, nn.Conv2d):
36
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
37
+ elif isinstance(m, nn.BatchNorm2d):
38
+ nn.init.constant_(m.weight, 1)
39
+ nn.init.constant_(m.bias, 0)
40
+
41
+
42
+ def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
43
+ """Constructs a ResNet model.
44
+ Args:
45
+ num_layers (int): Number of resnet layers. Must be 18 or 50
46
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
47
+ num_input_images (int): Number of frames stacked as input
48
+ """
49
+ assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
50
+ blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
51
+ block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
52
+ model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
53
+
54
+ if pretrained:
55
+ loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
56
+ loaded['conv1.weight'] = torch.cat(
57
+ [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
58
+ model.load_state_dict(loaded)
59
+ return model
60
+
61
+
62
+ class ResnetEncoder(nn.Module):
63
+ """Pytorch module for a resnet encoder
64
+ """
65
+ def __init__(self, num_layers, pretrained, num_input_images=1):
66
+ super(ResnetEncoder, self).__init__()
67
+
68
+ self.num_ch_enc = np.array([64, 64, 128, 256, 512])
69
+
70
+ resnets = {18: models.resnet18,
71
+ 34: models.resnet34,
72
+ 50: models.resnet50,
73
+ 101: models.resnet101,
74
+ 152: models.resnet152}
75
+
76
+ if num_layers not in resnets:
77
+ raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
78
+
79
+ if num_input_images > 1:
80
+ self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
81
+ else:
82
+ self.encoder = resnets[num_layers](pretrained)
83
+
84
+ if num_layers > 34:
85
+ self.num_ch_enc[1:] *= 4
86
+
87
+ def forward(self, input_image):
88
+ self.features = []
89
+ x = (input_image - 0.45) / 0.225
90
+ x = self.encoder.conv1(x)
91
+ x = self.encoder.bn1(x)
92
+ self.features.append(self.encoder.relu(x))
93
+ self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
94
+ self.features.append(self.encoder.layer2(self.features[-1]))
95
+ self.features.append(self.encoder.layer3(self.features[-1]))
96
+ self.features.append(self.encoder.layer4(self.features[-1]))
97
+
98
+ return self.features
modules/AdaIN.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def calc_mean_std(feat, eps=1e-5):
4
+ # eps is a small value added to the variance to avoid divide-by-zero.
5
+ size = feat.size()
6
+ assert (len(size) == 4)
7
+ N, C = size[:2]
8
+ feat_var = feat.view(N, C, -1).var(dim=2) + eps
9
+ feat_std = feat_var.sqrt().view(N, C, 1, 1)
10
+ feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
11
+ return feat_mean, feat_std
12
+
13
+ def adaptive_instance_normalization(content_feat, style_feat):
14
+ assert (content_feat.size()[:2] == style_feat.size()[:2])
15
+ size = content_feat.size()
16
+ style_mean, style_std = calc_mean_std(style_feat)
17
+ content_mean, content_std = calc_mean_std(content_feat)
18
+ normalized_feat = (content_feat - content_mean.expand(
19
+ size)) / content_std.expand(size)
20
+
21
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
22
+
23
+ def _calc_feat_flatten_mean_std(feat):
24
+ # takes 3D feat (C, H, W), return mean and std of array within channels
25
+ assert (feat.size()[0] == 3)
26
+ assert (isinstance(feat, torch.FloatTensor))
27
+ feat_flatten = feat.view(3, -1)
28
+ mean = feat_flatten.mean(dim=-1, keepdim=True)
29
+ std = feat_flatten.std(dim=-1, keepdim=True)
30
+ return feat_flatten, mean, std
31
+
32
+ def _mat_sqrt(x):
33
+ U, D, V = torch.svd(x)
34
+ return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
35
+
36
+ def coral(source, target):
37
+ # assume both source and target are 3D array (C, H, W)
38
+ # Note: flatten -> f
39
+ source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
40
+ source_f_norm = (source_f - source_f_mean.expand_as(
41
+ source_f)) / source_f_std.expand_as(source_f)
42
+ source_f_cov_eye = \
43
+ torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
44
+
45
+ target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
46
+ target_f_norm = (target_f - target_f_mean.expand_as(
47
+ target_f)) / target_f_std.expand_as(target_f)
48
+ target_f_cov_eye = \
49
+ torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
50
+
51
+ source_f_norm_transfer = torch.mm(
52
+ _mat_sqrt(target_f_cov_eye),
53
+ torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
54
+ source_f_norm)
55
+ )
56
+
57
+ source_f_transfer = source_f_norm_transfer * \
58
+ target_f_std.expand_as(source_f_norm) + \
59
+ target_f_mean.expand_as(source_f_norm)
60
+
61
+ return source_f_transfer.view(source.size())
modules/dense_motion.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
5
+ import pdb
6
+ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
7
+
8
+
9
+ class DenseMotionNetwork(nn.Module):
10
+ """
11
+ Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
12
+ """
13
+
14
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
15
+ scale_factor=1, kp_variance=0.01):
16
+ super(DenseMotionNetwork, self).__init__()
17
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
18
+ max_features=max_features, num_blocks=num_blocks)
19
+
20
+ self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
21
+
22
+ if estimate_occlusion_map:
23
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
24
+ else:
25
+ self.occlusion = None
26
+
27
+ self.num_kp = num_kp
28
+ self.scale_factor = scale_factor
29
+ self.kp_variance = kp_variance
30
+
31
+ if self.scale_factor != 1:
32
+ self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
33
+
34
+ def create_heatmap_representations(self, source_image, kp_driving, kp_source):
35
+ """
36
+ Eq 6. in the paper H_k(z)
37
+ """
38
+ spatial_size = source_image.shape[2:]
39
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
40
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
41
+ heatmap = gaussian_driving - gaussian_source
42
+ #adding background feature
43
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
44
+ heatmap = torch.cat([zeros, heatmap], dim=1)
45
+ heatmap = heatmap.unsqueeze(2)
46
+ return heatmap
47
+
48
+ def create_sparse_motions(self, source_image, kp_driving, kp_source):
49
+ """
50
+ Eq 4. in the paper T_{s<-d}(z)
51
+ """
52
+ bs, _, h, w = source_image.shape
53
+ identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
54
+ identity_grid = identity_grid.view(1, 1, h, w, 2)
55
+ coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
56
+ if 'jacobian' in kp_driving:
57
+ jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
58
+ jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
59
+ jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
60
+ coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
61
+ coordinate_grid = coordinate_grid.squeeze(-1)
62
+
63
+ driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
64
+
65
+ #adding background feature
66
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
67
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs, num_kp+1,w,h,2
68
+ return sparse_motions
69
+
70
+ def create_deformed_source_image(self, source_image, sparse_motions):
71
+ """
72
+ Eq 7. in the paper \hat{T}_{s<-d}(z)
73
+ """
74
+ bs, _, h, w = source_image.shape
75
+ source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
76
+ source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
77
+ sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
78
+ sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
79
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
80
+ return sparse_deformed
81
+
82
+ def forward(self, source_image, kp_driving, kp_source):
83
+ if self.scale_factor != 1:
84
+ source_image = self.down(source_image)
85
+ bs, _, h, w = source_image.shape
86
+ out_dict = dict()
87
+ heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
88
+ sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
89
+ deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
90
+ out_dict['sparse_deformed'] = deformed_source
91
+
92
+ input = torch.cat([heatmap_representation, deformed_source], dim=2)
93
+ input = input.view(bs, -1, h, w)
94
+
95
+ prediction = self.hourglass(input)
96
+
97
+ mask = self.mask(prediction)
98
+ mask = F.softmax(mask, dim=1)
99
+ out_dict['mask'] = mask
100
+ mask = mask.unsqueeze(2)
101
+ sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
102
+ deformation = (sparse_motion * mask).sum(dim=1)
103
+ deformation = deformation.permute(0, 2, 3, 1)
104
+
105
+ out_dict['deformation'] = deformation
106
+
107
+ # Sec. 3.2 in the paper
108
+ if self.occlusion:
109
+ occlusion_map = torch.sigmoid(self.occlusion(prediction))
110
+ out_dict['occlusion_map'] = occlusion_map
111
+
112
+ return out_dict
modules/discriminator.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ from modules.util import kp2gaussian
4
+ import torch
5
+ import pdb
6
+
7
+ class DownBlock2d(nn.Module):
8
+ """
9
+ Simple block for processing video (encoder).
10
+ """
11
+
12
+ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
13
+ super(DownBlock2d, self).__init__()
14
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
15
+
16
+ if sn:
17
+ self.conv = nn.utils.spectral_norm(self.conv)
18
+
19
+ if norm:
20
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
21
+ else:
22
+ self.norm = None
23
+ self.pool = pool
24
+
25
+ def forward(self, x):
26
+ out = x
27
+ out = self.conv(out)
28
+ if self.norm:
29
+ out = self.norm(out)
30
+ out = F.leaky_relu(out, 0.2)
31
+ if self.pool:
32
+ out = F.avg_pool2d(out, (2, 2))
33
+ return out
34
+
35
+
36
+ class Discriminator(nn.Module):
37
+ """
38
+ Discriminator similar to Pix2Pix
39
+ """
40
+
41
+ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
42
+ sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
43
+ super(Discriminator, self).__init__()
44
+
45
+ down_blocks = []
46
+ for i in range(num_blocks):
47
+ down_blocks.append(
48
+ DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
49
+ min(max_features, block_expansion * (2 ** (i + 1))),
50
+ norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
51
+ self.down_blocks = nn.ModuleList(down_blocks)
52
+ self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
53
+ if sn:
54
+ self.conv = nn.utils.spectral_norm(self.conv)
55
+ self.use_kp = use_kp
56
+ self.kp_variance = kp_variance
57
+
58
+ def forward(self, x, kp=None):
59
+ feature_maps = []
60
+ out = x
61
+ if self.use_kp:
62
+ heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
63
+ out = torch.cat([out, heatmap], dim=1)
64
+ # print(out.shape)
65
+ for down_block in self.down_blocks:
66
+ feature_maps.append(down_block(out))
67
+ out = feature_maps[-1]
68
+ # print(out.shape)
69
+ prediction_map = self.conv(out)
70
+
71
+ return feature_maps, prediction_map
72
+
73
+
74
+ class MultiScaleDiscriminator(nn.Module):
75
+ """
76
+ Multi-scale (scale) discriminator
77
+ """
78
+
79
+ def __init__(self, scales=(), **kwargs):
80
+ super(MultiScaleDiscriminator, self).__init__()
81
+ self.scales = scales
82
+ discs = {}
83
+ for scale in scales:
84
+ discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
85
+ self.discs = nn.ModuleDict(discs)
86
+
87
+ def forward(self, x, kp=None):
88
+ out_dict = {}
89
+ for scale, disc in self.discs.items():
90
+ scale = str(scale).replace('-', '.')
91
+ key = 'prediction_' + scale
92
+ feature_maps, prediction_map = disc(x[key], kp)
93
+ out_dict['feature_maps_' + scale] = feature_maps
94
+ out_dict['prediction_map_' + scale] = prediction_map
95
+ return out_dict
modules/dynamic_conv.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import pdb
5
+
6
+ class attention1d(nn.Module):
7
+ def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
8
+ super(attention1d, self).__init__()
9
+ assert temperature%3==1
10
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
11
+ if in_planes!=3:
12
+ hidden_planes = int(in_planes*ratios)+1
13
+ else:
14
+ hidden_planes = K
15
+ self.fc1 = nn.Conv1d(in_planes, hidden_planes, 1, bias=False)
16
+ # self.bn = nn.BatchNorm2d(hidden_planes)
17
+ self.fc2 = nn.Conv1d(hidden_planes, K, 1, bias=True)
18
+ self.temperature = temperature
19
+ if init_weight:
20
+ self._initialize_weights()
21
+
22
+
23
+ def _initialize_weights(self):
24
+ for m in self.modules():
25
+ if isinstance(m, nn.Conv1d):
26
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
27
+ if m.bias is not None:
28
+ nn.init.constant_(m.bias, 0)
29
+ if isinstance(m ,nn.BatchNorm2d):
30
+ nn.init.constant_(m.weight, 1)
31
+ nn.init.constant_(m.bias, 0)
32
+
33
+ def updata_temperature(self):
34
+ if self.temperature!=1:
35
+ self.temperature -=3
36
+ print('Change temperature to:', str(self.temperature))
37
+
38
+
39
+ def forward(self, x):
40
+ x = self.avgpool(x)
41
+ x = self.fc1(x)
42
+ x = F.relu(x)
43
+ x = self.fc2(x).view(x.size(0), -1)
44
+ return F.softmax(x/self.temperature, 1)
45
+
46
+
47
+ class Dynamic_conv1d(nn.Module):
48
+ def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,temperature=34, init_weight=True):
49
+ super(Dynamic_conv1d, self).__init__()
50
+ assert in_planes%groups==0
51
+ self.in_planes = in_planes
52
+ self.out_planes = out_planes
53
+ self.kernel_size = kernel_size
54
+ self.stride = stride
55
+ self.padding = padding
56
+ self.dilation = dilation
57
+ self.groups = groups
58
+ self.bias = bias
59
+ self.K = K
60
+ self.attention = attention1d(in_planes, ratio, K, temperature)
61
+
62
+ self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size), requires_grad=True)
63
+ if bias:
64
+ self.bias = nn.Parameter(torch.Tensor(K, out_planes))
65
+ else:
66
+ self.bias = None
67
+ if init_weight:
68
+ self._initialize_weights()
69
+
70
+ #TODO 初始化
71
+ def _initialize_weights(self):
72
+ for i in range(self.K):
73
+ nn.init.kaiming_uniform_(self.weight[i])
74
+
75
+
76
+ def update_temperature(self):
77
+ self.attention.updata_temperature()
78
+
79
+ def forward(self, x):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
80
+ softmax_attention = self.attention(x)
81
+ batch_size, in_planes, height = x.size()
82
+ x = x.view(1, -1, height, )# 变化成一个维度进行组卷积
83
+ weight = self.weight.view(self.K, -1)
84
+
85
+ # 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
86
+ aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size,)
87
+ if self.bias is not None:
88
+ aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
89
+ output = F.conv1d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
90
+ dilation=self.dilation, groups=self.groups*batch_size)
91
+ else:
92
+ output = F.conv1d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
93
+ dilation=self.dilation, groups=self.groups * batch_size)
94
+
95
+ output = output.view(batch_size, self.out_planes, output.size(-1))
96
+ return output
97
+
98
+
99
+
100
+ class attention2d(nn.Module):
101
+ def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
102
+ super(attention2d, self).__init__()
103
+ assert temperature%3==1
104
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
105
+ if in_planes!=3:
106
+ hidden_planes = int(in_planes*ratios)+1
107
+ else:
108
+ hidden_planes = K
109
+ self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
110
+ # self.bn = nn.BatchNorm2d(hidden_planes)
111
+ self.fc2 = nn.Conv2d(hidden_planes, K, 1, bias=True)
112
+ self.temperature = temperature
113
+ if init_weight:
114
+ self._initialize_weights()
115
+
116
+
117
+ def _initialize_weights(self):
118
+ for m in self.modules():
119
+ if isinstance(m, nn.Conv2d):
120
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
121
+ if m.bias is not None:
122
+ nn.init.constant_(m.bias, 0)
123
+ if isinstance(m ,nn.BatchNorm2d):
124
+ nn.init.constant_(m.weight, 1)
125
+ nn.init.constant_(m.bias, 0)
126
+
127
+ def updata_temperature(self):
128
+ if self.temperature!=1:
129
+ self.temperature -=3
130
+ print('Change temperature to:', str(self.temperature))
131
+
132
+
133
+ def forward(self, x):
134
+ x = self.avgpool(x)
135
+ x = self.fc1(x)
136
+ x = F.relu(x)
137
+ x = self.fc2(x).view(x.size(0), -1)
138
+ return F.softmax(x/self.temperature, 1)
139
+
140
+
141
+ class Dynamic_deepwise_conv2d(nn.Module):
142
+ def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,temperature=34, init_weight=True):
143
+ super(Dynamic_deepwise_conv2d, self).__init__()
144
+ assert in_planes%groups==0
145
+ self.in_planes = in_planes
146
+ self.out_planes = out_planes
147
+ self.kernel_size = kernel_size
148
+ self.stride = stride
149
+ self.padding = padding
150
+ self.dilation = dilation
151
+ self.groups = groups
152
+ self.bias = bias
153
+ self.K = K
154
+ self.attention = attention2d(in_planes, ratio, K, temperature)
155
+
156
+ self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size), requires_grad=True)
157
+ if bias:
158
+ self.bias = nn.Parameter(torch.Tensor(K, out_planes))
159
+ else:
160
+ self.bias = None
161
+ if init_weight:
162
+ self._initialize_weights()
163
+
164
+ #TODO 初始化
165
+ def _initialize_weights(self):
166
+ for i in range(self.K):
167
+ nn.init.kaiming_uniform_(self.weight[i])
168
+
169
+
170
+ def update_temperature(self):
171
+ self.attention.updata_temperature()
172
+
173
+ def forward(self, x, y):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
174
+ softmax_attention = self.attention(x)
175
+ batch_size, in_planes, height, width = x.size()
176
+ y = y.view(1, -1, height, width)# 变化成一个维度进行组卷积
177
+ weight = self.weight.view(self.K, -1)
178
+
179
+ # 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
180
+ aggregate_weight = torch.mm(softmax_attention, weight).view(-1, 1, self.kernel_size, self.kernel_size)
181
+ if self.bias is not None:
182
+ aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
183
+ output = F.conv2d(y, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
184
+ dilation=self.dilation, groups=self.groups*batch_size)
185
+ else:
186
+ output = F.conv2d(y, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
187
+ dilation=self.dilation, groups=self.groups * batch_size)
188
+
189
+ output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
190
+ return output
191
+
192
+ class Dynamic_conv2d(nn.Module):
193
+ def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,temperature=34, init_weight=True):
194
+ super(Dynamic_conv2d, self).__init__()
195
+ assert in_planes%groups==0
196
+ self.in_planes = in_planes
197
+ self.out_planes = out_planes
198
+ self.kernel_size = kernel_size
199
+ self.stride = stride
200
+ self.padding = padding
201
+ self.dilation = dilation
202
+ self.groups = groups
203
+ self.bias = bias
204
+ self.K = K
205
+ self.attention = attention2d(in_planes, ratio, K, temperature)
206
+
207
+ self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size), requires_grad=True)
208
+ if bias:
209
+ self.bias = nn.Parameter(torch.Tensor(K, out_planes))
210
+ else:
211
+ self.bias = None
212
+ if init_weight:
213
+ self._initialize_weights()
214
+
215
+ #TODO 初始化
216
+ def _initialize_weights(self):
217
+ for i in range(self.K):
218
+ nn.init.kaiming_uniform_(self.weight[i])
219
+
220
+
221
+ def update_temperature(self):
222
+ self.attention.updata_temperature()
223
+
224
+ def forward(self, x,y):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
225
+ softmax_attention = self.attention(x)
226
+ batch_size, in_planes, height, width = x.size()
227
+ y = y.view(1, -1, height, width)# 变化成一个维度进行组卷积
228
+ weight = self.weight.view(self.K, -1)
229
+
230
+ # 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
231
+ aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size, self.kernel_size)
232
+ if self.bias is not None:
233
+ aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
234
+ output = F.conv2d(y, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
235
+ dilation=self.dilation, groups=self.groups*batch_size)
236
+ else:
237
+ output = F.conv2d(y, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
238
+ dilation=self.dilation, groups=self.groups * batch_size)
239
+
240
+ output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
241
+ return output
242
+
243
+
244
+ class attention3d(nn.Module):
245
+ def __init__(self, in_planes, ratios, K, temperature):
246
+ super(attention3d, self).__init__()
247
+ assert temperature%3==1
248
+ self.avgpool = nn.AdaptiveAvgPool3d(1)
249
+ if in_planes != 3:
250
+ hidden_planes = int(in_planes * ratios)+1
251
+ else:
252
+ hidden_planes = K
253
+ self.fc1 = nn.Conv3d(in_planes, hidden_planes, 1, bias=False)
254
+ self.fc2 = nn.Conv3d(hidden_planes, K, 1, bias=False)
255
+ self.temperature = temperature
256
+
257
+ def updata_temperature(self):
258
+ if self.temperature!=1:
259
+ self.temperature -=3
260
+ print('Change temperature to:', str(self.temperature))
261
+
262
+ def forward(self, x):
263
+ x = self.avgpool(x)
264
+ x = self.fc1(x)
265
+ x = F.relu(x)
266
+ x = self.fc2(x).view(x.size(0), -1)
267
+ return F.softmax(x / self.temperature, 1)
268
+
269
+ class Dynamic_conv3d(nn.Module):
270
+ def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34):
271
+ super(Dynamic_conv3d, self).__init__()
272
+ assert in_planes%groups==0
273
+ self.in_planes = in_planes
274
+ self.out_planes = out_planes
275
+ self.kernel_size = kernel_size
276
+ self.stride = stride
277
+ self.padding = padding
278
+ self.dilation = dilation
279
+ self.groups = groups
280
+ self.bias = bias
281
+ self.K = K
282
+ self.attention = attention3d(in_planes, ratio, K, temperature)
283
+
284
+ self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size, kernel_size), requires_grad=True)
285
+ if bias:
286
+ self.bias = nn.Parameter(torch.Tensor(K, out_planes))
287
+ else:
288
+ self.bias = None
289
+
290
+
291
+ #TODO 初始化
292
+ # nn.init.kaiming_uniform_(self.weight, )
293
+
294
+ def update_temperature(self):
295
+ self.attention.updata_temperature()
296
+
297
+ def forward(self, x):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
298
+ softmax_attention = self.attention(x)
299
+ batch_size, in_planes, depth, height, width = x.size()
300
+ x = x.view(1, -1, depth, height, width)# 变化成一个维度进行组卷积
301
+ weight = self.weight.view(self.K, -1)
302
+
303
+ # 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
304
+ aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size, self.kernel_size, self.kernel_size)
305
+ if self.bias is not None:
306
+ aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
307
+ output = F.conv3d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
308
+ dilation=self.dilation, groups=self.groups*batch_size)
309
+ else:
310
+ output = F.conv3d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
311
+ dilation=self.dilation, groups=self.groups * batch_size)
312
+
313
+ output = output.view(batch_size, self.out_planes, output.size(-3), output.size(-2), output.size(-1))
314
+ return output
315
+
316
+
317
+
318
+
319
+ if __name__ == '__main__':
320
+ x = torch.randn(12, 256, 64, 64)
321
+ y = torch.randn(12, 256, 64, 64)
322
+
323
+ model = Dynamic_conv2d(in_planes=256, out_planes=256, kernel_size=3, ratio=0.25, padding=1,groups=1)
324
+ x = x.to('cuda:0')
325
+ y = y.to('cuda:0')
326
+ model.to('cuda')
327
+ # model.attention.cuda()
328
+ print(model(x,y).shape)
329
+ # nn.Conv3d()
330
+ # print(model(x).shape)
331
+ # model.update_temperature()
332
+ # model.update_temperature()
333
+ # model.update_temperature()
334
+ # model.update_temperature()
335
+ # model.update_temperature()
336
+ # model.update_temperature()
337
+ # model.update_temperature()
338
+ # model.update_temperature()
339
+ # model.update_temperature()
340
+ # model.update_temperature()
341
+ # model.update_temperature()
342
+ # model.update_temperature()
343
+ # model.update_temperature()
344
+ # print(model(x).shape)
345
+ # print(model(x).shape)
346
+ # print(model(x).shape)
347
+ # print(model(x).shape)
348
+ # print(model(x).shape)
349
+ # print(model(x).shape)
350
+ # print(model(x).shape)
351
+ # print(model(x).shape)
352
+ # print(model(x).shape)
353
+ # print(model(x).shape)
354
+ # print(model(x).shape)
355
+ # print(model(x).shape)
356
+ # print(model(x).shape)
357
+ # print(model(x).shape)
358
+ # print(model(x).shape)
359
+ # print(model(x).shape)
360
+ # print(model(x).shape)
361
+ # print(model(x).shape)
362
+ # print(model(x).shape)
363
+ # print(model(x).shape)
364
+ # print(model(x).shape)
365
+ # print(model(x).shape)
366
+ # print(model(x).shape)
367
+ # print(model(x).shape)
368
+ # print(model(x).shape)
369
+ # print(model(x).shape)
370
+ # print(model(x).shape)
371
+ # print(model(x).shape)
372
+ # print(model(x).shape)
373
+ # print(model(x).shape)
374
+ # print(model(x).shape)
375
+ # print(model(x).shape)
376
+ # print(model(x).shape)
377
+ # print(model(x).shape)
378
+ # print(model(x).shape)
379
+ # print(model(x).shape)
380
+ # print(model(x).shape)
381
+ # print(model(x).shape)
382
+
modules/generator.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d,SPADEResnetBlock
5
+ from modules.dense_motion import *
6
+ import pdb
7
+ from modules.AdaIN import calc_mean_std,adaptive_instance_normalization
8
+ from modules.dynamic_conv import Dynamic_conv2d
9
+ class SPADEGenerator(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+ ic = 256
13
+ cc = 4
14
+ oc = 64
15
+ norm_G = 'spadespectralinstance'
16
+ label_nc = 3 + cc
17
+
18
+ self.compress = nn.Conv2d(ic, cc, 3, padding=1)
19
+ self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
20
+
21
+ self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
22
+ self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
23
+ self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
24
+ # self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
25
+ # self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
26
+ # self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
27
+
28
+ self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
29
+ self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
30
+ self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
31
+ self.up = nn.Upsample(scale_factor=2)
32
+
33
+ def forward(self, feature, image):
34
+ cp = self.compress(feature)
35
+ seg = torch.cat((F.interpolate(cp, size=(image.shape[2], image.shape[3])), image), dim=1) # 7, 256, 256
36
+
37
+ x = feature # 256, 64, 64
38
+ x = self.fc(x) # 512, 64, 64
39
+ x = self.G_middle_0(x, seg)
40
+ x = self.G_middle_1(x, seg)
41
+ x = self.G_middle_2(x, seg)
42
+ # x = self.G_middle_3(x, seg)
43
+ # x = self.G_middle_4(x, seg)
44
+ # x = self.G_middle_5(x, seg)
45
+ x = self.up(x) # 256, 128, 128
46
+ x = self.up_0(x, seg)
47
+ x = self.up(x) # 64, 256, 256
48
+ x = self.up_1(x, seg)
49
+
50
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
51
+ # x = torch.tanh(x)
52
+ x = F.sigmoid(x)
53
+
54
+ return x
55
+
56
+ class DepthAwareAttention(nn.Module):
57
+ """ depth-aware attention Layer"""
58
+ def __init__(self,in_dim,activation):
59
+ super(DepthAwareAttention,self).__init__()
60
+ self.chanel_in = in_dim
61
+ self.activation = activation
62
+
63
+ self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
64
+ self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
65
+ self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
66
+ self.gamma = nn.Parameter(torch.zeros(1))
67
+
68
+ self.softmax = nn.Softmax(dim=-1) #
69
+ def forward(self,source,feat):
70
+ """
71
+ inputs :
72
+ source : input feature maps( B X C X W X H) 256,64,64
73
+ driving : input feature maps( B X C X W X H) 256,64,64
74
+ returns :
75
+ out : self attention value + input feature
76
+ attention: B X N X N (N is Width*Height)
77
+ """
78
+ m_batchsize,C,width ,height = source.size()
79
+ proj_query = self.activation(self.query_conv(source)).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N) [bz,32,64,64]
80
+ proj_key = self.activation(self.key_conv(feat)).view(m_batchsize,-1,width*height) # B X C x (*W*H)
81
+ energy = torch.bmm(proj_query,proj_key) # transpose check
82
+ attention = self.softmax(energy) # BX (N) X (N)
83
+ proj_value = self.activation(self.value_conv(feat)).view(m_batchsize,-1,width*height) # B X C X N
84
+
85
+ out = torch.bmm(proj_value,attention.permute(0,2,1) )
86
+ out = out.view(m_batchsize,C,width,height)
87
+ out = self.gamma*out + feat
88
+
89
+ return out,attention
90
+
91
+ #### main ####
92
+ class DepthAwareGenerator(nn.Module):
93
+ """
94
+ Generator that given source image and and keypoints try to transform image according to movement trajectories
95
+ induced by keypoints. Generator follows Johnson architecture.
96
+ """
97
+
98
+ def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
99
+ num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
100
+ super(DepthAwareGenerator, self).__init__()
101
+
102
+ if dense_motion_params is not None:
103
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
104
+ estimate_occlusion_map=estimate_occlusion_map,
105
+ **dense_motion_params)
106
+ else:
107
+ self.dense_motion_network = None
108
+
109
+ self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
110
+ down_blocks = []
111
+ for i in range(num_down_blocks):
112
+ in_features = min(max_features, block_expansion * (2 ** i))
113
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
114
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
115
+ self.down_blocks = nn.ModuleList(down_blocks)
116
+
117
+ #source depth
118
+ self.src_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
119
+ src_down_blocks = []
120
+ for i in range(num_down_blocks):
121
+ in_features = min(max_features, block_expansion * (2 ** i))
122
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
123
+ src_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
124
+ self.src_down_blocks = nn.ModuleList(src_down_blocks)
125
+
126
+ # #driving depth
127
+ # self.dst_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
128
+ # dst_down_blocks = []
129
+ # for i in range(num_down_blocks):
130
+ # in_features = min(max_features, block_expansion * (2 ** i))
131
+ # out_features = min(max_features, block_expansion * (2 ** (i + 1)))
132
+ # dst_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
133
+ # self.dst_down_blocks = nn.ModuleList(dst_down_blocks)
134
+
135
+ self.AttnModule = DepthAwareAttention(out_features,nn.ReLU())
136
+
137
+ up_blocks = []
138
+ for i in range(num_down_blocks):
139
+ in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
140
+ out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
141
+ up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
142
+ self.up_blocks = nn.ModuleList(up_blocks)
143
+
144
+ self.bottleneck = torch.nn.Sequential()
145
+ in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
146
+ for i in range(num_bottleneck_blocks):
147
+ self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
148
+
149
+ self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
150
+ self.estimate_occlusion_map = estimate_occlusion_map
151
+ self.num_channels = num_channels
152
+
153
+ def deform_input(self, inp, deformation):
154
+ _, h_old, w_old, _ = deformation.shape
155
+ _, _, h, w = inp.shape
156
+ if h_old != h or w_old != w:
157
+ deformation = deformation.permute(0, 3, 1, 2)
158
+ deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
159
+ deformation = deformation.permute(0, 2, 3, 1)
160
+ return F.grid_sample(inp, deformation)
161
+
162
+ def forward(self, source_image, kp_driving, kp_source, source_depth, driving_depth):
163
+ # Encoding (downsampling) part
164
+ out = self.first(source_image)
165
+ for i in range(len(self.down_blocks)):
166
+ out = self.down_blocks[i](out)
167
+
168
+ src_out = self.src_first(source_depth)
169
+ for i in range(len(self.src_down_blocks)):
170
+ src_out = self.src_down_blocks[i](src_out)
171
+
172
+ # dst_out = self.dst_first(driving_depth)
173
+ # for i in range(len(self.down_blocks)):
174
+ # dst_out = self.dst_down_blocks[i](dst_out)
175
+
176
+ # Transforming feature representation according to deformation and occlusion
177
+ output_dict = {}
178
+ if self.dense_motion_network is not None:
179
+ dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
180
+ kp_source=kp_source)
181
+ output_dict['mask'] = dense_motion['mask']
182
+ output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
183
+
184
+ if 'occlusion_map' in dense_motion:
185
+ occlusion_map = dense_motion['occlusion_map']
186
+ output_dict['occlusion_map'] = occlusion_map
187
+ else:
188
+ occlusion_map = None
189
+ deformation = dense_motion['deformation']
190
+ out = self.deform_input(out, deformation)
191
+
192
+ if occlusion_map is not None:
193
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
194
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
195
+ out = out * occlusion_map
196
+ out,attention = self.AttnModule(src_out,out)
197
+
198
+ output_dict["deformed"] = self.deform_input(source_image, deformation)
199
+ output_dict["attention"] = attention
200
+
201
+ # Decoding part
202
+ out = self.bottleneck(out)
203
+ for i in range(len(self.up_blocks)):
204
+ out = self.up_blocks[i](out)
205
+ out = self.final(out)
206
+ out = F.sigmoid(out)
207
+
208
+ output_dict["prediction"] = out
209
+
210
+ return output_dict
211
+
212
+ class SPADEDepthAwareGenerator(nn.Module):
213
+ """
214
+ Generator that given source image and and keypoints try to transform image according to movement trajectories
215
+ induced by keypoints. Generator follows Johnson architecture.
216
+ """
217
+
218
+ def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
219
+ num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
220
+ super(SPADEDepthAwareGenerator, self).__init__()
221
+
222
+ if dense_motion_params is not None:
223
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
224
+ estimate_occlusion_map=estimate_occlusion_map,
225
+ **dense_motion_params)
226
+ else:
227
+ self.dense_motion_network = None
228
+
229
+ self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
230
+ down_blocks = []
231
+ for i in range(num_down_blocks):
232
+ in_features = min(max_features, block_expansion * (2 ** i))
233
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
234
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
235
+ self.down_blocks = nn.ModuleList(down_blocks)
236
+
237
+ #source depth
238
+ self.src_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
239
+ src_down_blocks = []
240
+ for i in range(num_down_blocks):
241
+ in_features = min(max_features, block_expansion * (2 ** i))
242
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
243
+ src_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
244
+ self.src_down_blocks = nn.ModuleList(src_down_blocks)
245
+
246
+ # #driving depth
247
+ # self.dst_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
248
+ # dst_down_blocks = []
249
+ # for i in range(num_down_blocks):
250
+ # in_features = min(max_features, block_expansion * (2 ** i))
251
+ # out_features = min(max_features, block_expansion * (2 ** (i + 1)))
252
+ # dst_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
253
+ # self.dst_down_blocks = nn.ModuleList(dst_down_blocks)
254
+
255
+ self.AttnModule = DepthAwareAttention(out_features,nn.ReLU())
256
+ self.decoder = SPADEGenerator()
257
+
258
+ self.estimate_occlusion_map = estimate_occlusion_map
259
+ self.num_channels = num_channels
260
+
261
+ def deform_input(self, inp, deformation):
262
+ _, h_old, w_old, _ = deformation.shape
263
+ _, _, h, w = inp.shape
264
+ if h_old != h or w_old != w:
265
+ deformation = deformation.permute(0, 3, 1, 2)
266
+ deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
267
+ deformation = deformation.permute(0, 2, 3, 1)
268
+ return F.grid_sample(inp, deformation)
269
+
270
+ def forward(self, source_image, kp_driving, kp_source, source_depth, driving_depth):
271
+ # Encoding (downsampling) part
272
+ out = self.first(source_image)
273
+ for i in range(len(self.down_blocks)):
274
+ out = self.down_blocks[i](out)
275
+
276
+ src_out = self.src_first(source_depth)
277
+ for i in range(len(self.src_down_blocks)):
278
+ src_out = self.src_down_blocks[i](src_out)
279
+
280
+ # dst_out = self.dst_first(driving_depth)
281
+ # for i in range(len(self.down_blocks)):
282
+ # dst_out = self.dst_down_blocks[i](dst_out)
283
+
284
+ # Transforming feature representation according to deformation and occlusion
285
+ output_dict = {}
286
+ if self.dense_motion_network is not None:
287
+ dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
288
+ kp_source=kp_source)
289
+ output_dict['mask'] = dense_motion['mask']
290
+ output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
291
+
292
+ if 'occlusion_map' in dense_motion:
293
+ occlusion_map = dense_motion['occlusion_map']
294
+ output_dict['occlusion_map'] = occlusion_map
295
+ else:
296
+ occlusion_map = None
297
+ deformation = dense_motion['deformation']
298
+ out = self.deform_input(out, deformation)
299
+
300
+ if occlusion_map is not None:
301
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
302
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
303
+ out = out * occlusion_map
304
+
305
+ out,attention = self.AttnModule(src_out,out)
306
+
307
+ deformed_image = self.deform_input(source_image, deformation)
308
+ output_dict["deformed"] = deformed_image
309
+ output_dict["attention"] = attention
310
+
311
+ if occlusion_map is not None:
312
+ if deformed_image.shape[2] != occlusion_map.shape[2] or deformed_image.shape[3] != occlusion_map.shape[3]:
313
+ occlusion_map = F.interpolate(occlusion_map, size=deformed_image.shape[2:], mode='bilinear')
314
+ deformed_image = deformed_image * occlusion_map
315
+
316
+ out = self.decoder(out, deformed_image)
317
+
318
+ # # Decoding part
319
+ # out = self.bottleneck(out)
320
+ # for i in range(len(self.up_blocks)):
321
+ # out = self.up_blocks[i](out)
322
+ # out = self.final(out)
323
+ # out = F.sigmoid(out)
324
+ output_dict["prediction"] = out
325
+ return output_dict
modules/keypoint_detector.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d,Hourglass_2branch
5
+ import pdb
6
+
7
+ class KPDetector(nn.Module):
8
+ """
9
+ Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
10
+ """
11
+
12
+ def __init__(self, block_expansion, num_kp, num_channels, max_features,
13
+ num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
14
+ single_jacobian_map=False, pad=0):
15
+ super(KPDetector, self).__init__()
16
+ self.predictor = Hourglass(block_expansion, in_features=num_channels,
17
+ max_features=max_features, num_blocks=num_blocks)
18
+
19
+ self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
20
+ padding=pad)
21
+
22
+ if estimate_jacobian:
23
+ self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
24
+ self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
25
+ out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
26
+ self.jacobian.weight.data.zero_()
27
+ self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
28
+ else:
29
+ self.jacobian = None
30
+
31
+ self.temperature = temperature
32
+ self.scale_factor = scale_factor
33
+ if self.scale_factor != 1:
34
+ self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
35
+
36
+ def gaussian2kp(self, heatmap):
37
+ """
38
+ Extract the mean and from a heatmap
39
+ """
40
+ shape = heatmap.shape
41
+ heatmap = heatmap.unsqueeze(-1)
42
+ grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
43
+ value = (heatmap * grid).sum(dim=(2, 3))
44
+ kp = {'value': value}
45
+
46
+ return kp
47
+
48
+ def forward(self, x):
49
+ if self.scale_factor != 1:
50
+ x = self.down(x)
51
+ feature_map = self.predictor(x) #x bz,4,64,64
52
+ prediction = self.kp(feature_map)
53
+
54
+ final_shape = prediction.shape
55
+ heatmap = prediction.view(final_shape[0], final_shape[1], -1)
56
+ heatmap = F.softmax(heatmap / self.temperature, dim=2)
57
+ heatmap = heatmap.view(*final_shape)
58
+
59
+ out = self.gaussian2kp(heatmap)
60
+
61
+ if self.jacobian is not None:
62
+ jacobian_map = self.jacobian(feature_map)
63
+ # pdb.set_trace()
64
+ jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
65
+ final_shape[3])
66
+ heatmap = heatmap.unsqueeze(2)
67
+
68
+ jacobian = heatmap * jacobian_map
69
+ jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
70
+ jacobian = jacobian.sum(dim=-1)
71
+ jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
72
+ out['jacobian'] = jacobian
73
+
74
+ return out
75
+
modules/model.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
5
+ from torchvision import models
6
+ import numpy as np
7
+ from torch.autograd import grad
8
+ import pdb
9
+ import depth
10
+
11
+ class Vgg19(torch.nn.Module):
12
+ """
13
+ Vgg19 network for perceptual loss. See Sec 3.3.
14
+ """
15
+ def __init__(self, requires_grad=False):
16
+ super(Vgg19, self).__init__()
17
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
18
+ self.slice1 = torch.nn.Sequential()
19
+ self.slice2 = torch.nn.Sequential()
20
+ self.slice3 = torch.nn.Sequential()
21
+ self.slice4 = torch.nn.Sequential()
22
+ self.slice5 = torch.nn.Sequential()
23
+ for x in range(2):
24
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
25
+ for x in range(2, 7):
26
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
27
+ for x in range(7, 12):
28
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
29
+ for x in range(12, 21):
30
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
31
+ for x in range(21, 30):
32
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
33
+
34
+ self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
35
+ requires_grad=False)
36
+ self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
37
+ requires_grad=False)
38
+
39
+ if not requires_grad:
40
+ for param in self.parameters():
41
+ param.requires_grad = False
42
+
43
+ def forward(self, X):
44
+ X = (X - self.mean) / self.std
45
+ h_relu1 = self.slice1(X)
46
+ h_relu2 = self.slice2(h_relu1)
47
+ h_relu3 = self.slice3(h_relu2)
48
+ h_relu4 = self.slice4(h_relu3)
49
+ h_relu5 = self.slice5(h_relu4)
50
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
51
+ return out
52
+
53
+
54
+ class ImagePyramide(torch.nn.Module):
55
+ """
56
+ Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
57
+ """
58
+ def __init__(self, scales, num_channels):
59
+ super(ImagePyramide, self).__init__()
60
+ downs = {}
61
+ for scale in scales:
62
+ downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
63
+ self.downs = nn.ModuleDict(downs)
64
+
65
+ def forward(self, x):
66
+ out_dict = {}
67
+ for scale, down_module in self.downs.items():
68
+ out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
69
+ return out_dict
70
+
71
+
72
+ class Transform:
73
+ """
74
+ Random tps transformation for equivariance constraints. See Sec 3.3
75
+ """
76
+ def __init__(self, bs, **kwargs):
77
+ noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
78
+ self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
79
+ self.bs = bs
80
+
81
+ if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
82
+ self.tps = True
83
+ self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
84
+ self.control_points = self.control_points.unsqueeze(0)
85
+ self.control_params = torch.normal(mean=0,
86
+ std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
87
+ else:
88
+ self.tps = False
89
+
90
+ def transform_frame(self, frame):
91
+ grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
92
+ grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
93
+ grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
94
+ return F.grid_sample(frame, grid, padding_mode="reflection")
95
+
96
+ def warp_coordinates(self, coordinates):
97
+ theta = self.theta.type(coordinates.type())
98
+ theta = theta.unsqueeze(1)
99
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
100
+ transformed = transformed.squeeze(-1)
101
+
102
+ if self.tps:
103
+ control_points = self.control_points.type(coordinates.type())
104
+ control_params = self.control_params.type(coordinates.type())
105
+ distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
106
+ distances = torch.abs(distances).sum(-1)
107
+
108
+ result = distances ** 2
109
+ result = result * torch.log(distances + 1e-6)
110
+ result = result * control_params
111
+ result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
112
+ transformed = transformed + result
113
+
114
+ return transformed
115
+
116
+ def jacobian(self, coordinates):
117
+ new_coordinates = self.warp_coordinates(coordinates)
118
+ grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
119
+ grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
120
+ jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
121
+ return jacobian
122
+
123
+
124
+ def detach_kp(kp):
125
+ return {key: value.detach() for key, value in kp.items()}
126
+
127
+
128
+ class GeneratorFullModel(torch.nn.Module):
129
+ """
130
+ Merge all generator related updates into single model for better multi-gpu usage
131
+ """
132
+
133
+ def __init__(self, kp_extractor, generator, discriminator, train_params,opt):
134
+ super(GeneratorFullModel, self).__init__()
135
+ self.kp_extractor = kp_extractor
136
+ self.generator = generator
137
+ self.discriminator = discriminator
138
+ self.train_params = train_params
139
+ self.scales = train_params['scales']
140
+ self.disc_scales = self.discriminator.module.scales
141
+ self.pyramid = ImagePyramide(self.scales, generator.module.num_channels)
142
+ if torch.cuda.is_available():
143
+ self.pyramid = self.pyramid.cuda()
144
+ self.opt = opt
145
+ self.loss_weights = train_params['loss_weights']
146
+
147
+ if sum(self.loss_weights['perceptual']) != 0:
148
+ self.vgg = Vgg19()
149
+ if torch.cuda.is_available():
150
+ self.vgg = self.vgg.cuda()
151
+ self.depth_encoder = depth.ResnetEncoder(18, False).cuda()
152
+ self.depth_decoder = depth.DepthDecoder(num_ch_enc=self.depth_encoder.num_ch_enc, scales=range(4)).cuda()
153
+ loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth',map_location='cpu')
154
+ loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth',map_location='cpu')
155
+ filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in self.depth_encoder.state_dict()}
156
+ self.depth_encoder.load_state_dict(filtered_dict_enc)
157
+ self.depth_decoder.load_state_dict(loaded_dict_dec)
158
+ self.set_requires_grad(self.depth_encoder, False)
159
+ self.set_requires_grad(self.depth_decoder, False)
160
+ self.depth_decoder.eval()
161
+ self.depth_encoder.eval()
162
+ def set_requires_grad(self, nets, requires_grad=False):
163
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
164
+ Parameters:
165
+ nets (network list) -- a list of networks
166
+ requires_grad (bool) -- whether the networks require gradients or not
167
+ """
168
+ if not isinstance(nets, list):
169
+ nets = [nets]
170
+ for net in nets:
171
+ if net is not None:
172
+ for param in net.parameters():
173
+ param.requires_grad = requires_grad
174
+ def forward(self, x):
175
+ depth_source = None
176
+ depth_driving = None
177
+ outputs = self.depth_decoder(self.depth_encoder(x['source']))
178
+ depth_source = outputs[("disp", 0)]
179
+ outputs = self.depth_decoder(self.depth_encoder(x['driving']))
180
+ depth_driving = outputs[("disp", 0)]
181
+
182
+ if self.opt.use_depth:
183
+ kp_source = self.kp_extractor(depth_source)
184
+ kp_driving = self.kp_extractor(depth_driving)
185
+ elif self.opt.rgbd:
186
+ source = torch.cat((x['source'],depth_source),1)
187
+ driving = torch.cat((x['driving'],depth_driving),1)
188
+ kp_source = self.kp_extractor(source)
189
+ kp_driving = self.kp_extractor(driving)
190
+ else:
191
+ kp_source = self.kp_extractor(x['source'])
192
+ kp_driving = self.kp_extractor(x['driving'])
193
+ generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving, source_depth = depth_source, driving_depth = depth_driving)
194
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
195
+ loss_values = {}
196
+ pyramide_real = self.pyramid(x['driving'])
197
+ pyramide_generated = self.pyramid(generated['prediction'])
198
+ if sum(self.loss_weights['perceptual']) != 0:
199
+ value_total = 0
200
+ for scale in self.scales:
201
+ x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
202
+ y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
203
+
204
+ for i, weight in enumerate(self.loss_weights['perceptual']):
205
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
206
+ value_total += self.loss_weights['perceptual'][i] * value
207
+ loss_values['perceptual'] = value_total
208
+
209
+ if self.loss_weights['generator_gan'] != 0:
210
+
211
+ discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
212
+
213
+ discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
214
+ value_total = 0
215
+ for scale in self.disc_scales:
216
+ key = 'prediction_map_%s' % scale
217
+ value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
218
+ value_total += self.loss_weights['generator_gan'] * value
219
+ loss_values['gen_gan'] = value_total
220
+
221
+ if sum(self.loss_weights['feature_matching']) != 0:
222
+ value_total = 0
223
+ for scale in self.disc_scales:
224
+ key = 'feature_maps_%s' % scale
225
+ for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
226
+ if self.loss_weights['feature_matching'][i] == 0:
227
+ continue
228
+ value = torch.abs(a - b).mean()
229
+ value_total += self.loss_weights['feature_matching'][i] * value
230
+ loss_values['feature_matching'] = value_total
231
+
232
+ if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
233
+ transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
234
+ transformed_frame = transform.transform_frame(x['driving'])
235
+ if self.opt.use_depth:
236
+ outputs = self.depth_decoder(self.depth_encoder(transformed_frame))
237
+ depth_transform = outputs[("disp", 0)]
238
+ transformed_kp = self.kp_extractor(depth_transform)
239
+ elif self.opt.rgbd:
240
+ outputs = self.depth_decoder(self.depth_encoder(transformed_frame))
241
+ depth_transform = outputs[("disp", 0)]
242
+ transform_img = torch.cat((transformed_frame,depth_transform),1)
243
+ transformed_kp = self.kp_extractor(transform_img)
244
+ else:
245
+ transformed_kp = self.kp_extractor(transformed_frame)
246
+
247
+ generated['transformed_frame'] = transformed_frame
248
+ generated['transformed_kp'] = transformed_kp
249
+
250
+ ## Value loss part
251
+ if self.loss_weights['equivariance_value'] != 0:
252
+ value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
253
+ loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
254
+
255
+ ## jacobian loss part
256
+ if self.loss_weights['equivariance_jacobian'] != 0:
257
+ jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
258
+ transformed_kp['jacobian'])
259
+
260
+ normed_driving = torch.inverse(kp_driving['jacobian'])
261
+ normed_transformed = jacobian_transformed
262
+ value = torch.matmul(normed_driving, normed_transformed)
263
+
264
+ eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
265
+
266
+ value = torch.abs(eye - value).mean()
267
+ loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
268
+
269
+
270
+ if self.loss_weights['kp_distance']:
271
+ bz,num_kp,kp_dim = kp_source['value'].shape
272
+ sk = kp_source['value'].unsqueeze(2)-kp_source['value'].unsqueeze(1)
273
+ dk = kp_driving['value'].unsqueeze(2)-kp_driving['value'].unsqueeze(1)
274
+ source_dist_loss = (-torch.sign((torch.sqrt((sk*sk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()*0.2)-0.2)+1).mean()
275
+ driving_dist_loss = (-torch.sign((torch.sqrt((dk*dk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()*0.2)-0.2)+1).mean()
276
+ # driving_dist_loss = (torch.sign(1-(torch.sqrt((dk*dk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()))+1).mean()
277
+ value_total = self.loss_weights['kp_distance']*(source_dist_loss+driving_dist_loss)
278
+ loss_values['kp_distance'] = value_total
279
+ if self.loss_weights['kp_prior']:
280
+ bz,num_kp,kp_dim = kp_source['value'].shape
281
+ sk = kp_source['value'].unsqueeze(2)-kp_source['value'].unsqueeze(1)
282
+ dk = kp_driving['value'].unsqueeze(2)-kp_driving['value'].unsqueeze(1)
283
+ dis_loss = torch.relu(0.1-torch.sqrt((sk*sk).sum(-1)+1e-8))+torch.relu(0.1-torch.sqrt((dk*dk).sum(-1)+1e-8))
284
+ bs,nk,_=kp_source['value'].shape
285
+ scoor_depth = F.grid_sample(depth_source,kp_source['value'].view(bs,1,nk,-1))
286
+ dcoor_depth = F.grid_sample(depth_driving,kp_driving['value'].view(bs,1,nk,-1))
287
+ sd_loss = torch.abs(scoor_depth.mean(-1,keepdim=True) - kp_source['value'].view(bs,1,nk,-1)).mean()
288
+ dd_loss = torch.abs(dcoor_depth.mean(-1,keepdim=True) - kp_driving['value'].view(bs,1,nk,-1)).mean()
289
+ value_total = self.loss_weights['kp_distance']*(dis_loss+sd_loss+dd_loss)
290
+ loss_values['kp_distance'] = value_total
291
+
292
+
293
+ if self.loss_weights['kp_scale']:
294
+ bz,num_kp,kp_dim = kp_source['value'].shape
295
+ if self.opt.rgbd:
296
+ outputs = self.depth_decoder(self.depth_encoder(generated['prediction']))
297
+ depth_pred = outputs[("disp", 0)]
298
+ pred = torch.cat((generated['prediction'],depth_pred),1)
299
+ kp_pred = self.kp_extractor(pred)
300
+ elif self.opt.use_depth:
301
+ outputs = self.depth_decoder(self.depth_encoder(generated['prediction']))
302
+ depth_pred = outputs[("disp", 0)]
303
+ kp_pred = self.kp_extractor(depth_pred)
304
+ else:
305
+ kp_pred = self.kp_extractor(generated['prediction'])
306
+
307
+ pred_mean = kp_pred['value'].mean(1,keepdim=True)
308
+ driving_mean = kp_driving['value'].mean(1,keepdim=True)
309
+ pk = kp_source['value']-pred_mean
310
+ dk = kp_driving['value']- driving_mean
311
+ pred_dist_loss = torch.sqrt((pk*pk).sum(-1)+1e-8)
312
+ driving_dist_loss = torch.sqrt((dk*dk).sum(-1)+1e-8)
313
+ scale_vec = driving_dist_loss/pred_dist_loss
314
+ bz,n = scale_vec.shape
315
+ value = torch.abs(scale_vec[:,:n-1]-scale_vec[:,1:]).mean()
316
+ value_total = self.loss_weights['kp_scale']*value
317
+ loss_values['kp_scale'] = value_total
318
+ if self.loss_weights['depth_constraint']:
319
+ bz,num_kp,kp_dim = kp_source['value'].shape
320
+ outputs = self.depth_decoder(self.depth_encoder(generated['prediction']))
321
+ depth_pred = outputs[("disp", 0)]
322
+ value_total = self.loss_weights['depth_constraint']*torch.abs(depth_driving-depth_pred).mean()
323
+ loss_values['depth_constraint'] = value_total
324
+ return loss_values, generated
325
+
326
+
327
+
328
+ class DiscriminatorFullModel(torch.nn.Module):
329
+ """
330
+ Merge all discriminator related updates into single model for better multi-gpu usage
331
+ """
332
+
333
+ def __init__(self, kp_extractor, generator, discriminator, train_params):
334
+ super(DiscriminatorFullModel, self).__init__()
335
+ self.kp_extractor = kp_extractor
336
+ self.generator = generator
337
+ self.discriminator = discriminator
338
+ self.train_params = train_params
339
+ self.scales = self.discriminator.module.scales
340
+ self.pyramid = ImagePyramide(self.scales, generator.module.num_channels)
341
+ if torch.cuda.is_available():
342
+ self.pyramid = self.pyramid.cuda()
343
+
344
+ self.loss_weights = train_params['loss_weights']
345
+
346
+ def forward(self, x, generated):
347
+ pyramide_real = self.pyramid(x['driving'])
348
+ pyramide_generated = self.pyramid(generated['prediction'].detach())
349
+
350
+ kp_driving = generated['kp_driving']
351
+ discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
352
+ discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
353
+
354
+ loss_values = {}
355
+ value_total = 0
356
+ for scale in self.scales:
357
+ key = 'prediction_map_%s' % scale
358
+ value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
359
+ value_total += self.loss_weights['discriminator_gan'] * value.mean()
360
+ loss_values['disc_gan'] = value_total
361
+
362
+ return loss_values
modules/util.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ import torch.nn.functional as F
4
+ import torch
5
+
6
+ from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
7
+ import pdb
8
+ import torch.nn.utils.spectral_norm as spectral_norm
9
+ def kp2gaussian(kp, spatial_size, kp_variance):
10
+ """
11
+ Transform a keypoint into gaussian like representation
12
+ """
13
+ mean = kp['value']
14
+
15
+ coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
16
+ number_of_leading_dimensions = len(mean.shape) - 1
17
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
18
+ coordinate_grid = coordinate_grid.view(*shape)
19
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
20
+ coordinate_grid = coordinate_grid.repeat(*repeats)
21
+
22
+ # Preprocess kp shape
23
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
24
+ mean = mean.view(*shape)
25
+
26
+ mean_sub = (coordinate_grid - mean)
27
+
28
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
29
+
30
+ return out
31
+
32
+
33
+ def make_coordinate_grid(spatial_size, type):
34
+ """
35
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
36
+ """
37
+ h, w = spatial_size
38
+ x = torch.arange(w).type(type)
39
+ y = torch.arange(h).type(type)
40
+
41
+ x = (2 * (x / (w - 1)) - 1)
42
+ y = (2 * (y / (h - 1)) - 1)
43
+
44
+ yy = y.view(-1, 1).repeat(1, w)
45
+ xx = x.view(1, -1).repeat(h, 1)
46
+
47
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
48
+
49
+ return meshed
50
+
51
+
52
+ class ResBlock2d(nn.Module):
53
+ """
54
+ Res block, preserve spatial resolution.
55
+ """
56
+
57
+ def __init__(self, in_features, kernel_size, padding):
58
+ super(ResBlock2d, self).__init__()
59
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
60
+ padding=padding)
61
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
62
+ padding=padding)
63
+ self.norm1 = BatchNorm2d(in_features, affine=True)
64
+ self.norm2 = BatchNorm2d(in_features, affine=True)
65
+
66
+ def forward(self, x):
67
+ out = self.norm1(x)
68
+ out = F.relu(out)
69
+ out = self.conv1(out)
70
+ out = self.norm2(out)
71
+ out = F.relu(out)
72
+ out = self.conv2(out)
73
+ out += x
74
+ return out
75
+
76
+
77
+ class UpBlock2d(nn.Module):
78
+ """
79
+ Upsampling block for use in decoder.
80
+ """
81
+
82
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
83
+ super(UpBlock2d, self).__init__()
84
+
85
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
86
+ padding=padding, groups=groups)
87
+ self.norm = BatchNorm2d(out_features, affine=True)
88
+
89
+ def forward(self, x):
90
+ out = F.interpolate(x, scale_factor=2)
91
+ out = self.conv(out)
92
+ out = self.norm(out)
93
+ out = F.relu(out)
94
+ return out
95
+
96
+
97
+ class DownBlock2d(nn.Module):
98
+ """
99
+ Downsampling block for use in encoder.
100
+ """
101
+
102
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
103
+ super(DownBlock2d, self).__init__()
104
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
105
+ padding=padding, groups=groups)
106
+ self.norm = BatchNorm2d(out_features, affine=True)
107
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
108
+
109
+ def forward(self, x):
110
+ out = self.conv(x)
111
+ out = self.norm(out)
112
+ out = F.relu(out)
113
+ out = self.pool(out)
114
+ return out
115
+
116
+
117
+ class SameBlock2d(nn.Module):
118
+ """
119
+ Simple block, preserve spatial resolution.
120
+ """
121
+
122
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
123
+ super(SameBlock2d, self).__init__()
124
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
125
+ kernel_size=kernel_size, padding=padding, groups=groups)
126
+ self.norm = BatchNorm2d(out_features, affine=True)
127
+
128
+ def forward(self, x):
129
+ out = self.conv(x)
130
+ out = self.norm(out)
131
+ out = F.relu(out)
132
+ return out
133
+
134
+
135
+ class Encoder(nn.Module):
136
+ """
137
+ Hourglass Encoder
138
+ """
139
+
140
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
141
+ super(Encoder, self).__init__()
142
+
143
+ down_blocks = []
144
+ for i in range(num_blocks):
145
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
146
+ min(max_features, block_expansion * (2 ** (i + 1))),
147
+ kernel_size=3, padding=1))
148
+ self.down_blocks = nn.ModuleList(down_blocks)
149
+
150
+ def forward(self, x):
151
+ outs = [x]
152
+ for down_block in self.down_blocks:
153
+ outs.append(down_block(outs[-1]))
154
+ return outs
155
+
156
+
157
+ class Decoder(nn.Module):
158
+ """
159
+ Hourglass Decoder
160
+ """
161
+
162
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
163
+ super(Decoder, self).__init__()
164
+
165
+ up_blocks = []
166
+
167
+ for i in range(num_blocks)[::-1]:
168
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
169
+ out_filters = min(max_features, block_expansion * (2 ** i))
170
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
171
+
172
+ self.up_blocks = nn.ModuleList(up_blocks)
173
+ self.out_filters = block_expansion + in_features
174
+
175
+ def forward(self, x):
176
+ out = x.pop()
177
+ for up_block in self.up_blocks:
178
+ out = up_block(out)
179
+ skip = x.pop()
180
+ out = torch.cat([out, skip], dim=1)
181
+ return out
182
+
183
+
184
+ class Decoder_w_emb(nn.Module):
185
+ """
186
+ Hourglass Decoder
187
+ """
188
+
189
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
190
+ super(Decoder_w_emb, self).__init__()
191
+
192
+ up_blocks = []
193
+
194
+ for i in range(num_blocks)[::-1]:
195
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
196
+ out_filters = min(max_features, block_expansion * (2 ** i))
197
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
198
+
199
+ self.up_blocks = nn.ModuleList(up_blocks)
200
+ self.out_filters = block_expansion + in_features
201
+
202
+ def forward(self, x):
203
+ feats = []
204
+ out = x.pop()
205
+ feats.append(out)
206
+ for ind,up_block in enumerate(self.up_blocks):
207
+ out = up_block(out)
208
+ skip = x.pop()
209
+ feats.append(skip)
210
+ out = torch.cat([out, skip], dim=1)
211
+ return out,feats
212
+
213
+ class Decoder_2branch(nn.Module):
214
+ """
215
+ Hourglass Decoder
216
+ """
217
+
218
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
219
+ super(Decoder_2branch, self).__init__()
220
+ up_blocks = []
221
+ for i in range(num_blocks)[::-1]:
222
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
223
+ out_filters = min(max_features, block_expansion * (2 ** i))
224
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
225
+
226
+ self.up_blocks = nn.ModuleList(up_blocks)
227
+ self.out_filters = block_expansion + in_features
228
+
229
+ def forward(self, x):
230
+ # out = x.pop()
231
+ num_feat = len(x)
232
+ out=x[-1]
233
+ for i in range(len(self.up_blocks)):
234
+ out = self.up_blocks[i](out)
235
+ skip = x[-(i+1+1)]
236
+ out = torch.cat([out, skip], dim=1)
237
+ return out
238
+
239
+
240
+
241
+ class Hourglass(nn.Module):
242
+ """
243
+ Hourglass architecture.
244
+ """
245
+
246
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
247
+ super(Hourglass, self).__init__()
248
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
249
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
250
+ self.out_filters = self.decoder.out_filters
251
+ def forward(self, x):
252
+ return self.decoder(self.encoder(x))
253
+
254
+ class Hourglass_2branch(nn.Module):
255
+ """
256
+ Hourglass architecture.
257
+ """
258
+
259
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
260
+ super(Hourglass_2branch, self).__init__()
261
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
262
+ self.decoder_kp = Decoder_2branch(block_expansion, in_features, num_blocks, max_features)
263
+ self.decoder_mask = Decoder_2branch(block_expansion, in_features, num_blocks, max_features)
264
+
265
+ self.out_filters = self.decoder_kp.out_filters
266
+ def forward(self, x):
267
+ embd= self.encoder(x)
268
+ kp_feat = self.decoder_kp(embd)
269
+ mask_feat = self.decoder_mask(embd)
270
+ return kp_feat,mask_feat
271
+
272
+
273
+ class Hourglass_w_emb(nn.Module):
274
+ """
275
+ Hourglass architecture.
276
+ """
277
+
278
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
279
+ super(Hourglass_w_emb, self).__init__()
280
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
281
+ self.decoder = Decoder_w_emb(block_expansion, in_features, num_blocks, max_features)
282
+ self.out_filters = self.decoder.out_filters
283
+
284
+ def forward(self, x):
285
+ embs = self.encoder(x)
286
+ result,feats = self.decoder(embs)
287
+ return feats,result
288
+ class AntiAliasInterpolation2d(nn.Module):
289
+ """
290
+ Band-limited downsampling, for better preservation of the input signal.
291
+ """
292
+ def __init__(self, channels, scale):
293
+ super(AntiAliasInterpolation2d, self).__init__()
294
+ sigma = (1 / scale - 1) / 2
295
+ kernel_size = 2 * round(sigma * 4) + 1
296
+ self.ka = kernel_size // 2
297
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
298
+
299
+ kernel_size = [kernel_size, kernel_size]
300
+ sigma = [sigma, sigma]
301
+ # The gaussian kernel is the product of the
302
+ # gaussian function of each dimension.
303
+ kernel = 1
304
+ meshgrids = torch.meshgrid(
305
+ [
306
+ torch.arange(size, dtype=torch.float32)
307
+ for size in kernel_size
308
+ ]
309
+ )
310
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
311
+ mean = (size - 1) / 2
312
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
313
+
314
+ # Make sure sum of values in gaussian kernel equals 1.
315
+ kernel = kernel / torch.sum(kernel)
316
+ # Reshape to depthwise convolutional weight
317
+ kernel = kernel.view(1, 1, *kernel.size())
318
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
319
+
320
+ self.register_buffer('weight', kernel)
321
+ self.groups = channels
322
+ self.scale = scale
323
+ inv_scale = 1 / scale
324
+ self.int_inv_scale = int(inv_scale)
325
+
326
+ def forward(self, input):
327
+ if self.scale == 1.0:
328
+ return input
329
+
330
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
331
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
332
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
333
+
334
+ return out
335
+
336
+
337
+ class SPADE(nn.Module):
338
+ def __init__(self, norm_nc, label_nc):
339
+ super().__init__()
340
+
341
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
342
+ nhidden = 128
343
+
344
+ self.mlp_shared = nn.Sequential(
345
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
346
+ nn.ReLU())
347
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
348
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
349
+
350
+ def forward(self, x, segmap):
351
+ normalized = self.param_free_norm(x)
352
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
353
+ actv = self.mlp_shared(segmap)
354
+ gamma = self.mlp_gamma(actv)
355
+ beta = self.mlp_beta(actv)
356
+ out = normalized * (1 + gamma) + beta
357
+ return out
358
+
359
+
360
+ class SPADEResnetBlock(nn.Module):
361
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
362
+ super().__init__()
363
+ # Attributes
364
+ self.learned_shortcut = (fin != fout)
365
+ fmiddle = min(fin, fout)
366
+ self.use_se = use_se
367
+ # create conv layers
368
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
369
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
370
+ if self.learned_shortcut:
371
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
372
+ # apply spectral norm if specified
373
+ if 'spectral' in norm_G:
374
+ self.conv_0 = spectral_norm(self.conv_0)
375
+ self.conv_1 = spectral_norm(self.conv_1)
376
+ if self.learned_shortcut:
377
+ self.conv_s = spectral_norm(self.conv_s)
378
+ # define normalization layers
379
+ self.norm_0 = SPADE(fin, label_nc)
380
+ self.norm_1 = SPADE(fmiddle, label_nc)
381
+ if self.learned_shortcut:
382
+ self.norm_s = SPADE(fin, label_nc)
383
+
384
+ def forward(self, x, seg1):
385
+ x_s = self.shortcut(x, seg1)
386
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
387
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
388
+ out = x_s + dx
389
+ return out
390
+
391
+ def shortcut(self, x, seg1):
392
+ if self.learned_shortcut:
393
+ x_s = self.conv_s(self.norm_s(x, seg1))
394
+ else:
395
+ x_s = x
396
+ return x_s
397
+
398
+ def actvn(self, x):
399
+ return F.leaky_relu(x, 2e-1)
project/cartoon2.jpg ADDED
project/cartoon3.jpg ADDED
project/cartoon4.jpg ADDED
project/cartoon5.jpg ADDED
project/cartoon6.jpg ADDED
project/cartoon7.jpg ADDED
project/celeb1.jpg ADDED
project/celeb2.jpg ADDED
project/celeb3.jpg ADDED
project/celeb4.jpg ADDED
project/celeb6.jpg ADDED
project/celeb7.jpg ADDED
project/celeb8.jpg ADDED
project/video1.mp4 ADDED
Binary file (152 kB). View file
 
project/video2.mp4 ADDED
Binary file (271 kB). View file
 
sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ def forward(self, input):
49
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50
+ if not (self._is_parallel and self.training):
51
+ return F.batch_norm(
52
+ input, self.running_mean, self.running_var, self.weight, self.bias,
53
+ self.training, self.momentum, self.eps)
54
+
55
+ # Resize the input to (B, C, -1).
56
+ input_shape = input.size()
57
+ input = input.view(input.size(0), self.num_features, -1)
58
+
59
+ # Compute the sum and square-sum.
60
+ sum_size = input.size(0) * input.size(2)
61
+ input_sum = _sum_ft(input)
62
+ input_ssum = _sum_ft(input ** 2)
63
+
64
+ # Reduce-and-broadcast the statistics.
65
+ if self._parallel_id == 0:
66
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67
+ else:
68
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69
+
70
+ # Compute the output.
71
+ if self.affine:
72
+ # MJY:: Fuse the multiplication for speed.
73
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74
+ else:
75
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76
+
77
+ # Reshape it.
78
+ return output.view(input_shape)
79
+
80
+ def __data_parallel_replicate__(self, ctx, copy_id):
81
+ self._is_parallel = True
82
+ self._parallel_id = copy_id
83
+
84
+ # parallel_id == 0 means master device.
85
+ if self._parallel_id == 0:
86
+ ctx.sync_master = self._sync_master
87
+ else:
88
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89
+
90
+ def _data_parallel_master(self, intermediates):
91
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92
+
93
+ # Always using same "device order" makes the ReduceAdd operation faster.
94
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
95
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96
+
97
+ to_reduce = [i[1][:2] for i in intermediates]
98
+ to_reduce = [j for i in to_reduce for j in i] # flatten
99
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
100
+
101
+ sum_size = sum([i[1].sum_size for i in intermediates])
102
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104
+
105
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106
+
107
+ outputs = []
108
+ for i, rec in enumerate(intermediates):
109
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
110
+
111
+ return outputs
112
+
113
+ def _compute_mean_std(self, sum_, ssum, size):
114
+ """Compute the mean and standard-deviation with sum and square-sum. This method
115
+ also maintains the moving average on the master device."""
116
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117
+ mean = sum_ / size
118
+ sumvar = ssum - sum_ * mean
119
+ unbias_var = sumvar / (size - 1)
120
+ bias_var = sumvar / size
121
+
122
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124
+
125
+ return mean, bias_var.clamp(self.eps) ** -0.5
126
+
127
+
128
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130
+ mini-batch.
131
+
132
+ .. math::
133
+
134
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
135
+
136
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
137
+ standard-deviation are reduced across all devices during training.
138
+
139
+ For example, when one uses `nn.DataParallel` to wrap the network during
140
+ training, PyTorch's implementation normalize the tensor on each device using
141
+ the statistics only on that device, which accelerated the computation and
142
+ is also easy to implement, but the statistics might be inaccurate.
143
+ Instead, in this synchronized version, the statistics will be computed
144
+ over all training samples distributed on multiple devices.
145
+
146
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
147
+ as the built-in PyTorch implementation.
148
+
149
+ The mean and standard-deviation are calculated per-dimension over
150
+ the mini-batches and gamma and beta are learnable parameter vectors
151
+ of size C (where C is the input size).
152
+
153
+ During training, this layer keeps a running estimate of its computed mean
154
+ and variance. The running sum is kept with a default momentum of 0.1.
155
+
156
+ During evaluation, this running mean/variance is used for normalization.
157
+
158
+ Because the BatchNorm is done over the `C` dimension, computing statistics
159
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
160
+
161
+ Args:
162
+ num_features: num_features from an expected input of size
163
+ `batch_size x num_features [x width]`
164
+ eps: a value added to the denominator for numerical stability.
165
+ Default: 1e-5
166
+ momentum: the value used for the running_mean and running_var
167
+ computation. Default: 0.1
168
+ affine: a boolean value that when set to ``True``, gives the layer learnable
169
+ affine parameters. Default: ``True``
170
+
171
+ Shape:
172
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
173
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
174
+
175
+ Examples:
176
+ >>> # With Learnable Parameters
177
+ >>> m = SynchronizedBatchNorm1d(100)
178
+ >>> # Without Learnable Parameters
179
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
180
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
181
+ >>> output = m(input)
182
+ """
183
+
184
+ def _check_input_dim(self, input):
185
+ if input.dim() != 2 and input.dim() != 3:
186
+ raise ValueError('expected 2D or 3D input (got {}D input)'
187
+ .format(input.dim()))
188
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
189
+
190
+
191
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
192
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
193
+ of 3d inputs
194
+
195
+ .. math::
196
+
197
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
198
+
199
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
200
+ standard-deviation are reduced across all devices during training.
201
+
202
+ For example, when one uses `nn.DataParallel` to wrap the network during
203
+ training, PyTorch's implementation normalize the tensor on each device using
204
+ the statistics only on that device, which accelerated the computation and
205
+ is also easy to implement, but the statistics might be inaccurate.
206
+ Instead, in this synchronized version, the statistics will be computed
207
+ over all training samples distributed on multiple devices.
208
+
209
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
210
+ as the built-in PyTorch implementation.
211
+
212
+ The mean and standard-deviation are calculated per-dimension over
213
+ the mini-batches and gamma and beta are learnable parameter vectors
214
+ of size C (where C is the input size).
215
+
216
+ During training, this layer keeps a running estimate of its computed mean
217
+ and variance. The running sum is kept with a default momentum of 0.1.
218
+
219
+ During evaluation, this running mean/variance is used for normalization.
220
+
221
+ Because the BatchNorm is done over the `C` dimension, computing statistics
222
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
223
+
224
+ Args:
225
+ num_features: num_features from an expected input of
226
+ size batch_size x num_features x height x width
227
+ eps: a value added to the denominator for numerical stability.
228
+ Default: 1e-5
229
+ momentum: the value used for the running_mean and running_var
230
+ computation. Default: 0.1
231
+ affine: a boolean value that when set to ``True``, gives the layer learnable
232
+ affine parameters. Default: ``True``
233
+
234
+ Shape:
235
+ - Input: :math:`(N, C, H, W)`
236
+ - Output: :math:`(N, C, H, W)` (same shape as input)
237
+
238
+ Examples:
239
+ >>> # With Learnable Parameters
240
+ >>> m = SynchronizedBatchNorm2d(100)
241
+ >>> # Without Learnable Parameters
242
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
243
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
244
+ >>> output = m(input)
245
+ """
246
+
247
+ def _check_input_dim(self, input):
248
+ if input.dim() != 4:
249
+ raise ValueError('expected 4D input (got {}D input)'
250
+ .format(input.dim()))
251
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
252
+
253
+
254
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
255
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
256
+ of 4d inputs
257
+
258
+ .. math::
259
+
260
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
261
+
262
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
263
+ standard-deviation are reduced across all devices during training.
264
+
265
+ For example, when one uses `nn.DataParallel` to wrap the network during
266
+ training, PyTorch's implementation normalize the tensor on each device using
267
+ the statistics only on that device, which accelerated the computation and
268
+ is also easy to implement, but the statistics might be inaccurate.
269
+ Instead, in this synchronized version, the statistics will be computed
270
+ over all training samples distributed on multiple devices.
271
+
272
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
273
+ as the built-in PyTorch implementation.
274
+
275
+ The mean and standard-deviation are calculated per-dimension over
276
+ the mini-batches and gamma and beta are learnable parameter vectors
277
+ of size C (where C is the input size).
278
+
279
+ During training, this layer keeps a running estimate of its computed mean
280
+ and variance. The running sum is kept with a default momentum of 0.1.
281
+
282
+ During evaluation, this running mean/variance is used for normalization.
283
+
284
+ Because the BatchNorm is done over the `C` dimension, computing statistics
285
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
286
+ or Spatio-temporal BatchNorm
287
+
288
+ Args:
289
+ num_features: num_features from an expected input of
290
+ size batch_size x num_features x depth x height x width
291
+ eps: a value added to the denominator for numerical stability.
292
+ Default: 1e-5
293
+ momentum: the value used for the running_mean and running_var
294
+ computation. Default: 0.1
295
+ affine: a boolean value that when set to ``True``, gives the layer learnable
296
+ affine parameters. Default: ``True``
297
+
298
+ Shape:
299
+ - Input: :math:`(N, C, D, H, W)`
300
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
301
+
302
+ Examples:
303
+ >>> # With Learnable Parameters
304
+ >>> m = SynchronizedBatchNorm3d(100)
305
+ >>> # Without Learnable Parameters
306
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
307
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
308
+ >>> output = m(input)
309
+ """
310
+
311
+ def _check_input_dim(self, input):
312
+ if input.dim() != 5:
313
+ raise ValueError('expected 5D input (got {}D input)'
314
+ .format(input.dim()))
315
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result has\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
60
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
61
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine to message passed
64
+ back to each slave devices.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register an slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages were first collected from each devices (including the master device), and then
106
+ an callback will be invoked to compute the message to be sent back to each devices
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master want to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)
sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55
+ original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )