aaronb commited on
Commit
ab9cd73
1 Parent(s): 8e1efb4
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.29.0
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.29.0
8
+ app_file: gradio_app.py
9
  pinned: false
10
  ---
11
 
drag_gan.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import random
4
+ import urllib.request
5
+
6
+ import torch
7
+ import torch.nn.functional as FF
8
+ import torch.optim
9
+ from torchvision import utils
10
+ from tqdm import tqdm
11
+
12
+ from stylegan2.model import Generator
13
+
14
+
15
+ class DownloadProgressBar(tqdm):
16
+ def update_to(self, b=1, bsize=1, tsize=None):
17
+ if tsize is not None:
18
+ self.total = tsize
19
+ self.update(b * bsize - self.n)
20
+
21
+
22
+ def get_path(base_path):
23
+ BASE_DIR = os.path.join('checkpoints')
24
+
25
+ save_path = os.path.join(BASE_DIR, base_path)
26
+ if not os.path.exists(save_path):
27
+ url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
28
+ print(f'{base_path} not found')
29
+ print('Try to download from huggingface: ', url)
30
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
31
+ download_url(url, save_path)
32
+ print('Downloaded to ', save_path)
33
+ return save_path
34
+
35
+
36
+ def download_url(url, output_path):
37
+ with DownloadProgressBar(unit='B', unit_scale=True,
38
+ miniters=1, desc=url.split('/')[-1]) as t:
39
+ urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
40
+
41
+
42
+ class CustomGenerator(Generator):
43
+ def prepare(
44
+ self,
45
+ styles,
46
+ inject_index=None,
47
+ truncation=1,
48
+ truncation_latent=None,
49
+ input_is_latent=False,
50
+ noise=None,
51
+ randomize_noise=True,
52
+ ):
53
+ if not input_is_latent:
54
+ styles = [self.style(s) for s in styles]
55
+
56
+ if noise is None:
57
+ if randomize_noise:
58
+ noise = [None] * self.num_layers
59
+ else:
60
+ noise = [
61
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
62
+ ]
63
+
64
+ if truncation < 1:
65
+ style_t = []
66
+
67
+ for style in styles:
68
+ style_t.append(
69
+ truncation_latent + truncation * (style - truncation_latent)
70
+ )
71
+
72
+ styles = style_t
73
+
74
+ if len(styles) < 2:
75
+ inject_index = self.n_latent
76
+
77
+ if styles[0].ndim < 3:
78
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
79
+
80
+ else:
81
+ latent = styles[0]
82
+
83
+ else:
84
+ if inject_index is None:
85
+ inject_index = random.randint(1, self.n_latent - 1)
86
+
87
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
88
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
89
+
90
+ latent = torch.cat([latent, latent2], 1)
91
+
92
+ return latent, noise
93
+
94
+ def generate(
95
+ self,
96
+ latent,
97
+ noise,
98
+ ):
99
+ out = self.input(latent)
100
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
101
+
102
+ skip = self.to_rgb1(out, latent[:, 1])
103
+ i = 1
104
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
105
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
106
+ ):
107
+ out = conv1(out, latent[:, i], noise=noise1)
108
+ out = conv2(out, latent[:, i + 1], noise=noise2)
109
+ skip = to_rgb(out, latent[:, i + 2], skip)
110
+ if out.shape[-1] == 256: F = out
111
+ i += 2
112
+
113
+ image = skip
114
+ F = FF.interpolate(F, image.shape[-2:], mode='bilinear')
115
+ return image, F
116
+
117
+
118
+ def stylegan2(
119
+ size=1024,
120
+ channel_multiplier=2,
121
+ latent=512,
122
+ n_mlp=8,
123
+ ckpt='stylegan2-ffhq-config-f.pt'
124
+ ):
125
+ g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
126
+ checkpoint = torch.load(get_path(ckpt))
127
+ g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
128
+ g_ema.requires_grad_(False)
129
+ g_ema.eval()
130
+ return g_ema
131
+
132
+
133
+ def bilinear_interpolate_torch(im, y, x):
134
+ """
135
+ im : B,C,H,W
136
+ y : 1,numPoints -- pixel location y float
137
+ x : 1,numPOints -- pixel location y float
138
+ """
139
+ device = im.device
140
+
141
+ x0 = torch.floor(x).long().to(device)
142
+ x1 = x0 + 1
143
+
144
+ y0 = torch.floor(y).long().to(device)
145
+ y1 = y0 + 1
146
+
147
+ wa = ((x1.float() - x) * (y1.float() - y)).to(device)
148
+ wb = ((x1.float() - x) * (y - y0.float())).to(device)
149
+ wc = ((x - x0.float()) * (y1.float() - y)).to(device)
150
+ wd = ((x - x0.float()) * (y - y0.float())).to(device)
151
+ # Instead of clamp
152
+ x1 = x1 - torch.floor(x1 / im.shape[3]).int().to(device)
153
+ y1 = y1 - torch.floor(y1 / im.shape[2]).int().to(device)
154
+ Ia = im[:, :, y0, x0]
155
+ Ib = im[:, :, y1, x0]
156
+ Ic = im[:, :, y0, x1]
157
+ Id = im[:, :, y1, x1]
158
+
159
+ return Ia * wa + Ib * wb + Ic * wc + Id * wd
160
+
161
+
162
+ def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask, max_iters=1000):
163
+ handle_points0 = copy.deepcopy(handle_points)
164
+ n = len(handle_points)
165
+ r1, r2, lam, d = 3, 12, 20, 1
166
+
167
+ def neighbor(x, y, d):
168
+ points = []
169
+ for i in range(x - d, x + d):
170
+ for j in range(y - d, y + d):
171
+ points.append(torch.tensor([i, j]).float().cuda())
172
+ return points
173
+
174
+ F0 = F.detach().clone()
175
+
176
+ latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
177
+ latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
178
+ optimizer = torch.optim.Adam([latent_trainable], lr=2e-3)
179
+ for iter in range(max_iters):
180
+ for s in range(1):
181
+ optimizer.zero_grad()
182
+ latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
183
+ sample2, F2 = g_ema.generate(latent, noise)
184
+
185
+ # motion supervision
186
+ loss = 0
187
+ for i in range(n):
188
+ pi, ti = handle_points[i], target_points[i]
189
+ di = (ti - pi) / torch.sum((ti - pi)**2)
190
+
191
+ for qi in neighbor(int(pi[0]), int(pi[1]), r1):
192
+ # f1 = F[..., int(qi[0]), int(qi[1])]
193
+ # f2 = F2[..., int(qi[0] + di[0]), int(qi[1] + di[1])]
194
+ f1 = bilinear_interpolate_torch(F2, qi[0], qi[1]).detach()
195
+ f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
196
+ loss += FF.l1_loss(f2, f1)
197
+
198
+ if mask is not None:
199
+ loss += ((F2 - F0) * (1 - mask)).abs().mean() * lam
200
+
201
+ loss.backward()
202
+ optimizer.step()
203
+
204
+ # point tracking
205
+ with torch.no_grad():
206
+ sample2, F2 = g_ema.generate(latent, noise)
207
+ for i in range(n):
208
+ pi = handle_points0[i]
209
+ # f = F0[..., int(pi[0]), int(pi[1])]
210
+ f0 = bilinear_interpolate_torch(F0, pi[0], pi[1])
211
+ minv = 1e9
212
+ minx = 1e9
213
+ miny = 1e9
214
+ for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
215
+ # f2 = F2[..., int(qi[0]), int(qi[1])]
216
+ try:
217
+ f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
218
+ except:
219
+ import ipdb
220
+ ipdb.set_trace()
221
+ v = torch.norm(f2 - f0, p=1)
222
+ if v < minv:
223
+ minv = v
224
+ minx = int(qi[0])
225
+ miny = int(qi[1])
226
+ handle_points[i][0] = minx
227
+ handle_points[i][1] = miny
228
+
229
+ F = F2.detach().clone()
230
+ if iter % 1 == 0:
231
+ print(iter, loss.item(), handle_points, target_points)
232
+ # p = handle_points[0].int()
233
+ # sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] = sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] * 0
234
+ # t = target_points[0].int()
235
+ # sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] = sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] * 255
236
+
237
+ # sample2[0, :, 210, 134] = sample2[0, :, 210, 134] * 0
238
+ # utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))
239
+
240
+ yield sample2, latent, F2, handle_points
gradio_app.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ import imageio
6
+ from PIL import Image
7
+ import uuid
8
+
9
+ from drag_gan import drag_gan, stylegan2
10
+ from stylegan2.inversion import inverse_image
11
+
12
+ device = 'cpu'
13
+
14
+
15
+ SIZE_TO_CLICK_SIZE = {
16
+ 1024: 8,
17
+ 512: 5,
18
+ 256: 2
19
+ }
20
+
21
+ CKPT_SIZE = {
22
+ 'stylegan2-ffhq-config-f.pt': 1024,
23
+ 'stylegan2-cat-config-f.pt': 256,
24
+ 'stylegan2-church-config-f.pt': 256,
25
+ 'stylegan2-horse-config-f.pt': 256,
26
+ 'ada/ffhq.pt': 1024,
27
+ 'ada/afhqcat.pt': 512,
28
+ 'ada/afhqdog.pt': 512,
29
+ 'ada/afhqwild.pt': 512,
30
+ 'ada/brecahad.pt': 512,
31
+ 'ada/metfaces.pt': 512,
32
+ }
33
+
34
+ DEFAULT_CKPT = 'stylegan2-ffhq-config-f.pt'
35
+
36
+
37
+ class grImage(gr.components.Image):
38
+ is_template = True
39
+
40
+ def preprocess(self, x):
41
+ if x is None:
42
+ return x
43
+ if self.tool == "sketch" and self.source in ["upload", "webcam"]:
44
+ decode_image = gr.processing_utils.decode_base64_to_image(x)
45
+ width, height = decode_image.size
46
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
47
+ mask[..., -1] = 255
48
+ mask = self.postprocess(mask)
49
+ x = {'image': x, 'mask': mask}
50
+ return super().preprocess(x)
51
+
52
+
53
+ class ImageMask(gr.components.Image):
54
+ """
55
+ Sets: source="canvas", tool="sketch"
56
+ """
57
+
58
+ is_template = True
59
+
60
+ def __init__(self, **kwargs):
61
+ super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
62
+
63
+ def preprocess(self, x):
64
+ if x is None:
65
+ return x
66
+ if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
67
+ decode_image = gr.processing_utils.decode_base64_to_image(x)
68
+ width, height = decode_image.size
69
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
70
+ mask[..., -1] = 255
71
+ mask = self.postprocess(mask)
72
+ x = {'image': x, 'mask': mask}
73
+ return super().preprocess(x)
74
+
75
+
76
+ class ModelWrapper:
77
+ def __init__(self, **kwargs):
78
+ self.g_ema = stylegan2(**kwargs).to(device)
79
+
80
+
81
+ def to_image(tensor):
82
+ tensor = tensor.squeeze(0).permute(1, 2, 0)
83
+ arr = tensor.detach().cpu().numpy()
84
+ arr = (arr - arr.min()) / (arr.max() - arr.min())
85
+ arr = arr * 255
86
+ return arr.astype('uint8')
87
+
88
+
89
+ def add_points_to_image(image, points, size=5):
90
+ h, w, = image.shape[:2]
91
+
92
+ for x, y in points['target']:
93
+ image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [255, 0, 0]
94
+ for x, y in points['handle']:
95
+ image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [0, 0, 255]
96
+
97
+ return image
98
+
99
+
100
+ def on_click(image, target_point, points, size, evt: gr.SelectData):
101
+ if target_point:
102
+ points['target'].append([evt.index[1], evt.index[0]])
103
+ image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
104
+ return image, str(evt.index), not target_point
105
+ points['handle'].append([evt.index[1], evt.index[0]])
106
+ image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
107
+ return image, str(evt.index), not target_point
108
+
109
+
110
+ def on_drag(model, points, max_iters, state, size, mask):
111
+ if len(points['handle']) == 0:
112
+ raise gr.Error('You must select at least one handle point and target point.')
113
+ if len(points['handle']) != len(points['target']):
114
+ raise gr.Error('You have uncompleted handle points, try to selct a target point or undo the handle point.')
115
+ max_iters = int(max_iters)
116
+ latent = state['latent']
117
+ noise = state['noise']
118
+ F = state['F']
119
+
120
+ handle_points = [torch.tensor(p).float() for p in points['handle']]
121
+ target_points = [torch.tensor(p).float() for p in points['target']]
122
+
123
+ if mask.get('mask') is not None:
124
+ mask = Image.fromarray(mask['mask']).convert('L')
125
+ mask = np.array(mask) == 255
126
+
127
+ mask = torch.from_numpy(mask).float().to(device)
128
+ mask = mask.unsqueeze(0).unsqueeze(0)
129
+ else:
130
+ mask = None
131
+
132
+ step = 0
133
+ for sample2, latent, F, handle_points in drag_gan(model.g_ema, latent, noise, F,
134
+ handle_points, target_points, mask,
135
+ max_iters=max_iters):
136
+ image = to_image(sample2)
137
+
138
+ state['F'] = F
139
+ state['latent'] = latent
140
+ state['sample'] = sample2
141
+ points['handle'] = [p.cpu().numpy().astype('int') for p in handle_points]
142
+ add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
143
+
144
+ state['history'].append(image)
145
+ step += 1
146
+ yield image, state, step
147
+
148
+
149
+ def on_reset(points, image, state):
150
+ return {'target': [], 'handle': []}, to_image(state['sample'])
151
+
152
+
153
+ def on_undo(points, image, state, size):
154
+ image = to_image(state['sample'])
155
+
156
+ if len(points['target']) < len(points['handle']):
157
+ points['handle'] = points['handle'][:-1]
158
+ else:
159
+ points['handle'] = points['handle'][:-1]
160
+ points['target'] = points['target'][:-1]
161
+
162
+ add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
163
+ return points, image
164
+
165
+
166
+ def on_change_model(selected, model):
167
+ size = CKPT_SIZE[selected]
168
+ model = ModelWrapper(size=size, ckpt=selected)
169
+ g_ema = model.g_ema
170
+ sample_z = torch.randn([1, 512], device=device)
171
+ latent, noise = g_ema.prepare([sample_z])
172
+ sample, F = g_ema.generate(latent, noise)
173
+
174
+ state = {
175
+ 'latent': latent,
176
+ 'noise': noise,
177
+ 'F': F,
178
+ 'sample': sample,
179
+ 'history': []
180
+ }
181
+ return model, state, to_image(sample), to_image(sample), size
182
+
183
+
184
+ def on_new_image(model):
185
+ g_ema = model.g_ema
186
+ sample_z = torch.randn([1, 512], device=device)
187
+ latent, noise = g_ema.prepare([sample_z])
188
+ sample, F = g_ema.generate(latent, noise)
189
+
190
+ state = {
191
+ 'latent': latent,
192
+ 'noise': noise,
193
+ 'F': F,
194
+ 'sample': sample,
195
+ 'history': []
196
+ }
197
+ points = {'target': [], 'handle': []}
198
+ target_point = False
199
+ return to_image(sample), to_image(sample), state, points, target_point
200
+
201
+
202
+ def on_max_iter_change(max_iters):
203
+ return gr.update(maximum=max_iters)
204
+
205
+
206
+ def on_save_files(image, state):
207
+ os.makedirs('tmp', exist_ok=True)
208
+ image_name = f'tmp/image_{uuid.uuid4()}.png'
209
+ video_name = f'tmp/video_{uuid.uuid4()}.mp4'
210
+ imageio.imsave(image_name, image)
211
+ imageio.mimsave(video_name, state['history'])
212
+ return [image_name, video_name]
213
+
214
+
215
+ def on_show_save():
216
+ return gr.update(visible=True)
217
+
218
+
219
+ def on_image_change(model, image_size, image):
220
+ image = Image.fromarray(image)
221
+ result = inverse_image(
222
+ model.g_ema,
223
+ image,
224
+ image_size=image_size
225
+ )
226
+ result['history'] = []
227
+ image = to_image(result['sample'])
228
+ points = {'target': [], 'handle': []}
229
+ target_point = False
230
+ return image, image, result, points, target_point
231
+
232
+
233
+ def on_mask_change(mask):
234
+ return mask['image']
235
+
236
+
237
+ def main():
238
+ torch.cuda.manual_seed(25)
239
+
240
+ with gr.Blocks() as demo:
241
+ wrapped_model = ModelWrapper(ckpt=DEFAULT_CKPT, size=CKPT_SIZE[DEFAULT_CKPT])
242
+ model = gr.State(wrapped_model)
243
+ sample_z = torch.randn([1, 512], device=device)
244
+ latent, noise = wrapped_model.g_ema.prepare([sample_z])
245
+ sample, F = wrapped_model.g_ema.generate(latent, noise)
246
+
247
+ gr.Markdown(
248
+ """
249
+ # DragGAN
250
+
251
+ Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
252
+
253
+ [Our Implementation](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)
254
+
255
+ ## Tutorial
256
+
257
+ 1. (Optional) Draw a mask indicate the movable region.
258
+ 2. Setup a least one pair of handle point and target point.
259
+ 3. Click "Drag it".
260
+
261
+ ## Hints
262
+
263
+ - Handle points (Blue): the point you want to drag.
264
+ - Target points (Red): the destination you want to drag towards to.
265
+
266
+ ## Primary Support of Custom Image.
267
+
268
+ - We now support dragging user uploaded image by GAN inversion.
269
+ - **Please upload your image at `Setup Handle Points` pannel.** Upload it from `Draw a Mask` would cause errors for now.
270
+ - Due to the limitation of GAN inversion,
271
+ - You might wait roughly 1 minute to see the GAN version of the uploaded image.
272
+ - The shown image might be slightly difference from the uploaded one.
273
+ - It could also fail to invert the uploaded image and generate very poor results.
274
+ - Idealy, you should choose the closest model of the uploaded image. For example, choose `stylegan2-ffhq-config-f.pt` for human face. `stylegan2-cat-config-f.pt` for cat.
275
+
276
+ > Please fire an issue if you have encounted any problem. Also don't forgot to give a star to the [Official Repo](https://github.com/XingangPan/DragGAN), [our project](https://github.com/Zeqiang-Lai/DragGAN) could not exist without it.
277
+ """,
278
+ )
279
+ state = gr.State({
280
+ 'latent': latent,
281
+ 'noise': noise,
282
+ 'F': F,
283
+ 'sample': sample,
284
+ 'history': []
285
+ })
286
+ points = gr.State({'target': [], 'handle': []})
287
+ size = gr.State(CKPT_SIZE[DEFAULT_CKPT])
288
+
289
+ with gr.Row():
290
+ with gr.Column(scale=0.3):
291
+ with gr.Accordion("Model"):
292
+ model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value=DEFAULT_CKPT,
293
+ label='StyleGAN2 model')
294
+ max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')
295
+ new_btn = gr.Button('New Image')
296
+ with gr.Accordion('Drag'):
297
+ with gr.Row():
298
+ with gr.Column(min_width=100):
299
+ text = gr.Textbox(label='Selected Point', interactive=False)
300
+ with gr.Column(min_width=100):
301
+ target_point = gr.Checkbox(label='Target Point', interactive=False)
302
+ with gr.Row():
303
+ with gr.Column(min_width=100):
304
+ reset_btn = gr.Button('Reset All')
305
+ with gr.Column(min_width=100):
306
+ undo_btn = gr.Button('Undo Last')
307
+ with gr.Row():
308
+ btn = gr.Button('Drag it', variant='primary')
309
+
310
+ with gr.Accordion('Save', visible=False) as save_panel:
311
+ files = gr.Files(value=[])
312
+
313
+ progress = gr.Slider(value=0, maximum=20, label='Progress', interactive=False)
314
+
315
+ with gr.Column():
316
+ with gr.Tabs():
317
+ with gr.Tab('Draw a Mask', id='mask'):
318
+ mask = ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
319
+ with gr.Tab('Setup Handle Points', id='input'):
320
+ image = grImage(to_image(sample)).style(height=768, width=768)
321
+
322
+ image.select(on_click, [image, target_point, points, size], [image, text, target_point])
323
+ image.upload(on_image_change, [model, size, image], [image, mask, state, points, target_point])
324
+ mask.upload(on_mask_change, [mask], [image])
325
+ btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask], outputs=[image, state, progress]).then(
326
+ on_show_save, outputs=save_panel).then(
327
+ on_save_files, inputs=[image, state], outputs=[files]
328
+ )
329
+ reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image])
330
+ undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image])
331
+ model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, mask, size])
332
+ new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point])
333
+ max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
334
+ return demo
335
+
336
+
337
+ if __name__ == '__main__':
338
+ import argparse
339
+ parser = argparse.ArgumentParser()
340
+ parser.add_argument('--device', default='cuda')
341
+ parser.add_argument('--share', action='store_true')
342
+ parser.add_argument('-p', '--port', default=None)
343
+ parser.add_argument('--ip', default=None)
344
+ args = parser.parse_args()
345
+ device = args.device
346
+ demo = main()
347
+ print('Successfully loaded, starting gradio demo')
348
+ demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ tqdm
3
+ torch
4
+ numpy
5
+ ninja
6
+ fire
7
+ imageio
8
+ torchvision
9
+ IPython
stylegan2/__init__.py ADDED
File without changes
stylegan2/inversion.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import torch
5
+ from torch import optim
6
+ from torch.nn import functional as FF
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ import dataclasses
11
+
12
+ from .lpips import util
13
+
14
+
15
+ def noise_regularize(noises):
16
+ loss = 0
17
+
18
+ for noise in noises:
19
+ size = noise.shape[2]
20
+
21
+ while True:
22
+ loss = (
23
+ loss
24
+ + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
25
+ + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
26
+ )
27
+
28
+ if size <= 8:
29
+ break
30
+
31
+ noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
32
+ noise = noise.mean([3, 5])
33
+ size //= 2
34
+
35
+ return loss
36
+
37
+
38
+ def noise_normalize_(noises):
39
+ for noise in noises:
40
+ mean = noise.mean()
41
+ std = noise.std()
42
+
43
+ noise.data.add_(-mean).div_(std)
44
+
45
+
46
+ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
47
+ lr_ramp = min(1, (1 - t) / rampdown)
48
+ lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
49
+ lr_ramp = lr_ramp * min(1, t / rampup)
50
+
51
+ return initial_lr * lr_ramp
52
+
53
+
54
+ def latent_noise(latent, strength):
55
+ noise = torch.randn_like(latent) * strength
56
+
57
+ return latent + noise
58
+
59
+
60
+ def make_image(tensor):
61
+ return (
62
+ tensor.detach()
63
+ .clamp_(min=-1, max=1)
64
+ .add(1)
65
+ .div_(2)
66
+ .mul(255)
67
+ .type(torch.uint8)
68
+ .permute(0, 2, 3, 1)
69
+ .to("cpu")
70
+ .numpy()
71
+ )
72
+
73
+
74
+ @dataclasses.dataclass
75
+ class InverseConfig:
76
+ lr_warmup = 0.05
77
+ lr_decay = 0.25
78
+ lr = 0.1
79
+ noise = 0.05
80
+ noise_decay = 0.75
81
+ step = 1000
82
+ noise_regularize = 1e5
83
+ mse = 0
84
+ w_plus = False,
85
+
86
+
87
+ def inverse_image(
88
+ g_ema,
89
+ image,
90
+ image_size=256,
91
+ config=InverseConfig()
92
+ ):
93
+ device = "cuda"
94
+ args = config
95
+
96
+ n_mean_latent = 10000
97
+
98
+ resize = min(image_size, 256)
99
+
100
+ transform = transforms.Compose(
101
+ [
102
+ transforms.Resize(resize),
103
+ transforms.CenterCrop(resize),
104
+ transforms.ToTensor(),
105
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
106
+ ]
107
+ )
108
+
109
+ imgs = []
110
+ img = transform(image)
111
+ imgs.append(img)
112
+
113
+ imgs = torch.stack(imgs, 0).to(device)
114
+
115
+ with torch.no_grad():
116
+ noise_sample = torch.randn(n_mean_latent, 512, device=device)
117
+ latent_out = g_ema.style(noise_sample)
118
+
119
+ latent_mean = latent_out.mean(0)
120
+ latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
121
+
122
+ percept = util.PerceptualLoss(
123
+ model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
124
+ )
125
+
126
+ noises_single = g_ema.make_noise()
127
+ noises = []
128
+ for noise in noises_single:
129
+ noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
130
+
131
+ latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
132
+
133
+ if args.w_plus:
134
+ latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
135
+
136
+ latent_in.requires_grad = True
137
+
138
+ for noise in noises:
139
+ noise.requires_grad = True
140
+
141
+ optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
142
+
143
+ pbar = tqdm(range(args.step))
144
+ latent_path = []
145
+
146
+ for i in pbar:
147
+ t = i / args.step
148
+ lr = get_lr(t, args.lr)
149
+ optimizer.param_groups[0]["lr"] = lr
150
+ noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
151
+ latent_n = latent_noise(latent_in, noise_strength.item())
152
+
153
+ latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
154
+ img_gen, F = g_ema.generate(latent, noise)
155
+
156
+ batch, channel, height, width = img_gen.shape
157
+
158
+ if height > 256:
159
+ factor = height // 256
160
+
161
+ img_gen = img_gen.reshape(
162
+ batch, channel, height // factor, factor, width // factor, factor
163
+ )
164
+ img_gen = img_gen.mean([3, 5])
165
+
166
+ p_loss = percept(img_gen, imgs).sum()
167
+ n_loss = noise_regularize(noises)
168
+ mse_loss = FF.mse_loss(img_gen, imgs)
169
+
170
+ loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
171
+
172
+ optimizer.zero_grad()
173
+ loss.backward()
174
+ optimizer.step()
175
+
176
+ noise_normalize_(noises)
177
+
178
+ if (i + 1) % 100 == 0:
179
+ latent_path.append(latent_in.detach().clone())
180
+
181
+ pbar.set_description(
182
+ (
183
+ f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
184
+ f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
185
+ )
186
+ )
187
+
188
+ latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises)
189
+ img_gen, F = g_ema.generate(latent, noise)
190
+
191
+ img_ar = make_image(img_gen)
192
+
193
+ i = 0
194
+
195
+ noise_single = []
196
+ for noise in noises:
197
+ noise_single.append(noise[i: i + 1])
198
+
199
+ result = {
200
+ "latent": latent,
201
+ "noise": noise_single,
202
+ 'F': F,
203
+ "sample": img_gen,
204
+ }
205
+
206
+ pil_img = Image.fromarray(img_ar[i])
207
+ pil_img.save('project.png')
208
+
209
+ return result
stylegan2/lpips/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
stylegan2/lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+ from pdb import set_trace as st
6
+ from IPython import embed
7
+
8
+ class BaseModel():
9
+ def __init__(self):
10
+ pass;
11
+
12
+ def name(self):
13
+ return 'BaseModel'
14
+
15
+ def initialize(self, use_gpu=True, gpu_ids=[0]):
16
+ self.use_gpu = use_gpu
17
+ self.gpu_ids = gpu_ids
18
+
19
+ def forward(self):
20
+ pass
21
+
22
+ def get_image_paths(self):
23
+ pass
24
+
25
+ def optimize_parameters(self):
26
+ pass
27
+
28
+ def get_current_visuals(self):
29
+ return self.input
30
+
31
+ def get_current_errors(self):
32
+ return {}
33
+
34
+ def save(self, label):
35
+ pass
36
+
37
+ # helper saving function that can be used by subclasses
38
+ def save_network(self, network, path, network_label, epoch_label):
39
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
+ save_path = os.path.join(path, save_filename)
41
+ torch.save(network.state_dict(), save_path)
42
+
43
+ # helper loading function that can be used by subclasses
44
+ def load_network(self, network, network_label, epoch_label):
45
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
+ save_path = os.path.join(self.save_dir, save_filename)
47
+ print('Loading network from %s'%save_path)
48
+ network.load_state_dict(torch.load(save_path))
49
+
50
+ def update_learning_rate():
51
+ pass
52
+
53
+ def get_image_paths(self):
54
+ return self.image_paths
55
+
56
+ def save_done(self, flag=False):
57
+ np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
stylegan2/lpips/dist_model.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+ from collections import OrderedDict
10
+ from torch.autograd import Variable
11
+ import itertools
12
+ from .base_model import BaseModel
13
+ from scipy.ndimage import zoom
14
+ import fractions
15
+ import functools
16
+ import skimage.transform
17
+ from tqdm import tqdm
18
+ import urllib
19
+
20
+ from IPython import embed
21
+
22
+ from . import networks_basic as networks
23
+ from . import util
24
+
25
+
26
+ class DownloadProgressBar(tqdm):
27
+ def update_to(self, b=1, bsize=1, tsize=None):
28
+ if tsize is not None:
29
+ self.total = tsize
30
+ self.update(b * bsize - self.n)
31
+
32
+
33
+ def get_path(base_path):
34
+ BASE_DIR = os.path.join('checkpoints')
35
+
36
+ save_path = os.path.join(BASE_DIR, base_path)
37
+ if not os.path.exists(save_path):
38
+ url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
39
+ print(f'{base_path} not found')
40
+ print('Try to download from huggingface: ', url)
41
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
42
+ download_url(url, save_path)
43
+ print('Downloaded to ', save_path)
44
+ return save_path
45
+
46
+
47
+ def download_url(url, output_path):
48
+ with DownloadProgressBar(unit='B', unit_scale=True,
49
+ miniters=1, desc=url.split('/')[-1]) as t:
50
+ urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
51
+
52
+
53
+ class DistModel(BaseModel):
54
+ def name(self):
55
+ return self.model_name
56
+
57
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
58
+ use_gpu=True, printNet=False, spatial=False,
59
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
60
+ '''
61
+ INPUTS
62
+ model - ['net-lin'] for linearly calibrated network
63
+ ['net'] for off-the-shelf network
64
+ ['L2'] for L2 distance in Lab colorspace
65
+ ['SSIM'] for ssim in RGB colorspace
66
+ net - ['squeeze','alex','vgg']
67
+ model_path - if None, will look in weights/[NET_NAME].pth
68
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
69
+ use_gpu - bool - whether or not to use a GPU
70
+ printNet - bool - whether or not to print network architecture out
71
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
72
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
73
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
74
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
75
+ is_train - bool - [True] for training mode
76
+ lr - float - initial learning rate
77
+ beta1 - float - initial momentum term for adam
78
+ version - 0.1 for latest, 0.0 was original (with a bug)
79
+ gpu_ids - int array - [0] by default, gpus to use
80
+ '''
81
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
82
+
83
+ self.model = model
84
+ self.net = net
85
+ self.is_train = is_train
86
+ self.spatial = spatial
87
+ self.gpu_ids = gpu_ids
88
+ self.model_name = '%s [%s]' % (model, net)
89
+
90
+ if(self.model == 'net-lin'): # pretrained net + linear layer
91
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
92
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
93
+ kw = {}
94
+ if not use_gpu:
95
+ kw['map_location'] = 'cpu'
96
+ if(model_path is None):
97
+ model_path = get_path('weights/v%s/%s.pth' % (version, net))
98
+
99
+ if(not is_train):
100
+ print('Loading model from: %s' % model_path)
101
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
102
+
103
+ elif(self.model == 'net'): # pretrained network
104
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
105
+ elif(self.model in ['L2', 'l2']):
106
+ self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
107
+ self.model_name = 'L2'
108
+ elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
109
+ self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
110
+ self.model_name = 'SSIM'
111
+ else:
112
+ raise ValueError("Model [%s] not recognized." % self.model)
113
+
114
+ self.parameters = list(self.net.parameters())
115
+
116
+ if self.is_train: # training mode
117
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
118
+ self.rankLoss = networks.BCERankingLoss()
119
+ self.parameters += list(self.rankLoss.net.parameters())
120
+ self.lr = lr
121
+ self.old_lr = lr
122
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
123
+ else: # test mode
124
+ self.net.eval()
125
+
126
+ if(use_gpu):
127
+ self.net.to(gpu_ids[0])
128
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
129
+ if(self.is_train):
130
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
131
+
132
+ if(printNet):
133
+ print('---------- Networks initialized -------------')
134
+ networks.print_network(self.net)
135
+ print('-----------------------------------------------')
136
+
137
+ def forward(self, in0, in1, retPerLayer=False):
138
+ ''' Function computes the distance between image patches in0 and in1
139
+ INPUTS
140
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
141
+ OUTPUT
142
+ computed distances between in0 and in1
143
+ '''
144
+
145
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
146
+
147
+ # ***** TRAINING FUNCTIONS *****
148
+ def optimize_parameters(self):
149
+ self.forward_train()
150
+ self.optimizer_net.zero_grad()
151
+ self.backward_train()
152
+ self.optimizer_net.step()
153
+ self.clamp_weights()
154
+
155
+ def clamp_weights(self):
156
+ for module in self.net.modules():
157
+ if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
158
+ module.weight.data = torch.clamp(module.weight.data, min=0)
159
+
160
+ def set_input(self, data):
161
+ self.input_ref = data['ref']
162
+ self.input_p0 = data['p0']
163
+ self.input_p1 = data['p1']
164
+ self.input_judge = data['judge']
165
+
166
+ if(self.use_gpu):
167
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
168
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
169
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
170
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
171
+
172
+ self.var_ref = Variable(self.input_ref, requires_grad=True)
173
+ self.var_p0 = Variable(self.input_p0, requires_grad=True)
174
+ self.var_p1 = Variable(self.input_p1, requires_grad=True)
175
+
176
+ def forward_train(self): # run forward pass
177
+ # print(self.net.module.scaling_layer.shift)
178
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
179
+
180
+ self.d0 = self.forward(self.var_ref, self.var_p0)
181
+ self.d1 = self.forward(self.var_ref, self.var_p1)
182
+ self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
183
+
184
+ self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
185
+
186
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)
187
+
188
+ return self.loss_total
189
+
190
+ def backward_train(self):
191
+ torch.mean(self.loss_total).backward()
192
+
193
+ def compute_accuracy(self, d0, d1, judge):
194
+ ''' d0, d1 are Variables, judge is a Tensor '''
195
+ d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
196
+ judge_per = judge.cpu().numpy().flatten()
197
+ return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
198
+
199
+ def get_current_errors(self):
200
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
201
+ ('acc_r', self.acc_r)])
202
+
203
+ for key in retDict.keys():
204
+ retDict[key] = np.mean(retDict[key])
205
+
206
+ return retDict
207
+
208
+ def get_current_visuals(self):
209
+ zoom_factor = 256 / self.var_ref.data.size()[2]
210
+
211
+ ref_img = util.tensor2im(self.var_ref.data)
212
+ p0_img = util.tensor2im(self.var_p0.data)
213
+ p1_img = util.tensor2im(self.var_p1.data)
214
+
215
+ ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
216
+ p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
217
+ p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
218
+
219
+ return OrderedDict([('ref', ref_img_vis),
220
+ ('p0', p0_img_vis),
221
+ ('p1', p1_img_vis)])
222
+
223
+ def save(self, path, label):
224
+ if(self.use_gpu):
225
+ self.save_network(self.net.module, path, '', label)
226
+ else:
227
+ self.save_network(self.net, path, '', label)
228
+ self.save_network(self.rankLoss.net, path, 'rank', label)
229
+
230
+ def update_learning_rate(self, nepoch_decay):
231
+ lrd = self.lr / nepoch_decay
232
+ lr = self.old_lr - lrd
233
+
234
+ for param_group in self.optimizer_net.param_groups:
235
+ param_group['lr'] = lr
236
+
237
+ print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
238
+ self.old_lr = lr
239
+
240
+
241
+ def score_2afc_dataset(data_loader, func, name=''):
242
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
243
+ distance function 'func' in dataset 'data_loader'
244
+ INPUTS
245
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
246
+ func - callable distance function - calling d=func(in0,in1) should take 2
247
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
248
+ OUTPUTS
249
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
250
+ [1] - dictionary with following elements
251
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
252
+ gts - N array in [0,1], preferred patch selected by human evaluators
253
+ (closer to "0" for left patch p0, "1" for right patch p1,
254
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
255
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
256
+ CONSTS
257
+ N - number of test triplets in data_loader
258
+ '''
259
+
260
+ d0s = []
261
+ d1s = []
262
+ gts = []
263
+
264
+ for data in tqdm(data_loader.load_data(), desc=name):
265
+ d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
266
+ d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
267
+ gts += data['judge'].cpu().numpy().flatten().tolist()
268
+
269
+ d0s = np.array(d0s)
270
+ d1s = np.array(d1s)
271
+ gts = np.array(gts)
272
+ scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
273
+
274
+ return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
275
+
276
+
277
+ def score_jnd_dataset(data_loader, func, name=''):
278
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
279
+ INPUTS
280
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
281
+ func - callable distance function - calling d=func(in0,in1) should take 2
282
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
283
+ OUTPUTS
284
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
285
+ [1] - dictionary with following elements
286
+ ds - N array containing distances between two patches shown to human evaluator
287
+ sames - N array containing fraction of people who thought the two patches were identical
288
+ CONSTS
289
+ N - number of test triplets in data_loader
290
+ '''
291
+
292
+ ds = []
293
+ gts = []
294
+
295
+ for data in tqdm(data_loader.load_data(), desc=name):
296
+ ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
297
+ gts += data['same'].cpu().numpy().flatten().tolist()
298
+
299
+ sames = np.array(gts)
300
+ ds = np.array(ds)
301
+
302
+ sorted_inds = np.argsort(ds)
303
+ ds_sorted = ds[sorted_inds]
304
+ sames_sorted = sames[sorted_inds]
305
+
306
+ TPs = np.cumsum(sames_sorted)
307
+ FPs = np.cumsum(1 - sames_sorted)
308
+ FNs = np.sum(sames_sorted) - TPs
309
+
310
+ precs = TPs / (TPs + FPs)
311
+ recs = TPs / (TPs + FNs)
312
+ score = util.voc_ap(recs, precs)
313
+
314
+ return(score, dict(ds=ds, sames=sames))
stylegan2/lpips/networks_basic.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from pdb import set_trace as st
11
+ from skimage import color
12
+ from IPython import embed
13
+ from . import pretrained_networks as pn
14
+
15
+ from . import util
16
+
17
+
18
+ def spatial_average(in_tens, keepdim=True):
19
+ return in_tens.mean([2,3],keepdim=keepdim)
20
+
21
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
22
+ in_H = in_tens.shape[2]
23
+ scale_factor = 1.*out_H/in_H
24
+
25
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
26
+
27
+ # Learned perceptual metric
28
+ class PNetLin(nn.Module):
29
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
30
+ super(PNetLin, self).__init__()
31
+
32
+ self.pnet_type = pnet_type
33
+ self.pnet_tune = pnet_tune
34
+ self.pnet_rand = pnet_rand
35
+ self.spatial = spatial
36
+ self.lpips = lpips
37
+ self.version = version
38
+ self.scaling_layer = ScalingLayer()
39
+
40
+ if(self.pnet_type in ['vgg','vgg16']):
41
+ net_type = pn.vgg16
42
+ self.chns = [64,128,256,512,512]
43
+ elif(self.pnet_type=='alex'):
44
+ net_type = pn.alexnet
45
+ self.chns = [64,192,384,256,256]
46
+ elif(self.pnet_type=='squeeze'):
47
+ net_type = pn.squeezenet
48
+ self.chns = [64,128,256,384,384,512,512]
49
+ self.L = len(self.chns)
50
+
51
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
52
+
53
+ if(lpips):
54
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
55
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
56
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
57
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
58
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
59
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
60
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
61
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
62
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
63
+ self.lins+=[self.lin5,self.lin6]
64
+
65
+ def forward(self, in0, in1, retPerLayer=False):
66
+ # v0.0 - original release had a bug, where input was not scaled
67
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
68
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
69
+ feats0, feats1, diffs = {}, {}, {}
70
+
71
+ for kk in range(self.L):
72
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
73
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
74
+
75
+ if(self.lpips):
76
+ if(self.spatial):
77
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
78
+ else:
79
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
80
+ else:
81
+ if(self.spatial):
82
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
83
+ else:
84
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
85
+
86
+ val = res[0]
87
+ for l in range(1,self.L):
88
+ val += res[l]
89
+
90
+ if(retPerLayer):
91
+ return (val, res)
92
+ else:
93
+ return val
94
+
95
+ class ScalingLayer(nn.Module):
96
+ def __init__(self):
97
+ super(ScalingLayer, self).__init__()
98
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
99
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
100
+
101
+ def forward(self, inp):
102
+ return (inp - self.shift) / self.scale
103
+
104
+
105
+ class NetLinLayer(nn.Module):
106
+ ''' A single linear layer which does a 1x1 conv '''
107
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
108
+ super(NetLinLayer, self).__init__()
109
+
110
+ layers = [nn.Dropout(),] if(use_dropout) else []
111
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
112
+ self.model = nn.Sequential(*layers)
113
+
114
+
115
+ class Dist2LogitLayer(nn.Module):
116
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
117
+ def __init__(self, chn_mid=32, use_sigmoid=True):
118
+ super(Dist2LogitLayer, self).__init__()
119
+
120
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
121
+ layers += [nn.LeakyReLU(0.2,True),]
122
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
123
+ layers += [nn.LeakyReLU(0.2,True),]
124
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
125
+ if(use_sigmoid):
126
+ layers += [nn.Sigmoid(),]
127
+ self.model = nn.Sequential(*layers)
128
+
129
+ def forward(self,d0,d1,eps=0.1):
130
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
131
+
132
+ class BCERankingLoss(nn.Module):
133
+ def __init__(self, chn_mid=32):
134
+ super(BCERankingLoss, self).__init__()
135
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
136
+ # self.parameters = list(self.net.parameters())
137
+ self.loss = torch.nn.BCELoss()
138
+
139
+ def forward(self, d0, d1, judge):
140
+ per = (judge+1.)/2.
141
+ self.logit = self.net.forward(d0,d1)
142
+ return self.loss(self.logit, per)
143
+
144
+ # L2, DSSIM metrics
145
+ class FakeNet(nn.Module):
146
+ def __init__(self, use_gpu=True, colorspace='Lab'):
147
+ super(FakeNet, self).__init__()
148
+ self.use_gpu = use_gpu
149
+ self.colorspace=colorspace
150
+
151
+ class L2(FakeNet):
152
+
153
+ def forward(self, in0, in1, retPerLayer=None):
154
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
155
+
156
+ if(self.colorspace=='RGB'):
157
+ (N,C,X,Y) = in0.size()
158
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
159
+ return value
160
+ elif(self.colorspace=='Lab'):
161
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
162
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
163
+ ret_var = Variable( torch.Tensor((value,) ) )
164
+ if(self.use_gpu):
165
+ ret_var = ret_var.cuda()
166
+ return ret_var
167
+
168
+ class DSSIM(FakeNet):
169
+
170
+ def forward(self, in0, in1, retPerLayer=None):
171
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
172
+
173
+ if(self.colorspace=='RGB'):
174
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
175
+ elif(self.colorspace=='Lab'):
176
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
177
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
178
+ ret_var = Variable( torch.Tensor((value,) ) )
179
+ if(self.use_gpu):
180
+ ret_var = ret_var.cuda()
181
+ return ret_var
182
+
183
+ def print_network(net):
184
+ num_params = 0
185
+ for param in net.parameters():
186
+ num_params += param.numel()
187
+ print('Network',net)
188
+ print('Total number of parameters: %d' % num_params)
stylegan2/lpips/pretrained_networks.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+ from IPython import embed
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2,5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
+
54
+ return out
55
+
56
+
57
+ class alexnet(torch.nn.Module):
58
+ def __init__(self, requires_grad=False, pretrained=True):
59
+ super(alexnet, self).__init__()
60
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
+ self.slice1 = torch.nn.Sequential()
62
+ self.slice2 = torch.nn.Sequential()
63
+ self.slice3 = torch.nn.Sequential()
64
+ self.slice4 = torch.nn.Sequential()
65
+ self.slice5 = torch.nn.Sequential()
66
+ self.N_slices = 5
67
+ for x in range(2):
68
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
+ for x in range(2, 5):
70
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(5, 8):
72
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(8, 10):
74
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(10, 12):
76
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
+ if not requires_grad:
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, X):
82
+ h = self.slice1(X)
83
+ h_relu1 = h
84
+ h = self.slice2(h)
85
+ h_relu2 = h
86
+ h = self.slice3(h)
87
+ h_relu3 = h
88
+ h = self.slice4(h)
89
+ h_relu4 = h
90
+ h = self.slice5(h)
91
+ h_relu5 = h
92
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
+
95
+ return out
96
+
97
+ class vgg16(torch.nn.Module):
98
+ def __init__(self, requires_grad=False, pretrained=True):
99
+ super(vgg16, self).__init__()
100
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
+ self.slice1 = torch.nn.Sequential()
102
+ self.slice2 = torch.nn.Sequential()
103
+ self.slice3 = torch.nn.Sequential()
104
+ self.slice4 = torch.nn.Sequential()
105
+ self.slice5 = torch.nn.Sequential()
106
+ self.N_slices = 5
107
+ for x in range(4):
108
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(4, 9):
110
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(9, 16):
112
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(16, 23):
114
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(23, 30):
116
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
+ if not requires_grad:
118
+ for param in self.parameters():
119
+ param.requires_grad = False
120
+
121
+ def forward(self, X):
122
+ h = self.slice1(X)
123
+ h_relu1_2 = h
124
+ h = self.slice2(h)
125
+ h_relu2_2 = h
126
+ h = self.slice3(h)
127
+ h_relu3_3 = h
128
+ h = self.slice4(h)
129
+ h_relu4_3 = h
130
+ h = self.slice5(h)
131
+ h_relu5_3 = h
132
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
+
135
+ return out
136
+
137
+
138
+
139
+ class resnet(torch.nn.Module):
140
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
141
+ super(resnet, self).__init__()
142
+ if(num==18):
143
+ self.net = tv.resnet18(pretrained=pretrained)
144
+ elif(num==34):
145
+ self.net = tv.resnet34(pretrained=pretrained)
146
+ elif(num==50):
147
+ self.net = tv.resnet50(pretrained=pretrained)
148
+ elif(num==101):
149
+ self.net = tv.resnet101(pretrained=pretrained)
150
+ elif(num==152):
151
+ self.net = tv.resnet152(pretrained=pretrained)
152
+ self.N_slices = 5
153
+
154
+ self.conv1 = self.net.conv1
155
+ self.bn1 = self.net.bn1
156
+ self.relu = self.net.relu
157
+ self.maxpool = self.net.maxpool
158
+ self.layer1 = self.net.layer1
159
+ self.layer2 = self.net.layer2
160
+ self.layer3 = self.net.layer3
161
+ self.layer4 = self.net.layer4
162
+
163
+ def forward(self, X):
164
+ h = self.conv1(X)
165
+ h = self.bn1(h)
166
+ h = self.relu(h)
167
+ h_relu1 = h
168
+ h = self.maxpool(h)
169
+ h = self.layer1(h)
170
+ h_conv2 = h
171
+ h = self.layer2(h)
172
+ h_conv3 = h
173
+ h = self.layer3(h)
174
+ h_conv4 = h
175
+ h = self.layer4(h)
176
+ h_conv5 = h
177
+
178
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
+
181
+ return out
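Each wrapper above splits a torchvision backbone into sequential slices and returns the intermediate activations as a namedtuple; LPIPS compares images in these feature spaces. A small usage sketch follows; the import path is an assumption based on this commit's layout, and pretrained=False is used here only to avoid a weight download (LPIPS itself uses pretrained weights):

```python
import torch
from stylegan2.lpips.pretrained_networks import vgg16  # assumed import path

net = vgg16(requires_grad=False, pretrained=False).eval()

x = torch.randn(1, 3, 64, 64)        # dummy image batch
with torch.no_grad():
    feats = net(x)                   # VggOutputs namedtuple with 5 feature maps

for name, f in zip(feats._fields, feats):
    print(name, tuple(f.shape))      # relu1_2 ... relu5_3 at decreasing resolution
```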
stylegan2/lpips/util.py ADDED
@@ -0,0 +1,160 @@
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import numpy as np
7
+ from skimage.metrics import structural_similarity
8
+ import torch
9
+
10
+
11
+ from . import dist_model
12
+
13
+ class PerceptualLoss(torch.nn.Module):
14
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
15
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16
+ super(PerceptualLoss, self).__init__()
17
+ print('Setting up Perceptual loss...')
18
+ self.use_gpu = use_gpu
19
+ self.spatial = spatial
20
+ self.gpu_ids = gpu_ids
21
+ self.model = dist_model.DistModel()
22
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
23
+ print('...[%s] initialized'%self.model.name())
24
+ print('...Done')
25
+
26
+ def forward(self, pred, target, normalize=False):
27
+ """
28
+ Pred and target are Variables.
29
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30
+ If normalize is False, assumes the images are already between [-1,+1]
31
+
32
+ Inputs pred and target are Nx3xHxW
33
+ Output pytorch Variable N long
34
+ """
35
+
36
+ if normalize:
37
+ target = 2 * target - 1
38
+ pred = 2 * pred - 1
39
+
40
+ return self.model.forward(target, pred)
41
+
42
+ def normalize_tensor(in_feat,eps=1e-10):
43
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44
+ return in_feat/(norm_factor+eps)
45
+
46
+ def l2(p0, p1, range=255.):
47
+ return .5*np.mean((p0 / range - p1 / range)**2)
48
+
49
+ def psnr(p0, p1, peak=255.):
50
+ return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51
+
52
+ def dssim(p0, p1, range=255.):
53
+ return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.
54
+
55
+ def rgb2lab(in_img,mean_cent=False):
56
+ from skimage import color
57
+ img_lab = color.rgb2lab(in_img)
58
+ if(mean_cent):
59
+ img_lab[:,:,0] = img_lab[:,:,0]-50
60
+ return img_lab
61
+
62
+ def tensor2np(tensor_obj):
63
+ # change dimension of a tensor object into a numpy array
64
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
65
+
66
+ def np2tensor(np_obj):
67
+ # change dimension of np array into tensor array
68
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
69
+
70
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
71
+ # image tensor to lab tensor
72
+ from skimage import color
73
+
74
+ img = tensor2im(image_tensor)
75
+ img_lab = color.rgb2lab(img)
76
+ if(mc_only):
77
+ img_lab[:,:,0] = img_lab[:,:,0]-50
78
+ if(to_norm and not mc_only):
79
+ img_lab[:,:,0] = img_lab[:,:,0]-50
80
+ img_lab = img_lab/100.
81
+
82
+ return np2tensor(img_lab)
83
+
84
+ def tensorlab2tensor(lab_tensor,return_inbnd=False):
85
+ from skimage import color
86
+ import warnings
87
+ warnings.filterwarnings("ignore")
88
+
89
+ lab = tensor2np(lab_tensor)*100.
90
+ lab[:,:,0] = lab[:,:,0]+50
91
+
92
+ rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
93
+ if(return_inbnd):
94
+ # convert back to lab, see if we match
95
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
96
+ mask = 1.*np.isclose(lab_back,lab,atol=2.)
97
+ mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
98
+ return (im2tensor(rgb_back),mask)
99
+ else:
100
+ return im2tensor(rgb_back)
101
+
102
+ def rgb2lab(input):
103
+ from skimage import color
104
+ return color.rgb2lab(input / 255.)
105
+
106
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
107
+ image_numpy = image_tensor[0].cpu().float().numpy()
108
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
109
+ return image_numpy.astype(imtype)
110
+
111
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
112
+ return torch.Tensor((image / factor - cent)
113
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
114
+
115
+ def tensor2vec(vector_tensor):
116
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
117
+
118
+ def voc_ap(rec, prec, use_07_metric=False):
119
+ """ ap = voc_ap(rec, prec, [use_07_metric])
120
+ Compute VOC AP given precision and recall.
121
+ If use_07_metric is true, uses the
122
+ VOC 07 11 point method (default:False).
123
+ """
124
+ if use_07_metric:
125
+ # 11 point metric
126
+ ap = 0.
127
+ for t in np.arange(0., 1.1, 0.1):
128
+ if np.sum(rec >= t) == 0:
129
+ p = 0
130
+ else:
131
+ p = np.max(prec[rec >= t])
132
+ ap = ap + p / 11.
133
+ else:
134
+ # correct AP calculation
135
+ # first append sentinel values at the end
136
+ mrec = np.concatenate(([0.], rec, [1.]))
137
+ mpre = np.concatenate(([0.], prec, [0.]))
138
+
139
+ # compute the precision envelope
140
+ for i in range(mpre.size - 1, 0, -1):
141
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
142
+
143
+ # to calculate area under PR curve, look for points
144
+ # where X axis (recall) changes value
145
+ i = np.where(mrec[1:] != mrec[:-1])[0]
146
+
147
+ # and sum (\Delta recall) * prec
148
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
149
+ return ap
150
+
151
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
152
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
153
+ image_numpy = image_tensor[0].cpu().float().numpy()
154
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
155
+ return image_numpy.astype(imtype)
156
+
157
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
158
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
159
+ return torch.Tensor((image / factor - cent)
160
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
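im2tensor/tensor2im above convert between HxWx3 uint8 images and 1x3xHxW tensors scaled to roughly [-1, 1]. A quick roundtrip sketch, assuming the module is importable under the path below (importing util also pulls in dist_model):

```python
import numpy as np
from stylegan2.lpips import util  # assumed import path

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)

t = util.im2tensor(img)                        # 1x3x32x32, roughly in [-1, 1]
print(tuple(t.shape), float(t.min()), float(t.max()))

back = util.tensor2im(t)                       # back to HxWx3 uint8
print(util.psnr(1. * img, 1. * back))          # very large (or inf) for a near-lossless roundtrip
```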
stylegan2/model.py ADDED
@@ -0,0 +1,714 @@
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from .op.fused_act import fused
9
+
10
+ if fused is not None:
11
+ from .op.fused_act import FusedLeakyReLU, fused_leaky_relu
12
+ else:
13
+ from .op import FusedLeakyReLU_Native as FusedLeakyReLU
14
+ from .op import fused_leaky_relu_native as fused_leaky_relu
15
+
16
+ from .op.upfirdn2d import upfirdn2d_op
17
+
18
+ if upfirdn2d_op is not None:
19
+ from .op.upfirdn2d import upfirdn2d
20
+ else:
21
+ from .op import upfirdn2d_native as upfirdn2d
22
+
23
+ from .op import conv2d_gradfix
24
+
25
+ # https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py#L152
26
+ # https://github.com/rosinality/stylegan2-pytorch/issues/70
27
+
28
+
29
+ class PixelNorm(nn.Module):
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def forward(self, input):
34
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
35
+
36
+
37
+ def make_kernel(k):
38
+ k = torch.tensor(k, dtype=torch.float32)
39
+
40
+ if k.ndim == 1:
41
+ k = k[None, :] * k[:, None]
42
+
43
+ k /= k.sum()
44
+
45
+ return k
46
+
47
+
48
+ class Upsample(nn.Module):
49
+ def __init__(self, kernel, factor=2):
50
+ super().__init__()
51
+
52
+ self.factor = factor
53
+ kernel = make_kernel(kernel) * (factor ** 2)
54
+ self.register_buffer("kernel", kernel)
55
+
56
+ p = kernel.shape[0] - factor
57
+
58
+ pad0 = (p + 1) // 2 + factor - 1
59
+ pad1 = p // 2
60
+
61
+ self.pad = (pad0, pad1)
62
+
63
+ def forward(self, input):
64
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
65
+
66
+ return out
67
+
68
+
69
+ class Downsample(nn.Module):
70
+ def __init__(self, kernel, factor=2):
71
+ super().__init__()
72
+
73
+ self.factor = factor
74
+ kernel = make_kernel(kernel)
75
+ self.register_buffer("kernel", kernel)
76
+
77
+ p = kernel.shape[0] - factor
78
+
79
+ pad0 = (p + 1) // 2
80
+ pad1 = p // 2
81
+
82
+ self.pad = (pad0, pad1)
83
+
84
+ def forward(self, input):
85
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
86
+
87
+ return out
88
+
89
+
90
+ class Blur(nn.Module):
91
+ def __init__(self, kernel, pad, upsample_factor=1):
92
+ super().__init__()
93
+
94
+ kernel = make_kernel(kernel)
95
+
96
+ if upsample_factor > 1:
97
+ kernel = kernel * (upsample_factor ** 2)
98
+
99
+ self.register_buffer("kernel", kernel)
100
+
101
+ self.pad = pad
102
+
103
+ def forward(self, input):
104
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
105
+
106
+ return out
107
+
108
+
109
+ class EqualConv2d(nn.Module):
110
+ def __init__(
111
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
112
+ ):
113
+ super().__init__()
114
+
115
+ self.weight = nn.Parameter(
116
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
117
+ )
118
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
119
+
120
+ self.stride = stride
121
+ self.padding = padding
122
+
123
+ if bias:
124
+ self.bias = nn.Parameter(torch.zeros(out_channel))
125
+
126
+ else:
127
+ self.bias = None
128
+
129
+ def forward(self, input):
130
+ out = conv2d_gradfix.conv2d(
131
+ input,
132
+ self.weight * self.scale,
133
+ bias=self.bias,
134
+ stride=self.stride,
135
+ padding=self.padding,
136
+ )
137
+
138
+ return out
139
+
140
+ def __repr__(self):
141
+ return (
142
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
143
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
144
+ )
145
+
146
+
147
+ class EqualLinear(nn.Module):
148
+ def __init__(
149
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
150
+ ):
151
+ super().__init__()
152
+
153
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
154
+
155
+ if bias:
156
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
157
+
158
+ else:
159
+ self.bias = None
160
+
161
+ self.activation = activation
162
+
163
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
164
+ self.lr_mul = lr_mul
165
+
166
+ def forward(self, input):
167
+ if self.activation:
168
+ out = F.linear(input, self.weight * self.scale)
169
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
170
+
171
+ else:
172
+ out = F.linear(
173
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
174
+ )
175
+
176
+ return out
177
+
178
+ def __repr__(self):
179
+ return (
180
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
181
+ )
182
+
183
+
184
+ class ModulatedConv2d(nn.Module):
185
+ def __init__(
186
+ self,
187
+ in_channel,
188
+ out_channel,
189
+ kernel_size,
190
+ style_dim,
191
+ demodulate=True,
192
+ upsample=False,
193
+ downsample=False,
194
+ blur_kernel=[1, 3, 3, 1],
195
+ fused=True,
196
+ ):
197
+ super().__init__()
198
+
199
+ self.eps = 1e-8
200
+ self.kernel_size = kernel_size
201
+ self.in_channel = in_channel
202
+ self.out_channel = out_channel
203
+ self.upsample = upsample
204
+ self.downsample = downsample
205
+
206
+ if upsample:
207
+ factor = 2
208
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
209
+ pad0 = (p + 1) // 2 + factor - 1
210
+ pad1 = p // 2 + 1
211
+
212
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
213
+
214
+ if downsample:
215
+ factor = 2
216
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
217
+ pad0 = (p + 1) // 2
218
+ pad1 = p // 2
219
+
220
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
221
+
222
+ fan_in = in_channel * kernel_size ** 2
223
+ self.scale = 1 / math.sqrt(fan_in)
224
+ self.padding = kernel_size // 2
225
+
226
+ self.weight = nn.Parameter(
227
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
228
+ )
229
+
230
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
231
+
232
+ self.demodulate = demodulate
233
+ self.fused = fused
234
+
235
+ def __repr__(self):
236
+ return (
237
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
238
+ f"upsample={self.upsample}, downsample={self.downsample})"
239
+ )
240
+
241
+ def forward(self, input, style):
242
+ batch, in_channel, height, width = input.shape
243
+
244
+ if not self.fused:
245
+ weight = self.scale * self.weight.squeeze(0)
246
+ style = self.modulation(style)
247
+
248
+ if self.demodulate:
249
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
250
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
251
+
252
+ input = input * style.reshape(batch, in_channel, 1, 1)
253
+
254
+ if self.upsample:
255
+ weight = weight.transpose(0, 1)
256
+ out = conv2d_gradfix.conv_transpose2d(
257
+ input, weight, padding=0, stride=2
258
+ )
259
+ out = self.blur(out)
260
+
261
+ elif self.downsample:
262
+ input = self.blur(input)
263
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
264
+
265
+ else:
266
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
267
+
268
+ if self.demodulate:
269
+ out = out * dcoefs.view(batch, -1, 1, 1)
270
+
271
+ return out
272
+
273
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
274
+ weight = self.scale * self.weight * style
275
+
276
+ if self.demodulate:
277
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
278
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
279
+
280
+ weight = weight.view(
281
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
282
+ )
283
+
284
+ if self.upsample:
285
+ input = input.view(1, batch * in_channel, height, width)
286
+ weight = weight.view(
287
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
288
+ )
289
+ weight = weight.transpose(1, 2).reshape(
290
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
291
+ )
292
+ out = conv2d_gradfix.conv_transpose2d(
293
+ input, weight, padding=0, stride=2, groups=batch
294
+ )
295
+ _, _, height, width = out.shape
296
+ out = out.view(batch, self.out_channel, height, width)
297
+ out = self.blur(out)
298
+
299
+ elif self.downsample:
300
+ input = self.blur(input)
301
+ _, _, height, width = input.shape
302
+ input = input.view(1, batch * in_channel, height, width)
303
+ out = conv2d_gradfix.conv2d(
304
+ input, weight, padding=0, stride=2, groups=batch
305
+ )
306
+ _, _, height, width = out.shape
307
+ out = out.view(batch, self.out_channel, height, width)
308
+
309
+ else:
310
+ input = input.view(1, batch * in_channel, height, width)
311
+ out = conv2d_gradfix.conv2d(
312
+ input, weight, padding=self.padding, groups=batch
313
+ )
314
+ _, _, height, width = out.shape
315
+ out = out.view(batch, self.out_channel, height, width)
316
+
317
+ return out
318
+
319
+
320
+ class NoiseInjection(nn.Module):
321
+ def __init__(self):
322
+ super().__init__()
323
+
324
+ self.weight = nn.Parameter(torch.zeros(1))
325
+
326
+ def forward(self, image, noise=None):
327
+ if noise is None:
328
+ batch, _, height, width = image.shape
329
+ noise = image.new_empty(batch, 1, height, width).normal_()
330
+
331
+ return image + self.weight * noise
332
+
333
+
334
+ class ConstantInput(nn.Module):
335
+ def __init__(self, channel, size=4):
336
+ super().__init__()
337
+
338
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
339
+
340
+ def forward(self, input):
341
+ batch = input.shape[0]
342
+ out = self.input.repeat(batch, 1, 1, 1)
343
+
344
+ return out
345
+
346
+
347
+ class StyledConv(nn.Module):
348
+ def __init__(
349
+ self,
350
+ in_channel,
351
+ out_channel,
352
+ kernel_size,
353
+ style_dim,
354
+ upsample=False,
355
+ blur_kernel=[1, 3, 3, 1],
356
+ demodulate=True,
357
+ ):
358
+ super().__init__()
359
+
360
+ self.conv = ModulatedConv2d(
361
+ in_channel,
362
+ out_channel,
363
+ kernel_size,
364
+ style_dim,
365
+ upsample=upsample,
366
+ blur_kernel=blur_kernel,
367
+ demodulate=demodulate,
368
+ )
369
+
370
+ self.noise = NoiseInjection()
371
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
372
+ # self.activate = ScaledLeakyReLU(0.2)
373
+ self.activate = FusedLeakyReLU(out_channel)
374
+
375
+ def forward(self, input, style, noise=None):
376
+ out = self.conv(input, style)
377
+ out = self.noise(out, noise=noise)
378
+ # out = out + self.bias
379
+ out = self.activate(out)
380
+
381
+ return out
382
+
383
+
384
+ class ToRGB(nn.Module):
385
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
386
+ super().__init__()
387
+
388
+ if upsample:
389
+ self.upsample = Upsample(blur_kernel)
390
+
391
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
392
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
393
+
394
+ def forward(self, input, style, skip=None):
395
+ out = self.conv(input, style)
396
+ out = out + self.bias
397
+
398
+ if skip is not None:
399
+ skip = self.upsample(skip)
400
+
401
+ out = out + skip
402
+
403
+ return out
404
+
405
+
406
+ class Generator(nn.Module):
407
+ def __init__(
408
+ self,
409
+ size,
410
+ style_dim,
411
+ n_mlp,
412
+ channel_multiplier=2,
413
+ blur_kernel=[1, 3, 3, 1],
414
+ lr_mlp=0.01,
415
+ ):
416
+ super().__init__()
417
+
418
+ self.size = size
419
+
420
+ self.style_dim = style_dim
421
+
422
+ layers = [PixelNorm()]
423
+
424
+ for i in range(n_mlp):
425
+ layers.append(
426
+ EqualLinear(
427
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
428
+ )
429
+ )
430
+
431
+ self.style = nn.Sequential(*layers)
432
+
433
+ self.channels = {
434
+ 4: 512,
435
+ 8: 512,
436
+ 16: 512,
437
+ 32: 512,
438
+ 64: 256 * channel_multiplier,
439
+ 128: 128 * channel_multiplier,
440
+ 256: 64 * channel_multiplier,
441
+ 512: 32 * channel_multiplier,
442
+ 1024: 16 * channel_multiplier,
443
+ }
444
+
445
+ self.input = ConstantInput(self.channels[4])
446
+ self.conv1 = StyledConv(
447
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
448
+ )
449
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
450
+
451
+ self.log_size = int(math.log(size, 2))
452
+ self.num_layers = (self.log_size - 2) * 2 + 1
453
+
454
+ self.convs = nn.ModuleList()
455
+ self.upsamples = nn.ModuleList()
456
+ self.to_rgbs = nn.ModuleList()
457
+ self.noises = nn.Module()
458
+
459
+ in_channel = self.channels[4]
460
+
461
+ for layer_idx in range(self.num_layers):
462
+ res = (layer_idx + 5) // 2
463
+ shape = [1, 1, 2 ** res, 2 ** res]
464
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
465
+
466
+ for i in range(3, self.log_size + 1):
467
+ out_channel = self.channels[2 ** i]
468
+
469
+ self.convs.append(
470
+ StyledConv(
471
+ in_channel,
472
+ out_channel,
473
+ 3,
474
+ style_dim,
475
+ upsample=True,
476
+ blur_kernel=blur_kernel,
477
+ )
478
+ )
479
+
480
+ self.convs.append(
481
+ StyledConv(
482
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
483
+ )
484
+ )
485
+
486
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
487
+
488
+ in_channel = out_channel
489
+
490
+ self.n_latent = self.log_size * 2 - 2
491
+
492
+ def make_noise(self):
493
+ device = self.input.input.device
494
+
495
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
496
+
497
+ for i in range(3, self.log_size + 1):
498
+ for _ in range(2):
499
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
500
+
501
+ return noises
502
+
503
+ def mean_latent(self, n_latent):
504
+ latent_in = torch.randn(
505
+ n_latent, self.style_dim, device=self.input.input.device
506
+ )
507
+ latent = self.style(latent_in).mean(0, keepdim=True)
508
+
509
+ return latent
510
+
511
+ def get_latent(self, input):
512
+ return self.style(input)
513
+
514
+ def forward(
515
+ self,
516
+ styles,
517
+ return_latents=False,
518
+ inject_index=None,
519
+ truncation=1,
520
+ truncation_latent=None,
521
+ input_is_latent=False,
522
+ noise=None,
523
+ randomize_noise=True,
524
+ ):
525
+ if not input_is_latent:
526
+ styles = [self.style(s) for s in styles]
527
+
528
+ if noise is None:
529
+ if randomize_noise:
530
+ noise = [None] * self.num_layers
531
+ else:
532
+ noise = [
533
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
534
+ ]
535
+
536
+ if truncation < 1:
537
+ style_t = []
538
+
539
+ for style in styles:
540
+ style_t.append(
541
+ truncation_latent + truncation * (style - truncation_latent)
542
+ )
543
+
544
+ styles = style_t
545
+
546
+ if len(styles) < 2:
547
+ inject_index = self.n_latent
548
+
549
+ if styles[0].ndim < 3:
550
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
551
+
552
+ else:
553
+ latent = styles[0]
554
+
555
+ else:
556
+ if inject_index is None:
557
+ inject_index = random.randint(1, self.n_latent - 1)
558
+
559
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
560
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
561
+
562
+ latent = torch.cat([latent, latent2], 1)
563
+
564
+ out = self.input(latent)
565
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
566
+
567
+ skip = self.to_rgb1(out, latent[:, 1])
568
+
569
+ i = 1
570
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
571
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
572
+ ):
573
+ out = conv1(out, latent[:, i], noise=noise1)
574
+ out = conv2(out, latent[:, i + 1], noise=noise2)
575
+ skip = to_rgb(out, latent[:, i + 2], skip)
576
+
577
+ i += 2
578
+
579
+
580
+ image = skip
581
+
582
+ if return_latents:
583
+ return image, latent
584
+
585
+ else:
586
+ return image, None
587
+
588
+
589
+ class ConvLayer(nn.Sequential):
590
+ def __init__(
591
+ self,
592
+ in_channel,
593
+ out_channel,
594
+ kernel_size,
595
+ downsample=False,
596
+ blur_kernel=[1, 3, 3, 1],
597
+ bias=True,
598
+ activate=True,
599
+ ):
600
+ layers = []
601
+
602
+ if downsample:
603
+ factor = 2
604
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
605
+ pad0 = (p + 1) // 2
606
+ pad1 = p // 2
607
+
608
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
609
+
610
+ stride = 2
611
+ self.padding = 0
612
+
613
+ else:
614
+ stride = 1
615
+ self.padding = kernel_size // 2
616
+
617
+ layers.append(
618
+ EqualConv2d(
619
+ in_channel,
620
+ out_channel,
621
+ kernel_size,
622
+ padding=self.padding,
623
+ stride=stride,
624
+ bias=bias and not activate,
625
+ )
626
+ )
627
+
628
+ if activate:
629
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
630
+
631
+ super().__init__(*layers)
632
+
633
+
634
+ class ResBlock(nn.Module):
635
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
636
+ super().__init__()
637
+
638
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
639
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
640
+
641
+ self.skip = ConvLayer(
642
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
643
+ )
644
+
645
+ def forward(self, input):
646
+ out = self.conv1(input)
647
+ out = self.conv2(out)
648
+
649
+ skip = self.skip(input)
650
+ out = (out + skip) / math.sqrt(2)
651
+
652
+ return out
653
+
654
+
655
+ class Discriminator(nn.Module):
656
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
657
+ super().__init__()
658
+
659
+ channels = {
660
+ 4: 512,
661
+ 8: 512,
662
+ 16: 512,
663
+ 32: 512,
664
+ 64: 256 * channel_multiplier,
665
+ 128: 128 * channel_multiplier,
666
+ 256: 64 * channel_multiplier,
667
+ 512: 32 * channel_multiplier,
668
+ 1024: 16 * channel_multiplier,
669
+ }
670
+
671
+ convs = [ConvLayer(3, channels[size], 1)]
672
+
673
+ log_size = int(math.log(size, 2))
674
+
675
+ in_channel = channels[size]
676
+
677
+ for i in range(log_size, 2, -1):
678
+ out_channel = channels[2 ** (i - 1)]
679
+
680
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
681
+
682
+ in_channel = out_channel
683
+
684
+ self.convs = nn.Sequential(*convs)
685
+
686
+ self.stddev_group = 4
687
+ self.stddev_feat = 1
688
+
689
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
690
+ self.final_linear = nn.Sequential(
691
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
692
+ EqualLinear(channels[4], 1),
693
+ )
694
+
695
+ def forward(self, input):
696
+ out = self.convs(input)
697
+
698
+ batch, channel, height, width = out.shape
699
+ group = min(batch, self.stddev_group)
700
+ stddev = out.view(
701
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
702
+ )
703
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
704
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
705
+ stddev = stddev.repeat(group, 1, height, width)
706
+ out = torch.cat([out, stddev], 1)
707
+
708
+ out = self.final_conv(out)
709
+
710
+ out = out.view(batch, -1)
711
+ out = self.final_linear(out)
712
+
713
+ return out
714
+
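A minimal end-to-end sketch of the Generator and Discriminator defined above. It assumes the package is importable as stylegan2.model; because the op module falls back to the native implementations when the CUDA extensions cannot be built, this should also run on CPU:

```python
import torch
from stylegan2.model import Generator, Discriminator  # assumed import path

g = Generator(size=64, style_dim=512, n_mlp=8)  # small resolution keeps this cheap
d = Discriminator(size=64)

z = torch.randn(2, 512)                         # batch of latent codes
with torch.no_grad():
    img, _ = g([z])                             # styles are passed as a list
    score = d(img)

print(img.shape)    # torch.Size([2, 3, 64, 64])
print(score.shape)  # torch.Size([2, 1])
```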
stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu, fused_leaky_relu_native, FusedLeakyReLU_Native
2
+ from .upfirdn2d import upfirdn2d, upfirdn2d_native
stylegan2/op/conv2d_gradfix.py ADDED
@@ -0,0 +1,229 @@
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ return False
80
+
81
+ if (not enabled) or (not torch.backends.cudnn.enabled):
82
+ return False
83
+
84
+ if input.device.type != "cuda":
85
+ return False
86
+
87
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
88
+ return True
89
+
90
+ warnings.warn(
91
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
92
+ )
93
+
94
+ return False
95
+
96
+
97
+ def ensure_tuple(xs, ndim):
98
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
99
+
100
+ return xs
101
+
102
+
103
+ conv2d_gradfix_cache = dict()
104
+
105
+
106
+ def conv2d_gradfix(
107
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
108
+ ):
109
+ ndim = 2
110
+ weight_shape = tuple(weight_shape)
111
+ stride = ensure_tuple(stride, ndim)
112
+ padding = ensure_tuple(padding, ndim)
113
+ output_padding = ensure_tuple(output_padding, ndim)
114
+ dilation = ensure_tuple(dilation, ndim)
115
+
116
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
117
+ if key in conv2d_gradfix_cache:
118
+ return conv2d_gradfix_cache[key]
119
+
120
+ common_kwargs = dict(
121
+ stride=stride, padding=padding, dilation=dilation, groups=groups
122
+ )
123
+
124
+ def calc_output_padding(input_shape, output_shape):
125
+ if transpose:
126
+ return [0, 0]
127
+
128
+ return [
129
+ input_shape[i + 2]
130
+ - (output_shape[i + 2] - 1) * stride[i]
131
+ - (1 - 2 * padding[i])
132
+ - dilation[i] * (weight_shape[i + 2] - 1)
133
+ for i in range(ndim)
134
+ ]
135
+
136
+ class Conv2d(autograd.Function):
137
+ @staticmethod
138
+ def forward(ctx, input, weight, bias):
139
+ if not transpose:
140
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
141
+
142
+ else:
143
+ out = F.conv_transpose2d(
144
+ input=input,
145
+ weight=weight,
146
+ bias=bias,
147
+ output_padding=output_padding,
148
+ **common_kwargs,
149
+ )
150
+
151
+ ctx.save_for_backward(input, weight)
152
+
153
+ return out
154
+
155
+ @staticmethod
156
+ def backward(ctx, grad_output):
157
+ input, weight = ctx.saved_tensors
158
+ grad_input, grad_weight, grad_bias = None, None, None
159
+
160
+ if ctx.needs_input_grad[0]:
161
+ p = calc_output_padding(
162
+ input_shape=input.shape, output_shape=grad_output.shape
163
+ )
164
+ grad_input = conv2d_gradfix(
165
+ transpose=(not transpose),
166
+ weight_shape=weight_shape,
167
+ output_padding=p,
168
+ **common_kwargs,
169
+ ).apply(grad_output, weight, None)
170
+
171
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
172
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
173
+
174
+ if ctx.needs_input_grad[2]:
175
+ grad_bias = grad_output.sum((0, 2, 3))
176
+
177
+ return grad_input, grad_weight, grad_bias
178
+
179
+ class Conv2dGradWeight(autograd.Function):
180
+ @staticmethod
181
+ def forward(ctx, grad_output, input):
182
+ op = torch._C._jit_get_operation(
183
+ "aten::cudnn_convolution_backward_weight"
184
+ if not transpose
185
+ else "aten::cudnn_convolution_transpose_backward_weight"
186
+ )
187
+ flags = [
188
+ torch.backends.cudnn.benchmark,
189
+ torch.backends.cudnn.deterministic,
190
+ torch.backends.cudnn.allow_tf32,
191
+ ]
192
+ grad_weight = op(
193
+ weight_shape,
194
+ grad_output,
195
+ input,
196
+ padding,
197
+ stride,
198
+ dilation,
199
+ groups,
200
+ *flags,
201
+ )
202
+ ctx.save_for_backward(grad_output, input)
203
+
204
+ return grad_weight
205
+
206
+ @staticmethod
207
+ def backward(ctx, grad_grad_weight):
208
+ grad_output, input = ctx.saved_tensors
209
+ grad_grad_output, grad_grad_input = None, None
210
+
211
+ if ctx.needs_input_grad[0]:
212
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
213
+
214
+ if ctx.needs_input_grad[1]:
215
+ p = calc_output_padding(
216
+ input_shape=input.shape, output_shape=grad_output.shape
217
+ )
218
+ grad_grad_input = conv2d_gradfix(
219
+ transpose=(not transpose),
220
+ weight_shape=weight_shape,
221
+ output_padding=p,
222
+ **common_kwargs,
223
+ ).apply(grad_output, grad_grad_weight, None)
224
+
225
+ return grad_grad_output, grad_grad_input
226
+
227
+ conv2d_gradfix_cache[key] = Conv2d
228
+
229
+ return Conv2d
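conv2d_gradfix re-expresses convolution gradients so that weight gradients can be skipped during the regularization passes. In this copy, could_use_op() returns False up front, so every call falls through to F.conv2d / F.conv_transpose2d. A small sketch of the intended call pattern, with the import path assumed:

```python
import torch
from stylegan2.op import conv2d_gradfix  # assumed import path

x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)   # falls back to F.conv2d here

# During R1 / path-length regularization, StyleGAN2 wraps the backward pass in
# no_weight_gradients() so the custom op skips grad_weight; with the plain
# F.conv2d fallback the flag is only bookkeeping.
with conv2d_gradfix.no_weight_gradients():
    y.sum().backward()

print(x.grad.shape, w.grad.shape)            # torch.Size([1, 3, 8, 8]) torch.Size([4, 3, 3, 3])
```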
stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+ import warnings
10
+
11
+ module_path = os.path.dirname(os.path.abspath(__file__))
12
+
13
+ try:
14
+ fused = load(
15
+ "fused",
16
+ sources=[
17
+ os.path.join(module_path, "fused_bias_act.cpp"),
18
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
19
+ ],
20
+ )
21
+ except Exception:
22
+ warnings.warn(
23
+ f"(This is not error) Switch to native implementation"
24
+ )
25
+
26
+ fused = None
27
+
28
+
29
+ class FusedLeakyReLUFunctionBackward(Function):
30
+ @staticmethod
31
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
32
+ ctx.save_for_backward(out)
33
+ ctx.negative_slope = negative_slope
34
+ ctx.scale = scale
35
+
36
+ empty = grad_output.new_empty(0)
37
+
38
+ grad_input = fused.fused_bias_act(
39
+ grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
40
+ )
41
+
42
+ dim = [0]
43
+
44
+ if grad_input.ndim > 2:
45
+ dim += list(range(2, grad_input.ndim))
46
+
47
+ if bias:
48
+ grad_bias = grad_input.sum(dim).detach()
49
+
50
+ else:
51
+ grad_bias = empty
52
+
53
+ return grad_input, grad_bias
54
+
55
+ @staticmethod
56
+ def backward(ctx, gradgrad_input, gradgrad_bias):
57
+ out, = ctx.saved_tensors
58
+ gradgrad_out = fused.fused_bias_act(
59
+ gradgrad_input.contiguous(),
60
+ gradgrad_bias,
61
+ out,
62
+ 3,
63
+ 1,
64
+ ctx.negative_slope,
65
+ ctx.scale,
66
+ )
67
+
68
+ return gradgrad_out, None, None, None, None
69
+
70
+
71
+ class FusedLeakyReLUFunction(Function):
72
+ @staticmethod
73
+ def forward(ctx, input, bias, negative_slope, scale):
74
+ empty = input.new_empty(0)
75
+
76
+ ctx.bias = bias is not None
77
+
78
+ if bias is None:
79
+ bias = empty
80
+
81
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
82
+ ctx.save_for_backward(out)
83
+ ctx.negative_slope = negative_slope
84
+ ctx.scale = scale
85
+
86
+ return out
87
+
88
+ @staticmethod
89
+ def backward(ctx, grad_output):
90
+ out, = ctx.saved_tensors
91
+
92
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
93
+ grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
94
+ )
95
+
96
+ if not ctx.bias:
97
+ grad_bias = None
98
+
99
+ return grad_input, grad_bias, None, None
100
+
101
+
102
+ class FusedLeakyReLU(nn.Module):
103
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
104
+ super().__init__()
105
+
106
+ if bias:
107
+ self.bias = nn.Parameter(torch.zeros(channel))
108
+
109
+ else:
110
+ self.bias = None
111
+
112
+ self.negative_slope = negative_slope
113
+ self.scale = scale
114
+
115
+ def forward(self, input):
116
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
117
+
118
+
119
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
120
+ if input.device.type == "cpu":
121
+ if bias is not None:
122
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
123
+ return (
124
+ F.leaky_relu(
125
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
126
+ )
127
+ * scale
128
+ )
129
+
130
+ else:
131
+ return F.leaky_relu(input, negative_slope=0.2) * scale
132
+
133
+ else:
134
+ return FusedLeakyReLUFunction.apply(
135
+ input.contiguous(), bias, negative_slope, scale
136
+ )
137
+
138
+
139
+ class FusedLeakyReLU_Native(nn.Module):
140
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
141
+ super().__init__()
142
+
143
+ if bias:
144
+ self.bias = nn.Parameter(torch.zeros(channel))
145
+
146
+ else:
147
+ self.bias = None
148
+
149
+ self.negative_slope = negative_slope
150
+ self.scale = scale
151
+
152
+ def forward(self, input):
153
+ return fused_leaky_relu_native(input, self.bias, self.negative_slope, self.scale)
154
+
155
+
156
+ def fused_leaky_relu_native(input, bias, negative_slope=0.2, scale=2 ** 0.5):
157
+ return scale * F.leaky_relu(input + bias.view((1, -1) + (1,) * (len(input.shape) - 2)), negative_slope=negative_slope)
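The *_Native variants above reproduce the fused CUDA op in plain PyTorch: add a per-channel bias, apply leaky ReLU, then scale by sqrt(2). A quick sketch, with the import path assumed (importing the module will try, and harmlessly fail, to build the CUDA extension first):

```python
import torch
from stylegan2.op.fused_act import FusedLeakyReLU_Native, fused_leaky_relu_native  # assumed path

act = FusedLeakyReLU_Native(channel=8)     # learnable per-channel bias, initialized to zero
x = torch.randn(2, 8, 4, 4)
y = act(x)
print(y.shape)                             # torch.Size([2, 8, 4, 4])

# Functional form: bias is broadcast across every non-channel dimension.
y2 = fused_leaky_relu_native(x, torch.zeros(8))
print(torch.allclose(y, y2))               # True while act.bias is still zero
```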
stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,32 @@
1
+
2
+ #include <ATen/ATen.h>
3
+ #include <torch/extension.h>
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
+ const torch::Tensor &bias,
7
+ const torch::Tensor &refer, int act, int grad,
8
+ float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) \
11
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) \
13
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+ #define CHECK_INPUT(x) \
15
+ CHECK_CUDA(x); \
16
+ CHECK_CONTIGUOUS(x)
17
+
18
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
19
+ const torch::Tensor &bias,
20
+ const torch::Tensor &refer, int act, int grad,
21
+ float alpha, float scale) {
22
+ CHECK_INPUT(input);
23
+ CHECK_INPUT(bias);
24
+
25
+ at::DeviceGuard guard(input.device());
26
+
27
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
+ }
29
+
30
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
+ }
stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,105 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void
20
+ fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
+ const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
+ scalar_t scale, int loop_x, int size_x, int step_b,
23
+ int size_b, int use_bias, int use_ref) {
24
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
+
26
+ scalar_t zero = 0.0;
27
+
28
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
+ loop_idx++, xi += blockDim.x) {
30
+ scalar_t x = p_x[xi];
31
+
32
+ if (use_bias) {
33
+ x += p_b[(xi / step_b) % size_b];
34
+ }
35
+
36
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
37
+
38
+ scalar_t y;
39
+
40
+ switch (act * 10 + grad) {
41
+ default:
42
+ case 10:
43
+ y = x;
44
+ break;
45
+ case 11:
46
+ y = x;
47
+ break;
48
+ case 12:
49
+ y = 0.0;
50
+ break;
51
+
52
+ case 30:
53
+ y = (x > 0.0) ? x : x * alpha;
54
+ break;
55
+ case 31:
56
+ y = (ref > 0.0) ? x : x * alpha;
57
+ break;
58
+ case 32:
59
+ y = 0.0;
60
+ break;
61
+ }
62
+
63
+ out[xi] = y * scale;
64
+ }
65
+ }
66
+
67
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
+ const torch::Tensor &bias,
69
+ const torch::Tensor &refer, int act, int grad,
70
+ float alpha, float scale) {
71
+ int curDevice = -1;
72
+ cudaGetDevice(&curDevice);
73
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
+
75
+ auto x = input.contiguous();
76
+ auto b = bias.contiguous();
77
+ auto ref = refer.contiguous();
78
+
79
+ int use_bias = b.numel() ? 1 : 0;
80
+ int use_ref = ref.numel() ? 1 : 0;
81
+
82
+ int size_x = x.numel();
83
+ int size_b = b.numel();
84
+ int step_b = 1;
85
+
86
+ for (int i = 1 + 1; i < x.dim(); i++) {
87
+ step_b *= x.size(i);
88
+ }
89
+
90
+ int loop_x = 4;
91
+ int block_size = 4 * 32;
92
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
+
94
+ auto y = torch::empty_like(x);
95
+
96
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
+ x.scalar_type(), "fused_bias_act_kernel", [&] {
98
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
+ y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
+ b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
+ scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
+ });
103
+
104
+ return y;
105
+ }
stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,31 @@
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
+ const torch::Tensor &kernel, int up_x, int up_y,
6
+ int down_x, int down_y, int pad_x0, int pad_x1,
7
+ int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) \
10
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) \
12
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x)
16
+
17
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
+ int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
+ int pad_x1, int pad_y0, int pad_y1) {
20
+ CHECK_INPUT(input);
21
+ CHECK_INPUT(kernel);
22
+
23
+ at::DeviceGuard guard(input.device());
24
+
25
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
+ pad_y0, pad_y1);
27
+ }
28
+
29
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
+ }
stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,232 @@
1
+ from collections import abc
2
+ import os
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+ import warnings
9
+
10
+ module_path = os.path.dirname(os.path.abspath(__file__))
11
+
12
+ try:
13
+ upfirdn2d_op = load(
14
+ "upfirdn2d",
15
+ sources=[
16
+ os.path.join(module_path, "upfirdn2d.cpp"),
17
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
18
+ ],
19
+ )
20
+ except Exception:
21
+ warnings.warn(
22
+ f"(This is not error) Switch to native implementation"
23
+ )
24
+
25
+ upfirdn2d_op = None
26
+
27
+
28
+ class UpFirDn2dBackward(Function):
29
+ @staticmethod
30
+ def forward(
31
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
32
+ ):
33
+
34
+ up_x, up_y = up
35
+ down_x, down_y = down
36
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
37
+
38
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
39
+
40
+ grad_input = upfirdn2d_op.upfirdn2d(
41
+ grad_output,
42
+ grad_kernel,
43
+ down_x,
44
+ down_y,
45
+ up_x,
46
+ up_y,
47
+ g_pad_x0,
48
+ g_pad_x1,
49
+ g_pad_y0,
50
+ g_pad_y1,
51
+ )
52
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
53
+
54
+ ctx.save_for_backward(kernel)
55
+
56
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
57
+
58
+ ctx.up_x = up_x
59
+ ctx.up_y = up_y
60
+ ctx.down_x = down_x
61
+ ctx.down_y = down_y
62
+ ctx.pad_x0 = pad_x0
63
+ ctx.pad_x1 = pad_x1
64
+ ctx.pad_y0 = pad_y0
65
+ ctx.pad_y1 = pad_y1
66
+ ctx.in_size = in_size
67
+ ctx.out_size = out_size
68
+
69
+ return grad_input
70
+
71
+ @staticmethod
72
+ def backward(ctx, gradgrad_input):
73
+ kernel, = ctx.saved_tensors
74
+
75
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
76
+
77
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
78
+ gradgrad_input,
79
+ kernel,
80
+ ctx.up_x,
81
+ ctx.up_y,
82
+ ctx.down_x,
83
+ ctx.down_y,
84
+ ctx.pad_x0,
85
+ ctx.pad_x1,
86
+ ctx.pad_y0,
87
+ ctx.pad_y1,
88
+ )
89
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
90
+ gradgrad_out = gradgrad_out.view(
91
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
92
+ )
93
+
94
+ return gradgrad_out, None, None, None, None, None, None, None, None
95
+
96
+
97
+ class UpFirDn2d(Function):
98
+ @staticmethod
99
+ def forward(ctx, input, kernel, up, down, pad):
100
+ up_x, up_y = up
101
+ down_x, down_y = down
102
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
103
+
104
+ kernel_h, kernel_w = kernel.shape
105
+ batch, channel, in_h, in_w = input.shape
106
+ ctx.in_size = input.shape
107
+
108
+ input = input.reshape(-1, in_h, in_w, 1)
109
+
110
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
111
+
112
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
113
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
114
+ ctx.out_size = (out_h, out_w)
115
+
116
+ ctx.up = (up_x, up_y)
117
+ ctx.down = (down_x, down_y)
118
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
119
+
120
+ g_pad_x0 = kernel_w - pad_x0 - 1
121
+ g_pad_y0 = kernel_h - pad_y0 - 1
122
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
123
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
124
+
125
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
126
+
127
+ out = upfirdn2d_op.upfirdn2d(
128
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
129
+ )
130
+ # out = out.view(major, out_h, out_w, minor)
131
+ out = out.view(-1, channel, out_h, out_w)
132
+
133
+ return out
134
+
135
+ @staticmethod
136
+ def backward(ctx, grad_output):
137
+ kernel, grad_kernel = ctx.saved_tensors
138
+
139
+ grad_input = None
140
+
141
+ if ctx.needs_input_grad[0]:
142
+ grad_input = UpFirDn2dBackward.apply(
143
+ grad_output,
144
+ kernel,
145
+ grad_kernel,
146
+ ctx.up,
147
+ ctx.down,
148
+ ctx.pad,
149
+ ctx.g_pad,
150
+ ctx.in_size,
151
+ ctx.out_size,
152
+ )
153
+
154
+ return grad_input, None, None, None, None
155
+
156
+
157
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
158
+ if not isinstance(up, abc.Iterable):
159
+ up = (up, up)
160
+
161
+ if not isinstance(down, abc.Iterable):
162
+ down = (down, down)
163
+
164
+ if len(pad) == 2:
165
+ pad = (pad[0], pad[1], pad[0], pad[1])
166
+
167
+ if input.device.type == "cpu":
168
+ out = _upfirdn2d_native(input, kernel, *up, *down, *pad)
169
+
170
+ else:
171
+ out = UpFirDn2d.apply(input, kernel, up, down, pad)
172
+
173
+ return out
174
+
175
+
176
+ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
177
+ if not isinstance(up, abc.Iterable):
178
+ up = (up, up)
179
+
180
+ if not isinstance(down, abc.Iterable):
181
+ down = (down, down)
182
+
183
+ if len(pad) == 2:
184
+ pad = (pad[0], pad[1], pad[0], pad[1])
185
+
186
+ out = _upfirdn2d_native(input, kernel, *up, *down, *pad)
187
+
188
+ return out
189
+
190
+
191
+ def _upfirdn2d_native(
192
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
193
+ ):
194
+ _, channel, in_h, in_w = input.shape
195
+ input = input.reshape(-1, in_h, in_w, 1)
196
+
197
+ _, in_h, in_w, minor = input.shape
198
+ kernel_h, kernel_w = kernel.shape
199
+
200
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
201
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
202
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
203
+
204
+ out = F.pad(
205
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
206
+ )
207
+ out = out[
208
+ :,
209
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
210
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
211
+ :,
212
+ ]
213
+
214
+ out = out.permute(0, 3, 1, 2)
215
+ out = out.reshape(
216
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
217
+ )
218
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
219
+ out = F.conv2d(out, w)
220
+ out = out.reshape(
221
+ -1,
222
+ minor,
223
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
224
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
225
+ )
226
+ out = out.permute(0, 2, 3, 1)
227
+ out = out[:, ::down_y, ::down_x, :]
228
+
229
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
230
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
231
+
232
+ return out.view(-1, channel, out_h, out_w)
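upfirdn2d pads, upsamples, FIR-filters, and downsamples in one pass; Blur, Upsample, and Downsample in model.py are thin wrappers around it. A sketch using the pure-PyTorch fallback with the [1, 3, 3, 1] binomial kernel, import path assumed:

```python
import torch
from stylegan2.op.upfirdn2d import upfirdn2d_native  # assumed import path

# Separable [1, 3, 3, 1] kernel, normalized (same construction as make_kernel in model.py).
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 16, 16)

# pad=(2, 1) keeps the spatial size for a 4-tap kernel with up=down=1 (plain blur).
blurred = upfirdn2d_native(x, kernel, up=1, down=1, pad=(2, 1))
print(blurred.shape)   # torch.Size([1, 3, 16, 16])

# 2x upsampling: the kernel gain is scaled by factor**2, as in the Upsample module.
up = upfirdn2d_native(x, kernel * 4, up=2, down=1, pad=(2, 1))
print(up.shape)        # torch.Size([1, 3, 32, 32])
```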
stylegan2/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
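`upfirdn2d_op` is the host entry point: it computes the output size, picks one of six templated tile configurations for common up/down/kernel combinations (falling back to `upfirdn2d_kernel_large` otherwise), and dispatches over floating dtypes. The sketch below shows one way such a kernel file is typically JIT-compiled and called from PyTorch; it assumes a companion binding file (`upfirdn2d.cpp`) that exposes `upfirdn2d_op` to Python under the name `upfirdn2d`, which is not shown in this diff, so the file names and the bound function name are illustrative only.

    # Sketch only: torch.utils.cpp_extension.load JIT-builds the sources;
    # the .cpp binding and exported name "upfirdn2d" are assumptions.
    import torch
    from torch.utils.cpp_extension import load

    upfirdn2d_ext = load(
        name="upfirdn2d",
        sources=[
            "stylegan2/op/upfirdn2d.cpp",        # assumed pybind11 binding
            "stylegan2/op/upfirdn2d_kernel.cu",  # the kernel file above
        ],
    )

    # The op works on (major, in_h, in_w, minor) layout, matching x.size(0..3)
    # as read in upfirdn2d_op; the Python wrapper normally reshapes NCHW into
    # this layout before calling it.
    x = torch.randn(4, 16, 16, 1, device="cuda")
    k = torch.randn(4, 4, device="cuda")

    # up=(2, 2), down=(1, 1), pad=(1, 1) per axis: hits the mode-3 fast path.
    out = upfirdn2d_ext.upfirdn2d(x, k, 2, 2, 1, 1, 1, 1, 1, 1)

    # out_h = (16 * 2 + 1 + 1 - 4 + 1) // 1 = 31
    print(out.shape)  # expected: (4, 31, 31, 1)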