Spaces:
Running
Running
harlanhong
commited on
Commit
•
bcec73a
1
Parent(s):
cd0695c
first
Browse files- app.py +67 -0
- config/vox-adv-256.yaml +88 -0
- demo_dagan.py +208 -0
- depth/__init__.py +5 -0
- depth/depth_decoder.py +65 -0
- depth/layers.py +285 -0
- depth/pose_cnn.py +50 -0
- depth/pose_decoder.py +71 -0
- depth/resnet_encoder.py +98 -0
- modules/AdaIN.py +61 -0
- modules/dense_motion.py +112 -0
- modules/discriminator.py +95 -0
- modules/dynamic_conv.py +382 -0
- modules/generator.py +325 -0
- modules/keypoint_detector.py +75 -0
- modules/model.py +362 -0
- modules/util.py +399 -0
- project/cartoon2.jpg +0 -0
- project/cartoon3.jpg +0 -0
- project/cartoon4.jpg +0 -0
- project/cartoon5.jpg +0 -0
- project/cartoon6.jpg +0 -0
- project/cartoon7.jpg +0 -0
- project/celeb1.jpg +0 -0
- project/celeb2.jpg +0 -0
- project/celeb3.jpg +0 -0
- project/celeb4.jpg +0 -0
- project/celeb6.jpg +0 -0
- project/celeb7.jpg +0 -0
- project/celeb8.jpg +0 -0
- project/video1.mp4 +0 -0
- project/video2.mp4 +0 -0
- sync_batchnorm/__init__.py +12 -0
- sync_batchnorm/batchnorm.py +315 -0
- sync_batchnorm/comm.py +137 -0
- sync_batchnorm/replicate.py +94 -0
- sync_batchnorm/unittest.py +29 -0
app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import torch
|
4 |
+
import cv2
|
5 |
+
import gradio as gr
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
#os.chdir('Restormer')
|
9 |
+
|
10 |
+
# Download sample images
|
11 |
+
os.system("wget https://github.com/swz30/Restormer/releases/download/v1.0/sample_images.zip")
|
12 |
+
shutil.unpack_archive('sample_images.zip')
|
13 |
+
os.remove('sample_images.zip')
|
14 |
+
|
15 |
+
|
16 |
+
examples = [['project/cartoon2.jpg'],
|
17 |
+
['project/cartoon3.jpg'],
|
18 |
+
['project/celeb1.jpg'],
|
19 |
+
['project/celeb2.jpg']
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
inference_on = ['Full Resolution Image', 'Downsampled Image']
|
24 |
+
|
25 |
+
title = "DaGAN"
|
26 |
+
description = """
|
27 |
+
Gradio demo for <b>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</b>, CVPR 2022L. <a href='https://arxiv.org/abs/2203.06605'>[Paper]</a><a href='https://github.com/harlanhong/CVPR2022-DaGAN'>[Github Code]</a>\n
|
28 |
+
"""
|
29 |
+
##With Restormer, you can perform: (1) Image Denoising, (2) Defocus Deblurring, (3) Motion Deblurring, and (4) Image Deraining.
|
30 |
+
##To use it, simply upload your own image, or click one of the examples provided below.
|
31 |
+
|
32 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.06605'>Depth-Aware Generative Adversarial Network for Talking Head Video Generation</a> | <a href='https://github.com/harlanhong/CVPR2022-DaGAN'>Github Repo</a></p>"
|
33 |
+
|
34 |
+
|
35 |
+
def inference(img, video):
|
36 |
+
if not os.path.exists('temp'):
|
37 |
+
os.system('mkdir temp')
|
38 |
+
|
39 |
+
#### Resize the longer edge of the input image
|
40 |
+
max_res = 256
|
41 |
+
width, height = img.size
|
42 |
+
if max(width,height) > max_res:
|
43 |
+
scale = max_res /max(width,height)
|
44 |
+
width = int(scale*width)
|
45 |
+
height = int(scale*height)
|
46 |
+
img = img.resize((width,height), Image.ANTIALIAS)
|
47 |
+
|
48 |
+
img.save("temp/image.jpg", "JPEG")
|
49 |
+
video.save('temp/video.mp4')
|
50 |
+
os.system("python demo_dagan.py --source_image 'temp/image.jpg' --driving_video {} --output 'temp/rst.mp4'".format(video))
|
51 |
+
|
52 |
+
return f'temp/rst.mp4'
|
53 |
+
|
54 |
+
gr.Interface(
|
55 |
+
inference,
|
56 |
+
[
|
57 |
+
gr.inputs.Image(type="pil", label="Input"),
|
58 |
+
gr.inputs.Video(label="Input"),
|
59 |
+
],
|
60 |
+
gr.outputs.Video(type="mp4", label="Output"),
|
61 |
+
title=title,
|
62 |
+
description=description,
|
63 |
+
article=article,
|
64 |
+
theme ="huggingface",
|
65 |
+
examples=examples,
|
66 |
+
allow_flagging=False,
|
67 |
+
).launch(debug=False,enable_queue=True)
|
config/vox-adv-256.yaml
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_params:
|
2 |
+
root_dir: /data/fhongac/origDataset/vox1_frames
|
3 |
+
frame_shape: [256, 256, 3]
|
4 |
+
id_sampling: True
|
5 |
+
pairs_list: data/vox256.csv
|
6 |
+
augmentation_params:
|
7 |
+
flip_param:
|
8 |
+
horizontal_flip: True
|
9 |
+
time_flip: True
|
10 |
+
jitter_param:
|
11 |
+
brightness: 0.1
|
12 |
+
contrast: 0.1
|
13 |
+
saturation: 0.1
|
14 |
+
hue: 0.1
|
15 |
+
|
16 |
+
|
17 |
+
model_params:
|
18 |
+
common_params:
|
19 |
+
num_kp: 15
|
20 |
+
num_channels: 3
|
21 |
+
estimate_jacobian: True
|
22 |
+
kp_detector_params:
|
23 |
+
temperature: 0.1
|
24 |
+
block_expansion: 32
|
25 |
+
max_features: 1024
|
26 |
+
scale_factor: 0.25
|
27 |
+
num_blocks: 5
|
28 |
+
generator_params:
|
29 |
+
block_expansion: 64
|
30 |
+
max_features: 512
|
31 |
+
num_down_blocks: 2
|
32 |
+
num_bottleneck_blocks: 6
|
33 |
+
estimate_occlusion_map: True
|
34 |
+
dense_motion_params:
|
35 |
+
block_expansion: 64
|
36 |
+
max_features: 1024
|
37 |
+
num_blocks: 5
|
38 |
+
scale_factor: 0.25
|
39 |
+
discriminator_params:
|
40 |
+
scales: [1]
|
41 |
+
block_expansion: 32
|
42 |
+
max_features: 512
|
43 |
+
num_blocks: 4
|
44 |
+
use_kp: True
|
45 |
+
|
46 |
+
|
47 |
+
train_params:
|
48 |
+
num_epochs: 150
|
49 |
+
num_repeats: 75
|
50 |
+
epoch_milestones: []
|
51 |
+
lr_generator: 2.0e-4
|
52 |
+
lr_discriminator: 2.0e-4
|
53 |
+
lr_kp_detector: 2.0e-4
|
54 |
+
batch_size: 4
|
55 |
+
scales: [1, 0.5, 0.25, 0.125]
|
56 |
+
checkpoint_freq: 10
|
57 |
+
transform_params:
|
58 |
+
sigma_affine: 0.05
|
59 |
+
sigma_tps: 0.005
|
60 |
+
points_tps: 5
|
61 |
+
loss_weights:
|
62 |
+
generator_gan: 1
|
63 |
+
discriminator_gan: 1
|
64 |
+
feature_matching: [10, 10, 10, 10]
|
65 |
+
perceptual: [10, 10, 10, 10, 10]
|
66 |
+
equivariance_value: 10
|
67 |
+
equivariance_jacobian: 10
|
68 |
+
kp_distance: 10
|
69 |
+
kp_prior: 0
|
70 |
+
kp_scale: 0
|
71 |
+
depth_constraint: 0
|
72 |
+
|
73 |
+
reconstruction_params:
|
74 |
+
num_videos: 1000
|
75 |
+
format: '.mp4'
|
76 |
+
|
77 |
+
animate_params:
|
78 |
+
num_pairs: 50
|
79 |
+
format: '.mp4'
|
80 |
+
normalization_params:
|
81 |
+
adapt_movement_scale: False
|
82 |
+
use_relative_movement: True
|
83 |
+
use_relative_jacobian: True
|
84 |
+
|
85 |
+
visualizer_params:
|
86 |
+
kp_size: 5
|
87 |
+
draw_border: True
|
88 |
+
colormap: 'gist_rainbow'
|
demo_dagan.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Restormer: Efficient Transformer for High-Resolution Image Restoration
|
2 |
+
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
|
3 |
+
## https://arxiv.org/abs/2111.09881
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import os
|
9 |
+
from skimage import img_as_ubyte
|
10 |
+
import cv2
|
11 |
+
import argparse
|
12 |
+
import imageio
|
13 |
+
from skimage.transform import resize
|
14 |
+
from scipy.spatial import ConvexHull
|
15 |
+
from tqdm import tqdm
|
16 |
+
import numpy as np
|
17 |
+
import modules.generator as G
|
18 |
+
import modules.keypoint_detector as KPD
|
19 |
+
import yaml
|
20 |
+
from collections import OrderedDict
|
21 |
+
import depth
|
22 |
+
parser = argparse.ArgumentParser(description='Test DaGAN on your own images')
|
23 |
+
parser.add_argument('--source_image', default='./temp/source.jpg', type=str, help='Directory of input source image')
|
24 |
+
parser.add_argument('--driving_video', default='./temp/driving.mp4', type=str, help='Directory for driving video')
|
25 |
+
parser.add_argument('--output', default='./temp/result.mp4', type=str, help='Directory for driving video')
|
26 |
+
|
27 |
+
|
28 |
+
args = parser.parse_args()
|
29 |
+
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
|
30 |
+
use_relative_movement=False, use_relative_jacobian=False):
|
31 |
+
if adapt_movement_scale:
|
32 |
+
source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
|
33 |
+
driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
|
34 |
+
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
|
35 |
+
else:
|
36 |
+
adapt_movement_scale = 1
|
37 |
+
|
38 |
+
kp_new = {k: v for k, v in kp_driving.items()}
|
39 |
+
|
40 |
+
if use_relative_movement:
|
41 |
+
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
|
42 |
+
kp_value_diff *= adapt_movement_scale
|
43 |
+
kp_new['value'] = kp_value_diff + kp_source['value']
|
44 |
+
|
45 |
+
if use_relative_jacobian:
|
46 |
+
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
|
47 |
+
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
|
48 |
+
return kp_new
|
49 |
+
def find_best_frame(source, driving, cpu=False):
|
50 |
+
import face_alignment
|
51 |
+
|
52 |
+
def normalize_kp(kp):
|
53 |
+
kp = kp - kp.mean(axis=0, keepdims=True)
|
54 |
+
area = ConvexHull(kp[:, :2]).volume
|
55 |
+
area = np.sqrt(area)
|
56 |
+
kp[:, :2] = kp[:, :2] / area
|
57 |
+
return kp
|
58 |
+
|
59 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
|
60 |
+
device='cpu' if cpu else 'cuda')
|
61 |
+
kp_source = fa.get_landmarks(255 * source)[0]
|
62 |
+
kp_source = normalize_kp(kp_source)
|
63 |
+
norm = float('inf')
|
64 |
+
frame_num = 0
|
65 |
+
for i, image in tqdm(enumerate(driving)):
|
66 |
+
kp_driving = fa.get_landmarks(255 * image)[0]
|
67 |
+
kp_driving = normalize_kp(kp_driving)
|
68 |
+
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
|
69 |
+
if new_norm < norm:
|
70 |
+
norm = new_norm
|
71 |
+
frame_num = i
|
72 |
+
return frame_num
|
73 |
+
|
74 |
+
|
75 |
+
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
|
76 |
+
sources = []
|
77 |
+
drivings = []
|
78 |
+
with torch.no_grad():
|
79 |
+
predictions = []
|
80 |
+
depth_gray = []
|
81 |
+
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
|
82 |
+
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
|
83 |
+
if not cpu:
|
84 |
+
source = source.cuda()
|
85 |
+
driving = driving.cuda()
|
86 |
+
outputs = depth_decoder(depth_encoder(source))
|
87 |
+
depth_source = outputs[("disp", 0)]
|
88 |
+
|
89 |
+
outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
|
90 |
+
depth_driving = outputs[("disp", 0)]
|
91 |
+
source_kp = torch.cat((source,depth_source),1)
|
92 |
+
driving_kp = torch.cat((driving[:, :, 0],depth_driving),1)
|
93 |
+
|
94 |
+
kp_source = kp_detector(source_kp)
|
95 |
+
kp_driving_initial = kp_detector(driving_kp)
|
96 |
+
|
97 |
+
# kp_source = kp_detector(source)
|
98 |
+
# kp_driving_initial = kp_detector(driving[:, :, 0])
|
99 |
+
|
100 |
+
for frame_idx in tqdm(range(driving.shape[2])):
|
101 |
+
driving_frame = driving[:, :, frame_idx]
|
102 |
+
|
103 |
+
if not cpu:
|
104 |
+
driving_frame = driving_frame.cuda()
|
105 |
+
outputs = depth_decoder(depth_encoder(driving_frame))
|
106 |
+
depth_map = outputs[("disp", 0)]
|
107 |
+
|
108 |
+
gray_driving = np.transpose(depth_map.data.cpu().numpy(), [0, 2, 3, 1])[0]
|
109 |
+
gray_driving = 1-gray_driving/np.max(gray_driving)
|
110 |
+
|
111 |
+
frame = torch.cat((driving_frame,depth_map),1)
|
112 |
+
kp_driving = kp_detector(frame)
|
113 |
+
|
114 |
+
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
|
115 |
+
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
|
116 |
+
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
|
117 |
+
out = generator(source, kp_source=kp_source, kp_driving=kp_norm,source_depth = depth_source, driving_depth = depth_map)
|
118 |
+
|
119 |
+
drivings.append(np.transpose(driving_frame.data.cpu().numpy(), [0, 2, 3, 1])[0])
|
120 |
+
sources.append(np.transpose(source.data.cpu().numpy(), [0, 2, 3, 1])[0])
|
121 |
+
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
122 |
+
depth_gray.append(gray_driving)
|
123 |
+
return sources, drivings, predictions,depth_gray
|
124 |
+
with open("config/vox-adv-256.yaml") as f:
|
125 |
+
config = yaml.load(f)
|
126 |
+
generator = G.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],**config['model_params']['common_params'])
|
127 |
+
config['model_params']['common_params']['num_channels'] = 4
|
128 |
+
kp_detector = KPD.KPDetector(**config['model_params']['kp_detector_params'],**config['model_params']['common_params'])
|
129 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
130 |
+
|
131 |
+
|
132 |
+
g_checkpoint = torch.load("generator.pt", map_location=device)
|
133 |
+
kp_checkpoint = torch.load("kp_detector.pt", map_location=device)
|
134 |
+
|
135 |
+
ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
|
136 |
+
generator.load_state_dict(ckp_generator)
|
137 |
+
ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in kp_checkpoint.items())
|
138 |
+
kp_detector.load_state_dict(ckp_kp_detector)
|
139 |
+
|
140 |
+
depth_encoder = depth.ResnetEncoder(18, False)
|
141 |
+
depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
|
142 |
+
loaded_dict_enc = torch.load('encoder.pth')
|
143 |
+
loaded_dict_dec = torch.load('depth.pth')
|
144 |
+
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
|
145 |
+
depth_encoder.load_state_dict(filtered_dict_enc)
|
146 |
+
ckp_depth_decoder= {k: v for k, v in loaded_dict_dec.items() if k in depth_decoder.state_dict()}
|
147 |
+
depth_decoder.load_state_dict(ckp_depth_decoder)
|
148 |
+
depth_encoder.eval()
|
149 |
+
depth_decoder.eval()
|
150 |
+
|
151 |
+
# device = torch.device('cpu')
|
152 |
+
# stx()
|
153 |
+
|
154 |
+
generator = generator.to(device)
|
155 |
+
kp_detector = kp_detector.to(device)
|
156 |
+
depth_encoder = depth_encoder.to(device)
|
157 |
+
depth_decoder = depth_decoder.to(device)
|
158 |
+
|
159 |
+
generator.eval()
|
160 |
+
kp_detector.eval()
|
161 |
+
depth_encoder.eval()
|
162 |
+
depth_decoder.eval()
|
163 |
+
|
164 |
+
img_multiple_of = 8
|
165 |
+
|
166 |
+
with torch.inference_mode():
|
167 |
+
if torch.cuda.is_available():
|
168 |
+
torch.cuda.ipc_collect()
|
169 |
+
torch.cuda.empty_cache()
|
170 |
+
source_image = imageio.imread(args.source_image)
|
171 |
+
reader = imageio.get_reader(args.driving_video)
|
172 |
+
fps = reader.get_meta_data()['fps']
|
173 |
+
driving_video = []
|
174 |
+
try:
|
175 |
+
for im in reader:
|
176 |
+
driving_video.append(im)
|
177 |
+
except RuntimeError:
|
178 |
+
pass
|
179 |
+
reader.close()
|
180 |
+
|
181 |
+
source_image = resize(source_image, (256, 256))[..., :3]
|
182 |
+
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
i = find_best_frame(source_image, driving_video)
|
187 |
+
print ("Best frame: " + str(i))
|
188 |
+
driving_forward = driving_video[i:]
|
189 |
+
driving_backward = driving_video[:(i+1)][::-1]
|
190 |
+
sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
|
191 |
+
sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False)
|
192 |
+
predictions = predictions_backward[::-1] + predictions_forward[1:]
|
193 |
+
sources = sources_backward[::-1] + sources_forward[1:]
|
194 |
+
drivings = drivings_backward[::-1] + drivings_forward[1:]
|
195 |
+
depth_gray = depth_backward[::-1] + depth_forward[1:]
|
196 |
+
|
197 |
+
imageio.mimsave(args.output, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
|
198 |
+
imageio.mimsave("gray.mp4", depth_gray, fps=fps)
|
199 |
+
# merge the gray video
|
200 |
+
animation = np.array(imageio.mimread(args.output,memtest=False))
|
201 |
+
gray = np.array(imageio.mimread("gray.mp4",memtest=False))
|
202 |
+
|
203 |
+
src_dst = animation[:,:,:512,:]
|
204 |
+
animate = animation[:,:,512:,:]
|
205 |
+
merge = np.concatenate((src_dst,gray,animate),2)
|
206 |
+
imageio.mimsave(args.output, merge, fps=fps)
|
207 |
+
|
208 |
+
# print(f"\nRestored images are saved at {out_dir}")
|
depth/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .resnet_encoder import ResnetEncoder
|
2 |
+
from .depth_decoder import DepthDecoder
|
3 |
+
from .pose_decoder import PoseDecoder
|
4 |
+
from .pose_cnn import PoseCNN
|
5 |
+
|
depth/depth_decoder.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
2 |
+
#
|
3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
5 |
+
# available in the LICENSE file.
|
6 |
+
|
7 |
+
from __future__ import absolute_import, division, print_function
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
from collections import OrderedDict
|
14 |
+
from depth.layers import *
|
15 |
+
|
16 |
+
|
17 |
+
class DepthDecoder(nn.Module):
|
18 |
+
def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
|
19 |
+
super(DepthDecoder, self).__init__()
|
20 |
+
|
21 |
+
self.num_output_channels = num_output_channels
|
22 |
+
self.use_skips = use_skips
|
23 |
+
self.upsample_mode = 'nearest'
|
24 |
+
self.scales = scales
|
25 |
+
|
26 |
+
self.num_ch_enc = num_ch_enc
|
27 |
+
self.num_ch_dec = np.array([16, 32, 64, 128, 256])
|
28 |
+
|
29 |
+
# decoder
|
30 |
+
self.convs = OrderedDict()
|
31 |
+
for i in range(4, -1, -1):
|
32 |
+
# upconv_0
|
33 |
+
num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
|
34 |
+
num_ch_out = self.num_ch_dec[i]
|
35 |
+
self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
|
36 |
+
|
37 |
+
# upconv_1
|
38 |
+
num_ch_in = self.num_ch_dec[i]
|
39 |
+
if self.use_skips and i > 0:
|
40 |
+
num_ch_in += self.num_ch_enc[i - 1]
|
41 |
+
num_ch_out = self.num_ch_dec[i]
|
42 |
+
self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
|
43 |
+
|
44 |
+
for s in self.scales:
|
45 |
+
self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
|
46 |
+
|
47 |
+
self.decoder = nn.ModuleList(list(self.convs.values()))
|
48 |
+
self.sigmoid = nn.Sigmoid()
|
49 |
+
|
50 |
+
def forward(self, input_features):
|
51 |
+
self.outputs = {}
|
52 |
+
|
53 |
+
# decoder
|
54 |
+
x = input_features[-1]
|
55 |
+
for i in range(4, -1, -1):
|
56 |
+
x = self.convs[("upconv", i, 0)](x)
|
57 |
+
x = [upsample(x)]
|
58 |
+
if self.use_skips and i > 0:
|
59 |
+
x += [input_features[i - 1]]
|
60 |
+
x = torch.cat(x, 1)
|
61 |
+
x = self.convs[("upconv", i, 1)](x)
|
62 |
+
if i in self.scales:
|
63 |
+
self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
|
64 |
+
|
65 |
+
return self.outputs
|
depth/layers.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
2 |
+
#
|
3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
5 |
+
# available in the LICENSE file.
|
6 |
+
|
7 |
+
from __future__ import absolute_import, division, print_function
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pdb
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
|
16 |
+
def disp_to_depth(disp, min_depth, max_depth):
|
17 |
+
"""Convert network's sigmoid output into depth prediction
|
18 |
+
The formula for this conversion is given in the 'additional considerations'
|
19 |
+
section of the paper.
|
20 |
+
"""
|
21 |
+
min_disp = 1 / max_depth
|
22 |
+
max_disp = 1 / min_depth
|
23 |
+
scaled_disp = min_disp + (max_disp - min_disp) * disp
|
24 |
+
depth = 1 / scaled_disp
|
25 |
+
return scaled_disp, depth
|
26 |
+
|
27 |
+
|
28 |
+
def transformation_from_parameters(axisangle, translation, invert=False):
|
29 |
+
"""Convert the network's (axisangle, translation) output into a 4x4 matrix
|
30 |
+
"""
|
31 |
+
R = rot_from_axisangle(axisangle)
|
32 |
+
t = translation.clone()
|
33 |
+
|
34 |
+
if invert:
|
35 |
+
R = R.transpose(1, 2)
|
36 |
+
t *= -1
|
37 |
+
|
38 |
+
T = get_translation_matrix(t)
|
39 |
+
|
40 |
+
if invert:
|
41 |
+
M = torch.matmul(R, T)
|
42 |
+
else:
|
43 |
+
M = torch.matmul(T, R)
|
44 |
+
|
45 |
+
return M
|
46 |
+
|
47 |
+
|
48 |
+
def get_translation_matrix(translation_vector):
|
49 |
+
"""Convert a translation vector into a 4x4 transformation matrix
|
50 |
+
"""
|
51 |
+
T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
|
52 |
+
|
53 |
+
t = translation_vector.contiguous().view(-1, 3, 1)
|
54 |
+
|
55 |
+
T[:, 0, 0] = 1
|
56 |
+
T[:, 1, 1] = 1
|
57 |
+
T[:, 2, 2] = 1
|
58 |
+
T[:, 3, 3] = 1
|
59 |
+
T[:, :3, 3, None] = t
|
60 |
+
|
61 |
+
return T
|
62 |
+
|
63 |
+
|
64 |
+
def rot_from_axisangle(vec):
|
65 |
+
"""Convert an axisangle rotation into a 4x4 transformation matrix
|
66 |
+
(adapted from https://github.com/Wallacoloo/printipi)
|
67 |
+
Input 'vec' has to be Bx1x3
|
68 |
+
"""
|
69 |
+
angle = torch.norm(vec, 2, 2, True)
|
70 |
+
axis = vec / (angle + 1e-7)
|
71 |
+
|
72 |
+
ca = torch.cos(angle)
|
73 |
+
sa = torch.sin(angle)
|
74 |
+
C = 1 - ca
|
75 |
+
|
76 |
+
x = axis[..., 0].unsqueeze(1)
|
77 |
+
y = axis[..., 1].unsqueeze(1)
|
78 |
+
z = axis[..., 2].unsqueeze(1)
|
79 |
+
|
80 |
+
xs = x * sa
|
81 |
+
ys = y * sa
|
82 |
+
zs = z * sa
|
83 |
+
xC = x * C
|
84 |
+
yC = y * C
|
85 |
+
zC = z * C
|
86 |
+
xyC = x * yC
|
87 |
+
yzC = y * zC
|
88 |
+
zxC = z * xC
|
89 |
+
|
90 |
+
rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
|
91 |
+
|
92 |
+
rot[:, 0, 0] = torch.squeeze(x * xC + ca)
|
93 |
+
rot[:, 0, 1] = torch.squeeze(xyC - zs)
|
94 |
+
rot[:, 0, 2] = torch.squeeze(zxC + ys)
|
95 |
+
rot[:, 1, 0] = torch.squeeze(xyC + zs)
|
96 |
+
rot[:, 1, 1] = torch.squeeze(y * yC + ca)
|
97 |
+
rot[:, 1, 2] = torch.squeeze(yzC - xs)
|
98 |
+
rot[:, 2, 0] = torch.squeeze(zxC - ys)
|
99 |
+
rot[:, 2, 1] = torch.squeeze(yzC + xs)
|
100 |
+
rot[:, 2, 2] = torch.squeeze(z * zC + ca)
|
101 |
+
rot[:, 3, 3] = 1
|
102 |
+
|
103 |
+
return rot
|
104 |
+
|
105 |
+
|
106 |
+
class ConvBlock(nn.Module):
|
107 |
+
"""Layer to perform a convolution followed by ELU
|
108 |
+
"""
|
109 |
+
def __init__(self, in_channels, out_channels):
|
110 |
+
super(ConvBlock, self).__init__()
|
111 |
+
|
112 |
+
self.conv = Conv3x3(in_channels, out_channels)
|
113 |
+
self.nonlin = nn.ELU(inplace=True)
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
out = self.conv(x)
|
117 |
+
out = self.nonlin(out)
|
118 |
+
return out
|
119 |
+
|
120 |
+
|
121 |
+
class Conv3x3(nn.Module):
|
122 |
+
"""Layer to pad and convolve input
|
123 |
+
"""
|
124 |
+
def __init__(self, in_channels, out_channels, use_refl=True):
|
125 |
+
super(Conv3x3, self).__init__()
|
126 |
+
|
127 |
+
if use_refl:
|
128 |
+
self.pad = nn.ReflectionPad2d(1)
|
129 |
+
else:
|
130 |
+
self.pad = nn.ZeroPad2d(1)
|
131 |
+
self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
out = self.pad(x)
|
135 |
+
out = self.conv(out)
|
136 |
+
return out
|
137 |
+
|
138 |
+
|
139 |
+
class BackprojectDepth(nn.Module):
|
140 |
+
"""Layer to transform a depth image into a point cloud
|
141 |
+
"""
|
142 |
+
def __init__(self, batch_size, height, width):
|
143 |
+
super(BackprojectDepth, self).__init__()
|
144 |
+
|
145 |
+
self.batch_size = batch_size
|
146 |
+
self.height = height
|
147 |
+
self.width = width
|
148 |
+
|
149 |
+
meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
|
150 |
+
self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
|
151 |
+
self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
|
152 |
+
requires_grad=False)
|
153 |
+
|
154 |
+
self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
|
155 |
+
requires_grad=False)
|
156 |
+
|
157 |
+
self.pix_coords = torch.unsqueeze(torch.stack(
|
158 |
+
[self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
|
159 |
+
self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
|
160 |
+
self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
|
161 |
+
requires_grad=False)
|
162 |
+
|
163 |
+
def forward(self, depth, K,scale):
|
164 |
+
K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc()
|
165 |
+
b,n,n = K.shape
|
166 |
+
inv_K = torch.linalg.inv(K)
|
167 |
+
#inv_K = torch.cholesky_inverse(K)
|
168 |
+
pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda()
|
169 |
+
inv_K = torch.cat([inv_K,pad],-1)
|
170 |
+
pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda()
|
171 |
+
inv_K = torch.cat([inv_K,pad],1)
|
172 |
+
cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
|
173 |
+
cam_points = depth.view(self.batch_size, 1, -1) * cam_points
|
174 |
+
cam_points = torch.cat([cam_points, self.ones], 1)
|
175 |
+
|
176 |
+
return cam_points
|
177 |
+
|
178 |
+
|
179 |
+
class Project3D(nn.Module):
|
180 |
+
"""Layer which projects 3D points into a camera with intrinsics K and at position T
|
181 |
+
"""
|
182 |
+
def __init__(self, batch_size, height, width, eps=1e-7):
|
183 |
+
super(Project3D, self).__init__()
|
184 |
+
|
185 |
+
self.batch_size = batch_size
|
186 |
+
self.height = height
|
187 |
+
self.width = width
|
188 |
+
self.eps = eps
|
189 |
+
|
190 |
+
def forward(self, points, K, T,scale=0):
|
191 |
+
# K[0, :] *= self.width // (2 ** scale)
|
192 |
+
# K[1, :] *= self.height // (2 ** scale)
|
193 |
+
K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc()
|
194 |
+
b,n,n = K.shape
|
195 |
+
pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda()
|
196 |
+
K = torch.cat([K,pad],-1)
|
197 |
+
pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda()
|
198 |
+
K = torch.cat([K,pad],1)
|
199 |
+
P = torch.matmul(K, T)[:, :3, :]
|
200 |
+
|
201 |
+
cam_points = torch.matmul(P, points)
|
202 |
+
|
203 |
+
pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
|
204 |
+
pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
|
205 |
+
pix_coords = pix_coords.permute(0, 2, 3, 1)
|
206 |
+
pix_coords[..., 0] /= self.width - 1
|
207 |
+
pix_coords[..., 1] /= self.height - 1
|
208 |
+
pix_coords = (pix_coords - 0.5) * 2
|
209 |
+
return pix_coords
|
210 |
+
|
211 |
+
|
212 |
+
def upsample(x):
|
213 |
+
"""Upsample input tensor by a factor of 2
|
214 |
+
"""
|
215 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
216 |
+
|
217 |
+
|
218 |
+
def get_smooth_loss(disp, img):
|
219 |
+
"""Computes the smoothness loss for a disparity image
|
220 |
+
The color image is used for edge-aware smoothness
|
221 |
+
"""
|
222 |
+
grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
|
223 |
+
grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
|
224 |
+
|
225 |
+
grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
|
226 |
+
grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
|
227 |
+
|
228 |
+
grad_disp_x *= torch.exp(-grad_img_x)
|
229 |
+
grad_disp_y *= torch.exp(-grad_img_y)
|
230 |
+
|
231 |
+
return grad_disp_x.mean() + grad_disp_y.mean()
|
232 |
+
|
233 |
+
|
234 |
+
class SSIM(nn.Module):
|
235 |
+
"""Layer to compute the SSIM loss between a pair of images
|
236 |
+
"""
|
237 |
+
def __init__(self):
|
238 |
+
super(SSIM, self).__init__()
|
239 |
+
self.mu_x_pool = nn.AvgPool2d(3, 1)
|
240 |
+
self.mu_y_pool = nn.AvgPool2d(3, 1)
|
241 |
+
self.sig_x_pool = nn.AvgPool2d(3, 1)
|
242 |
+
self.sig_y_pool = nn.AvgPool2d(3, 1)
|
243 |
+
self.sig_xy_pool = nn.AvgPool2d(3, 1)
|
244 |
+
|
245 |
+
self.refl = nn.ReflectionPad2d(1)
|
246 |
+
|
247 |
+
self.C1 = 0.01 ** 2
|
248 |
+
self.C2 = 0.03 ** 2
|
249 |
+
|
250 |
+
def forward(self, x, y):
|
251 |
+
x = self.refl(x)
|
252 |
+
y = self.refl(y)
|
253 |
+
|
254 |
+
mu_x = self.mu_x_pool(x)
|
255 |
+
mu_y = self.mu_y_pool(y)
|
256 |
+
|
257 |
+
sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
|
258 |
+
sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
|
259 |
+
sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
|
260 |
+
|
261 |
+
SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
|
262 |
+
SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
|
263 |
+
|
264 |
+
return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
|
265 |
+
|
266 |
+
|
267 |
+
def compute_depth_errors(gt, pred):
|
268 |
+
"""Computation of error metrics between predicted and ground truth depths
|
269 |
+
"""
|
270 |
+
thresh = torch.max((gt / pred), (pred / gt))
|
271 |
+
a1 = (thresh < 1.25 ).float().mean()
|
272 |
+
a2 = (thresh < 1.25 ** 2).float().mean()
|
273 |
+
a3 = (thresh < 1.25 ** 3).float().mean()
|
274 |
+
|
275 |
+
rmse = (gt - pred) ** 2
|
276 |
+
rmse = torch.sqrt(rmse.mean())
|
277 |
+
|
278 |
+
rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
|
279 |
+
rmse_log = torch.sqrt(rmse_log.mean())
|
280 |
+
|
281 |
+
abs_rel = torch.mean(torch.abs(gt - pred) / gt)
|
282 |
+
|
283 |
+
sq_rel = torch.mean((gt - pred) ** 2 / gt)
|
284 |
+
|
285 |
+
return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
|
depth/pose_cnn.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
2 |
+
#
|
3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
5 |
+
# available in the LICENSE file.
|
6 |
+
|
7 |
+
from __future__ import absolute_import, division, print_function
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
|
13 |
+
class PoseCNN(nn.Module):
|
14 |
+
def __init__(self, num_input_frames):
|
15 |
+
super(PoseCNN, self).__init__()
|
16 |
+
|
17 |
+
self.num_input_frames = num_input_frames
|
18 |
+
|
19 |
+
self.convs = {}
|
20 |
+
self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
|
21 |
+
self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
|
22 |
+
self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
|
23 |
+
self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
|
24 |
+
self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
|
25 |
+
self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
|
26 |
+
self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)
|
27 |
+
|
28 |
+
self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)
|
29 |
+
|
30 |
+
self.num_convs = len(self.convs)
|
31 |
+
|
32 |
+
self.relu = nn.ReLU(True)
|
33 |
+
|
34 |
+
self.net = nn.ModuleList(list(self.convs.values()))
|
35 |
+
|
36 |
+
def forward(self, out):
|
37 |
+
|
38 |
+
for i in range(self.num_convs):
|
39 |
+
out = self.convs[i](out)
|
40 |
+
out = self.relu(out)
|
41 |
+
|
42 |
+
out = self.pose_conv(out)
|
43 |
+
out = out.mean(3).mean(2)
|
44 |
+
|
45 |
+
out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)
|
46 |
+
|
47 |
+
axisangle = out[..., :3]
|
48 |
+
translation = out[..., 3:]
|
49 |
+
|
50 |
+
return axisangle, translation
|
depth/pose_decoder.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
2 |
+
#
|
3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
5 |
+
# available in the LICENSE file.
|
6 |
+
|
7 |
+
from __future__ import absolute_import, division, print_function
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from collections import OrderedDict
|
12 |
+
import pdb
|
13 |
+
import torch.nn.functional as F
|
14 |
+
# from options import MonodepthOptions
|
15 |
+
# options = MonodepthOptions()
|
16 |
+
# opts = options.parse()
|
17 |
+
class PoseDecoder(nn.Module):
|
18 |
+
def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
|
19 |
+
super(PoseDecoder, self).__init__()
|
20 |
+
self.num_ch_enc = num_ch_enc
|
21 |
+
self.num_input_features = num_input_features
|
22 |
+
|
23 |
+
if num_frames_to_predict_for is None:
|
24 |
+
num_frames_to_predict_for = num_input_features - 1
|
25 |
+
self.num_frames_to_predict_for = num_frames_to_predict_for
|
26 |
+
|
27 |
+
self.convs = OrderedDict()
|
28 |
+
self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
|
29 |
+
self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
|
30 |
+
self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
|
31 |
+
self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
|
32 |
+
self.convs[("intrinsics", 'focal')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1)
|
33 |
+
self.convs[("intrinsics", 'offset')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1)
|
34 |
+
|
35 |
+
self.relu = nn.ReLU()
|
36 |
+
self.net = nn.ModuleList(list(self.convs.values()))
|
37 |
+
|
38 |
+
def forward(self, input_features):
|
39 |
+
last_features = [f[-1] for f in input_features]
|
40 |
+
|
41 |
+
cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
|
42 |
+
cat_features = torch.cat(cat_features, 1)
|
43 |
+
|
44 |
+
feat = cat_features
|
45 |
+
for i in range(2):
|
46 |
+
feat = self.convs[("pose", i)](feat)
|
47 |
+
feat = self.relu(feat)
|
48 |
+
out = self.convs[("pose", 2)](feat)
|
49 |
+
|
50 |
+
out = out.mean(3).mean(2)
|
51 |
+
out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
|
52 |
+
|
53 |
+
axisangle = out[..., :3]
|
54 |
+
translation = out[..., 3:]
|
55 |
+
|
56 |
+
#add_intrinsics_head
|
57 |
+
scales = torch.tensor([256,256]).cuda()
|
58 |
+
focals = F.softplus(self.convs[("intrinsics", 'focal')](feat)).mean(3).mean(2)*scales
|
59 |
+
offset = (F.softplus(self.convs[("intrinsics", 'offset')](feat)).mean(3).mean(2)+0.5)*scales
|
60 |
+
#focals = F.softplus(self.convs[("intrinsics",'focal')](feat).mean(3).mean(2))
|
61 |
+
#offset = F.softplus(self.convs[("intrinsics",'offset')](feat).mean(3).mean(2))
|
62 |
+
eyes = torch.eye(2).cuda()
|
63 |
+
b,xy = focals.shape
|
64 |
+
focals = focals.unsqueeze(-1).expand(b,xy,xy)
|
65 |
+
eyes = eyes.unsqueeze(0).expand(b,xy,xy)
|
66 |
+
intrin = focals*eyes
|
67 |
+
offset = offset.view(b,2,1).contiguous()
|
68 |
+
intrin = torch.cat([intrin,offset],-1)
|
69 |
+
pad = torch.tensor([0.0,0.0,1.0]).view(1,1,3).expand(b,1,3).cuda()
|
70 |
+
intrinsics = torch.cat([intrin,pad],1)
|
71 |
+
return axisangle, translation,intrinsics
|
depth/resnet_encoder.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
2 |
+
#
|
3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
5 |
+
# available in the LICENSE file.
|
6 |
+
|
7 |
+
from __future__ import absolute_import, division, print_function
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torchvision.models as models
|
14 |
+
import torch.utils.model_zoo as model_zoo
|
15 |
+
|
16 |
+
|
17 |
+
class ResNetMultiImageInput(models.ResNet):
|
18 |
+
"""Constructs a resnet model with varying number of input images.
|
19 |
+
Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
|
20 |
+
"""
|
21 |
+
def __init__(self, block, layers, num_classes=1000, num_input_images=1):
|
22 |
+
super(ResNetMultiImageInput, self).__init__(block, layers)
|
23 |
+
self.inplanes = 64
|
24 |
+
self.conv1 = nn.Conv2d(
|
25 |
+
num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
26 |
+
self.bn1 = nn.BatchNorm2d(64)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
29 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
30 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
31 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
32 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
33 |
+
|
34 |
+
for m in self.modules():
|
35 |
+
if isinstance(m, nn.Conv2d):
|
36 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
37 |
+
elif isinstance(m, nn.BatchNorm2d):
|
38 |
+
nn.init.constant_(m.weight, 1)
|
39 |
+
nn.init.constant_(m.bias, 0)
|
40 |
+
|
41 |
+
|
42 |
+
def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
|
43 |
+
"""Constructs a ResNet model.
|
44 |
+
Args:
|
45 |
+
num_layers (int): Number of resnet layers. Must be 18 or 50
|
46 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
47 |
+
num_input_images (int): Number of frames stacked as input
|
48 |
+
"""
|
49 |
+
assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
|
50 |
+
blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
|
51 |
+
block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
|
52 |
+
model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
|
53 |
+
|
54 |
+
if pretrained:
|
55 |
+
loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
|
56 |
+
loaded['conv1.weight'] = torch.cat(
|
57 |
+
[loaded['conv1.weight']] * num_input_images, 1) / num_input_images
|
58 |
+
model.load_state_dict(loaded)
|
59 |
+
return model
|
60 |
+
|
61 |
+
|
62 |
+
class ResnetEncoder(nn.Module):
|
63 |
+
"""Pytorch module for a resnet encoder
|
64 |
+
"""
|
65 |
+
def __init__(self, num_layers, pretrained, num_input_images=1):
|
66 |
+
super(ResnetEncoder, self).__init__()
|
67 |
+
|
68 |
+
self.num_ch_enc = np.array([64, 64, 128, 256, 512])
|
69 |
+
|
70 |
+
resnets = {18: models.resnet18,
|
71 |
+
34: models.resnet34,
|
72 |
+
50: models.resnet50,
|
73 |
+
101: models.resnet101,
|
74 |
+
152: models.resnet152}
|
75 |
+
|
76 |
+
if num_layers not in resnets:
|
77 |
+
raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
|
78 |
+
|
79 |
+
if num_input_images > 1:
|
80 |
+
self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
|
81 |
+
else:
|
82 |
+
self.encoder = resnets[num_layers](pretrained)
|
83 |
+
|
84 |
+
if num_layers > 34:
|
85 |
+
self.num_ch_enc[1:] *= 4
|
86 |
+
|
87 |
+
def forward(self, input_image):
|
88 |
+
self.features = []
|
89 |
+
x = (input_image - 0.45) / 0.225
|
90 |
+
x = self.encoder.conv1(x)
|
91 |
+
x = self.encoder.bn1(x)
|
92 |
+
self.features.append(self.encoder.relu(x))
|
93 |
+
self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
|
94 |
+
self.features.append(self.encoder.layer2(self.features[-1]))
|
95 |
+
self.features.append(self.encoder.layer3(self.features[-1]))
|
96 |
+
self.features.append(self.encoder.layer4(self.features[-1]))
|
97 |
+
|
98 |
+
return self.features
|
modules/AdaIN.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def calc_mean_std(feat, eps=1e-5):
|
4 |
+
# eps is a small value added to the variance to avoid divide-by-zero.
|
5 |
+
size = feat.size()
|
6 |
+
assert (len(size) == 4)
|
7 |
+
N, C = size[:2]
|
8 |
+
feat_var = feat.view(N, C, -1).var(dim=2) + eps
|
9 |
+
feat_std = feat_var.sqrt().view(N, C, 1, 1)
|
10 |
+
feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
11 |
+
return feat_mean, feat_std
|
12 |
+
|
13 |
+
def adaptive_instance_normalization(content_feat, style_feat):
|
14 |
+
assert (content_feat.size()[:2] == style_feat.size()[:2])
|
15 |
+
size = content_feat.size()
|
16 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
17 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
18 |
+
normalized_feat = (content_feat - content_mean.expand(
|
19 |
+
size)) / content_std.expand(size)
|
20 |
+
|
21 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
22 |
+
|
23 |
+
def _calc_feat_flatten_mean_std(feat):
|
24 |
+
# takes 3D feat (C, H, W), return mean and std of array within channels
|
25 |
+
assert (feat.size()[0] == 3)
|
26 |
+
assert (isinstance(feat, torch.FloatTensor))
|
27 |
+
feat_flatten = feat.view(3, -1)
|
28 |
+
mean = feat_flatten.mean(dim=-1, keepdim=True)
|
29 |
+
std = feat_flatten.std(dim=-1, keepdim=True)
|
30 |
+
return feat_flatten, mean, std
|
31 |
+
|
32 |
+
def _mat_sqrt(x):
|
33 |
+
U, D, V = torch.svd(x)
|
34 |
+
return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
|
35 |
+
|
36 |
+
def coral(source, target):
|
37 |
+
# assume both source and target are 3D array (C, H, W)
|
38 |
+
# Note: flatten -> f
|
39 |
+
source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
|
40 |
+
source_f_norm = (source_f - source_f_mean.expand_as(
|
41 |
+
source_f)) / source_f_std.expand_as(source_f)
|
42 |
+
source_f_cov_eye = \
|
43 |
+
torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
|
44 |
+
|
45 |
+
target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
|
46 |
+
target_f_norm = (target_f - target_f_mean.expand_as(
|
47 |
+
target_f)) / target_f_std.expand_as(target_f)
|
48 |
+
target_f_cov_eye = \
|
49 |
+
torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
|
50 |
+
|
51 |
+
source_f_norm_transfer = torch.mm(
|
52 |
+
_mat_sqrt(target_f_cov_eye),
|
53 |
+
torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
|
54 |
+
source_f_norm)
|
55 |
+
)
|
56 |
+
|
57 |
+
source_f_transfer = source_f_norm_transfer * \
|
58 |
+
target_f_std.expand_as(source_f_norm) + \
|
59 |
+
target_f_mean.expand_as(source_f_norm)
|
60 |
+
|
61 |
+
return source_f_transfer.view(source.size())
|
modules/dense_motion.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch
|
4 |
+
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
|
5 |
+
import pdb
|
6 |
+
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
|
7 |
+
|
8 |
+
|
9 |
+
class DenseMotionNetwork(nn.Module):
|
10 |
+
"""
|
11 |
+
Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
|
15 |
+
scale_factor=1, kp_variance=0.01):
|
16 |
+
super(DenseMotionNetwork, self).__init__()
|
17 |
+
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
|
18 |
+
max_features=max_features, num_blocks=num_blocks)
|
19 |
+
|
20 |
+
self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
|
21 |
+
|
22 |
+
if estimate_occlusion_map:
|
23 |
+
self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
|
24 |
+
else:
|
25 |
+
self.occlusion = None
|
26 |
+
|
27 |
+
self.num_kp = num_kp
|
28 |
+
self.scale_factor = scale_factor
|
29 |
+
self.kp_variance = kp_variance
|
30 |
+
|
31 |
+
if self.scale_factor != 1:
|
32 |
+
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
|
33 |
+
|
34 |
+
def create_heatmap_representations(self, source_image, kp_driving, kp_source):
|
35 |
+
"""
|
36 |
+
Eq 6. in the paper H_k(z)
|
37 |
+
"""
|
38 |
+
spatial_size = source_image.shape[2:]
|
39 |
+
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
|
40 |
+
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
|
41 |
+
heatmap = gaussian_driving - gaussian_source
|
42 |
+
#adding background feature
|
43 |
+
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
|
44 |
+
heatmap = torch.cat([zeros, heatmap], dim=1)
|
45 |
+
heatmap = heatmap.unsqueeze(2)
|
46 |
+
return heatmap
|
47 |
+
|
48 |
+
def create_sparse_motions(self, source_image, kp_driving, kp_source):
|
49 |
+
"""
|
50 |
+
Eq 4. in the paper T_{s<-d}(z)
|
51 |
+
"""
|
52 |
+
bs, _, h, w = source_image.shape
|
53 |
+
identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
|
54 |
+
identity_grid = identity_grid.view(1, 1, h, w, 2)
|
55 |
+
coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
|
56 |
+
if 'jacobian' in kp_driving:
|
57 |
+
jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
|
58 |
+
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
|
59 |
+
jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
|
60 |
+
coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
|
61 |
+
coordinate_grid = coordinate_grid.squeeze(-1)
|
62 |
+
|
63 |
+
driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
|
64 |
+
|
65 |
+
#adding background feature
|
66 |
+
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
|
67 |
+
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs, num_kp+1,w,h,2
|
68 |
+
return sparse_motions
|
69 |
+
|
70 |
+
def create_deformed_source_image(self, source_image, sparse_motions):
|
71 |
+
"""
|
72 |
+
Eq 7. in the paper \hat{T}_{s<-d}(z)
|
73 |
+
"""
|
74 |
+
bs, _, h, w = source_image.shape
|
75 |
+
source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
|
76 |
+
source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
|
77 |
+
sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
|
78 |
+
sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
|
79 |
+
sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
|
80 |
+
return sparse_deformed
|
81 |
+
|
82 |
+
def forward(self, source_image, kp_driving, kp_source):
|
83 |
+
if self.scale_factor != 1:
|
84 |
+
source_image = self.down(source_image)
|
85 |
+
bs, _, h, w = source_image.shape
|
86 |
+
out_dict = dict()
|
87 |
+
heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
|
88 |
+
sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
|
89 |
+
deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
|
90 |
+
out_dict['sparse_deformed'] = deformed_source
|
91 |
+
|
92 |
+
input = torch.cat([heatmap_representation, deformed_source], dim=2)
|
93 |
+
input = input.view(bs, -1, h, w)
|
94 |
+
|
95 |
+
prediction = self.hourglass(input)
|
96 |
+
|
97 |
+
mask = self.mask(prediction)
|
98 |
+
mask = F.softmax(mask, dim=1)
|
99 |
+
out_dict['mask'] = mask
|
100 |
+
mask = mask.unsqueeze(2)
|
101 |
+
sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
|
102 |
+
deformation = (sparse_motion * mask).sum(dim=1)
|
103 |
+
deformation = deformation.permute(0, 2, 3, 1)
|
104 |
+
|
105 |
+
out_dict['deformation'] = deformation
|
106 |
+
|
107 |
+
# Sec. 3.2 in the paper
|
108 |
+
if self.occlusion:
|
109 |
+
occlusion_map = torch.sigmoid(self.occlusion(prediction))
|
110 |
+
out_dict['occlusion_map'] = occlusion_map
|
111 |
+
|
112 |
+
return out_dict
|
modules/discriminator.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from modules.util import kp2gaussian
|
4 |
+
import torch
|
5 |
+
import pdb
|
6 |
+
|
7 |
+
class DownBlock2d(nn.Module):
|
8 |
+
"""
|
9 |
+
Simple block for processing video (encoder).
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
|
13 |
+
super(DownBlock2d, self).__init__()
|
14 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
|
15 |
+
|
16 |
+
if sn:
|
17 |
+
self.conv = nn.utils.spectral_norm(self.conv)
|
18 |
+
|
19 |
+
if norm:
|
20 |
+
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
21 |
+
else:
|
22 |
+
self.norm = None
|
23 |
+
self.pool = pool
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
out = x
|
27 |
+
out = self.conv(out)
|
28 |
+
if self.norm:
|
29 |
+
out = self.norm(out)
|
30 |
+
out = F.leaky_relu(out, 0.2)
|
31 |
+
if self.pool:
|
32 |
+
out = F.avg_pool2d(out, (2, 2))
|
33 |
+
return out
|
34 |
+
|
35 |
+
|
36 |
+
class Discriminator(nn.Module):
|
37 |
+
"""
|
38 |
+
Discriminator similar to Pix2Pix
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
|
42 |
+
sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
|
43 |
+
super(Discriminator, self).__init__()
|
44 |
+
|
45 |
+
down_blocks = []
|
46 |
+
for i in range(num_blocks):
|
47 |
+
down_blocks.append(
|
48 |
+
DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
49 |
+
min(max_features, block_expansion * (2 ** (i + 1))),
|
50 |
+
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
|
51 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
52 |
+
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
|
53 |
+
if sn:
|
54 |
+
self.conv = nn.utils.spectral_norm(self.conv)
|
55 |
+
self.use_kp = use_kp
|
56 |
+
self.kp_variance = kp_variance
|
57 |
+
|
58 |
+
def forward(self, x, kp=None):
|
59 |
+
feature_maps = []
|
60 |
+
out = x
|
61 |
+
if self.use_kp:
|
62 |
+
heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
|
63 |
+
out = torch.cat([out, heatmap], dim=1)
|
64 |
+
# print(out.shape)
|
65 |
+
for down_block in self.down_blocks:
|
66 |
+
feature_maps.append(down_block(out))
|
67 |
+
out = feature_maps[-1]
|
68 |
+
# print(out.shape)
|
69 |
+
prediction_map = self.conv(out)
|
70 |
+
|
71 |
+
return feature_maps, prediction_map
|
72 |
+
|
73 |
+
|
74 |
+
class MultiScaleDiscriminator(nn.Module):
|
75 |
+
"""
|
76 |
+
Multi-scale (scale) discriminator
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, scales=(), **kwargs):
|
80 |
+
super(MultiScaleDiscriminator, self).__init__()
|
81 |
+
self.scales = scales
|
82 |
+
discs = {}
|
83 |
+
for scale in scales:
|
84 |
+
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
|
85 |
+
self.discs = nn.ModuleDict(discs)
|
86 |
+
|
87 |
+
def forward(self, x, kp=None):
|
88 |
+
out_dict = {}
|
89 |
+
for scale, disc in self.discs.items():
|
90 |
+
scale = str(scale).replace('-', '.')
|
91 |
+
key = 'prediction_' + scale
|
92 |
+
feature_maps, prediction_map = disc(x[key], kp)
|
93 |
+
out_dict['feature_maps_' + scale] = feature_maps
|
94 |
+
out_dict['prediction_map_' + scale] = prediction_map
|
95 |
+
return out_dict
|
modules/dynamic_conv.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import pdb
|
5 |
+
|
6 |
+
class attention1d(nn.Module):
|
7 |
+
def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
|
8 |
+
super(attention1d, self).__init__()
|
9 |
+
assert temperature%3==1
|
10 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
11 |
+
if in_planes!=3:
|
12 |
+
hidden_planes = int(in_planes*ratios)+1
|
13 |
+
else:
|
14 |
+
hidden_planes = K
|
15 |
+
self.fc1 = nn.Conv1d(in_planes, hidden_planes, 1, bias=False)
|
16 |
+
# self.bn = nn.BatchNorm2d(hidden_planes)
|
17 |
+
self.fc2 = nn.Conv1d(hidden_planes, K, 1, bias=True)
|
18 |
+
self.temperature = temperature
|
19 |
+
if init_weight:
|
20 |
+
self._initialize_weights()
|
21 |
+
|
22 |
+
|
23 |
+
def _initialize_weights(self):
|
24 |
+
for m in self.modules():
|
25 |
+
if isinstance(m, nn.Conv1d):
|
26 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
27 |
+
if m.bias is not None:
|
28 |
+
nn.init.constant_(m.bias, 0)
|
29 |
+
if isinstance(m ,nn.BatchNorm2d):
|
30 |
+
nn.init.constant_(m.weight, 1)
|
31 |
+
nn.init.constant_(m.bias, 0)
|
32 |
+
|
33 |
+
def updata_temperature(self):
|
34 |
+
if self.temperature!=1:
|
35 |
+
self.temperature -=3
|
36 |
+
print('Change temperature to:', str(self.temperature))
|
37 |
+
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x = self.avgpool(x)
|
41 |
+
x = self.fc1(x)
|
42 |
+
x = F.relu(x)
|
43 |
+
x = self.fc2(x).view(x.size(0), -1)
|
44 |
+
return F.softmax(x/self.temperature, 1)
|
45 |
+
|
46 |
+
|
47 |
+
class Dynamic_conv1d(nn.Module):
|
48 |
+
def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,temperature=34, init_weight=True):
|
49 |
+
super(Dynamic_conv1d, self).__init__()
|
50 |
+
assert in_planes%groups==0
|
51 |
+
self.in_planes = in_planes
|
52 |
+
self.out_planes = out_planes
|
53 |
+
self.kernel_size = kernel_size
|
54 |
+
self.stride = stride
|
55 |
+
self.padding = padding
|
56 |
+
self.dilation = dilation
|
57 |
+
self.groups = groups
|
58 |
+
self.bias = bias
|
59 |
+
self.K = K
|
60 |
+
self.attention = attention1d(in_planes, ratio, K, temperature)
|
61 |
+
|
62 |
+
self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size), requires_grad=True)
|
63 |
+
if bias:
|
64 |
+
self.bias = nn.Parameter(torch.Tensor(K, out_planes))
|
65 |
+
else:
|
66 |
+
self.bias = None
|
67 |
+
if init_weight:
|
68 |
+
self._initialize_weights()
|
69 |
+
|
70 |
+
#TODO 初始化
|
71 |
+
def _initialize_weights(self):
|
72 |
+
for i in range(self.K):
|
73 |
+
nn.init.kaiming_uniform_(self.weight[i])
|
74 |
+
|
75 |
+
|
76 |
+
def update_temperature(self):
|
77 |
+
self.attention.updata_temperature()
|
78 |
+
|
79 |
+
def forward(self, x):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
|
80 |
+
softmax_attention = self.attention(x)
|
81 |
+
batch_size, in_planes, height = x.size()
|
82 |
+
x = x.view(1, -1, height, )# 变化成一个维度进行组卷积
|
83 |
+
weight = self.weight.view(self.K, -1)
|
84 |
+
|
85 |
+
# 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
|
86 |
+
aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size,)
|
87 |
+
if self.bias is not None:
|
88 |
+
aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
|
89 |
+
output = F.conv1d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
|
90 |
+
dilation=self.dilation, groups=self.groups*batch_size)
|
91 |
+
else:
|
92 |
+
output = F.conv1d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
|
93 |
+
dilation=self.dilation, groups=self.groups * batch_size)
|
94 |
+
|
95 |
+
output = output.view(batch_size, self.out_planes, output.size(-1))
|
96 |
+
return output
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
class attention2d(nn.Module):
|
101 |
+
def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
|
102 |
+
super(attention2d, self).__init__()
|
103 |
+
assert temperature%3==1
|
104 |
+
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
105 |
+
if in_planes!=3:
|
106 |
+
hidden_planes = int(in_planes*ratios)+1
|
107 |
+
else:
|
108 |
+
hidden_planes = K
|
109 |
+
self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
|
110 |
+
# self.bn = nn.BatchNorm2d(hidden_planes)
|
111 |
+
self.fc2 = nn.Conv2d(hidden_planes, K, 1, bias=True)
|
112 |
+
self.temperature = temperature
|
113 |
+
if init_weight:
|
114 |
+
self._initialize_weights()
|
115 |
+
|
116 |
+
|
117 |
+
def _initialize_weights(self):
|
118 |
+
for m in self.modules():
|
119 |
+
if isinstance(m, nn.Conv2d):
|
120 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
121 |
+
if m.bias is not None:
|
122 |
+
nn.init.constant_(m.bias, 0)
|
123 |
+
if isinstance(m ,nn.BatchNorm2d):
|
124 |
+
nn.init.constant_(m.weight, 1)
|
125 |
+
nn.init.constant_(m.bias, 0)
|
126 |
+
|
127 |
+
def updata_temperature(self):
|
128 |
+
if self.temperature!=1:
|
129 |
+
self.temperature -=3
|
130 |
+
print('Change temperature to:', str(self.temperature))
|
131 |
+
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
x = self.avgpool(x)
|
135 |
+
x = self.fc1(x)
|
136 |
+
x = F.relu(x)
|
137 |
+
x = self.fc2(x).view(x.size(0), -1)
|
138 |
+
return F.softmax(x/self.temperature, 1)
|
139 |
+
|
140 |
+
|
141 |
+
class Dynamic_deepwise_conv2d(nn.Module):
|
142 |
+
def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,temperature=34, init_weight=True):
|
143 |
+
super(Dynamic_deepwise_conv2d, self).__init__()
|
144 |
+
assert in_planes%groups==0
|
145 |
+
self.in_planes = in_planes
|
146 |
+
self.out_planes = out_planes
|
147 |
+
self.kernel_size = kernel_size
|
148 |
+
self.stride = stride
|
149 |
+
self.padding = padding
|
150 |
+
self.dilation = dilation
|
151 |
+
self.groups = groups
|
152 |
+
self.bias = bias
|
153 |
+
self.K = K
|
154 |
+
self.attention = attention2d(in_planes, ratio, K, temperature)
|
155 |
+
|
156 |
+
self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size), requires_grad=True)
|
157 |
+
if bias:
|
158 |
+
self.bias = nn.Parameter(torch.Tensor(K, out_planes))
|
159 |
+
else:
|
160 |
+
self.bias = None
|
161 |
+
if init_weight:
|
162 |
+
self._initialize_weights()
|
163 |
+
|
164 |
+
#TODO 初始化
|
165 |
+
def _initialize_weights(self):
|
166 |
+
for i in range(self.K):
|
167 |
+
nn.init.kaiming_uniform_(self.weight[i])
|
168 |
+
|
169 |
+
|
170 |
+
def update_temperature(self):
|
171 |
+
self.attention.updata_temperature()
|
172 |
+
|
173 |
+
def forward(self, x, y):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
|
174 |
+
softmax_attention = self.attention(x)
|
175 |
+
batch_size, in_planes, height, width = x.size()
|
176 |
+
y = y.view(1, -1, height, width)# 变化成一个维度进行组卷积
|
177 |
+
weight = self.weight.view(self.K, -1)
|
178 |
+
|
179 |
+
# 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
|
180 |
+
aggregate_weight = torch.mm(softmax_attention, weight).view(-1, 1, self.kernel_size, self.kernel_size)
|
181 |
+
if self.bias is not None:
|
182 |
+
aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
|
183 |
+
output = F.conv2d(y, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
|
184 |
+
dilation=self.dilation, groups=self.groups*batch_size)
|
185 |
+
else:
|
186 |
+
output = F.conv2d(y, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
|
187 |
+
dilation=self.dilation, groups=self.groups * batch_size)
|
188 |
+
|
189 |
+
output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
|
190 |
+
return output
|
191 |
+
|
192 |
+
class Dynamic_conv2d(nn.Module):
|
193 |
+
def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,temperature=34, init_weight=True):
|
194 |
+
super(Dynamic_conv2d, self).__init__()
|
195 |
+
assert in_planes%groups==0
|
196 |
+
self.in_planes = in_planes
|
197 |
+
self.out_planes = out_planes
|
198 |
+
self.kernel_size = kernel_size
|
199 |
+
self.stride = stride
|
200 |
+
self.padding = padding
|
201 |
+
self.dilation = dilation
|
202 |
+
self.groups = groups
|
203 |
+
self.bias = bias
|
204 |
+
self.K = K
|
205 |
+
self.attention = attention2d(in_planes, ratio, K, temperature)
|
206 |
+
|
207 |
+
self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size), requires_grad=True)
|
208 |
+
if bias:
|
209 |
+
self.bias = nn.Parameter(torch.Tensor(K, out_planes))
|
210 |
+
else:
|
211 |
+
self.bias = None
|
212 |
+
if init_weight:
|
213 |
+
self._initialize_weights()
|
214 |
+
|
215 |
+
#TODO 初始化
|
216 |
+
def _initialize_weights(self):
|
217 |
+
for i in range(self.K):
|
218 |
+
nn.init.kaiming_uniform_(self.weight[i])
|
219 |
+
|
220 |
+
|
221 |
+
def update_temperature(self):
|
222 |
+
self.attention.updata_temperature()
|
223 |
+
|
224 |
+
def forward(self, x,y):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
|
225 |
+
softmax_attention = self.attention(x)
|
226 |
+
batch_size, in_planes, height, width = x.size()
|
227 |
+
y = y.view(1, -1, height, width)# 变化成一个维度进行组卷积
|
228 |
+
weight = self.weight.view(self.K, -1)
|
229 |
+
|
230 |
+
# 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
|
231 |
+
aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size, self.kernel_size)
|
232 |
+
if self.bias is not None:
|
233 |
+
aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
|
234 |
+
output = F.conv2d(y, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
|
235 |
+
dilation=self.dilation, groups=self.groups*batch_size)
|
236 |
+
else:
|
237 |
+
output = F.conv2d(y, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
|
238 |
+
dilation=self.dilation, groups=self.groups * batch_size)
|
239 |
+
|
240 |
+
output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
|
241 |
+
return output
|
242 |
+
|
243 |
+
|
244 |
+
class attention3d(nn.Module):
|
245 |
+
def __init__(self, in_planes, ratios, K, temperature):
|
246 |
+
super(attention3d, self).__init__()
|
247 |
+
assert temperature%3==1
|
248 |
+
self.avgpool = nn.AdaptiveAvgPool3d(1)
|
249 |
+
if in_planes != 3:
|
250 |
+
hidden_planes = int(in_planes * ratios)+1
|
251 |
+
else:
|
252 |
+
hidden_planes = K
|
253 |
+
self.fc1 = nn.Conv3d(in_planes, hidden_planes, 1, bias=False)
|
254 |
+
self.fc2 = nn.Conv3d(hidden_planes, K, 1, bias=False)
|
255 |
+
self.temperature = temperature
|
256 |
+
|
257 |
+
def updata_temperature(self):
|
258 |
+
if self.temperature!=1:
|
259 |
+
self.temperature -=3
|
260 |
+
print('Change temperature to:', str(self.temperature))
|
261 |
+
|
262 |
+
def forward(self, x):
|
263 |
+
x = self.avgpool(x)
|
264 |
+
x = self.fc1(x)
|
265 |
+
x = F.relu(x)
|
266 |
+
x = self.fc2(x).view(x.size(0), -1)
|
267 |
+
return F.softmax(x / self.temperature, 1)
|
268 |
+
|
269 |
+
class Dynamic_conv3d(nn.Module):
|
270 |
+
def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34):
|
271 |
+
super(Dynamic_conv3d, self).__init__()
|
272 |
+
assert in_planes%groups==0
|
273 |
+
self.in_planes = in_planes
|
274 |
+
self.out_planes = out_planes
|
275 |
+
self.kernel_size = kernel_size
|
276 |
+
self.stride = stride
|
277 |
+
self.padding = padding
|
278 |
+
self.dilation = dilation
|
279 |
+
self.groups = groups
|
280 |
+
self.bias = bias
|
281 |
+
self.K = K
|
282 |
+
self.attention = attention3d(in_planes, ratio, K, temperature)
|
283 |
+
|
284 |
+
self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size, kernel_size), requires_grad=True)
|
285 |
+
if bias:
|
286 |
+
self.bias = nn.Parameter(torch.Tensor(K, out_planes))
|
287 |
+
else:
|
288 |
+
self.bias = None
|
289 |
+
|
290 |
+
|
291 |
+
#TODO 初始化
|
292 |
+
# nn.init.kaiming_uniform_(self.weight, )
|
293 |
+
|
294 |
+
def update_temperature(self):
|
295 |
+
self.attention.updata_temperature()
|
296 |
+
|
297 |
+
def forward(self, x):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
|
298 |
+
softmax_attention = self.attention(x)
|
299 |
+
batch_size, in_planes, depth, height, width = x.size()
|
300 |
+
x = x.view(1, -1, depth, height, width)# 变化成一个维度进行组卷积
|
301 |
+
weight = self.weight.view(self.K, -1)
|
302 |
+
|
303 |
+
# 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
|
304 |
+
aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size, self.kernel_size, self.kernel_size)
|
305 |
+
if self.bias is not None:
|
306 |
+
aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
|
307 |
+
output = F.conv3d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
|
308 |
+
dilation=self.dilation, groups=self.groups*batch_size)
|
309 |
+
else:
|
310 |
+
output = F.conv3d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
|
311 |
+
dilation=self.dilation, groups=self.groups * batch_size)
|
312 |
+
|
313 |
+
output = output.view(batch_size, self.out_planes, output.size(-3), output.size(-2), output.size(-1))
|
314 |
+
return output
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
|
319 |
+
if __name__ == '__main__':
|
320 |
+
x = torch.randn(12, 256, 64, 64)
|
321 |
+
y = torch.randn(12, 256, 64, 64)
|
322 |
+
|
323 |
+
model = Dynamic_conv2d(in_planes=256, out_planes=256, kernel_size=3, ratio=0.25, padding=1,groups=1)
|
324 |
+
x = x.to('cuda:0')
|
325 |
+
y = y.to('cuda:0')
|
326 |
+
model.to('cuda')
|
327 |
+
# model.attention.cuda()
|
328 |
+
print(model(x,y).shape)
|
329 |
+
# nn.Conv3d()
|
330 |
+
# print(model(x).shape)
|
331 |
+
# model.update_temperature()
|
332 |
+
# model.update_temperature()
|
333 |
+
# model.update_temperature()
|
334 |
+
# model.update_temperature()
|
335 |
+
# model.update_temperature()
|
336 |
+
# model.update_temperature()
|
337 |
+
# model.update_temperature()
|
338 |
+
# model.update_temperature()
|
339 |
+
# model.update_temperature()
|
340 |
+
# model.update_temperature()
|
341 |
+
# model.update_temperature()
|
342 |
+
# model.update_temperature()
|
343 |
+
# model.update_temperature()
|
344 |
+
# print(model(x).shape)
|
345 |
+
# print(model(x).shape)
|
346 |
+
# print(model(x).shape)
|
347 |
+
# print(model(x).shape)
|
348 |
+
# print(model(x).shape)
|
349 |
+
# print(model(x).shape)
|
350 |
+
# print(model(x).shape)
|
351 |
+
# print(model(x).shape)
|
352 |
+
# print(model(x).shape)
|
353 |
+
# print(model(x).shape)
|
354 |
+
# print(model(x).shape)
|
355 |
+
# print(model(x).shape)
|
356 |
+
# print(model(x).shape)
|
357 |
+
# print(model(x).shape)
|
358 |
+
# print(model(x).shape)
|
359 |
+
# print(model(x).shape)
|
360 |
+
# print(model(x).shape)
|
361 |
+
# print(model(x).shape)
|
362 |
+
# print(model(x).shape)
|
363 |
+
# print(model(x).shape)
|
364 |
+
# print(model(x).shape)
|
365 |
+
# print(model(x).shape)
|
366 |
+
# print(model(x).shape)
|
367 |
+
# print(model(x).shape)
|
368 |
+
# print(model(x).shape)
|
369 |
+
# print(model(x).shape)
|
370 |
+
# print(model(x).shape)
|
371 |
+
# print(model(x).shape)
|
372 |
+
# print(model(x).shape)
|
373 |
+
# print(model(x).shape)
|
374 |
+
# print(model(x).shape)
|
375 |
+
# print(model(x).shape)
|
376 |
+
# print(model(x).shape)
|
377 |
+
# print(model(x).shape)
|
378 |
+
# print(model(x).shape)
|
379 |
+
# print(model(x).shape)
|
380 |
+
# print(model(x).shape)
|
381 |
+
# print(model(x).shape)
|
382 |
+
|
modules/generator.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d,SPADEResnetBlock
|
5 |
+
from modules.dense_motion import *
|
6 |
+
import pdb
|
7 |
+
from modules.AdaIN import calc_mean_std,adaptive_instance_normalization
|
8 |
+
from modules.dynamic_conv import Dynamic_conv2d
|
9 |
+
class SPADEGenerator(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
ic = 256
|
13 |
+
cc = 4
|
14 |
+
oc = 64
|
15 |
+
norm_G = 'spadespectralinstance'
|
16 |
+
label_nc = 3 + cc
|
17 |
+
|
18 |
+
self.compress = nn.Conv2d(ic, cc, 3, padding=1)
|
19 |
+
self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
|
20 |
+
|
21 |
+
self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
|
22 |
+
self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
|
23 |
+
self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
|
24 |
+
# self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
|
25 |
+
# self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
|
26 |
+
# self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
|
27 |
+
|
28 |
+
self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
|
29 |
+
self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
|
30 |
+
self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
|
31 |
+
self.up = nn.Upsample(scale_factor=2)
|
32 |
+
|
33 |
+
def forward(self, feature, image):
|
34 |
+
cp = self.compress(feature)
|
35 |
+
seg = torch.cat((F.interpolate(cp, size=(image.shape[2], image.shape[3])), image), dim=1) # 7, 256, 256
|
36 |
+
|
37 |
+
x = feature # 256, 64, 64
|
38 |
+
x = self.fc(x) # 512, 64, 64
|
39 |
+
x = self.G_middle_0(x, seg)
|
40 |
+
x = self.G_middle_1(x, seg)
|
41 |
+
x = self.G_middle_2(x, seg)
|
42 |
+
# x = self.G_middle_3(x, seg)
|
43 |
+
# x = self.G_middle_4(x, seg)
|
44 |
+
# x = self.G_middle_5(x, seg)
|
45 |
+
x = self.up(x) # 256, 128, 128
|
46 |
+
x = self.up_0(x, seg)
|
47 |
+
x = self.up(x) # 64, 256, 256
|
48 |
+
x = self.up_1(x, seg)
|
49 |
+
|
50 |
+
x = self.conv_img(F.leaky_relu(x, 2e-1))
|
51 |
+
# x = torch.tanh(x)
|
52 |
+
x = F.sigmoid(x)
|
53 |
+
|
54 |
+
return x
|
55 |
+
|
56 |
+
class DepthAwareAttention(nn.Module):
|
57 |
+
""" depth-aware attention Layer"""
|
58 |
+
def __init__(self,in_dim,activation):
|
59 |
+
super(DepthAwareAttention,self).__init__()
|
60 |
+
self.chanel_in = in_dim
|
61 |
+
self.activation = activation
|
62 |
+
|
63 |
+
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
|
64 |
+
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
|
65 |
+
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
|
66 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
67 |
+
|
68 |
+
self.softmax = nn.Softmax(dim=-1) #
|
69 |
+
def forward(self,source,feat):
|
70 |
+
"""
|
71 |
+
inputs :
|
72 |
+
source : input feature maps( B X C X W X H) 256,64,64
|
73 |
+
driving : input feature maps( B X C X W X H) 256,64,64
|
74 |
+
returns :
|
75 |
+
out : self attention value + input feature
|
76 |
+
attention: B X N X N (N is Width*Height)
|
77 |
+
"""
|
78 |
+
m_batchsize,C,width ,height = source.size()
|
79 |
+
proj_query = self.activation(self.query_conv(source)).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N) [bz,32,64,64]
|
80 |
+
proj_key = self.activation(self.key_conv(feat)).view(m_batchsize,-1,width*height) # B X C x (*W*H)
|
81 |
+
energy = torch.bmm(proj_query,proj_key) # transpose check
|
82 |
+
attention = self.softmax(energy) # BX (N) X (N)
|
83 |
+
proj_value = self.activation(self.value_conv(feat)).view(m_batchsize,-1,width*height) # B X C X N
|
84 |
+
|
85 |
+
out = torch.bmm(proj_value,attention.permute(0,2,1) )
|
86 |
+
out = out.view(m_batchsize,C,width,height)
|
87 |
+
out = self.gamma*out + feat
|
88 |
+
|
89 |
+
return out,attention
|
90 |
+
|
91 |
+
#### main ####
|
92 |
+
class DepthAwareGenerator(nn.Module):
|
93 |
+
"""
|
94 |
+
Generator that given source image and and keypoints try to transform image according to movement trajectories
|
95 |
+
induced by keypoints. Generator follows Johnson architecture.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
|
99 |
+
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
|
100 |
+
super(DepthAwareGenerator, self).__init__()
|
101 |
+
|
102 |
+
if dense_motion_params is not None:
|
103 |
+
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
|
104 |
+
estimate_occlusion_map=estimate_occlusion_map,
|
105 |
+
**dense_motion_params)
|
106 |
+
else:
|
107 |
+
self.dense_motion_network = None
|
108 |
+
|
109 |
+
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
110 |
+
down_blocks = []
|
111 |
+
for i in range(num_down_blocks):
|
112 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
113 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
114 |
+
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
115 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
116 |
+
|
117 |
+
#source depth
|
118 |
+
self.src_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
119 |
+
src_down_blocks = []
|
120 |
+
for i in range(num_down_blocks):
|
121 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
122 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
123 |
+
src_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
124 |
+
self.src_down_blocks = nn.ModuleList(src_down_blocks)
|
125 |
+
|
126 |
+
# #driving depth
|
127 |
+
# self.dst_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
128 |
+
# dst_down_blocks = []
|
129 |
+
# for i in range(num_down_blocks):
|
130 |
+
# in_features = min(max_features, block_expansion * (2 ** i))
|
131 |
+
# out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
132 |
+
# dst_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
133 |
+
# self.dst_down_blocks = nn.ModuleList(dst_down_blocks)
|
134 |
+
|
135 |
+
self.AttnModule = DepthAwareAttention(out_features,nn.ReLU())
|
136 |
+
|
137 |
+
up_blocks = []
|
138 |
+
for i in range(num_down_blocks):
|
139 |
+
in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
|
140 |
+
out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
|
141 |
+
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
142 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
143 |
+
|
144 |
+
self.bottleneck = torch.nn.Sequential()
|
145 |
+
in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
|
146 |
+
for i in range(num_bottleneck_blocks):
|
147 |
+
self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
|
148 |
+
|
149 |
+
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
|
150 |
+
self.estimate_occlusion_map = estimate_occlusion_map
|
151 |
+
self.num_channels = num_channels
|
152 |
+
|
153 |
+
def deform_input(self, inp, deformation):
|
154 |
+
_, h_old, w_old, _ = deformation.shape
|
155 |
+
_, _, h, w = inp.shape
|
156 |
+
if h_old != h or w_old != w:
|
157 |
+
deformation = deformation.permute(0, 3, 1, 2)
|
158 |
+
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
|
159 |
+
deformation = deformation.permute(0, 2, 3, 1)
|
160 |
+
return F.grid_sample(inp, deformation)
|
161 |
+
|
162 |
+
def forward(self, source_image, kp_driving, kp_source, source_depth, driving_depth):
|
163 |
+
# Encoding (downsampling) part
|
164 |
+
out = self.first(source_image)
|
165 |
+
for i in range(len(self.down_blocks)):
|
166 |
+
out = self.down_blocks[i](out)
|
167 |
+
|
168 |
+
src_out = self.src_first(source_depth)
|
169 |
+
for i in range(len(self.src_down_blocks)):
|
170 |
+
src_out = self.src_down_blocks[i](src_out)
|
171 |
+
|
172 |
+
# dst_out = self.dst_first(driving_depth)
|
173 |
+
# for i in range(len(self.down_blocks)):
|
174 |
+
# dst_out = self.dst_down_blocks[i](dst_out)
|
175 |
+
|
176 |
+
# Transforming feature representation according to deformation and occlusion
|
177 |
+
output_dict = {}
|
178 |
+
if self.dense_motion_network is not None:
|
179 |
+
dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
|
180 |
+
kp_source=kp_source)
|
181 |
+
output_dict['mask'] = dense_motion['mask']
|
182 |
+
output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
|
183 |
+
|
184 |
+
if 'occlusion_map' in dense_motion:
|
185 |
+
occlusion_map = dense_motion['occlusion_map']
|
186 |
+
output_dict['occlusion_map'] = occlusion_map
|
187 |
+
else:
|
188 |
+
occlusion_map = None
|
189 |
+
deformation = dense_motion['deformation']
|
190 |
+
out = self.deform_input(out, deformation)
|
191 |
+
|
192 |
+
if occlusion_map is not None:
|
193 |
+
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
|
194 |
+
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
|
195 |
+
out = out * occlusion_map
|
196 |
+
out,attention = self.AttnModule(src_out,out)
|
197 |
+
|
198 |
+
output_dict["deformed"] = self.deform_input(source_image, deformation)
|
199 |
+
output_dict["attention"] = attention
|
200 |
+
|
201 |
+
# Decoding part
|
202 |
+
out = self.bottleneck(out)
|
203 |
+
for i in range(len(self.up_blocks)):
|
204 |
+
out = self.up_blocks[i](out)
|
205 |
+
out = self.final(out)
|
206 |
+
out = F.sigmoid(out)
|
207 |
+
|
208 |
+
output_dict["prediction"] = out
|
209 |
+
|
210 |
+
return output_dict
|
211 |
+
|
212 |
+
class SPADEDepthAwareGenerator(nn.Module):
|
213 |
+
"""
|
214 |
+
Generator that given source image and and keypoints try to transform image according to movement trajectories
|
215 |
+
induced by keypoints. Generator follows Johnson architecture.
|
216 |
+
"""
|
217 |
+
|
218 |
+
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
|
219 |
+
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
|
220 |
+
super(SPADEDepthAwareGenerator, self).__init__()
|
221 |
+
|
222 |
+
if dense_motion_params is not None:
|
223 |
+
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
|
224 |
+
estimate_occlusion_map=estimate_occlusion_map,
|
225 |
+
**dense_motion_params)
|
226 |
+
else:
|
227 |
+
self.dense_motion_network = None
|
228 |
+
|
229 |
+
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
230 |
+
down_blocks = []
|
231 |
+
for i in range(num_down_blocks):
|
232 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
233 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
234 |
+
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
235 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
236 |
+
|
237 |
+
#source depth
|
238 |
+
self.src_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
239 |
+
src_down_blocks = []
|
240 |
+
for i in range(num_down_blocks):
|
241 |
+
in_features = min(max_features, block_expansion * (2 ** i))
|
242 |
+
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
243 |
+
src_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
244 |
+
self.src_down_blocks = nn.ModuleList(src_down_blocks)
|
245 |
+
|
246 |
+
# #driving depth
|
247 |
+
# self.dst_first = SameBlock2d(1, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
248 |
+
# dst_down_blocks = []
|
249 |
+
# for i in range(num_down_blocks):
|
250 |
+
# in_features = min(max_features, block_expansion * (2 ** i))
|
251 |
+
# out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
252 |
+
# dst_down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
253 |
+
# self.dst_down_blocks = nn.ModuleList(dst_down_blocks)
|
254 |
+
|
255 |
+
self.AttnModule = DepthAwareAttention(out_features,nn.ReLU())
|
256 |
+
self.decoder = SPADEGenerator()
|
257 |
+
|
258 |
+
self.estimate_occlusion_map = estimate_occlusion_map
|
259 |
+
self.num_channels = num_channels
|
260 |
+
|
261 |
+
def deform_input(self, inp, deformation):
|
262 |
+
_, h_old, w_old, _ = deformation.shape
|
263 |
+
_, _, h, w = inp.shape
|
264 |
+
if h_old != h or w_old != w:
|
265 |
+
deformation = deformation.permute(0, 3, 1, 2)
|
266 |
+
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
|
267 |
+
deformation = deformation.permute(0, 2, 3, 1)
|
268 |
+
return F.grid_sample(inp, deformation)
|
269 |
+
|
270 |
+
def forward(self, source_image, kp_driving, kp_source, source_depth, driving_depth):
|
271 |
+
# Encoding (downsampling) part
|
272 |
+
out = self.first(source_image)
|
273 |
+
for i in range(len(self.down_blocks)):
|
274 |
+
out = self.down_blocks[i](out)
|
275 |
+
|
276 |
+
src_out = self.src_first(source_depth)
|
277 |
+
for i in range(len(self.src_down_blocks)):
|
278 |
+
src_out = self.src_down_blocks[i](src_out)
|
279 |
+
|
280 |
+
# dst_out = self.dst_first(driving_depth)
|
281 |
+
# for i in range(len(self.down_blocks)):
|
282 |
+
# dst_out = self.dst_down_blocks[i](dst_out)
|
283 |
+
|
284 |
+
# Transforming feature representation according to deformation and occlusion
|
285 |
+
output_dict = {}
|
286 |
+
if self.dense_motion_network is not None:
|
287 |
+
dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
|
288 |
+
kp_source=kp_source)
|
289 |
+
output_dict['mask'] = dense_motion['mask']
|
290 |
+
output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
|
291 |
+
|
292 |
+
if 'occlusion_map' in dense_motion:
|
293 |
+
occlusion_map = dense_motion['occlusion_map']
|
294 |
+
output_dict['occlusion_map'] = occlusion_map
|
295 |
+
else:
|
296 |
+
occlusion_map = None
|
297 |
+
deformation = dense_motion['deformation']
|
298 |
+
out = self.deform_input(out, deformation)
|
299 |
+
|
300 |
+
if occlusion_map is not None:
|
301 |
+
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
|
302 |
+
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
|
303 |
+
out = out * occlusion_map
|
304 |
+
|
305 |
+
out,attention = self.AttnModule(src_out,out)
|
306 |
+
|
307 |
+
deformed_image = self.deform_input(source_image, deformation)
|
308 |
+
output_dict["deformed"] = deformed_image
|
309 |
+
output_dict["attention"] = attention
|
310 |
+
|
311 |
+
if occlusion_map is not None:
|
312 |
+
if deformed_image.shape[2] != occlusion_map.shape[2] or deformed_image.shape[3] != occlusion_map.shape[3]:
|
313 |
+
occlusion_map = F.interpolate(occlusion_map, size=deformed_image.shape[2:], mode='bilinear')
|
314 |
+
deformed_image = deformed_image * occlusion_map
|
315 |
+
|
316 |
+
out = self.decoder(out, deformed_image)
|
317 |
+
|
318 |
+
# # Decoding part
|
319 |
+
# out = self.bottleneck(out)
|
320 |
+
# for i in range(len(self.up_blocks)):
|
321 |
+
# out = self.up_blocks[i](out)
|
322 |
+
# out = self.final(out)
|
323 |
+
# out = F.sigmoid(out)
|
324 |
+
output_dict["prediction"] = out
|
325 |
+
return output_dict
|
modules/keypoint_detector.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d,Hourglass_2branch
|
5 |
+
import pdb
|
6 |
+
|
7 |
+
class KPDetector(nn.Module):
|
8 |
+
"""
|
9 |
+
Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, block_expansion, num_kp, num_channels, max_features,
|
13 |
+
num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
|
14 |
+
single_jacobian_map=False, pad=0):
|
15 |
+
super(KPDetector, self).__init__()
|
16 |
+
self.predictor = Hourglass(block_expansion, in_features=num_channels,
|
17 |
+
max_features=max_features, num_blocks=num_blocks)
|
18 |
+
|
19 |
+
self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
|
20 |
+
padding=pad)
|
21 |
+
|
22 |
+
if estimate_jacobian:
|
23 |
+
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
|
24 |
+
self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
|
25 |
+
out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
|
26 |
+
self.jacobian.weight.data.zero_()
|
27 |
+
self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
|
28 |
+
else:
|
29 |
+
self.jacobian = None
|
30 |
+
|
31 |
+
self.temperature = temperature
|
32 |
+
self.scale_factor = scale_factor
|
33 |
+
if self.scale_factor != 1:
|
34 |
+
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
|
35 |
+
|
36 |
+
def gaussian2kp(self, heatmap):
|
37 |
+
"""
|
38 |
+
Extract the mean and from a heatmap
|
39 |
+
"""
|
40 |
+
shape = heatmap.shape
|
41 |
+
heatmap = heatmap.unsqueeze(-1)
|
42 |
+
grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
|
43 |
+
value = (heatmap * grid).sum(dim=(2, 3))
|
44 |
+
kp = {'value': value}
|
45 |
+
|
46 |
+
return kp
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
if self.scale_factor != 1:
|
50 |
+
x = self.down(x)
|
51 |
+
feature_map = self.predictor(x) #x bz,4,64,64
|
52 |
+
prediction = self.kp(feature_map)
|
53 |
+
|
54 |
+
final_shape = prediction.shape
|
55 |
+
heatmap = prediction.view(final_shape[0], final_shape[1], -1)
|
56 |
+
heatmap = F.softmax(heatmap / self.temperature, dim=2)
|
57 |
+
heatmap = heatmap.view(*final_shape)
|
58 |
+
|
59 |
+
out = self.gaussian2kp(heatmap)
|
60 |
+
|
61 |
+
if self.jacobian is not None:
|
62 |
+
jacobian_map = self.jacobian(feature_map)
|
63 |
+
# pdb.set_trace()
|
64 |
+
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
|
65 |
+
final_shape[3])
|
66 |
+
heatmap = heatmap.unsqueeze(2)
|
67 |
+
|
68 |
+
jacobian = heatmap * jacobian_map
|
69 |
+
jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
|
70 |
+
jacobian = jacobian.sum(dim=-1)
|
71 |
+
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
|
72 |
+
out['jacobian'] = jacobian
|
73 |
+
|
74 |
+
return out
|
75 |
+
|
modules/model.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
|
5 |
+
from torchvision import models
|
6 |
+
import numpy as np
|
7 |
+
from torch.autograd import grad
|
8 |
+
import pdb
|
9 |
+
import depth
|
10 |
+
|
11 |
+
class Vgg19(torch.nn.Module):
|
12 |
+
"""
|
13 |
+
Vgg19 network for perceptual loss. See Sec 3.3.
|
14 |
+
"""
|
15 |
+
def __init__(self, requires_grad=False):
|
16 |
+
super(Vgg19, self).__init__()
|
17 |
+
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
18 |
+
self.slice1 = torch.nn.Sequential()
|
19 |
+
self.slice2 = torch.nn.Sequential()
|
20 |
+
self.slice3 = torch.nn.Sequential()
|
21 |
+
self.slice4 = torch.nn.Sequential()
|
22 |
+
self.slice5 = torch.nn.Sequential()
|
23 |
+
for x in range(2):
|
24 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
25 |
+
for x in range(2, 7):
|
26 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
27 |
+
for x in range(7, 12):
|
28 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
29 |
+
for x in range(12, 21):
|
30 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
31 |
+
for x in range(21, 30):
|
32 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
33 |
+
|
34 |
+
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
35 |
+
requires_grad=False)
|
36 |
+
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
37 |
+
requires_grad=False)
|
38 |
+
|
39 |
+
if not requires_grad:
|
40 |
+
for param in self.parameters():
|
41 |
+
param.requires_grad = False
|
42 |
+
|
43 |
+
def forward(self, X):
|
44 |
+
X = (X - self.mean) / self.std
|
45 |
+
h_relu1 = self.slice1(X)
|
46 |
+
h_relu2 = self.slice2(h_relu1)
|
47 |
+
h_relu3 = self.slice3(h_relu2)
|
48 |
+
h_relu4 = self.slice4(h_relu3)
|
49 |
+
h_relu5 = self.slice5(h_relu4)
|
50 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
51 |
+
return out
|
52 |
+
|
53 |
+
|
54 |
+
class ImagePyramide(torch.nn.Module):
|
55 |
+
"""
|
56 |
+
Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
|
57 |
+
"""
|
58 |
+
def __init__(self, scales, num_channels):
|
59 |
+
super(ImagePyramide, self).__init__()
|
60 |
+
downs = {}
|
61 |
+
for scale in scales:
|
62 |
+
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
63 |
+
self.downs = nn.ModuleDict(downs)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
out_dict = {}
|
67 |
+
for scale, down_module in self.downs.items():
|
68 |
+
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
69 |
+
return out_dict
|
70 |
+
|
71 |
+
|
72 |
+
class Transform:
|
73 |
+
"""
|
74 |
+
Random tps transformation for equivariance constraints. See Sec 3.3
|
75 |
+
"""
|
76 |
+
def __init__(self, bs, **kwargs):
|
77 |
+
noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
|
78 |
+
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
|
79 |
+
self.bs = bs
|
80 |
+
|
81 |
+
if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
|
82 |
+
self.tps = True
|
83 |
+
self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
|
84 |
+
self.control_points = self.control_points.unsqueeze(0)
|
85 |
+
self.control_params = torch.normal(mean=0,
|
86 |
+
std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
|
87 |
+
else:
|
88 |
+
self.tps = False
|
89 |
+
|
90 |
+
def transform_frame(self, frame):
|
91 |
+
grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
|
92 |
+
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
|
93 |
+
grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
|
94 |
+
return F.grid_sample(frame, grid, padding_mode="reflection")
|
95 |
+
|
96 |
+
def warp_coordinates(self, coordinates):
|
97 |
+
theta = self.theta.type(coordinates.type())
|
98 |
+
theta = theta.unsqueeze(1)
|
99 |
+
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
|
100 |
+
transformed = transformed.squeeze(-1)
|
101 |
+
|
102 |
+
if self.tps:
|
103 |
+
control_points = self.control_points.type(coordinates.type())
|
104 |
+
control_params = self.control_params.type(coordinates.type())
|
105 |
+
distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
|
106 |
+
distances = torch.abs(distances).sum(-1)
|
107 |
+
|
108 |
+
result = distances ** 2
|
109 |
+
result = result * torch.log(distances + 1e-6)
|
110 |
+
result = result * control_params
|
111 |
+
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
|
112 |
+
transformed = transformed + result
|
113 |
+
|
114 |
+
return transformed
|
115 |
+
|
116 |
+
def jacobian(self, coordinates):
|
117 |
+
new_coordinates = self.warp_coordinates(coordinates)
|
118 |
+
grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
|
119 |
+
grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
|
120 |
+
jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
|
121 |
+
return jacobian
|
122 |
+
|
123 |
+
|
124 |
+
def detach_kp(kp):
|
125 |
+
return {key: value.detach() for key, value in kp.items()}
|
126 |
+
|
127 |
+
|
128 |
+
class GeneratorFullModel(torch.nn.Module):
|
129 |
+
"""
|
130 |
+
Merge all generator related updates into single model for better multi-gpu usage
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(self, kp_extractor, generator, discriminator, train_params,opt):
|
134 |
+
super(GeneratorFullModel, self).__init__()
|
135 |
+
self.kp_extractor = kp_extractor
|
136 |
+
self.generator = generator
|
137 |
+
self.discriminator = discriminator
|
138 |
+
self.train_params = train_params
|
139 |
+
self.scales = train_params['scales']
|
140 |
+
self.disc_scales = self.discriminator.module.scales
|
141 |
+
self.pyramid = ImagePyramide(self.scales, generator.module.num_channels)
|
142 |
+
if torch.cuda.is_available():
|
143 |
+
self.pyramid = self.pyramid.cuda()
|
144 |
+
self.opt = opt
|
145 |
+
self.loss_weights = train_params['loss_weights']
|
146 |
+
|
147 |
+
if sum(self.loss_weights['perceptual']) != 0:
|
148 |
+
self.vgg = Vgg19()
|
149 |
+
if torch.cuda.is_available():
|
150 |
+
self.vgg = self.vgg.cuda()
|
151 |
+
self.depth_encoder = depth.ResnetEncoder(18, False).cuda()
|
152 |
+
self.depth_decoder = depth.DepthDecoder(num_ch_enc=self.depth_encoder.num_ch_enc, scales=range(4)).cuda()
|
153 |
+
loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth',map_location='cpu')
|
154 |
+
loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth',map_location='cpu')
|
155 |
+
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in self.depth_encoder.state_dict()}
|
156 |
+
self.depth_encoder.load_state_dict(filtered_dict_enc)
|
157 |
+
self.depth_decoder.load_state_dict(loaded_dict_dec)
|
158 |
+
self.set_requires_grad(self.depth_encoder, False)
|
159 |
+
self.set_requires_grad(self.depth_decoder, False)
|
160 |
+
self.depth_decoder.eval()
|
161 |
+
self.depth_encoder.eval()
|
162 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
163 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
164 |
+
Parameters:
|
165 |
+
nets (network list) -- a list of networks
|
166 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
167 |
+
"""
|
168 |
+
if not isinstance(nets, list):
|
169 |
+
nets = [nets]
|
170 |
+
for net in nets:
|
171 |
+
if net is not None:
|
172 |
+
for param in net.parameters():
|
173 |
+
param.requires_grad = requires_grad
|
174 |
+
def forward(self, x):
|
175 |
+
depth_source = None
|
176 |
+
depth_driving = None
|
177 |
+
outputs = self.depth_decoder(self.depth_encoder(x['source']))
|
178 |
+
depth_source = outputs[("disp", 0)]
|
179 |
+
outputs = self.depth_decoder(self.depth_encoder(x['driving']))
|
180 |
+
depth_driving = outputs[("disp", 0)]
|
181 |
+
|
182 |
+
if self.opt.use_depth:
|
183 |
+
kp_source = self.kp_extractor(depth_source)
|
184 |
+
kp_driving = self.kp_extractor(depth_driving)
|
185 |
+
elif self.opt.rgbd:
|
186 |
+
source = torch.cat((x['source'],depth_source),1)
|
187 |
+
driving = torch.cat((x['driving'],depth_driving),1)
|
188 |
+
kp_source = self.kp_extractor(source)
|
189 |
+
kp_driving = self.kp_extractor(driving)
|
190 |
+
else:
|
191 |
+
kp_source = self.kp_extractor(x['source'])
|
192 |
+
kp_driving = self.kp_extractor(x['driving'])
|
193 |
+
generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving, source_depth = depth_source, driving_depth = depth_driving)
|
194 |
+
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
|
195 |
+
loss_values = {}
|
196 |
+
pyramide_real = self.pyramid(x['driving'])
|
197 |
+
pyramide_generated = self.pyramid(generated['prediction'])
|
198 |
+
if sum(self.loss_weights['perceptual']) != 0:
|
199 |
+
value_total = 0
|
200 |
+
for scale in self.scales:
|
201 |
+
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
|
202 |
+
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
|
203 |
+
|
204 |
+
for i, weight in enumerate(self.loss_weights['perceptual']):
|
205 |
+
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
206 |
+
value_total += self.loss_weights['perceptual'][i] * value
|
207 |
+
loss_values['perceptual'] = value_total
|
208 |
+
|
209 |
+
if self.loss_weights['generator_gan'] != 0:
|
210 |
+
|
211 |
+
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
212 |
+
|
213 |
+
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
214 |
+
value_total = 0
|
215 |
+
for scale in self.disc_scales:
|
216 |
+
key = 'prediction_map_%s' % scale
|
217 |
+
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
218 |
+
value_total += self.loss_weights['generator_gan'] * value
|
219 |
+
loss_values['gen_gan'] = value_total
|
220 |
+
|
221 |
+
if sum(self.loss_weights['feature_matching']) != 0:
|
222 |
+
value_total = 0
|
223 |
+
for scale in self.disc_scales:
|
224 |
+
key = 'feature_maps_%s' % scale
|
225 |
+
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
|
226 |
+
if self.loss_weights['feature_matching'][i] == 0:
|
227 |
+
continue
|
228 |
+
value = torch.abs(a - b).mean()
|
229 |
+
value_total += self.loss_weights['feature_matching'][i] * value
|
230 |
+
loss_values['feature_matching'] = value_total
|
231 |
+
|
232 |
+
if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
|
233 |
+
transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
|
234 |
+
transformed_frame = transform.transform_frame(x['driving'])
|
235 |
+
if self.opt.use_depth:
|
236 |
+
outputs = self.depth_decoder(self.depth_encoder(transformed_frame))
|
237 |
+
depth_transform = outputs[("disp", 0)]
|
238 |
+
transformed_kp = self.kp_extractor(depth_transform)
|
239 |
+
elif self.opt.rgbd:
|
240 |
+
outputs = self.depth_decoder(self.depth_encoder(transformed_frame))
|
241 |
+
depth_transform = outputs[("disp", 0)]
|
242 |
+
transform_img = torch.cat((transformed_frame,depth_transform),1)
|
243 |
+
transformed_kp = self.kp_extractor(transform_img)
|
244 |
+
else:
|
245 |
+
transformed_kp = self.kp_extractor(transformed_frame)
|
246 |
+
|
247 |
+
generated['transformed_frame'] = transformed_frame
|
248 |
+
generated['transformed_kp'] = transformed_kp
|
249 |
+
|
250 |
+
## Value loss part
|
251 |
+
if self.loss_weights['equivariance_value'] != 0:
|
252 |
+
value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
|
253 |
+
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
|
254 |
+
|
255 |
+
## jacobian loss part
|
256 |
+
if self.loss_weights['equivariance_jacobian'] != 0:
|
257 |
+
jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
|
258 |
+
transformed_kp['jacobian'])
|
259 |
+
|
260 |
+
normed_driving = torch.inverse(kp_driving['jacobian'])
|
261 |
+
normed_transformed = jacobian_transformed
|
262 |
+
value = torch.matmul(normed_driving, normed_transformed)
|
263 |
+
|
264 |
+
eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
|
265 |
+
|
266 |
+
value = torch.abs(eye - value).mean()
|
267 |
+
loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
|
268 |
+
|
269 |
+
|
270 |
+
if self.loss_weights['kp_distance']:
|
271 |
+
bz,num_kp,kp_dim = kp_source['value'].shape
|
272 |
+
sk = kp_source['value'].unsqueeze(2)-kp_source['value'].unsqueeze(1)
|
273 |
+
dk = kp_driving['value'].unsqueeze(2)-kp_driving['value'].unsqueeze(1)
|
274 |
+
source_dist_loss = (-torch.sign((torch.sqrt((sk*sk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()*0.2)-0.2)+1).mean()
|
275 |
+
driving_dist_loss = (-torch.sign((torch.sqrt((dk*dk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()*0.2)-0.2)+1).mean()
|
276 |
+
# driving_dist_loss = (torch.sign(1-(torch.sqrt((dk*dk).sum(-1)+1e-8)+torch.eye(num_kp).cuda()))+1).mean()
|
277 |
+
value_total = self.loss_weights['kp_distance']*(source_dist_loss+driving_dist_loss)
|
278 |
+
loss_values['kp_distance'] = value_total
|
279 |
+
if self.loss_weights['kp_prior']:
|
280 |
+
bz,num_kp,kp_dim = kp_source['value'].shape
|
281 |
+
sk = kp_source['value'].unsqueeze(2)-kp_source['value'].unsqueeze(1)
|
282 |
+
dk = kp_driving['value'].unsqueeze(2)-kp_driving['value'].unsqueeze(1)
|
283 |
+
dis_loss = torch.relu(0.1-torch.sqrt((sk*sk).sum(-1)+1e-8))+torch.relu(0.1-torch.sqrt((dk*dk).sum(-1)+1e-8))
|
284 |
+
bs,nk,_=kp_source['value'].shape
|
285 |
+
scoor_depth = F.grid_sample(depth_source,kp_source['value'].view(bs,1,nk,-1))
|
286 |
+
dcoor_depth = F.grid_sample(depth_driving,kp_driving['value'].view(bs,1,nk,-1))
|
287 |
+
sd_loss = torch.abs(scoor_depth.mean(-1,keepdim=True) - kp_source['value'].view(bs,1,nk,-1)).mean()
|
288 |
+
dd_loss = torch.abs(dcoor_depth.mean(-1,keepdim=True) - kp_driving['value'].view(bs,1,nk,-1)).mean()
|
289 |
+
value_total = self.loss_weights['kp_distance']*(dis_loss+sd_loss+dd_loss)
|
290 |
+
loss_values['kp_distance'] = value_total
|
291 |
+
|
292 |
+
|
293 |
+
if self.loss_weights['kp_scale']:
|
294 |
+
bz,num_kp,kp_dim = kp_source['value'].shape
|
295 |
+
if self.opt.rgbd:
|
296 |
+
outputs = self.depth_decoder(self.depth_encoder(generated['prediction']))
|
297 |
+
depth_pred = outputs[("disp", 0)]
|
298 |
+
pred = torch.cat((generated['prediction'],depth_pred),1)
|
299 |
+
kp_pred = self.kp_extractor(pred)
|
300 |
+
elif self.opt.use_depth:
|
301 |
+
outputs = self.depth_decoder(self.depth_encoder(generated['prediction']))
|
302 |
+
depth_pred = outputs[("disp", 0)]
|
303 |
+
kp_pred = self.kp_extractor(depth_pred)
|
304 |
+
else:
|
305 |
+
kp_pred = self.kp_extractor(generated['prediction'])
|
306 |
+
|
307 |
+
pred_mean = kp_pred['value'].mean(1,keepdim=True)
|
308 |
+
driving_mean = kp_driving['value'].mean(1,keepdim=True)
|
309 |
+
pk = kp_source['value']-pred_mean
|
310 |
+
dk = kp_driving['value']- driving_mean
|
311 |
+
pred_dist_loss = torch.sqrt((pk*pk).sum(-1)+1e-8)
|
312 |
+
driving_dist_loss = torch.sqrt((dk*dk).sum(-1)+1e-8)
|
313 |
+
scale_vec = driving_dist_loss/pred_dist_loss
|
314 |
+
bz,n = scale_vec.shape
|
315 |
+
value = torch.abs(scale_vec[:,:n-1]-scale_vec[:,1:]).mean()
|
316 |
+
value_total = self.loss_weights['kp_scale']*value
|
317 |
+
loss_values['kp_scale'] = value_total
|
318 |
+
if self.loss_weights['depth_constraint']:
|
319 |
+
bz,num_kp,kp_dim = kp_source['value'].shape
|
320 |
+
outputs = self.depth_decoder(self.depth_encoder(generated['prediction']))
|
321 |
+
depth_pred = outputs[("disp", 0)]
|
322 |
+
value_total = self.loss_weights['depth_constraint']*torch.abs(depth_driving-depth_pred).mean()
|
323 |
+
loss_values['depth_constraint'] = value_total
|
324 |
+
return loss_values, generated
|
325 |
+
|
326 |
+
|
327 |
+
|
328 |
+
class DiscriminatorFullModel(torch.nn.Module):
|
329 |
+
"""
|
330 |
+
Merge all discriminator related updates into single model for better multi-gpu usage
|
331 |
+
"""
|
332 |
+
|
333 |
+
def __init__(self, kp_extractor, generator, discriminator, train_params):
|
334 |
+
super(DiscriminatorFullModel, self).__init__()
|
335 |
+
self.kp_extractor = kp_extractor
|
336 |
+
self.generator = generator
|
337 |
+
self.discriminator = discriminator
|
338 |
+
self.train_params = train_params
|
339 |
+
self.scales = self.discriminator.module.scales
|
340 |
+
self.pyramid = ImagePyramide(self.scales, generator.module.num_channels)
|
341 |
+
if torch.cuda.is_available():
|
342 |
+
self.pyramid = self.pyramid.cuda()
|
343 |
+
|
344 |
+
self.loss_weights = train_params['loss_weights']
|
345 |
+
|
346 |
+
def forward(self, x, generated):
|
347 |
+
pyramide_real = self.pyramid(x['driving'])
|
348 |
+
pyramide_generated = self.pyramid(generated['prediction'].detach())
|
349 |
+
|
350 |
+
kp_driving = generated['kp_driving']
|
351 |
+
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
352 |
+
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
353 |
+
|
354 |
+
loss_values = {}
|
355 |
+
value_total = 0
|
356 |
+
for scale in self.scales:
|
357 |
+
key = 'prediction_map_%s' % scale
|
358 |
+
value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
|
359 |
+
value_total += self.loss_weights['discriminator_gan'] * value.mean()
|
360 |
+
loss_values['disc_gan'] = value_total
|
361 |
+
|
362 |
+
return loss_values
|
modules/util.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
|
7 |
+
import pdb
|
8 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
9 |
+
def kp2gaussian(kp, spatial_size, kp_variance):
|
10 |
+
"""
|
11 |
+
Transform a keypoint into gaussian like representation
|
12 |
+
"""
|
13 |
+
mean = kp['value']
|
14 |
+
|
15 |
+
coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
|
16 |
+
number_of_leading_dimensions = len(mean.shape) - 1
|
17 |
+
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
18 |
+
coordinate_grid = coordinate_grid.view(*shape)
|
19 |
+
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
|
20 |
+
coordinate_grid = coordinate_grid.repeat(*repeats)
|
21 |
+
|
22 |
+
# Preprocess kp shape
|
23 |
+
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
|
24 |
+
mean = mean.view(*shape)
|
25 |
+
|
26 |
+
mean_sub = (coordinate_grid - mean)
|
27 |
+
|
28 |
+
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
29 |
+
|
30 |
+
return out
|
31 |
+
|
32 |
+
|
33 |
+
def make_coordinate_grid(spatial_size, type):
|
34 |
+
"""
|
35 |
+
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
|
36 |
+
"""
|
37 |
+
h, w = spatial_size
|
38 |
+
x = torch.arange(w).type(type)
|
39 |
+
y = torch.arange(h).type(type)
|
40 |
+
|
41 |
+
x = (2 * (x / (w - 1)) - 1)
|
42 |
+
y = (2 * (y / (h - 1)) - 1)
|
43 |
+
|
44 |
+
yy = y.view(-1, 1).repeat(1, w)
|
45 |
+
xx = x.view(1, -1).repeat(h, 1)
|
46 |
+
|
47 |
+
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
|
48 |
+
|
49 |
+
return meshed
|
50 |
+
|
51 |
+
|
52 |
+
class ResBlock2d(nn.Module):
|
53 |
+
"""
|
54 |
+
Res block, preserve spatial resolution.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, in_features, kernel_size, padding):
|
58 |
+
super(ResBlock2d, self).__init__()
|
59 |
+
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
60 |
+
padding=padding)
|
61 |
+
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
62 |
+
padding=padding)
|
63 |
+
self.norm1 = BatchNorm2d(in_features, affine=True)
|
64 |
+
self.norm2 = BatchNorm2d(in_features, affine=True)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
out = self.norm1(x)
|
68 |
+
out = F.relu(out)
|
69 |
+
out = self.conv1(out)
|
70 |
+
out = self.norm2(out)
|
71 |
+
out = F.relu(out)
|
72 |
+
out = self.conv2(out)
|
73 |
+
out += x
|
74 |
+
return out
|
75 |
+
|
76 |
+
|
77 |
+
class UpBlock2d(nn.Module):
|
78 |
+
"""
|
79 |
+
Upsampling block for use in decoder.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
83 |
+
super(UpBlock2d, self).__init__()
|
84 |
+
|
85 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
86 |
+
padding=padding, groups=groups)
|
87 |
+
self.norm = BatchNorm2d(out_features, affine=True)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
out = F.interpolate(x, scale_factor=2)
|
91 |
+
out = self.conv(out)
|
92 |
+
out = self.norm(out)
|
93 |
+
out = F.relu(out)
|
94 |
+
return out
|
95 |
+
|
96 |
+
|
97 |
+
class DownBlock2d(nn.Module):
|
98 |
+
"""
|
99 |
+
Downsampling block for use in encoder.
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
103 |
+
super(DownBlock2d, self).__init__()
|
104 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
105 |
+
padding=padding, groups=groups)
|
106 |
+
self.norm = BatchNorm2d(out_features, affine=True)
|
107 |
+
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
out = self.conv(x)
|
111 |
+
out = self.norm(out)
|
112 |
+
out = F.relu(out)
|
113 |
+
out = self.pool(out)
|
114 |
+
return out
|
115 |
+
|
116 |
+
|
117 |
+
class SameBlock2d(nn.Module):
|
118 |
+
"""
|
119 |
+
Simple block, preserve spatial resolution.
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
|
123 |
+
super(SameBlock2d, self).__init__()
|
124 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
|
125 |
+
kernel_size=kernel_size, padding=padding, groups=groups)
|
126 |
+
self.norm = BatchNorm2d(out_features, affine=True)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
out = self.conv(x)
|
130 |
+
out = self.norm(out)
|
131 |
+
out = F.relu(out)
|
132 |
+
return out
|
133 |
+
|
134 |
+
|
135 |
+
class Encoder(nn.Module):
|
136 |
+
"""
|
137 |
+
Hourglass Encoder
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
141 |
+
super(Encoder, self).__init__()
|
142 |
+
|
143 |
+
down_blocks = []
|
144 |
+
for i in range(num_blocks):
|
145 |
+
down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
146 |
+
min(max_features, block_expansion * (2 ** (i + 1))),
|
147 |
+
kernel_size=3, padding=1))
|
148 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
outs = [x]
|
152 |
+
for down_block in self.down_blocks:
|
153 |
+
outs.append(down_block(outs[-1]))
|
154 |
+
return outs
|
155 |
+
|
156 |
+
|
157 |
+
class Decoder(nn.Module):
|
158 |
+
"""
|
159 |
+
Hourglass Decoder
|
160 |
+
"""
|
161 |
+
|
162 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
163 |
+
super(Decoder, self).__init__()
|
164 |
+
|
165 |
+
up_blocks = []
|
166 |
+
|
167 |
+
for i in range(num_blocks)[::-1]:
|
168 |
+
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
169 |
+
out_filters = min(max_features, block_expansion * (2 ** i))
|
170 |
+
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
|
171 |
+
|
172 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
173 |
+
self.out_filters = block_expansion + in_features
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
out = x.pop()
|
177 |
+
for up_block in self.up_blocks:
|
178 |
+
out = up_block(out)
|
179 |
+
skip = x.pop()
|
180 |
+
out = torch.cat([out, skip], dim=1)
|
181 |
+
return out
|
182 |
+
|
183 |
+
|
184 |
+
class Decoder_w_emb(nn.Module):
|
185 |
+
"""
|
186 |
+
Hourglass Decoder
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
190 |
+
super(Decoder_w_emb, self).__init__()
|
191 |
+
|
192 |
+
up_blocks = []
|
193 |
+
|
194 |
+
for i in range(num_blocks)[::-1]:
|
195 |
+
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
196 |
+
out_filters = min(max_features, block_expansion * (2 ** i))
|
197 |
+
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
|
198 |
+
|
199 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
200 |
+
self.out_filters = block_expansion + in_features
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
feats = []
|
204 |
+
out = x.pop()
|
205 |
+
feats.append(out)
|
206 |
+
for ind,up_block in enumerate(self.up_blocks):
|
207 |
+
out = up_block(out)
|
208 |
+
skip = x.pop()
|
209 |
+
feats.append(skip)
|
210 |
+
out = torch.cat([out, skip], dim=1)
|
211 |
+
return out,feats
|
212 |
+
|
213 |
+
class Decoder_2branch(nn.Module):
|
214 |
+
"""
|
215 |
+
Hourglass Decoder
|
216 |
+
"""
|
217 |
+
|
218 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
219 |
+
super(Decoder_2branch, self).__init__()
|
220 |
+
up_blocks = []
|
221 |
+
for i in range(num_blocks)[::-1]:
|
222 |
+
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
223 |
+
out_filters = min(max_features, block_expansion * (2 ** i))
|
224 |
+
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
|
225 |
+
|
226 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
227 |
+
self.out_filters = block_expansion + in_features
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
# out = x.pop()
|
231 |
+
num_feat = len(x)
|
232 |
+
out=x[-1]
|
233 |
+
for i in range(len(self.up_blocks)):
|
234 |
+
out = self.up_blocks[i](out)
|
235 |
+
skip = x[-(i+1+1)]
|
236 |
+
out = torch.cat([out, skip], dim=1)
|
237 |
+
return out
|
238 |
+
|
239 |
+
|
240 |
+
|
241 |
+
class Hourglass(nn.Module):
|
242 |
+
"""
|
243 |
+
Hourglass architecture.
|
244 |
+
"""
|
245 |
+
|
246 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
247 |
+
super(Hourglass, self).__init__()
|
248 |
+
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
249 |
+
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
|
250 |
+
self.out_filters = self.decoder.out_filters
|
251 |
+
def forward(self, x):
|
252 |
+
return self.decoder(self.encoder(x))
|
253 |
+
|
254 |
+
class Hourglass_2branch(nn.Module):
|
255 |
+
"""
|
256 |
+
Hourglass architecture.
|
257 |
+
"""
|
258 |
+
|
259 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
260 |
+
super(Hourglass_2branch, self).__init__()
|
261 |
+
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
262 |
+
self.decoder_kp = Decoder_2branch(block_expansion, in_features, num_blocks, max_features)
|
263 |
+
self.decoder_mask = Decoder_2branch(block_expansion, in_features, num_blocks, max_features)
|
264 |
+
|
265 |
+
self.out_filters = self.decoder_kp.out_filters
|
266 |
+
def forward(self, x):
|
267 |
+
embd= self.encoder(x)
|
268 |
+
kp_feat = self.decoder_kp(embd)
|
269 |
+
mask_feat = self.decoder_mask(embd)
|
270 |
+
return kp_feat,mask_feat
|
271 |
+
|
272 |
+
|
273 |
+
class Hourglass_w_emb(nn.Module):
|
274 |
+
"""
|
275 |
+
Hourglass architecture.
|
276 |
+
"""
|
277 |
+
|
278 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
279 |
+
super(Hourglass_w_emb, self).__init__()
|
280 |
+
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
281 |
+
self.decoder = Decoder_w_emb(block_expansion, in_features, num_blocks, max_features)
|
282 |
+
self.out_filters = self.decoder.out_filters
|
283 |
+
|
284 |
+
def forward(self, x):
|
285 |
+
embs = self.encoder(x)
|
286 |
+
result,feats = self.decoder(embs)
|
287 |
+
return feats,result
|
288 |
+
class AntiAliasInterpolation2d(nn.Module):
|
289 |
+
"""
|
290 |
+
Band-limited downsampling, for better preservation of the input signal.
|
291 |
+
"""
|
292 |
+
def __init__(self, channels, scale):
|
293 |
+
super(AntiAliasInterpolation2d, self).__init__()
|
294 |
+
sigma = (1 / scale - 1) / 2
|
295 |
+
kernel_size = 2 * round(sigma * 4) + 1
|
296 |
+
self.ka = kernel_size // 2
|
297 |
+
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
298 |
+
|
299 |
+
kernel_size = [kernel_size, kernel_size]
|
300 |
+
sigma = [sigma, sigma]
|
301 |
+
# The gaussian kernel is the product of the
|
302 |
+
# gaussian function of each dimension.
|
303 |
+
kernel = 1
|
304 |
+
meshgrids = torch.meshgrid(
|
305 |
+
[
|
306 |
+
torch.arange(size, dtype=torch.float32)
|
307 |
+
for size in kernel_size
|
308 |
+
]
|
309 |
+
)
|
310 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
311 |
+
mean = (size - 1) / 2
|
312 |
+
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
313 |
+
|
314 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
315 |
+
kernel = kernel / torch.sum(kernel)
|
316 |
+
# Reshape to depthwise convolutional weight
|
317 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
318 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
319 |
+
|
320 |
+
self.register_buffer('weight', kernel)
|
321 |
+
self.groups = channels
|
322 |
+
self.scale = scale
|
323 |
+
inv_scale = 1 / scale
|
324 |
+
self.int_inv_scale = int(inv_scale)
|
325 |
+
|
326 |
+
def forward(self, input):
|
327 |
+
if self.scale == 1.0:
|
328 |
+
return input
|
329 |
+
|
330 |
+
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
331 |
+
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
332 |
+
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
|
333 |
+
|
334 |
+
return out
|
335 |
+
|
336 |
+
|
337 |
+
class SPADE(nn.Module):
|
338 |
+
def __init__(self, norm_nc, label_nc):
|
339 |
+
super().__init__()
|
340 |
+
|
341 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
342 |
+
nhidden = 128
|
343 |
+
|
344 |
+
self.mlp_shared = nn.Sequential(
|
345 |
+
nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
|
346 |
+
nn.ReLU())
|
347 |
+
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
|
348 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
|
349 |
+
|
350 |
+
def forward(self, x, segmap):
|
351 |
+
normalized = self.param_free_norm(x)
|
352 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
|
353 |
+
actv = self.mlp_shared(segmap)
|
354 |
+
gamma = self.mlp_gamma(actv)
|
355 |
+
beta = self.mlp_beta(actv)
|
356 |
+
out = normalized * (1 + gamma) + beta
|
357 |
+
return out
|
358 |
+
|
359 |
+
|
360 |
+
class SPADEResnetBlock(nn.Module):
|
361 |
+
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
|
362 |
+
super().__init__()
|
363 |
+
# Attributes
|
364 |
+
self.learned_shortcut = (fin != fout)
|
365 |
+
fmiddle = min(fin, fout)
|
366 |
+
self.use_se = use_se
|
367 |
+
# create conv layers
|
368 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
|
369 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
|
370 |
+
if self.learned_shortcut:
|
371 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
372 |
+
# apply spectral norm if specified
|
373 |
+
if 'spectral' in norm_G:
|
374 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
375 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
376 |
+
if self.learned_shortcut:
|
377 |
+
self.conv_s = spectral_norm(self.conv_s)
|
378 |
+
# define normalization layers
|
379 |
+
self.norm_0 = SPADE(fin, label_nc)
|
380 |
+
self.norm_1 = SPADE(fmiddle, label_nc)
|
381 |
+
if self.learned_shortcut:
|
382 |
+
self.norm_s = SPADE(fin, label_nc)
|
383 |
+
|
384 |
+
def forward(self, x, seg1):
|
385 |
+
x_s = self.shortcut(x, seg1)
|
386 |
+
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
|
387 |
+
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
|
388 |
+
out = x_s + dx
|
389 |
+
return out
|
390 |
+
|
391 |
+
def shortcut(self, x, seg1):
|
392 |
+
if self.learned_shortcut:
|
393 |
+
x_s = self.conv_s(self.norm_s(x, seg1))
|
394 |
+
else:
|
395 |
+
x_s = x
|
396 |
+
return x_s
|
397 |
+
|
398 |
+
def actvn(self, x):
|
399 |
+
return F.leaky_relu(x, 2e-1)
|
project/cartoon2.jpg
ADDED
project/cartoon3.jpg
ADDED
project/cartoon4.jpg
ADDED
project/cartoon5.jpg
ADDED
project/cartoon6.jpg
ADDED
project/cartoon7.jpg
ADDED
project/celeb1.jpg
ADDED
project/celeb2.jpg
ADDED
project/celeb3.jpg
ADDED
project/celeb4.jpg
ADDED
project/celeb6.jpg
ADDED
project/celeb7.jpg
ADDED
project/celeb8.jpg
ADDED
project/video1.mp4
ADDED
Binary file (152 kB). View file
|
|
project/video2.mp4
ADDED
Binary file (271 kB). View file
|
|
sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : __init__.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
12 |
+
from .replicate import DataParallelWithCallback, patch_replication_callback
|
sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
17 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
18 |
+
|
19 |
+
from .comm import SyncMaster
|
20 |
+
|
21 |
+
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
22 |
+
|
23 |
+
|
24 |
+
def _sum_ft(tensor):
|
25 |
+
"""sum over the first and last dimention"""
|
26 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
27 |
+
|
28 |
+
|
29 |
+
def _unsqueeze_ft(tensor):
|
30 |
+
"""add new dementions at the front and the tail"""
|
31 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
32 |
+
|
33 |
+
|
34 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
35 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
36 |
+
|
37 |
+
|
38 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
39 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
|
40 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
41 |
+
|
42 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
43 |
+
|
44 |
+
self._is_parallel = False
|
45 |
+
self._parallel_id = None
|
46 |
+
self._slave_pipe = None
|
47 |
+
|
48 |
+
def forward(self, input):
|
49 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
50 |
+
if not (self._is_parallel and self.training):
|
51 |
+
return F.batch_norm(
|
52 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
53 |
+
self.training, self.momentum, self.eps)
|
54 |
+
|
55 |
+
# Resize the input to (B, C, -1).
|
56 |
+
input_shape = input.size()
|
57 |
+
input = input.view(input.size(0), self.num_features, -1)
|
58 |
+
|
59 |
+
# Compute the sum and square-sum.
|
60 |
+
sum_size = input.size(0) * input.size(2)
|
61 |
+
input_sum = _sum_ft(input)
|
62 |
+
input_ssum = _sum_ft(input ** 2)
|
63 |
+
|
64 |
+
# Reduce-and-broadcast the statistics.
|
65 |
+
if self._parallel_id == 0:
|
66 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
67 |
+
else:
|
68 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
69 |
+
|
70 |
+
# Compute the output.
|
71 |
+
if self.affine:
|
72 |
+
# MJY:: Fuse the multiplication for speed.
|
73 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
74 |
+
else:
|
75 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
76 |
+
|
77 |
+
# Reshape it.
|
78 |
+
return output.view(input_shape)
|
79 |
+
|
80 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
81 |
+
self._is_parallel = True
|
82 |
+
self._parallel_id = copy_id
|
83 |
+
|
84 |
+
# parallel_id == 0 means master device.
|
85 |
+
if self._parallel_id == 0:
|
86 |
+
ctx.sync_master = self._sync_master
|
87 |
+
else:
|
88 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
89 |
+
|
90 |
+
def _data_parallel_master(self, intermediates):
|
91 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
92 |
+
|
93 |
+
# Always using same "device order" makes the ReduceAdd operation faster.
|
94 |
+
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
95 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
96 |
+
|
97 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
98 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
99 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
100 |
+
|
101 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
102 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
103 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
104 |
+
|
105 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
106 |
+
|
107 |
+
outputs = []
|
108 |
+
for i, rec in enumerate(intermediates):
|
109 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
110 |
+
|
111 |
+
return outputs
|
112 |
+
|
113 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
114 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
115 |
+
also maintains the moving average on the master device."""
|
116 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
117 |
+
mean = sum_ / size
|
118 |
+
sumvar = ssum - sum_ * mean
|
119 |
+
unbias_var = sumvar / (size - 1)
|
120 |
+
bias_var = sumvar / size
|
121 |
+
|
122 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
123 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
124 |
+
|
125 |
+
return mean, bias_var.clamp(self.eps) ** -0.5
|
126 |
+
|
127 |
+
|
128 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
129 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
130 |
+
mini-batch.
|
131 |
+
|
132 |
+
.. math::
|
133 |
+
|
134 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
135 |
+
|
136 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
137 |
+
standard-deviation are reduced across all devices during training.
|
138 |
+
|
139 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
140 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
141 |
+
the statistics only on that device, which accelerated the computation and
|
142 |
+
is also easy to implement, but the statistics might be inaccurate.
|
143 |
+
Instead, in this synchronized version, the statistics will be computed
|
144 |
+
over all training samples distributed on multiple devices.
|
145 |
+
|
146 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
147 |
+
as the built-in PyTorch implementation.
|
148 |
+
|
149 |
+
The mean and standard-deviation are calculated per-dimension over
|
150 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
151 |
+
of size C (where C is the input size).
|
152 |
+
|
153 |
+
During training, this layer keeps a running estimate of its computed mean
|
154 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
155 |
+
|
156 |
+
During evaluation, this running mean/variance is used for normalization.
|
157 |
+
|
158 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
159 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
160 |
+
|
161 |
+
Args:
|
162 |
+
num_features: num_features from an expected input of size
|
163 |
+
`batch_size x num_features [x width]`
|
164 |
+
eps: a value added to the denominator for numerical stability.
|
165 |
+
Default: 1e-5
|
166 |
+
momentum: the value used for the running_mean and running_var
|
167 |
+
computation. Default: 0.1
|
168 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
169 |
+
affine parameters. Default: ``True``
|
170 |
+
|
171 |
+
Shape:
|
172 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
173 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
174 |
+
|
175 |
+
Examples:
|
176 |
+
>>> # With Learnable Parameters
|
177 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
178 |
+
>>> # Without Learnable Parameters
|
179 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
180 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
181 |
+
>>> output = m(input)
|
182 |
+
"""
|
183 |
+
|
184 |
+
def _check_input_dim(self, input):
|
185 |
+
if input.dim() != 2 and input.dim() != 3:
|
186 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
187 |
+
.format(input.dim()))
|
188 |
+
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
189 |
+
|
190 |
+
|
191 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
192 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
193 |
+
of 3d inputs
|
194 |
+
|
195 |
+
.. math::
|
196 |
+
|
197 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
198 |
+
|
199 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
200 |
+
standard-deviation are reduced across all devices during training.
|
201 |
+
|
202 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
203 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
204 |
+
the statistics only on that device, which accelerated the computation and
|
205 |
+
is also easy to implement, but the statistics might be inaccurate.
|
206 |
+
Instead, in this synchronized version, the statistics will be computed
|
207 |
+
over all training samples distributed on multiple devices.
|
208 |
+
|
209 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
210 |
+
as the built-in PyTorch implementation.
|
211 |
+
|
212 |
+
The mean and standard-deviation are calculated per-dimension over
|
213 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
214 |
+
of size C (where C is the input size).
|
215 |
+
|
216 |
+
During training, this layer keeps a running estimate of its computed mean
|
217 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
218 |
+
|
219 |
+
During evaluation, this running mean/variance is used for normalization.
|
220 |
+
|
221 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
222 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
223 |
+
|
224 |
+
Args:
|
225 |
+
num_features: num_features from an expected input of
|
226 |
+
size batch_size x num_features x height x width
|
227 |
+
eps: a value added to the denominator for numerical stability.
|
228 |
+
Default: 1e-5
|
229 |
+
momentum: the value used for the running_mean and running_var
|
230 |
+
computation. Default: 0.1
|
231 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
232 |
+
affine parameters. Default: ``True``
|
233 |
+
|
234 |
+
Shape:
|
235 |
+
- Input: :math:`(N, C, H, W)`
|
236 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
237 |
+
|
238 |
+
Examples:
|
239 |
+
>>> # With Learnable Parameters
|
240 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
241 |
+
>>> # Without Learnable Parameters
|
242 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
243 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
244 |
+
>>> output = m(input)
|
245 |
+
"""
|
246 |
+
|
247 |
+
def _check_input_dim(self, input):
|
248 |
+
if input.dim() != 4:
|
249 |
+
raise ValueError('expected 4D input (got {}D input)'
|
250 |
+
.format(input.dim()))
|
251 |
+
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
252 |
+
|
253 |
+
|
254 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
255 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
256 |
+
of 4d inputs
|
257 |
+
|
258 |
+
.. math::
|
259 |
+
|
260 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
261 |
+
|
262 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
263 |
+
standard-deviation are reduced across all devices during training.
|
264 |
+
|
265 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
266 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
267 |
+
the statistics only on that device, which accelerated the computation and
|
268 |
+
is also easy to implement, but the statistics might be inaccurate.
|
269 |
+
Instead, in this synchronized version, the statistics will be computed
|
270 |
+
over all training samples distributed on multiple devices.
|
271 |
+
|
272 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
273 |
+
as the built-in PyTorch implementation.
|
274 |
+
|
275 |
+
The mean and standard-deviation are calculated per-dimension over
|
276 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
277 |
+
of size C (where C is the input size).
|
278 |
+
|
279 |
+
During training, this layer keeps a running estimate of its computed mean
|
280 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
281 |
+
|
282 |
+
During evaluation, this running mean/variance is used for normalization.
|
283 |
+
|
284 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
285 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
286 |
+
or Spatio-temporal BatchNorm
|
287 |
+
|
288 |
+
Args:
|
289 |
+
num_features: num_features from an expected input of
|
290 |
+
size batch_size x num_features x depth x height x width
|
291 |
+
eps: a value added to the denominator for numerical stability.
|
292 |
+
Default: 1e-5
|
293 |
+
momentum: the value used for the running_mean and running_var
|
294 |
+
computation. Default: 0.1
|
295 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
296 |
+
affine parameters. Default: ``True``
|
297 |
+
|
298 |
+
Shape:
|
299 |
+
- Input: :math:`(N, C, D, H, W)`
|
300 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
301 |
+
|
302 |
+
Examples:
|
303 |
+
>>> # With Learnable Parameters
|
304 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
305 |
+
>>> # Without Learnable Parameters
|
306 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
307 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
308 |
+
>>> output = m(input)
|
309 |
+
"""
|
310 |
+
|
311 |
+
def _check_input_dim(self, input):
|
312 |
+
if input.dim() != 5:
|
313 |
+
raise ValueError('expected 5D input (got {}D input)'
|
314 |
+
.format(input.dim()))
|
315 |
+
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
|
59 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
60 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
61 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
62 |
+
and passed to a registered callback.
|
63 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
64 |
+
back to each slave devices.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, master_callback):
|
68 |
+
"""
|
69 |
+
|
70 |
+
Args:
|
71 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
72 |
+
"""
|
73 |
+
self._master_callback = master_callback
|
74 |
+
self._queue = queue.Queue()
|
75 |
+
self._registry = collections.OrderedDict()
|
76 |
+
self._activated = False
|
77 |
+
|
78 |
+
def __getstate__(self):
|
79 |
+
return {'master_callback': self._master_callback}
|
80 |
+
|
81 |
+
def __setstate__(self, state):
|
82 |
+
self.__init__(state['master_callback'])
|
83 |
+
|
84 |
+
def register_slave(self, identifier):
|
85 |
+
"""
|
86 |
+
Register an slave device.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
identifier: an identifier, usually is the device id.
|
90 |
+
|
91 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
92 |
+
|
93 |
+
"""
|
94 |
+
if self._activated:
|
95 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
96 |
+
self._activated = False
|
97 |
+
self._registry.clear()
|
98 |
+
future = FutureResult()
|
99 |
+
self._registry[identifier] = _MasterRegistry(future)
|
100 |
+
return SlavePipe(identifier, self._queue, future)
|
101 |
+
|
102 |
+
def run_master(self, master_msg):
|
103 |
+
"""
|
104 |
+
Main entry for the master device in each forward pass.
|
105 |
+
The messages were first collected from each devices (including the master device), and then
|
106 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
107 |
+
(including the master device).
|
108 |
+
|
109 |
+
Args:
|
110 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
111 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
112 |
+
|
113 |
+
Returns: the message to be sent back to the master device.
|
114 |
+
|
115 |
+
"""
|
116 |
+
self._activated = True
|
117 |
+
|
118 |
+
intermediates = [(0, master_msg)]
|
119 |
+
for i in range(self.nr_slaves):
|
120 |
+
intermediates.append(self._queue.get())
|
121 |
+
|
122 |
+
results = self._master_callback(intermediates)
|
123 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
124 |
+
|
125 |
+
for i, res in results:
|
126 |
+
if i == 0:
|
127 |
+
continue
|
128 |
+
self._registry[i].result.put(res)
|
129 |
+
|
130 |
+
for i in range(self.nr_slaves):
|
131 |
+
assert self._queue.get() is True
|
132 |
+
|
133 |
+
return results[0][1]
|
134 |
+
|
135 |
+
@property
|
136 |
+
def nr_slaves(self):
|
137 |
+
return len(self._registry)
|
sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : replicate.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import functools
|
12 |
+
|
13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'CallbackContext',
|
17 |
+
'execute_replication_callbacks',
|
18 |
+
'DataParallelWithCallback',
|
19 |
+
'patch_replication_callback'
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
class CallbackContext(object):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
def execute_replication_callbacks(modules):
|
28 |
+
"""
|
29 |
+
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
30 |
+
|
31 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
32 |
+
|
33 |
+
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
34 |
+
(shared among multiple copies of this module on different devices).
|
35 |
+
Through this context, different copies can share some information.
|
36 |
+
|
37 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
38 |
+
of any slave copies.
|
39 |
+
"""
|
40 |
+
master_copy = modules[0]
|
41 |
+
nr_modules = len(list(master_copy.modules()))
|
42 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
43 |
+
|
44 |
+
for i, module in enumerate(modules):
|
45 |
+
for j, m in enumerate(module.modules()):
|
46 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
47 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
48 |
+
|
49 |
+
|
50 |
+
class DataParallelWithCallback(DataParallel):
|
51 |
+
"""
|
52 |
+
Data Parallel with a replication callback.
|
53 |
+
|
54 |
+
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
55 |
+
original `replicate` function.
|
56 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
60 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
61 |
+
# sync_bn.__data_parallel_replicate__ will be invoked.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def replicate(self, module, device_ids):
|
65 |
+
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
66 |
+
execute_replication_callbacks(modules)
|
67 |
+
return modules
|
68 |
+
|
69 |
+
|
70 |
+
def patch_replication_callback(data_parallel):
|
71 |
+
"""
|
72 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
73 |
+
Useful when you have customized `DataParallel` implementation.
|
74 |
+
|
75 |
+
Examples:
|
76 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
77 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
78 |
+
> patch_replication_callback(sync_bn)
|
79 |
+
# this is equivalent to
|
80 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
81 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
82 |
+
"""
|
83 |
+
|
84 |
+
assert isinstance(data_parallel, DataParallel)
|
85 |
+
|
86 |
+
old_replicate = data_parallel.replicate
|
87 |
+
|
88 |
+
@functools.wraps(old_replicate)
|
89 |
+
def new_replicate(module, device_ids):
|
90 |
+
modules = old_replicate(module, device_ids)
|
91 |
+
execute_replication_callbacks(modules)
|
92 |
+
return modules
|
93 |
+
|
94 |
+
data_parallel.replicate = new_replicate
|
sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : unittest.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import unittest
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from torch.autograd import Variable
|
15 |
+
|
16 |
+
|
17 |
+
def as_numpy(v):
|
18 |
+
if isinstance(v, Variable):
|
19 |
+
v = v.data
|
20 |
+
return v.cpu().numpy()
|
21 |
+
|
22 |
+
|
23 |
+
class TorchTestCase(unittest.TestCase):
|
24 |
+
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
|
25 |
+
npa, npb = as_numpy(a), as_numpy(b)
|
26 |
+
self.assertTrue(
|
27 |
+
np.allclose(npa, npb, atol=atol),
|
28 |
+
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
|
29 |
+
)
|