Spaces: Runtime error
update
- .gitignore +160 -0
- README.md +1 -1
- drag_gan.py +240 -0
- gradio_app.py +348 -0
- requirements.txt +9 -0
- stylegan2/__init__.py +0 -0
- stylegan2/inversion.py +209 -0
- stylegan2/lpips/__init__.py +5 -0
- stylegan2/lpips/base_model.py +58 -0
- stylegan2/lpips/dist_model.py +314 -0
- stylegan2/lpips/networks_basic.py +188 -0
- stylegan2/lpips/pretrained_networks.py +181 -0
- stylegan2/lpips/util.py +160 -0
- stylegan2/model.py +714 -0
- stylegan2/op/__init__.py +2 -0
- stylegan2/op/conv2d_gradfix.py +229 -0
- stylegan2/op/fused_act.py +157 -0
- stylegan2/op/fused_bias_act.cpp +32 -0
- stylegan2/op/fused_bias_act_kernel.cu +105 -0
- stylegan2/op/upfirdn2d.cpp +31 -0
- stylegan2/op/upfirdn2d.py +232 -0
- stylegan2/op/upfirdn2d_kernel.cu +369 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
 colorTo: blue
 sdk: gradio
 sdk_version: 3.29.0
-app_file:
+app_file: gradio_app.py
 pinned: false
 ---
 
drag_gan.py
ADDED
@@ -0,0 +1,240 @@
+import copy
+import os
+import random
+import urllib.request
+
+import torch
+import torch.nn.functional as FF
+import torch.optim
+from torchvision import utils
+from tqdm import tqdm
+
+from stylegan2.model import Generator
+
+
+class DownloadProgressBar(tqdm):
+    def update_to(self, b=1, bsize=1, tsize=None):
+        if tsize is not None:
+            self.total = tsize
+        self.update(b * bsize - self.n)
+
+
+def get_path(base_path):
+    BASE_DIR = os.path.join('checkpoints')
+
+    save_path = os.path.join(BASE_DIR, base_path)
+    if not os.path.exists(save_path):
+        url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
+        print(f'{base_path} not found')
+        print('Try to download from huggingface: ', url)
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        download_url(url, save_path)
+        print('Downloaded to ', save_path)
+    return save_path
+
+
+def download_url(url, output_path):
+    with DownloadProgressBar(unit='B', unit_scale=True,
+                             miniters=1, desc=url.split('/')[-1]) as t:
+        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
+
+
+class CustomGenerator(Generator):
+    def prepare(
+        self,
+        styles,
+        inject_index=None,
+        truncation=1,
+        truncation_latent=None,
+        input_is_latent=False,
+        noise=None,
+        randomize_noise=True,
+    ):
+        if not input_is_latent:
+            styles = [self.style(s) for s in styles]
+
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers
+            else:
+                noise = [
+                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
+                ]
+
+        if truncation < 1:
+            style_t = []
+
+            for style in styles:
+                style_t.append(
+                    truncation_latent + truncation * (style - truncation_latent)
+                )
+
+            styles = style_t
+
+        if len(styles) < 2:
+            inject_index = self.n_latent
+
+            if styles[0].ndim < 3:
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+
+            else:
+                latent = styles[0]
+
+        else:
+            if inject_index is None:
+                inject_index = random.randint(1, self.n_latent - 1)
+
+            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+            latent = torch.cat([latent, latent2], 1)
+
+        return latent, noise
+
+    def generate(
+        self,
+        latent,
+        noise,
+    ):
+        out = self.input(latent)
+        out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+        skip = self.to_rgb1(out, latent[:, 1])
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(
+            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+        ):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+            if out.shape[-1] == 256: F = out
+            i += 2
+
+        image = skip
+        F = FF.interpolate(F, image.shape[-2:], mode='bilinear')
+        return image, F
+
+
+def stylegan2(
+    size=1024,
+    channel_multiplier=2,
+    latent=512,
+    n_mlp=8,
+    ckpt='stylegan2-ffhq-config-f.pt'
+):
+    g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
+    checkpoint = torch.load(get_path(ckpt))
+    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
+    g_ema.requires_grad_(False)
+    g_ema.eval()
+    return g_ema
+
+
+def bilinear_interpolate_torch(im, y, x):
+    """
+    im : B,C,H,W
+    y : 1,numPoints -- pixel location y float
+    x : 1,numPoints -- pixel location x float
+    """
+    device = im.device
+
+    x0 = torch.floor(x).long().to(device)
+    x1 = x0 + 1
+
+    y0 = torch.floor(y).long().to(device)
+    y1 = y0 + 1
+
+    wa = ((x1.float() - x) * (y1.float() - y)).to(device)
+    wb = ((x1.float() - x) * (y - y0.float())).to(device)
+    wc = ((x - x0.float()) * (y1.float() - y)).to(device)
+    wd = ((x - x0.float()) * (y - y0.float())).to(device)
+    # Instead of clamp
+    x1 = x1 - torch.floor(x1 / im.shape[3]).int().to(device)
+    y1 = y1 - torch.floor(y1 / im.shape[2]).int().to(device)
+    Ia = im[:, :, y0, x0]
+    Ib = im[:, :, y1, x0]
+    Ic = im[:, :, y0, x1]
+    Id = im[:, :, y1, x1]
+
+    return Ia * wa + Ib * wb + Ic * wc + Id * wd
+
+
+def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask, max_iters=1000):
+    handle_points0 = copy.deepcopy(handle_points)
+    n = len(handle_points)
+    r1, r2, lam, d = 3, 12, 20, 1
+
+    def neighbor(x, y, d):
+        points = []
+        for i in range(x - d, x + d):
+            for j in range(y - d, y + d):
+                points.append(torch.tensor([i, j]).float().cuda())
+        return points
+
+    F0 = F.detach().clone()
+
+    latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
+    latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
+    optimizer = torch.optim.Adam([latent_trainable], lr=2e-3)
+    for iter in range(max_iters):
+        for s in range(1):
+            optimizer.zero_grad()
+            latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
+            sample2, F2 = g_ema.generate(latent, noise)
+
+            # motion supervision
+            loss = 0
+            for i in range(n):
+                pi, ti = handle_points[i], target_points[i]
+                di = (ti - pi) / torch.sum((ti - pi)**2)
+
+                for qi in neighbor(int(pi[0]), int(pi[1]), r1):
+                    # f1 = F[..., int(qi[0]), int(qi[1])]
+                    # f2 = F2[..., int(qi[0] + di[0]), int(qi[1] + di[1])]
+                    f1 = bilinear_interpolate_torch(F2, qi[0], qi[1]).detach()
+                    f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
+                    loss += FF.l1_loss(f2, f1)
+
+            if mask is not None:
+                loss += ((F2 - F0) * (1 - mask)).abs().mean() * lam
+
+            loss.backward()
+            optimizer.step()
+
+        # point tracking
+        with torch.no_grad():
+            sample2, F2 = g_ema.generate(latent, noise)
+            for i in range(n):
+                pi = handle_points0[i]
+                # f = F0[..., int(pi[0]), int(pi[1])]
+                f0 = bilinear_interpolate_torch(F0, pi[0], pi[1])
+                minv = 1e9
+                minx = 1e9
+                miny = 1e9
+                for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
+                    # f2 = F2[..., int(qi[0]), int(qi[1])]
+                    try:
+                        f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
+                    except:
+                        import ipdb
+                        ipdb.set_trace()
+                    v = torch.norm(f2 - f0, p=1)
+                    if v < minv:
+                        minv = v
+                        minx = int(qi[0])
+                        miny = int(qi[1])
+                handle_points[i][0] = minx
+                handle_points[i][1] = miny
+
+        F = F2.detach().clone()
+        if iter % 1 == 0:
+            print(iter, loss.item(), handle_points, target_points)
+            # p = handle_points[0].int()
+            # sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] = sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] * 0
+            # t = target_points[0].int()
+            # sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] = sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] * 255
+
+            # sample2[0, :, 210, 134] = sample2[0, :, 210, 134] * 0
+            # utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))
+
+        yield sample2, latent, F2, handle_points
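For orientation, a minimal usage sketch of the API added above (not part of the commit): it mirrors what on_drag() in gradio_app.py below does. The checkpoint name comes from the file itself; the point coordinates, iteration count, and the assumption of a CUDA GPU are placeholders.

# Minimal sketch; assumes a CUDA GPU and the default FFHQ checkpoint.
import torch
from drag_gan import stylegan2, drag_gan

g_ema = stylegan2(ckpt='stylegan2-ffhq-config-f.pt').cuda()
latent, noise = g_ema.prepare([torch.randn([1, 512], device='cuda')])
sample, F = g_ema.generate(latent, noise)

handle = [torch.tensor([100., 100.], device='cuda')]  # placeholder (row, col) to drag from
target = [torch.tensor([120., 100.], device='cuda')]  # placeholder (row, col) to drag towards

# drag_gan is a generator: each optimization step yields the current image,
# latent, feature map and the tracked handle points.
for image, latent, F, handle in drag_gan(g_ema, latent, noise, F,
                                         handle, target, mask=None, max_iters=50):
    pass  # visualize or save `image` here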
gradio_app.py
ADDED
@@ -0,0 +1,348 @@
+import os
+import gradio as gr
+import torch
+import numpy as np
+import imageio
+from PIL import Image
+import uuid
+
+from drag_gan import drag_gan, stylegan2
+from stylegan2.inversion import inverse_image
+
+device = 'cpu'
+
+
+SIZE_TO_CLICK_SIZE = {
+    1024: 8,
+    512: 5,
+    256: 2
+}
+
+CKPT_SIZE = {
+    'stylegan2-ffhq-config-f.pt': 1024,
+    'stylegan2-cat-config-f.pt': 256,
+    'stylegan2-church-config-f.pt': 256,
+    'stylegan2-horse-config-f.pt': 256,
+    'ada/ffhq.pt': 1024,
+    'ada/afhqcat.pt': 512,
+    'ada/afhqdog.pt': 512,
+    'ada/afhqwild.pt': 512,
+    'ada/brecahad.pt': 512,
+    'ada/metfaces.pt': 512,
+}
+
+DEFAULT_CKPT = 'stylegan2-ffhq-config-f.pt'
+
+
+class grImage(gr.components.Image):
+    is_template = True
+
+    def preprocess(self, x):
+        if x is None:
+            return x
+        if self.tool == "sketch" and self.source in ["upload", "webcam"]:
+            decode_image = gr.processing_utils.decode_base64_to_image(x)
+            width, height = decode_image.size
+            mask = np.zeros((height, width, 4), dtype=np.uint8)
+            mask[..., -1] = 255
+            mask = self.postprocess(mask)
+            x = {'image': x, 'mask': mask}
+        return super().preprocess(x)
+
+
+class ImageMask(gr.components.Image):
+    """
+    Sets: source="canvas", tool="sketch"
+    """
+
+    is_template = True
+
+    def __init__(self, **kwargs):
+        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
+
+    def preprocess(self, x):
+        if x is None:
+            return x
+        if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
+            decode_image = gr.processing_utils.decode_base64_to_image(x)
+            width, height = decode_image.size
+            mask = np.zeros((height, width, 4), dtype=np.uint8)
+            mask[..., -1] = 255
+            mask = self.postprocess(mask)
+            x = {'image': x, 'mask': mask}
+        return super().preprocess(x)
+
+
+class ModelWrapper:
+    def __init__(self, **kwargs):
+        self.g_ema = stylegan2(**kwargs).to(device)
+
+
+def to_image(tensor):
+    tensor = tensor.squeeze(0).permute(1, 2, 0)
+    arr = tensor.detach().cpu().numpy()
+    arr = (arr - arr.min()) / (arr.max() - arr.min())
+    arr = arr * 255
+    return arr.astype('uint8')
+
+
+def add_points_to_image(image, points, size=5):
+    h, w, = image.shape[:2]
+
+    for x, y in points['target']:
+        image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [255, 0, 0]
+    for x, y in points['handle']:
+        image[max(0, x - size):min(x + size, h - 1), max(0, y - size):min(y + size, w), :] = [0, 0, 255]
+
+    return image
+
+
+def on_click(image, target_point, points, size, evt: gr.SelectData):
+    if target_point:
+        points['target'].append([evt.index[1], evt.index[0]])
+        image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+        return image, str(evt.index), not target_point
+    points['handle'].append([evt.index[1], evt.index[0]])
+    image = add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+    return image, str(evt.index), not target_point
+
+
+def on_drag(model, points, max_iters, state, size, mask):
+    if len(points['handle']) == 0:
+        raise gr.Error('You must select at least one handle point and target point.')
+    if len(points['handle']) != len(points['target']):
+        raise gr.Error('You have incomplete handle points; select a target point or undo the handle point.')
+    max_iters = int(max_iters)
+    latent = state['latent']
+    noise = state['noise']
+    F = state['F']
+
+    handle_points = [torch.tensor(p).float() for p in points['handle']]
+    target_points = [torch.tensor(p).float() for p in points['target']]
+
+    if mask.get('mask') is not None:
+        mask = Image.fromarray(mask['mask']).convert('L')
+        mask = np.array(mask) == 255
+
+        mask = torch.from_numpy(mask).float().to(device)
+        mask = mask.unsqueeze(0).unsqueeze(0)
+    else:
+        mask = None
+
+    step = 0
+    for sample2, latent, F, handle_points in drag_gan(model.g_ema, latent, noise, F,
+                                                      handle_points, target_points, mask,
+                                                      max_iters=max_iters):
+        image = to_image(sample2)
+
+        state['F'] = F
+        state['latent'] = latent
+        state['sample'] = sample2
+        points['handle'] = [p.cpu().numpy().astype('int') for p in handle_points]
+        add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+
+        state['history'].append(image)
+        step += 1
+        yield image, state, step
+
+
+def on_reset(points, image, state):
+    return {'target': [], 'handle': []}, to_image(state['sample'])
+
+
+def on_undo(points, image, state, size):
+    image = to_image(state['sample'])
+
+    if len(points['target']) < len(points['handle']):
+        points['handle'] = points['handle'][:-1]
+    else:
+        points['handle'] = points['handle'][:-1]
+        points['target'] = points['target'][:-1]
+
+    add_points_to_image(image, points, size=SIZE_TO_CLICK_SIZE[size])
+    return points, image
+
+
+def on_change_model(selected, model):
+    size = CKPT_SIZE[selected]
+    model = ModelWrapper(size=size, ckpt=selected)
+    g_ema = model.g_ema
+    sample_z = torch.randn([1, 512], device=device)
+    latent, noise = g_ema.prepare([sample_z])
+    sample, F = g_ema.generate(latent, noise)
+
+    state = {
+        'latent': latent,
+        'noise': noise,
+        'F': F,
+        'sample': sample,
+        'history': []
+    }
+    return model, state, to_image(sample), to_image(sample), size
+
+
+def on_new_image(model):
+    g_ema = model.g_ema
+    sample_z = torch.randn([1, 512], device=device)
+    latent, noise = g_ema.prepare([sample_z])
+    sample, F = g_ema.generate(latent, noise)
+
+    state = {
+        'latent': latent,
+        'noise': noise,
+        'F': F,
+        'sample': sample,
+        'history': []
+    }
+    points = {'target': [], 'handle': []}
+    target_point = False
+    return to_image(sample), to_image(sample), state, points, target_point
+
+
+def on_max_iter_change(max_iters):
+    return gr.update(maximum=max_iters)
+
+
+def on_save_files(image, state):
+    os.makedirs('tmp', exist_ok=True)
+    image_name = f'tmp/image_{uuid.uuid4()}.png'
+    video_name = f'tmp/video_{uuid.uuid4()}.mp4'
+    imageio.imsave(image_name, image)
+    imageio.mimsave(video_name, state['history'])
+    return [image_name, video_name]
+
+
+def on_show_save():
+    return gr.update(visible=True)
+
+
+def on_image_change(model, image_size, image):
+    image = Image.fromarray(image)
+    result = inverse_image(
+        model.g_ema,
+        image,
+        image_size=image_size
+    )
+    result['history'] = []
+    image = to_image(result['sample'])
+    points = {'target': [], 'handle': []}
+    target_point = False
+    return image, image, result, points, target_point
+
+
+def on_mask_change(mask):
+    return mask['image']
+
+
+def main():
+    torch.cuda.manual_seed(25)
+
+    with gr.Blocks() as demo:
+        wrapped_model = ModelWrapper(ckpt=DEFAULT_CKPT, size=CKPT_SIZE[DEFAULT_CKPT])
+        model = gr.State(wrapped_model)
+        sample_z = torch.randn([1, 512], device=device)
+        latent, noise = wrapped_model.g_ema.prepare([sample_z])
+        sample, F = wrapped_model.g_ema.generate(latent, noise)
+
+        gr.Markdown(
+            """
+            # DragGAN
+
+            Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
+
+            [Our Implementation](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)
+
+            ## Tutorial
+
+            1. (Optional) Draw a mask to indicate the movable region.
+            2. Set up at least one pair of handle point and target point.
+            3. Click "Drag it".
+
+            ## Hints
+
+            - Handle points (Blue): the points you want to drag.
+            - Target points (Red): the destinations you want to drag towards.
+
+            ## Preliminary Support for Custom Images
+
+            - We now support dragging a user-uploaded image via GAN inversion.
+            - **Please upload your image at the `Setup Handle Points` panel.** Uploading it from `Draw a Mask` would cause errors for now.
+            - Due to the limitations of GAN inversion,
+                - You might wait roughly 1 minute to see the GAN version of the uploaded image.
+                - The shown image might be slightly different from the uploaded one.
+                - It could also fail to invert the uploaded image and generate very poor results.
+                - Ideally, you should choose the model closest to the uploaded image, e.g. `stylegan2-ffhq-config-f.pt` for a human face, `stylegan2-cat-config-f.pt` for a cat.
+
+            > Please file an issue if you have encountered any problem. Also don't forget to give a star to the [Official Repo](https://github.com/XingangPan/DragGAN); [our project](https://github.com/Zeqiang-Lai/DragGAN) could not exist without it.
+            """,
+        )
+        state = gr.State({
+            'latent': latent,
+            'noise': noise,
+            'F': F,
+            'sample': sample,
+            'history': []
+        })
+        points = gr.State({'target': [], 'handle': []})
+        size = gr.State(CKPT_SIZE[DEFAULT_CKPT])
+
+        with gr.Row():
+            with gr.Column(scale=0.3):
+                with gr.Accordion("Model"):
+                    model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value=DEFAULT_CKPT,
+                                                 label='StyleGAN2 model')
+                    max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')
+                    new_btn = gr.Button('New Image')
+                with gr.Accordion('Drag'):
+                    with gr.Row():
+                        with gr.Column(min_width=100):
+                            text = gr.Textbox(label='Selected Point', interactive=False)
+                        with gr.Column(min_width=100):
+                            target_point = gr.Checkbox(label='Target Point', interactive=False)
+                    with gr.Row():
+                        with gr.Column(min_width=100):
+                            reset_btn = gr.Button('Reset All')
+                        with gr.Column(min_width=100):
+                            undo_btn = gr.Button('Undo Last')
+                    with gr.Row():
+                        btn = gr.Button('Drag it', variant='primary')
+
+                with gr.Accordion('Save', visible=False) as save_panel:
+                    files = gr.Files(value=[])
+
+                progress = gr.Slider(value=0, maximum=20, label='Progress', interactive=False)
+
+            with gr.Column():
+                with gr.Tabs():
+                    with gr.Tab('Draw a Mask', id='mask'):
+                        mask = ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
+                    with gr.Tab('Setup Handle Points', id='input'):
+                        image = grImage(to_image(sample)).style(height=768, width=768)
+
+        image.select(on_click, [image, target_point, points, size], [image, text, target_point])
+        image.upload(on_image_change, [model, size, image], [image, mask, state, points, target_point])
+        mask.upload(on_mask_change, [mask], [image])
+        btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask], outputs=[image, state, progress]).then(
+            on_show_save, outputs=save_panel).then(
+            on_save_files, inputs=[image, state], outputs=[files]
+        )
+        reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image])
+        undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image])
+        model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, mask, size])
+        new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point])
+        max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
+    return demo
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', default='cuda')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('-p', '--port', default=None)
+    parser.add_argument('--ip', default=None)
+    args = parser.parse_args()
+    device = args.device
+    demo = main()
+    print('Successfully loaded, starting gradio demo')
+    demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+gradio
+tqdm
+torch
+numpy
+ninja
+fire
+imageio
+torchvision
+IPython
stylegan2/__init__.py
ADDED
File without changes
stylegan2/inversion.py
ADDED
@@ -0,0 +1,209 @@
+import math
+import os
+
+import torch
+from torch import optim
+from torch.nn import functional as FF
+from torchvision import transforms
+from PIL import Image
+from tqdm import tqdm
+import dataclasses
+
+from .lpips import util
+
+
+def noise_regularize(noises):
+    loss = 0
+
+    for noise in noises:
+        size = noise.shape[2]
+
+        while True:
+            loss = (
+                loss
+                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
+                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
+            )
+
+            if size <= 8:
+                break
+
+            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
+            noise = noise.mean([3, 5])
+            size //= 2
+
+    return loss
+
+
+def noise_normalize_(noises):
+    for noise in noises:
+        mean = noise.mean()
+        std = noise.std()
+
+        noise.data.add_(-mean).div_(std)
+
+
+def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
+    lr_ramp = min(1, (1 - t) / rampdown)
+    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
+    lr_ramp = lr_ramp * min(1, t / rampup)
+
+    return initial_lr * lr_ramp
+
+
+def latent_noise(latent, strength):
+    noise = torch.randn_like(latent) * strength
+
+    return latent + noise
+
+
+def make_image(tensor):
+    return (
+        tensor.detach()
+        .clamp_(min=-1, max=1)
+        .add(1)
+        .div_(2)
+        .mul(255)
+        .type(torch.uint8)
+        .permute(0, 2, 3, 1)
+        .to("cpu")
+        .numpy()
+    )
+
+
+@dataclasses.dataclass
+class InverseConfig:
+    lr_warmup = 0.05
+    lr_decay = 0.25
+    lr = 0.1
+    noise = 0.05
+    noise_decay = 0.75
+    step = 1000
+    noise_regularize = 1e5
+    mse = 0
+    w_plus = False,
+
+
+def inverse_image(
+    g_ema,
+    image,
+    image_size=256,
+    config=InverseConfig()
+):
+    device = "cuda"
+    args = config
+
+    n_mean_latent = 10000
+
+    resize = min(image_size, 256)
+
+    transform = transforms.Compose(
+        [
+            transforms.Resize(resize),
+            transforms.CenterCrop(resize),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+        ]
+    )
+
+    imgs = []
+    img = transform(image)
+    imgs.append(img)
+
+    imgs = torch.stack(imgs, 0).to(device)
+
+    with torch.no_grad():
+        noise_sample = torch.randn(n_mean_latent, 512, device=device)
+        latent_out = g_ema.style(noise_sample)
+
+        latent_mean = latent_out.mean(0)
+        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
+
+    percept = util.PerceptualLoss(
+        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
+    )
+
+    noises_single = g_ema.make_noise()
+    noises = []
+    for noise in noises_single:
+        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
+
+    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
+
+    if args.w_plus:
+        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
+
+    latent_in.requires_grad = True
+
+    for noise in noises:
+        noise.requires_grad = True
+
+    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
+
+    pbar = tqdm(range(args.step))
+    latent_path = []
+
+    for i in pbar:
+        t = i / args.step
+        lr = get_lr(t, args.lr)
+        optimizer.param_groups[0]["lr"] = lr
+        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
+        latent_n = latent_noise(latent_in, noise_strength.item())
+
+        latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
+        img_gen, F = g_ema.generate(latent, noise)
+
+        batch, channel, height, width = img_gen.shape
+
+        if height > 256:
+            factor = height // 256
+
+            img_gen = img_gen.reshape(
+                batch, channel, height // factor, factor, width // factor, factor
+            )
+            img_gen = img_gen.mean([3, 5])
+
+        p_loss = percept(img_gen, imgs).sum()
+        n_loss = noise_regularize(noises)
+        mse_loss = FF.mse_loss(img_gen, imgs)
+
+        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        noise_normalize_(noises)
+
+        if (i + 1) % 100 == 0:
+            latent_path.append(latent_in.detach().clone())
+
+        pbar.set_description(
+            (
+                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
+                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
+            )
+        )
+
+    latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises)
+    img_gen, F = g_ema.generate(latent, noise)
+
+    img_ar = make_image(img_gen)
+
+    i = 0
+
+    noise_single = []
+    for noise in noises:
+        noise_single.append(noise[i: i + 1])
+
+    result = {
+        "latent": latent,
+        "noise": noise_single,
+        'F': F,
+        "sample": img_gen,
+    }
+
+    pil_img = Image.fromarray(img_ar[i])
+    pil_img.save('project.png')
+
+    return result
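A minimal sketch of how this inversion routine feeds the dragging pipeline (not part of the commit); it mirrors on_image_change() in gradio_app.py above. The image path is a placeholder, and a CUDA GPU is assumed since inverse_image hard-codes device="cuda".

from PIL import Image
from drag_gan import stylegan2
from stylegan2.inversion import inverse_image

g_ema = stylegan2(ckpt='stylegan2-ffhq-config-f.pt').cuda()
result = inverse_image(g_ema, Image.open('face.png').convert('RGB'), image_size=1024)
# `result` carries 'latent', 'noise', 'F' and 'sample': the same state dict the
# UI hands to drag_gan() when dragging a user-uploaded image.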
stylegan2/lpips/__init__.py
ADDED
@@ -0,0 +1,5 @@
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
stylegan2/lpips/base_model.py
ADDED
@@ -0,0 +1,58 @@
+import os
+import numpy as np
+import torch
+from torch.autograd import Variable
+from pdb import set_trace as st
+from IPython import embed
+
+class BaseModel():
+    def __init__(self):
+        pass
+
+    def name(self):
+        return 'BaseModel'
+
+    def initialize(self, use_gpu=True, gpu_ids=[0]):
+        self.use_gpu = use_gpu
+        self.gpu_ids = gpu_ids
+
+    def forward(self):
+        pass
+
+    def get_image_paths(self):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        return self.input
+
+    def get_current_errors(self):
+        return {}
+
+    def save(self, label):
+        pass
+
+    # helper saving function that can be used by subclasses
+    def save_network(self, network, path, network_label, epoch_label):
+        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+        save_path = os.path.join(path, save_filename)
+        torch.save(network.state_dict(), save_path)
+
+    # helper loading function that can be used by subclasses
+    def load_network(self, network, network_label, epoch_label):
+        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+        save_path = os.path.join(self.save_dir, save_filename)
+        print('Loading network from %s' % save_path)
+        network.load_state_dict(torch.load(save_path))
+
+    def update_learning_rate():
+        pass
+
+    def get_image_paths(self):
+        return self.image_paths
+
+    def save_done(self, flag=False):
+        np.save(os.path.join(self.save_dir, 'done_flag'), flag)
+        np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
stylegan2/lpips/dist_model.py
ADDED
@@ -0,0 +1,314 @@
+
+from __future__ import absolute_import
+
+import sys
+import numpy as np
+import torch
+from torch import nn
+import os
+from collections import OrderedDict
+from torch.autograd import Variable
+import itertools
+from .base_model import BaseModel
+from scipy.ndimage import zoom
+import fractions
+import functools
+import skimage.transform
+from tqdm import tqdm
+import urllib
+
+from IPython import embed
+
+from . import networks_basic as networks
+from . import util
+
+
+class DownloadProgressBar(tqdm):
+    def update_to(self, b=1, bsize=1, tsize=None):
+        if tsize is not None:
+            self.total = tsize
+        self.update(b * bsize - self.n)
+
+
+def get_path(base_path):
+    BASE_DIR = os.path.join('checkpoints')
+
+    save_path = os.path.join(BASE_DIR, base_path)
+    if not os.path.exists(save_path):
+        url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
+        print(f'{base_path} not found')
+        print('Try to download from huggingface: ', url)
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        download_url(url, save_path)
+        print('Downloaded to ', save_path)
+    return save_path
+
+
+def download_url(url, output_path):
+    with DownloadProgressBar(unit='B', unit_scale=True,
+                             miniters=1, desc=url.split('/')[-1]) as t:
+        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
+
+
+class DistModel(BaseModel):
+    def name(self):
+        return self.model_name
+
+    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
+                   use_gpu=True, printNet=False, spatial=False,
+                   is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
+        '''
+        INPUTS
+            model - ['net-lin'] for linearly calibrated network
+                    ['net'] for off-the-shelf network
+                    ['L2'] for L2 distance in Lab colorspace
+                    ['SSIM'] for ssim in RGB colorspace
+            net - ['squeeze','alex','vgg']
+            model_path - if None, will look in weights/[NET_NAME].pth
+            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
+            use_gpu - bool - whether or not to use a GPU
+            printNet - bool - whether or not to print network architecture out
+            spatial - bool - whether to output an array containing varying distances across spatial dimensions
+            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
+            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
+            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
+            is_train - bool - [True] for training mode
+            lr - float - initial learning rate
+            beta1 - float - initial momentum term for adam
+            version - 0.1 for latest, 0.0 was original (with a bug)
+            gpu_ids - int array - [0] by default, gpus to use
+        '''
+        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
+
+        self.model = model
+        self.net = net
+        self.is_train = is_train
+        self.spatial = spatial
+        self.gpu_ids = gpu_ids
+        self.model_name = '%s [%s]' % (model, net)
+
+        if(self.model == 'net-lin'):  # pretrained net + linear layer
+            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
+                                        use_dropout=True, spatial=spatial, version=version, lpips=True)
+            kw = {}
+            if not use_gpu:
+                kw['map_location'] = 'cpu'
+            if(model_path is None):
+                model_path = get_path('weights/v%s/%s.pth' % (version, net))
+
+            if(not is_train):
+                print('Loading model from: %s' % model_path)
+                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
+
+        elif(self.model == 'net'):  # pretrained network
+            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
+        elif(self.model in ['L2', 'l2']):
+            self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace)  # not really a network, only for testing
+            self.model_name = 'L2'
+        elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
+            self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
+            self.model_name = 'SSIM'
+        else:
+            raise ValueError("Model [%s] not recognized." % self.model)
+
+        self.parameters = list(self.net.parameters())
+
+        if self.is_train:  # training mode
+            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
+            self.rankLoss = networks.BCERankingLoss()
+            self.parameters += list(self.rankLoss.net.parameters())
+            self.lr = lr
+            self.old_lr = lr
+            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
+        else:  # test mode
+            self.net.eval()
+
+        if(use_gpu):
+            self.net.to(gpu_ids[0])
+            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
+            if(self.is_train):
+                self.rankLoss = self.rankLoss.to(device=gpu_ids[0])  # just put this on GPU0
+
+        if(printNet):
+            print('---------- Networks initialized -------------')
+            networks.print_network(self.net)
+            print('-----------------------------------------------')
+
+    def forward(self, in0, in1, retPerLayer=False):
+        ''' Function computes the distance between image patches in0 and in1
+        INPUTS
+            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
+        OUTPUT
+            computed distances between in0 and in1
+        '''
+
+        return self.net.forward(in0, in1, retPerLayer=retPerLayer)
+
+    # ***** TRAINING FUNCTIONS *****
+    def optimize_parameters(self):
+        self.forward_train()
+        self.optimizer_net.zero_grad()
+        self.backward_train()
+        self.optimizer_net.step()
+        self.clamp_weights()
+
+    def clamp_weights(self):
+        for module in self.net.modules():
+            if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
+                module.weight.data = torch.clamp(module.weight.data, min=0)
+
+    def set_input(self, data):
+        self.input_ref = data['ref']
+        self.input_p0 = data['p0']
+        self.input_p1 = data['p1']
+        self.input_judge = data['judge']
+
+        if(self.use_gpu):
+            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
+            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
+            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
+            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
+
+        self.var_ref = Variable(self.input_ref, requires_grad=True)
+        self.var_p0 = Variable(self.input_p0, requires_grad=True)
+        self.var_p1 = Variable(self.input_p1, requires_grad=True)
+
+    def forward_train(self):  # run forward pass
+        # print(self.net.module.scaling_layer.shift)
+        # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
+
+        self.d0 = self.forward(self.var_ref, self.var_p0)
+        self.d1 = self.forward(self.var_ref, self.var_p1)
+        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
+
+        self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
+
+        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)
+
+        return self.loss_total
+
+    def backward_train(self):
+        torch.mean(self.loss_total).backward()
+
+    def compute_accuracy(self, d0, d1, judge):
+        ''' d0, d1 are Variables, judge is a Tensor '''
+        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
+        judge_per = judge.cpu().numpy().flatten()
+        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
+
+    def get_current_errors(self):
+        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
+                               ('acc_r', self.acc_r)])
+
+        for key in retDict.keys():
+            retDict[key] = np.mean(retDict[key])
+
+        return retDict
+
+    def get_current_visuals(self):
+        zoom_factor = 256 / self.var_ref.data.size()[2]
+
+        ref_img = util.tensor2im(self.var_ref.data)
+        p0_img = util.tensor2im(self.var_p0.data)
+        p1_img = util.tensor2im(self.var_p1.data)
+
+        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
+        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
+        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
+
+        return OrderedDict([('ref', ref_img_vis),
+                            ('p0', p0_img_vis),
+                            ('p1', p1_img_vis)])
+
+    def save(self, path, label):
+        if(self.use_gpu):
+            self.save_network(self.net.module, path, '', label)
+        else:
+            self.save_network(self.net, path, '', label)
+        self.save_network(self.rankLoss.net, path, 'rank', label)
+
+    def update_learning_rate(self, nepoch_decay):
+        lrd = self.lr / nepoch_decay
+        lr = self.old_lr - lrd
+
+        for param_group in self.optimizer_net.param_groups:
+            param_group['lr'] = lr
+
+        print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
+        self.old_lr = lr
+
+
+def score_2afc_dataset(data_loader, func, name=''):
+    ''' Function computes Two Alternative Forced Choice (2AFC) score using
+        distance function 'func' in dataset 'data_loader'
+    INPUTS
+        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
+        func - callable distance function - calling d=func(in0,in1) should take 2
+            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
+    OUTPUTS
+        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
+        [1] - dictionary with following elements
+            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
+            gts - N array in [0,1], preferred patch selected by human evaluators
+                (closer to "0" for left patch p0, "1" for right patch p1,
+                "0.6" means 60pct people preferred right patch, 40pct preferred left)
+            scores - N array in [0,1], corresponding to what percentage function agreed with humans
+    CONSTS
+        N - number of test triplets in data_loader
+    '''
+
+    d0s = []
+    d1s = []
+    gts = []
+
+    for data in tqdm(data_loader.load_data(), desc=name):
+        d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
+        d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
+        gts += data['judge'].cpu().numpy().flatten().tolist()
+
+    d0s = np.array(d0s)
+    d1s = np.array(d1s)
+    gts = np.array(gts)
+    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
+
+    return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
+
+
+def score_jnd_dataset(data_loader, func, name=''):
+    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
+    INPUTS
+        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
+        func - callable distance function - calling d=func(in0,in1) should take 2
+            pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
+    OUTPUTS
+        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
+        [1] - dictionary with following elements
+            ds - N array containing distances between two patches shown to human evaluator
+            sames - N array containing fraction of people who thought the two patches were identical
+    CONSTS
+        N - number of test triplets in data_loader
+    '''
+
+    ds = []
+    gts = []
+
+    for data in tqdm(data_loader.load_data(), desc=name):
+        ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
+        gts += data['same'].cpu().numpy().flatten().tolist()
+
+    sames = np.array(gts)
+    ds = np.array(ds)
+
+    sorted_inds = np.argsort(ds)
+    ds_sorted = ds[sorted_inds]
+    sames_sorted = sames[sorted_inds]
+
+    TPs = np.cumsum(sames_sorted)
+    FPs = np.cumsum(1 - sames_sorted)
+    FNs = np.sum(sames_sorted) - TPs
+
+    precs = TPs / (TPs + FPs)
+    recs = TPs / (TPs + FNs)
+    score = util.voc_ap(recs, precs)
+
+    return(score, dict(ds=ds, sames=sames))
stylegan2/lpips/networks_basic.py
ADDED
@@ -0,0 +1,188 @@
+
+from __future__ import absolute_import
+
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from torch.autograd import Variable
+import numpy as np
+from pdb import set_trace as st
+from skimage import color
+from IPython import embed
+from . import pretrained_networks as pn
+
+from . import util
+
+
+def spatial_average(in_tens, keepdim=True):
+    return in_tens.mean([2, 3], keepdim=keepdim)
+
+def upsample(in_tens, out_H=64):  # assumes scale factor is same for H and W
+    in_H = in_tens.shape[2]
+    scale_factor = 1. * out_H / in_H
+
+    return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
+
+# Learned perceptual metric
+class PNetLin(nn.Module):
+    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
+        super(PNetLin, self).__init__()
+
+        self.pnet_type = pnet_type
+        self.pnet_tune = pnet_tune
+        self.pnet_rand = pnet_rand
+        self.spatial = spatial
+        self.lpips = lpips
+        self.version = version
+        self.scaling_layer = ScalingLayer()
+
+        if(self.pnet_type in ['vgg', 'vgg16']):
+            net_type = pn.vgg16
+            self.chns = [64, 128, 256, 512, 512]
+        elif(self.pnet_type == 'alex'):
+            net_type = pn.alexnet
+            self.chns = [64, 192, 384, 256, 256]
+        elif(self.pnet_type == 'squeeze'):
+            net_type = pn.squeezenet
+            self.chns = [64, 128, 256, 384, 384, 512, 512]
+        self.L = len(self.chns)
+
+        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
+
+        if(lpips):
+            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+            if(self.pnet_type == 'squeeze'):  # 7 layers for squeezenet
+                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
+                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
+                self.lins += [self.lin5, self.lin6]
+
+    def forward(self, in0, in1, retPerLayer=False):
+        # v0.0 - original release had a bug, where input was not scaled
+        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)
+        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
+        feats0, feats1, diffs = {}, {}, {}
+
+        for kk in range(self.L):
+            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
+            diffs[kk] = (feats0[kk] - feats1[kk])**2
+
+        if(self.lpips):
+            if(self.spatial):
+                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
+            else:
+                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
+        else:
+            if(self.spatial):
+                res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
+            else:
+                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
+
+        val = res[0]
+        for l in range(1, self.L):
+            val += res[l]
+
+        if(retPerLayer):
+            return (val, res)
+        else:
+            return val
+
+class ScalingLayer(nn.Module):
+    def __init__(self):
+        super(ScalingLayer, self).__init__()
+        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
|
100 |
+
|
101 |
+
def forward(self, inp):
|
102 |
+
return (inp - self.shift) / self.scale
|
103 |
+
|
104 |
+
|
105 |
+
class NetLinLayer(nn.Module):
|
106 |
+
''' A single linear layer which does a 1x1 conv '''
|
107 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
108 |
+
super(NetLinLayer, self).__init__()
|
109 |
+
|
110 |
+
layers = [nn.Dropout(),] if(use_dropout) else []
|
111 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
|
112 |
+
self.model = nn.Sequential(*layers)
|
113 |
+
|
114 |
+
|
115 |
+
class Dist2LogitLayer(nn.Module):
|
116 |
+
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
|
117 |
+
def __init__(self, chn_mid=32, use_sigmoid=True):
|
118 |
+
super(Dist2LogitLayer, self).__init__()
|
119 |
+
|
120 |
+
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
|
121 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
122 |
+
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
|
123 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
124 |
+
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
|
125 |
+
if(use_sigmoid):
|
126 |
+
layers += [nn.Sigmoid(),]
|
127 |
+
self.model = nn.Sequential(*layers)
|
128 |
+
|
129 |
+
def forward(self,d0,d1,eps=0.1):
|
130 |
+
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
|
131 |
+
|
132 |
+
class BCERankingLoss(nn.Module):
|
133 |
+
def __init__(self, chn_mid=32):
|
134 |
+
super(BCERankingLoss, self).__init__()
|
135 |
+
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
136 |
+
# self.parameters = list(self.net.parameters())
|
137 |
+
self.loss = torch.nn.BCELoss()
|
138 |
+
|
139 |
+
def forward(self, d0, d1, judge):
|
140 |
+
per = (judge+1.)/2.
|
141 |
+
self.logit = self.net.forward(d0,d1)
|
142 |
+
return self.loss(self.logit, per)
|
143 |
+
|
144 |
+
# L2, DSSIM metrics
|
145 |
+
class FakeNet(nn.Module):
|
146 |
+
def __init__(self, use_gpu=True, colorspace='Lab'):
|
147 |
+
super(FakeNet, self).__init__()
|
148 |
+
self.use_gpu = use_gpu
|
149 |
+
self.colorspace=colorspace
|
150 |
+
|
151 |
+
class L2(FakeNet):
|
152 |
+
|
153 |
+
def forward(self, in0, in1, retPerLayer=None):
|
154 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
155 |
+
|
156 |
+
if(self.colorspace=='RGB'):
|
157 |
+
(N,C,X,Y) = in0.size()
|
158 |
+
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
|
159 |
+
return value
|
160 |
+
elif(self.colorspace=='Lab'):
|
161 |
+
value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
|
162 |
+
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
163 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
164 |
+
if(self.use_gpu):
|
165 |
+
ret_var = ret_var.cuda()
|
166 |
+
return ret_var
|
167 |
+
|
168 |
+
class DSSIM(FakeNet):
|
169 |
+
|
170 |
+
def forward(self, in0, in1, retPerLayer=None):
|
171 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
172 |
+
|
173 |
+
if(self.colorspace=='RGB'):
|
174 |
+
value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
|
175 |
+
elif(self.colorspace=='Lab'):
|
176 |
+
value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
|
177 |
+
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
178 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
179 |
+
if(self.use_gpu):
|
180 |
+
ret_var = ret_var.cuda()
|
181 |
+
return ret_var
|
182 |
+
|
183 |
+
def print_network(net):
|
184 |
+
num_params = 0
|
185 |
+
for param in net.parameters():
|
186 |
+
num_params += param.numel()
|
187 |
+
print('Network',net)
|
188 |
+
print('Total number of parameters: %d' % num_params)
|
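For orientation: the LPIPS distance computed by PNetLin is, per layer, a unit-normalized squared feature difference reweighted by a learned 1x1 convolution and averaged over space, with the per-layer values summed into val. A minimal single-layer sketch, using random tensors in place of real VGG features and an untrained 1x1 conv in place of the learned lin layer:

import torch
import torch.nn as nn

def normalize_tensor(feat, eps=1e-10):
    # unit-normalize each spatial position across channels (as in util.normalize_tensor)
    norm = torch.sqrt(torch.sum(feat ** 2, dim=1, keepdim=True))
    return feat / (norm + eps)

f0 = torch.randn(1, 64, 32, 32)            # stand-in for outs0[kk]
f1 = torch.randn(1, 64, 32, 32)            # stand-in for outs1[kk]
lin = nn.Conv2d(64, 1, 1, bias=False)      # stand-in for self.lins[kk].model

diff = (normalize_tensor(f0) - normalize_tensor(f1)) ** 2
layer_val = lin(diff).mean([2, 3], keepdim=True)   # spatial_average(...)
print(layer_val.shape)                     # torch.Size([1, 1, 1, 1])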
stylegan2/lpips/pretrained_networks.py
ADDED
@@ -0,0 +1,181 @@
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torchvision import models as tv
|
4 |
+
from IPython import embed
|
5 |
+
|
6 |
+
class squeezenet(torch.nn.Module):
|
7 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
8 |
+
super(squeezenet, self).__init__()
|
9 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
10 |
+
self.slice1 = torch.nn.Sequential()
|
11 |
+
self.slice2 = torch.nn.Sequential()
|
12 |
+
self.slice3 = torch.nn.Sequential()
|
13 |
+
self.slice4 = torch.nn.Sequential()
|
14 |
+
self.slice5 = torch.nn.Sequential()
|
15 |
+
self.slice6 = torch.nn.Sequential()
|
16 |
+
self.slice7 = torch.nn.Sequential()
|
17 |
+
self.N_slices = 7
|
18 |
+
for x in range(2):
|
19 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
20 |
+
for x in range(2,5):
|
21 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
22 |
+
for x in range(5, 8):
|
23 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
24 |
+
for x in range(8, 10):
|
25 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
26 |
+
for x in range(10, 11):
|
27 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
28 |
+
for x in range(11, 12):
|
29 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
30 |
+
for x in range(12, 13):
|
31 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
32 |
+
if not requires_grad:
|
33 |
+
for param in self.parameters():
|
34 |
+
param.requires_grad = False
|
35 |
+
|
36 |
+
def forward(self, X):
|
37 |
+
h = self.slice1(X)
|
38 |
+
h_relu1 = h
|
39 |
+
h = self.slice2(h)
|
40 |
+
h_relu2 = h
|
41 |
+
h = self.slice3(h)
|
42 |
+
h_relu3 = h
|
43 |
+
h = self.slice4(h)
|
44 |
+
h_relu4 = h
|
45 |
+
h = self.slice5(h)
|
46 |
+
h_relu5 = h
|
47 |
+
h = self.slice6(h)
|
48 |
+
h_relu6 = h
|
49 |
+
h = self.slice7(h)
|
50 |
+
h_relu7 = h
|
51 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
|
52 |
+
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
|
53 |
+
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
class alexnet(torch.nn.Module):
|
58 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
59 |
+
super(alexnet, self).__init__()
|
60 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
61 |
+
self.slice1 = torch.nn.Sequential()
|
62 |
+
self.slice2 = torch.nn.Sequential()
|
63 |
+
self.slice3 = torch.nn.Sequential()
|
64 |
+
self.slice4 = torch.nn.Sequential()
|
65 |
+
self.slice5 = torch.nn.Sequential()
|
66 |
+
self.N_slices = 5
|
67 |
+
for x in range(2):
|
68 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
69 |
+
for x in range(2, 5):
|
70 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
71 |
+
for x in range(5, 8):
|
72 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
73 |
+
for x in range(8, 10):
|
74 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
75 |
+
for x in range(10, 12):
|
76 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
77 |
+
if not requires_grad:
|
78 |
+
for param in self.parameters():
|
79 |
+
param.requires_grad = False
|
80 |
+
|
81 |
+
def forward(self, X):
|
82 |
+
h = self.slice1(X)
|
83 |
+
h_relu1 = h
|
84 |
+
h = self.slice2(h)
|
85 |
+
h_relu2 = h
|
86 |
+
h = self.slice3(h)
|
87 |
+
h_relu3 = h
|
88 |
+
h = self.slice4(h)
|
89 |
+
h_relu4 = h
|
90 |
+
h = self.slice5(h)
|
91 |
+
h_relu5 = h
|
92 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
93 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
class vgg16(torch.nn.Module):
|
98 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
99 |
+
super(vgg16, self).__init__()
|
100 |
+
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
101 |
+
self.slice1 = torch.nn.Sequential()
|
102 |
+
self.slice2 = torch.nn.Sequential()
|
103 |
+
self.slice3 = torch.nn.Sequential()
|
104 |
+
self.slice4 = torch.nn.Sequential()
|
105 |
+
self.slice5 = torch.nn.Sequential()
|
106 |
+
self.N_slices = 5
|
107 |
+
for x in range(4):
|
108 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
109 |
+
for x in range(4, 9):
|
110 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
for x in range(9, 16):
|
112 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
113 |
+
for x in range(16, 23):
|
114 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(23, 30):
|
116 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
if not requires_grad:
|
118 |
+
for param in self.parameters():
|
119 |
+
param.requires_grad = False
|
120 |
+
|
121 |
+
def forward(self, X):
|
122 |
+
h = self.slice1(X)
|
123 |
+
h_relu1_2 = h
|
124 |
+
h = self.slice2(h)
|
125 |
+
h_relu2_2 = h
|
126 |
+
h = self.slice3(h)
|
127 |
+
h_relu3_3 = h
|
128 |
+
h = self.slice4(h)
|
129 |
+
h_relu4_3 = h
|
130 |
+
h = self.slice5(h)
|
131 |
+
h_relu5_3 = h
|
132 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
133 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
134 |
+
|
135 |
+
return out
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
class resnet(torch.nn.Module):
|
140 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
141 |
+
super(resnet, self).__init__()
|
142 |
+
if(num==18):
|
143 |
+
self.net = tv.resnet18(pretrained=pretrained)
|
144 |
+
elif(num==34):
|
145 |
+
self.net = tv.resnet34(pretrained=pretrained)
|
146 |
+
elif(num==50):
|
147 |
+
self.net = tv.resnet50(pretrained=pretrained)
|
148 |
+
elif(num==101):
|
149 |
+
self.net = tv.resnet101(pretrained=pretrained)
|
150 |
+
elif(num==152):
|
151 |
+
self.net = tv.resnet152(pretrained=pretrained)
|
152 |
+
self.N_slices = 5
|
153 |
+
|
154 |
+
self.conv1 = self.net.conv1
|
155 |
+
self.bn1 = self.net.bn1
|
156 |
+
self.relu = self.net.relu
|
157 |
+
self.maxpool = self.net.maxpool
|
158 |
+
self.layer1 = self.net.layer1
|
159 |
+
self.layer2 = self.net.layer2
|
160 |
+
self.layer3 = self.net.layer3
|
161 |
+
self.layer4 = self.net.layer4
|
162 |
+
|
163 |
+
def forward(self, X):
|
164 |
+
h = self.conv1(X)
|
165 |
+
h = self.bn1(h)
|
166 |
+
h = self.relu(h)
|
167 |
+
h_relu1 = h
|
168 |
+
h = self.maxpool(h)
|
169 |
+
h = self.layer1(h)
|
170 |
+
h_conv2 = h
|
171 |
+
h = self.layer2(h)
|
172 |
+
h_conv3 = h
|
173 |
+
h = self.layer3(h)
|
174 |
+
h_conv4 = h
|
175 |
+
h = self.layer4(h)
|
176 |
+
h_conv5 = h
|
177 |
+
|
178 |
+
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
|
179 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
180 |
+
|
181 |
+
return out
|
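The slice boundaries above carve torchvision's vgg16.features at indices 4, 9, 16, 23 and 30, which lands on the relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3 activations with 64/128/256/512/512 channels, matching self.chns in PNetLin. A short usage sketch, assuming the repository root is on the Python path and the torchvision weights can be loaded:

import torch
from stylegan2.lpips.pretrained_networks import vgg16

net = vgg16(requires_grad=False, pretrained=True)
feats = net(torch.randn(1, 3, 64, 64))     # VggOutputs namedtuple
for name, f in zip(feats._fields, feats):
    print(name, tuple(f.shape))
# relu1_2 (1, 64, 64, 64) ... relu5_3 (1, 512, 4, 4)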
stylegan2/lpips/util.py
ADDED
@@ -0,0 +1,160 @@
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
from __future__ import division
|
4 |
+
from __future__ import print_function
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from skimage.metrics import structural_similarity
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
from . import dist_model
|
12 |
+
|
13 |
+
class PerceptualLoss(torch.nn.Module):
|
14 |
+
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
|
15 |
+
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
|
16 |
+
super(PerceptualLoss, self).__init__()
|
17 |
+
print('Setting up Perceptual loss...')
|
18 |
+
self.use_gpu = use_gpu
|
19 |
+
self.spatial = spatial
|
20 |
+
self.gpu_ids = gpu_ids
|
21 |
+
self.model = dist_model.DistModel()
|
22 |
+
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
|
23 |
+
print('...[%s] initialized'%self.model.name())
|
24 |
+
print('...Done')
|
25 |
+
|
26 |
+
def forward(self, pred, target, normalize=False):
|
27 |
+
"""
|
28 |
+
Pred and target are Variables.
|
29 |
+
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
|
30 |
+
If normalize is False, assumes the images are already between [-1,+1]
|
31 |
+
|
32 |
+
Inputs pred and target are Nx3xHxW
|
33 |
+
Output pytorch Variable N long
|
34 |
+
"""
|
35 |
+
|
36 |
+
if normalize:
|
37 |
+
target = 2 * target - 1
|
38 |
+
pred = 2 * pred - 1
|
39 |
+
|
40 |
+
return self.model.forward(target, pred)
|
41 |
+
|
42 |
+
def normalize_tensor(in_feat,eps=1e-10):
|
43 |
+
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
|
44 |
+
return in_feat/(norm_factor+eps)
|
45 |
+
|
46 |
+
def l2(p0, p1, range=255.):
|
47 |
+
return .5*np.mean((p0 / range - p1 / range)**2)
|
48 |
+
|
49 |
+
def psnr(p0, p1, peak=255.):
|
50 |
+
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
|
51 |
+
|
52 |
+
def dssim(p0, p1, range=255.):
|
53 |
+
return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.
|
54 |
+
|
55 |
+
def rgb2lab(in_img,mean_cent=False):
|
56 |
+
from skimage import color
|
57 |
+
img_lab = color.rgb2lab(in_img)
|
58 |
+
if(mean_cent):
|
59 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
60 |
+
return img_lab
|
61 |
+
|
62 |
+
def tensor2np(tensor_obj):
|
63 |
+
# change dimension of a tensor object into a numpy array
|
64 |
+
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
|
65 |
+
|
66 |
+
def np2tensor(np_obj):
|
67 |
+
# change dimenion of np array into tensor array
|
68 |
+
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
69 |
+
|
70 |
+
def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
|
71 |
+
# image tensor to lab tensor
|
72 |
+
from skimage import color
|
73 |
+
|
74 |
+
img = tensor2im(image_tensor)
|
75 |
+
img_lab = color.rgb2lab(img)
|
76 |
+
if(mc_only):
|
77 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
78 |
+
if(to_norm and not mc_only):
|
79 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
80 |
+
img_lab = img_lab/100.
|
81 |
+
|
82 |
+
return np2tensor(img_lab)
|
83 |
+
|
84 |
+
def tensorlab2tensor(lab_tensor,return_inbnd=False):
|
85 |
+
from skimage import color
|
86 |
+
import warnings
|
87 |
+
warnings.filterwarnings("ignore")
|
88 |
+
|
89 |
+
lab = tensor2np(lab_tensor)*100.
|
90 |
+
lab[:,:,0] = lab[:,:,0]+50
|
91 |
+
|
92 |
+
rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
|
93 |
+
if(return_inbnd):
|
94 |
+
# convert back to lab, see if we match
|
95 |
+
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
|
96 |
+
mask = 1.*np.isclose(lab_back,lab,atol=2.)
|
97 |
+
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
|
98 |
+
return (im2tensor(rgb_back),mask)
|
99 |
+
else:
|
100 |
+
return im2tensor(rgb_back)
|
101 |
+
|
102 |
+
def rgb2lab(input):
|
103 |
+
from skimage import color
|
104 |
+
return color.rgb2lab(input / 255.)
|
105 |
+
|
106 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
107 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
108 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
109 |
+
return image_numpy.astype(imtype)
|
110 |
+
|
111 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
|
112 |
+
return torch.Tensor((image / factor - cent)
|
113 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
114 |
+
|
115 |
+
def tensor2vec(vector_tensor):
|
116 |
+
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
|
117 |
+
|
118 |
+
def voc_ap(rec, prec, use_07_metric=False):
|
119 |
+
""" ap = voc_ap(rec, prec, [use_07_metric])
|
120 |
+
Compute VOC AP given precision and recall.
|
121 |
+
If use_07_metric is true, uses the
|
122 |
+
VOC 07 11 point method (default:False).
|
123 |
+
"""
|
124 |
+
if use_07_metric:
|
125 |
+
# 11 point metric
|
126 |
+
ap = 0.
|
127 |
+
for t in np.arange(0., 1.1, 0.1):
|
128 |
+
if np.sum(rec >= t) == 0:
|
129 |
+
p = 0
|
130 |
+
else:
|
131 |
+
p = np.max(prec[rec >= t])
|
132 |
+
ap = ap + p / 11.
|
133 |
+
else:
|
134 |
+
# correct AP calculation
|
135 |
+
# first append sentinel values at the end
|
136 |
+
mrec = np.concatenate(([0.], rec, [1.]))
|
137 |
+
mpre = np.concatenate(([0.], prec, [0.]))
|
138 |
+
|
139 |
+
# compute the precision envelope
|
140 |
+
for i in range(mpre.size - 1, 0, -1):
|
141 |
+
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
|
142 |
+
|
143 |
+
# to calculate area under PR curve, look for points
|
144 |
+
# where X axis (recall) changes value
|
145 |
+
i = np.where(mrec[1:] != mrec[:-1])[0]
|
146 |
+
|
147 |
+
# and sum (\Delta recall) * prec
|
148 |
+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
|
149 |
+
return ap
|
150 |
+
|
151 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
152 |
+
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
|
153 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
154 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
155 |
+
return image_numpy.astype(imtype)
|
156 |
+
|
157 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
|
158 |
+
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
|
159 |
+
return torch.Tensor((image / factor - cent)
|
160 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
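im2tensor and tensor2im are inverses up to rounding: an HxWx3 image in [0, 255] is mapped to a 1x3xHxW tensor in [-1, 1] (factor=127.5, cent=1) and back. A tiny round-trip sketch repeating the definitions above on a random image:

import numpy as np
import torch

def im2tensor(image, cent=1., factor=255. / 2.):
    return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2im(t, imtype=np.uint8, cent=1., factor=255. / 2.):
    arr = t[0].cpu().float().numpy()
    return ((np.transpose(arr, (1, 2, 0)) + cent) * factor).astype(imtype)

img = np.random.randint(0, 256, size=(8, 8, 3)).astype(np.float64)
back = tensor2im(im2tensor(img))
print(np.abs(back.astype(np.float64) - img).max())   # at most 1, from float32/uint8 rounding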
stylegan2/model.py
ADDED
@@ -0,0 +1,714 @@
1 |
+
import math
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from .op.fused_act import fused
|
9 |
+
|
10 |
+
if fused is not None:
|
11 |
+
from .op.fused_act import FusedLeakyReLU, fused_leaky_relu
|
12 |
+
else:
|
13 |
+
from .op import FusedLeakyReLU_Native as FusedLeakyReLU
|
14 |
+
from .op import fused_leaky_relu_native as fused_leaky_relu
|
15 |
+
|
16 |
+
from .op.upfirdn2d import upfirdn2d_op
|
17 |
+
|
18 |
+
if upfirdn2d_op is not None:
|
19 |
+
from .op.upfirdn2d import upfirdn2d
|
20 |
+
else:
|
21 |
+
from .op import upfirdn2d_native as upfirdn2d
|
22 |
+
|
23 |
+
from .op import conv2d_gradfix
|
24 |
+
|
25 |
+
# https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py#L152
|
26 |
+
# https://github.com/rosinality/stylegan2-pytorch/issues/70
|
27 |
+
|
28 |
+
|
29 |
+
class PixelNorm(nn.Module):
|
30 |
+
def __init__(self):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
def forward(self, input):
|
34 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
35 |
+
|
36 |
+
|
37 |
+
def make_kernel(k):
|
38 |
+
k = torch.tensor(k, dtype=torch.float32)
|
39 |
+
|
40 |
+
if k.ndim == 1:
|
41 |
+
k = k[None, :] * k[:, None]
|
42 |
+
|
43 |
+
k /= k.sum()
|
44 |
+
|
45 |
+
return k
|
46 |
+
|
47 |
+
|
48 |
+
class Upsample(nn.Module):
|
49 |
+
def __init__(self, kernel, factor=2):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.factor = factor
|
53 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
54 |
+
self.register_buffer("kernel", kernel)
|
55 |
+
|
56 |
+
p = kernel.shape[0] - factor
|
57 |
+
|
58 |
+
pad0 = (p + 1) // 2 + factor - 1
|
59 |
+
pad1 = p // 2
|
60 |
+
|
61 |
+
self.pad = (pad0, pad1)
|
62 |
+
|
63 |
+
def forward(self, input):
|
64 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
65 |
+
|
66 |
+
return out
|
67 |
+
|
68 |
+
|
69 |
+
class Downsample(nn.Module):
|
70 |
+
def __init__(self, kernel, factor=2):
|
71 |
+
super().__init__()
|
72 |
+
|
73 |
+
self.factor = factor
|
74 |
+
kernel = make_kernel(kernel)
|
75 |
+
self.register_buffer("kernel", kernel)
|
76 |
+
|
77 |
+
p = kernel.shape[0] - factor
|
78 |
+
|
79 |
+
pad0 = (p + 1) // 2
|
80 |
+
pad1 = p // 2
|
81 |
+
|
82 |
+
self.pad = (pad0, pad1)
|
83 |
+
|
84 |
+
def forward(self, input):
|
85 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
86 |
+
|
87 |
+
return out
|
88 |
+
|
89 |
+
|
90 |
+
class Blur(nn.Module):
|
91 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
kernel = make_kernel(kernel)
|
95 |
+
|
96 |
+
if upsample_factor > 1:
|
97 |
+
kernel = kernel * (upsample_factor ** 2)
|
98 |
+
|
99 |
+
self.register_buffer("kernel", kernel)
|
100 |
+
|
101 |
+
self.pad = pad
|
102 |
+
|
103 |
+
def forward(self, input):
|
104 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
105 |
+
|
106 |
+
return out
|
107 |
+
|
108 |
+
|
109 |
+
class EqualConv2d(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
|
115 |
+
self.weight = nn.Parameter(
|
116 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
117 |
+
)
|
118 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
119 |
+
|
120 |
+
self.stride = stride
|
121 |
+
self.padding = padding
|
122 |
+
|
123 |
+
if bias:
|
124 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
125 |
+
|
126 |
+
else:
|
127 |
+
self.bias = None
|
128 |
+
|
129 |
+
def forward(self, input):
|
130 |
+
out = conv2d_gradfix.conv2d(
|
131 |
+
input,
|
132 |
+
self.weight * self.scale,
|
133 |
+
bias=self.bias,
|
134 |
+
stride=self.stride,
|
135 |
+
padding=self.padding,
|
136 |
+
)
|
137 |
+
|
138 |
+
return out
|
139 |
+
|
140 |
+
def __repr__(self):
|
141 |
+
return (
|
142 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
143 |
+
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
class EqualLinear(nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
150 |
+
):
|
151 |
+
super().__init__()
|
152 |
+
|
153 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
154 |
+
|
155 |
+
if bias:
|
156 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
157 |
+
|
158 |
+
else:
|
159 |
+
self.bias = None
|
160 |
+
|
161 |
+
self.activation = activation
|
162 |
+
|
163 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
164 |
+
self.lr_mul = lr_mul
|
165 |
+
|
166 |
+
def forward(self, input):
|
167 |
+
if self.activation:
|
168 |
+
out = F.linear(input, self.weight * self.scale)
|
169 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
170 |
+
|
171 |
+
else:
|
172 |
+
out = F.linear(
|
173 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
174 |
+
)
|
175 |
+
|
176 |
+
return out
|
177 |
+
|
178 |
+
def __repr__(self):
|
179 |
+
return (
|
180 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
|
181 |
+
)
|
182 |
+
|
183 |
+
|
184 |
+
class ModulatedConv2d(nn.Module):
|
185 |
+
def __init__(
|
186 |
+
self,
|
187 |
+
in_channel,
|
188 |
+
out_channel,
|
189 |
+
kernel_size,
|
190 |
+
style_dim,
|
191 |
+
demodulate=True,
|
192 |
+
upsample=False,
|
193 |
+
downsample=False,
|
194 |
+
blur_kernel=[1, 3, 3, 1],
|
195 |
+
fused=True,
|
196 |
+
):
|
197 |
+
super().__init__()
|
198 |
+
|
199 |
+
self.eps = 1e-8
|
200 |
+
self.kernel_size = kernel_size
|
201 |
+
self.in_channel = in_channel
|
202 |
+
self.out_channel = out_channel
|
203 |
+
self.upsample = upsample
|
204 |
+
self.downsample = downsample
|
205 |
+
|
206 |
+
if upsample:
|
207 |
+
factor = 2
|
208 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
209 |
+
pad0 = (p + 1) // 2 + factor - 1
|
210 |
+
pad1 = p // 2 + 1
|
211 |
+
|
212 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
213 |
+
|
214 |
+
if downsample:
|
215 |
+
factor = 2
|
216 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
217 |
+
pad0 = (p + 1) // 2
|
218 |
+
pad1 = p // 2
|
219 |
+
|
220 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
221 |
+
|
222 |
+
fan_in = in_channel * kernel_size ** 2
|
223 |
+
self.scale = 1 / math.sqrt(fan_in)
|
224 |
+
self.padding = kernel_size // 2
|
225 |
+
|
226 |
+
self.weight = nn.Parameter(
|
227 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
228 |
+
)
|
229 |
+
|
230 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
231 |
+
|
232 |
+
self.demodulate = demodulate
|
233 |
+
self.fused = fused
|
234 |
+
|
235 |
+
def __repr__(self):
|
236 |
+
return (
|
237 |
+
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
|
238 |
+
f"upsample={self.upsample}, downsample={self.downsample})"
|
239 |
+
)
|
240 |
+
|
241 |
+
def forward(self, input, style):
|
242 |
+
batch, in_channel, height, width = input.shape
|
243 |
+
|
244 |
+
if not self.fused:
|
245 |
+
weight = self.scale * self.weight.squeeze(0)
|
246 |
+
style = self.modulation(style)
|
247 |
+
|
248 |
+
if self.demodulate:
|
249 |
+
w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
|
250 |
+
dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
|
251 |
+
|
252 |
+
input = input * style.reshape(batch, in_channel, 1, 1)
|
253 |
+
|
254 |
+
if self.upsample:
|
255 |
+
weight = weight.transpose(0, 1)
|
256 |
+
out = conv2d_gradfix.conv_transpose2d(
|
257 |
+
input, weight, padding=0, stride=2
|
258 |
+
)
|
259 |
+
out = self.blur(out)
|
260 |
+
|
261 |
+
elif self.downsample:
|
262 |
+
input = self.blur(input)
|
263 |
+
out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
|
264 |
+
|
265 |
+
else:
|
266 |
+
out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
|
267 |
+
|
268 |
+
if self.demodulate:
|
269 |
+
out = out * dcoefs.view(batch, -1, 1, 1)
|
270 |
+
|
271 |
+
return out
|
272 |
+
|
273 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
274 |
+
weight = self.scale * self.weight * style
|
275 |
+
|
276 |
+
if self.demodulate:
|
277 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
278 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
279 |
+
|
280 |
+
weight = weight.view(
|
281 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
282 |
+
)
|
283 |
+
|
284 |
+
if self.upsample:
|
285 |
+
input = input.view(1, batch * in_channel, height, width)
|
286 |
+
weight = weight.view(
|
287 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
288 |
+
)
|
289 |
+
weight = weight.transpose(1, 2).reshape(
|
290 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
291 |
+
)
|
292 |
+
out = conv2d_gradfix.conv_transpose2d(
|
293 |
+
input, weight, padding=0, stride=2, groups=batch
|
294 |
+
)
|
295 |
+
_, _, height, width = out.shape
|
296 |
+
out = out.view(batch, self.out_channel, height, width)
|
297 |
+
out = self.blur(out)
|
298 |
+
|
299 |
+
elif self.downsample:
|
300 |
+
input = self.blur(input)
|
301 |
+
_, _, height, width = input.shape
|
302 |
+
input = input.view(1, batch * in_channel, height, width)
|
303 |
+
out = conv2d_gradfix.conv2d(
|
304 |
+
input, weight, padding=0, stride=2, groups=batch
|
305 |
+
)
|
306 |
+
_, _, height, width = out.shape
|
307 |
+
out = out.view(batch, self.out_channel, height, width)
|
308 |
+
|
309 |
+
else:
|
310 |
+
input = input.view(1, batch * in_channel, height, width)
|
311 |
+
out = conv2d_gradfix.conv2d(
|
312 |
+
input, weight, padding=self.padding, groups=batch
|
313 |
+
)
|
314 |
+
_, _, height, width = out.shape
|
315 |
+
out = out.view(batch, self.out_channel, height, width)
|
316 |
+
|
317 |
+
return out
|
318 |
+
|
319 |
+
|
320 |
+
class NoiseInjection(nn.Module):
|
321 |
+
def __init__(self):
|
322 |
+
super().__init__()
|
323 |
+
|
324 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
325 |
+
|
326 |
+
def forward(self, image, noise=None):
|
327 |
+
if noise is None:
|
328 |
+
batch, _, height, width = image.shape
|
329 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
330 |
+
|
331 |
+
return image + self.weight * noise
|
332 |
+
|
333 |
+
|
334 |
+
class ConstantInput(nn.Module):
|
335 |
+
def __init__(self, channel, size=4):
|
336 |
+
super().__init__()
|
337 |
+
|
338 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
339 |
+
|
340 |
+
def forward(self, input):
|
341 |
+
batch = input.shape[0]
|
342 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
343 |
+
|
344 |
+
return out
|
345 |
+
|
346 |
+
|
347 |
+
class StyledConv(nn.Module):
|
348 |
+
def __init__(
|
349 |
+
self,
|
350 |
+
in_channel,
|
351 |
+
out_channel,
|
352 |
+
kernel_size,
|
353 |
+
style_dim,
|
354 |
+
upsample=False,
|
355 |
+
blur_kernel=[1, 3, 3, 1],
|
356 |
+
demodulate=True,
|
357 |
+
):
|
358 |
+
super().__init__()
|
359 |
+
|
360 |
+
self.conv = ModulatedConv2d(
|
361 |
+
in_channel,
|
362 |
+
out_channel,
|
363 |
+
kernel_size,
|
364 |
+
style_dim,
|
365 |
+
upsample=upsample,
|
366 |
+
blur_kernel=blur_kernel,
|
367 |
+
demodulate=demodulate,
|
368 |
+
)
|
369 |
+
|
370 |
+
self.noise = NoiseInjection()
|
371 |
+
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
372 |
+
# self.activate = ScaledLeakyReLU(0.2)
|
373 |
+
self.activate = FusedLeakyReLU(out_channel)
|
374 |
+
|
375 |
+
def forward(self, input, style, noise=None):
|
376 |
+
out = self.conv(input, style)
|
377 |
+
out = self.noise(out, noise=noise)
|
378 |
+
# out = out + self.bias
|
379 |
+
out = self.activate(out)
|
380 |
+
|
381 |
+
return out
|
382 |
+
|
383 |
+
|
384 |
+
class ToRGB(nn.Module):
|
385 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
386 |
+
super().__init__()
|
387 |
+
|
388 |
+
if upsample:
|
389 |
+
self.upsample = Upsample(blur_kernel)
|
390 |
+
|
391 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
392 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
393 |
+
|
394 |
+
def forward(self, input, style, skip=None):
|
395 |
+
out = self.conv(input, style)
|
396 |
+
out = out + self.bias
|
397 |
+
|
398 |
+
if skip is not None:
|
399 |
+
skip = self.upsample(skip)
|
400 |
+
|
401 |
+
out = out + skip
|
402 |
+
|
403 |
+
return out
|
404 |
+
|
405 |
+
|
406 |
+
class Generator(nn.Module):
|
407 |
+
def __init__(
|
408 |
+
self,
|
409 |
+
size,
|
410 |
+
style_dim,
|
411 |
+
n_mlp,
|
412 |
+
channel_multiplier=2,
|
413 |
+
blur_kernel=[1, 3, 3, 1],
|
414 |
+
lr_mlp=0.01,
|
415 |
+
):
|
416 |
+
super().__init__()
|
417 |
+
|
418 |
+
self.size = size
|
419 |
+
|
420 |
+
self.style_dim = style_dim
|
421 |
+
|
422 |
+
layers = [PixelNorm()]
|
423 |
+
|
424 |
+
for i in range(n_mlp):
|
425 |
+
layers.append(
|
426 |
+
EqualLinear(
|
427 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
428 |
+
)
|
429 |
+
)
|
430 |
+
|
431 |
+
self.style = nn.Sequential(*layers)
|
432 |
+
|
433 |
+
self.channels = {
|
434 |
+
4: 512,
|
435 |
+
8: 512,
|
436 |
+
16: 512,
|
437 |
+
32: 512,
|
438 |
+
64: 256 * channel_multiplier,
|
439 |
+
128: 128 * channel_multiplier,
|
440 |
+
256: 64 * channel_multiplier,
|
441 |
+
512: 32 * channel_multiplier,
|
442 |
+
1024: 16 * channel_multiplier,
|
443 |
+
}
|
444 |
+
|
445 |
+
self.input = ConstantInput(self.channels[4])
|
446 |
+
self.conv1 = StyledConv(
|
447 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
448 |
+
)
|
449 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
450 |
+
|
451 |
+
self.log_size = int(math.log(size, 2))
|
452 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
453 |
+
|
454 |
+
self.convs = nn.ModuleList()
|
455 |
+
self.upsamples = nn.ModuleList()
|
456 |
+
self.to_rgbs = nn.ModuleList()
|
457 |
+
self.noises = nn.Module()
|
458 |
+
|
459 |
+
in_channel = self.channels[4]
|
460 |
+
|
461 |
+
for layer_idx in range(self.num_layers):
|
462 |
+
res = (layer_idx + 5) // 2
|
463 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
464 |
+
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
|
465 |
+
|
466 |
+
for i in range(3, self.log_size + 1):
|
467 |
+
out_channel = self.channels[2 ** i]
|
468 |
+
|
469 |
+
self.convs.append(
|
470 |
+
StyledConv(
|
471 |
+
in_channel,
|
472 |
+
out_channel,
|
473 |
+
3,
|
474 |
+
style_dim,
|
475 |
+
upsample=True,
|
476 |
+
blur_kernel=blur_kernel,
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
+
self.convs.append(
|
481 |
+
StyledConv(
|
482 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
483 |
+
)
|
484 |
+
)
|
485 |
+
|
486 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
487 |
+
|
488 |
+
in_channel = out_channel
|
489 |
+
|
490 |
+
self.n_latent = self.log_size * 2 - 2
|
491 |
+
|
492 |
+
def make_noise(self):
|
493 |
+
device = self.input.input.device
|
494 |
+
|
495 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
496 |
+
|
497 |
+
for i in range(3, self.log_size + 1):
|
498 |
+
for _ in range(2):
|
499 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
500 |
+
|
501 |
+
return noises
|
502 |
+
|
503 |
+
def mean_latent(self, n_latent):
|
504 |
+
latent_in = torch.randn(
|
505 |
+
n_latent, self.style_dim, device=self.input.input.device
|
506 |
+
)
|
507 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
508 |
+
|
509 |
+
return latent
|
510 |
+
|
511 |
+
def get_latent(self, input):
|
512 |
+
return self.style(input)
|
513 |
+
|
514 |
+
def forward(
|
515 |
+
self,
|
516 |
+
styles,
|
517 |
+
return_latents=False,
|
518 |
+
inject_index=None,
|
519 |
+
truncation=1,
|
520 |
+
truncation_latent=None,
|
521 |
+
input_is_latent=False,
|
522 |
+
noise=None,
|
523 |
+
randomize_noise=True,
|
524 |
+
):
|
525 |
+
if not input_is_latent:
|
526 |
+
styles = [self.style(s) for s in styles]
|
527 |
+
|
528 |
+
if noise is None:
|
529 |
+
if randomize_noise:
|
530 |
+
noise = [None] * self.num_layers
|
531 |
+
else:
|
532 |
+
noise = [
|
533 |
+
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
|
534 |
+
]
|
535 |
+
|
536 |
+
if truncation < 1:
|
537 |
+
style_t = []
|
538 |
+
|
539 |
+
for style in styles:
|
540 |
+
style_t.append(
|
541 |
+
truncation_latent + truncation * (style - truncation_latent)
|
542 |
+
)
|
543 |
+
|
544 |
+
styles = style_t
|
545 |
+
|
546 |
+
if len(styles) < 2:
|
547 |
+
inject_index = self.n_latent
|
548 |
+
|
549 |
+
if styles[0].ndim < 3:
|
550 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
551 |
+
|
552 |
+
else:
|
553 |
+
latent = styles[0]
|
554 |
+
|
555 |
+
else:
|
556 |
+
if inject_index is None:
|
557 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
558 |
+
|
559 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
560 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
561 |
+
|
562 |
+
latent = torch.cat([latent, latent2], 1)
|
563 |
+
|
564 |
+
out = self.input(latent)
|
565 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
566 |
+
|
567 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
568 |
+
|
569 |
+
i = 1
|
570 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
571 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
572 |
+
):
|
573 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
574 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
575 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
576 |
+
|
577 |
+
i += 2
|
578 |
+
|
579 |
+
|
580 |
+
image = skip
|
581 |
+
|
582 |
+
if return_latents:
|
583 |
+
return image, latent
|
584 |
+
|
585 |
+
else:
|
586 |
+
return image, None
|
587 |
+
|
588 |
+
|
589 |
+
class ConvLayer(nn.Sequential):
|
590 |
+
def __init__(
|
591 |
+
self,
|
592 |
+
in_channel,
|
593 |
+
out_channel,
|
594 |
+
kernel_size,
|
595 |
+
downsample=False,
|
596 |
+
blur_kernel=[1, 3, 3, 1],
|
597 |
+
bias=True,
|
598 |
+
activate=True,
|
599 |
+
):
|
600 |
+
layers = []
|
601 |
+
|
602 |
+
if downsample:
|
603 |
+
factor = 2
|
604 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
605 |
+
pad0 = (p + 1) // 2
|
606 |
+
pad1 = p // 2
|
607 |
+
|
608 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
609 |
+
|
610 |
+
stride = 2
|
611 |
+
self.padding = 0
|
612 |
+
|
613 |
+
else:
|
614 |
+
stride = 1
|
615 |
+
self.padding = kernel_size // 2
|
616 |
+
|
617 |
+
layers.append(
|
618 |
+
EqualConv2d(
|
619 |
+
in_channel,
|
620 |
+
out_channel,
|
621 |
+
kernel_size,
|
622 |
+
padding=self.padding,
|
623 |
+
stride=stride,
|
624 |
+
bias=bias and not activate,
|
625 |
+
)
|
626 |
+
)
|
627 |
+
|
628 |
+
if activate:
|
629 |
+
layers.append(FusedLeakyReLU(out_channel, bias=bias))
|
630 |
+
|
631 |
+
super().__init__(*layers)
|
632 |
+
|
633 |
+
|
634 |
+
class ResBlock(nn.Module):
|
635 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
636 |
+
super().__init__()
|
637 |
+
|
638 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
639 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
640 |
+
|
641 |
+
self.skip = ConvLayer(
|
642 |
+
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
643 |
+
)
|
644 |
+
|
645 |
+
def forward(self, input):
|
646 |
+
out = self.conv1(input)
|
647 |
+
out = self.conv2(out)
|
648 |
+
|
649 |
+
skip = self.skip(input)
|
650 |
+
out = (out + skip) / math.sqrt(2)
|
651 |
+
|
652 |
+
return out
|
653 |
+
|
654 |
+
|
655 |
+
class Discriminator(nn.Module):
|
656 |
+
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
657 |
+
super().__init__()
|
658 |
+
|
659 |
+
channels = {
|
660 |
+
4: 512,
|
661 |
+
8: 512,
|
662 |
+
16: 512,
|
663 |
+
32: 512,
|
664 |
+
64: 256 * channel_multiplier,
|
665 |
+
128: 128 * channel_multiplier,
|
666 |
+
256: 64 * channel_multiplier,
|
667 |
+
512: 32 * channel_multiplier,
|
668 |
+
1024: 16 * channel_multiplier,
|
669 |
+
}
|
670 |
+
|
671 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
672 |
+
|
673 |
+
log_size = int(math.log(size, 2))
|
674 |
+
|
675 |
+
in_channel = channels[size]
|
676 |
+
|
677 |
+
for i in range(log_size, 2, -1):
|
678 |
+
out_channel = channels[2 ** (i - 1)]
|
679 |
+
|
680 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
681 |
+
|
682 |
+
in_channel = out_channel
|
683 |
+
|
684 |
+
self.convs = nn.Sequential(*convs)
|
685 |
+
|
686 |
+
self.stddev_group = 4
|
687 |
+
self.stddev_feat = 1
|
688 |
+
|
689 |
+
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
690 |
+
self.final_linear = nn.Sequential(
|
691 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
|
692 |
+
EqualLinear(channels[4], 1),
|
693 |
+
)
|
694 |
+
|
695 |
+
def forward(self, input):
|
696 |
+
out = self.convs(input)
|
697 |
+
|
698 |
+
batch, channel, height, width = out.shape
|
699 |
+
group = min(batch, self.stddev_group)
|
700 |
+
stddev = out.view(
|
701 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
702 |
+
)
|
703 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
704 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
705 |
+
stddev = stddev.repeat(group, 1, height, width)
|
706 |
+
out = torch.cat([out, stddev], 1)
|
707 |
+
|
708 |
+
out = self.final_conv(out)
|
709 |
+
|
710 |
+
out = out.view(batch, -1)
|
711 |
+
out = self.final_linear(out)
|
712 |
+
|
713 |
+
return out
|
714 |
+
|
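Generator maps each z through the n_mlp-layer style network to w, broadcasts it over n_latent = 2*log2(size) - 2 layers, and runs the StyledConv / ToRGB stack from the 4x4 ConstantInput; forward takes a list of styles and returns (image, latent). A minimal sampling sketch at 256x256, assuming the repo root is importable and the pure-PyTorch fallback ops are in use (no CUDA extension needed):

import torch
from stylegan2.model import Generator

g = Generator(size=256, style_dim=512, n_mlp=8)
z = torch.randn(2, 512)                      # two latent codes
img, w = g([z], return_latents=True)         # styles are passed as a list
print(img.shape)                             # torch.Size([2, 3, 256, 256])
print(w.shape)                               # torch.Size([2, 14, 512]); n_latent = 14 for size 256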
stylegan2/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu, fused_leaky_relu_native, FusedLeakyReLU_Native
from .upfirdn2d import upfirdn2d, upfirdn2d_native
stylegan2/op/conv2d_gradfix.py
ADDED
@@ -0,0 +1,229 @@
1 |
+
import contextlib
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import autograd
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
enabled = True
|
9 |
+
weight_gradients_disabled = False
|
10 |
+
|
11 |
+
|
12 |
+
@contextlib.contextmanager
|
13 |
+
def no_weight_gradients():
|
14 |
+
global weight_gradients_disabled
|
15 |
+
|
16 |
+
old = weight_gradients_disabled
|
17 |
+
weight_gradients_disabled = True
|
18 |
+
yield
|
19 |
+
weight_gradients_disabled = old
|
20 |
+
|
21 |
+
|
22 |
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
23 |
+
if could_use_op(input):
|
24 |
+
return conv2d_gradfix(
|
25 |
+
transpose=False,
|
26 |
+
weight_shape=weight.shape,
|
27 |
+
stride=stride,
|
28 |
+
padding=padding,
|
29 |
+
output_padding=0,
|
30 |
+
dilation=dilation,
|
31 |
+
groups=groups,
|
32 |
+
).apply(input, weight, bias)
|
33 |
+
|
34 |
+
return F.conv2d(
|
35 |
+
input=input,
|
36 |
+
weight=weight,
|
37 |
+
bias=bias,
|
38 |
+
stride=stride,
|
39 |
+
padding=padding,
|
40 |
+
dilation=dilation,
|
41 |
+
groups=groups,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def conv_transpose2d(
|
46 |
+
input,
|
47 |
+
weight,
|
48 |
+
bias=None,
|
49 |
+
stride=1,
|
50 |
+
padding=0,
|
51 |
+
output_padding=0,
|
52 |
+
groups=1,
|
53 |
+
dilation=1,
|
54 |
+
):
|
55 |
+
if could_use_op(input):
|
56 |
+
return conv2d_gradfix(
|
57 |
+
transpose=True,
|
58 |
+
weight_shape=weight.shape,
|
59 |
+
stride=stride,
|
60 |
+
padding=padding,
|
61 |
+
output_padding=output_padding,
|
62 |
+
groups=groups,
|
63 |
+
dilation=dilation,
|
64 |
+
).apply(input, weight, bias)
|
65 |
+
|
66 |
+
return F.conv_transpose2d(
|
67 |
+
input=input,
|
68 |
+
weight=weight,
|
69 |
+
bias=bias,
|
70 |
+
stride=stride,
|
71 |
+
padding=padding,
|
72 |
+
output_padding=output_padding,
|
73 |
+
dilation=dilation,
|
74 |
+
groups=groups,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def could_use_op(input):
|
79 |
+
return False
|
80 |
+
|
81 |
+
if (not enabled) or (not torch.backends.cudnn.enabled):
|
82 |
+
return False
|
83 |
+
|
84 |
+
if input.device.type != "cuda":
|
85 |
+
return False
|
86 |
+
|
87 |
+
if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
|
88 |
+
return True
|
89 |
+
|
90 |
+
warnings.warn(
|
91 |
+
f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
|
92 |
+
)
|
93 |
+
|
94 |
+
return False
|
95 |
+
|
96 |
+
|
97 |
+
def ensure_tuple(xs, ndim):
|
98 |
+
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
99 |
+
|
100 |
+
return xs
|
101 |
+
|
102 |
+
|
103 |
+
conv2d_gradfix_cache = dict()
|
104 |
+
|
105 |
+
|
106 |
+
def conv2d_gradfix(
|
107 |
+
transpose, weight_shape, stride, padding, output_padding, dilation, groups
|
108 |
+
):
|
109 |
+
ndim = 2
|
110 |
+
weight_shape = tuple(weight_shape)
|
111 |
+
stride = ensure_tuple(stride, ndim)
|
112 |
+
padding = ensure_tuple(padding, ndim)
|
113 |
+
output_padding = ensure_tuple(output_padding, ndim)
|
114 |
+
dilation = ensure_tuple(dilation, ndim)
|
115 |
+
|
116 |
+
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
117 |
+
if key in conv2d_gradfix_cache:
|
118 |
+
return conv2d_gradfix_cache[key]
|
119 |
+
|
120 |
+
common_kwargs = dict(
|
121 |
+
stride=stride, padding=padding, dilation=dilation, groups=groups
|
122 |
+
)
|
123 |
+
|
124 |
+
def calc_output_padding(input_shape, output_shape):
|
125 |
+
if transpose:
|
126 |
+
return [0, 0]
|
127 |
+
|
128 |
+
return [
|
129 |
+
input_shape[i + 2]
|
130 |
+
- (output_shape[i + 2] - 1) * stride[i]
|
131 |
+
- (1 - 2 * padding[i])
|
132 |
+
- dilation[i] * (weight_shape[i + 2] - 1)
|
133 |
+
for i in range(ndim)
|
134 |
+
]
|
135 |
+
|
136 |
+
class Conv2d(autograd.Function):
|
137 |
+
@staticmethod
|
138 |
+
def forward(ctx, input, weight, bias):
|
139 |
+
if not transpose:
|
140 |
+
out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
141 |
+
|
142 |
+
else:
|
143 |
+
out = F.conv_transpose2d(
|
144 |
+
input=input,
|
145 |
+
weight=weight,
|
146 |
+
bias=bias,
|
147 |
+
output_padding=output_padding,
|
148 |
+
**common_kwargs,
|
149 |
+
)
|
150 |
+
|
151 |
+
ctx.save_for_backward(input, weight)
|
152 |
+
|
153 |
+
return out
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def backward(ctx, grad_output):
|
157 |
+
input, weight = ctx.saved_tensors
|
158 |
+
grad_input, grad_weight, grad_bias = None, None, None
|
159 |
+
|
160 |
+
if ctx.needs_input_grad[0]:
|
161 |
+
p = calc_output_padding(
|
162 |
+
input_shape=input.shape, output_shape=grad_output.shape
|
163 |
+
)
|
164 |
+
grad_input = conv2d_gradfix(
|
165 |
+
transpose=(not transpose),
|
166 |
+
weight_shape=weight_shape,
|
167 |
+
output_padding=p,
|
168 |
+
**common_kwargs,
|
169 |
+
).apply(grad_output, weight, None)
|
170 |
+
|
171 |
+
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
172 |
+
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
173 |
+
|
174 |
+
if ctx.needs_input_grad[2]:
|
175 |
+
grad_bias = grad_output.sum((0, 2, 3))
|
176 |
+
|
177 |
+
return grad_input, grad_weight, grad_bias
|
178 |
+
|
179 |
+
class Conv2dGradWeight(autograd.Function):
|
180 |
+
@staticmethod
|
181 |
+
def forward(ctx, grad_output, input):
|
182 |
+
op = torch._C._jit_get_operation(
|
183 |
+
"aten::cudnn_convolution_backward_weight"
|
184 |
+
if not transpose
|
185 |
+
else "aten::cudnn_convolution_transpose_backward_weight"
|
186 |
+
)
|
187 |
+
flags = [
|
188 |
+
torch.backends.cudnn.benchmark,
|
189 |
+
torch.backends.cudnn.deterministic,
|
190 |
+
torch.backends.cudnn.allow_tf32,
|
191 |
+
]
|
192 |
+
grad_weight = op(
|
193 |
+
weight_shape,
|
194 |
+
grad_output,
|
195 |
+
input,
|
196 |
+
padding,
|
197 |
+
stride,
|
198 |
+
dilation,
|
199 |
+
groups,
|
200 |
+
*flags,
|
201 |
+
)
|
202 |
+
ctx.save_for_backward(grad_output, input)
|
203 |
+
|
204 |
+
return grad_weight
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
def backward(ctx, grad_grad_weight):
|
208 |
+
grad_output, input = ctx.saved_tensors
|
209 |
+
grad_grad_output, grad_grad_input = None, None
|
210 |
+
|
211 |
+
if ctx.needs_input_grad[0]:
|
212 |
+
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
|
213 |
+
|
214 |
+
if ctx.needs_input_grad[1]:
|
215 |
+
p = calc_output_padding(
|
216 |
+
input_shape=input.shape, output_shape=grad_output.shape
|
217 |
+
)
|
218 |
+
grad_grad_input = conv2d_gradfix(
|
219 |
+
transpose=(not transpose),
|
220 |
+
weight_shape=weight_shape,
|
221 |
+
output_padding=p,
|
222 |
+
**common_kwargs,
|
223 |
+
).apply(grad_output, grad_grad_weight, None)
|
224 |
+
|
225 |
+
return grad_grad_output, grad_grad_input
|
226 |
+
|
227 |
+
conv2d_gradfix_cache[key] = Conv2d
|
228 |
+
|
229 |
+
return Conv2d
|
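Note that could_use_op() returns False on its first line, so the PyTorch-version checks below it never run and conv2d / conv_transpose2d here always fall through to the plain torch.nn.functional calls; the custom double-backward Conv2d Function and no_weight_gradients() are effectively inert. A small drop-in usage sketch under that assumption, with the repo root on the Python path:

import torch
from stylegan2.op import conv2d_gradfix

x = torch.randn(1, 3, 16, 16, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)   # same signature as F.conv2d

with conv2d_gradfix.no_weight_gradients():   # no effect on the F.conv2d fallback path
    y.sum().backward()
print(x.grad.shape, w.grad.shape)            # torch.Size([1, 3, 16, 16]) torch.Size([8, 3, 3, 3])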
stylegan2/op/fused_act.py
ADDED
@@ -0,0 +1,157 @@
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load

import warnings

module_path = os.path.dirname(os.path.abspath(__file__))

try:
    fused = load(
        "fused",
        sources=[
            os.path.join(module_path, "fused_bias_act.cpp"),
            os.path.join(module_path, "fused_bias_act_kernel.cu"),
        ],
    )
except Exception:
    warnings.warn(
        "(This is not an error) Falling back to the native implementation"
    )

    fused = None


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(
            grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
        )

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()

        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input.contiguous(),
            gradgrad_bias,
            out,
            3,
            1,
            ctx.negative_slope,
            ctx.scale,
        )

        return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    if input.device.type == "cpu":
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (
                F.leaky_relu(
                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
                )
                * scale
            )

        else:
            return F.leaky_relu(input, negative_slope=0.2) * scale

    else:
        return FusedLeakyReLUFunction.apply(
            input.contiguous(), bias, negative_slope, scale
        )


class FusedLeakyReLU_Native(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu_native(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu_native(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    return scale * F.leaky_relu(
        input + bias.view((1, -1) + (1,) * (len(input.shape) - 2)),
        negative_slope=negative_slope,
    )
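A minimal usage sketch for the module defined above, with illustrative shapes; on a machine without the compiled CUDA extension the call goes through the CPU fallback path:

import torch
from stylegan2.op.fused_act import FusedLeakyReLU

act = FusedLeakyReLU(channel=512)   # learnable per-channel bias, slope 0.2, scale sqrt(2)
x = torch.randn(4, 512, 16, 16)
y = act(x)                          # leaky_relu(x + bias) * scale
print(y.shape)                      # torch.Size([4, 512, 16, 16])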
stylegan2/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,32 @@

#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor &input,
                             const torch::Tensor &bias,
                             const torch::Tensor &refer, int act, int grad,
                             float alpha, float scale) {
  CHECK_INPUT(input);
  CHECK_INPUT(bias);

  at::DeviceGuard guard(input.device());

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
stylegan2/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,105 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
    default:
    case 10:
      y = x;
      break;
    case 11:
      y = x;
      break;
    case 12:
      y = 0.0;
      break;

    case 30:
      y = (x > 0.0) ? x : x * alpha;
      break;
    case 31:
      y = (ref > 0.0) ? x : x * alpha;
      break;
    case 32:
      y = 0.0;
      break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
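The `switch (act * 10 + grad)` in the kernel above encodes the activation type in the tens digit (3 = leaky ReLU) and the derivative order in the ones digit (0 = forward, 1 = first derivative gated by the saved output, 2 = second derivative, which is zero for a piecewise-linear op). A scalar Python sketch of the leaky ReLU branches, with a made-up helper name, purely to spell out the logic:

# Illustrative equivalent of cases 30/31/32 for a single element.
def leaky_relu_branch(x, ref, grad, alpha=0.2, scale=2 ** 0.5):
    if grad == 0:            # case 30: forward
        y = x if x > 0.0 else x * alpha
    elif grad == 1:          # case 31: first derivative, gated by the saved output ref
        y = x if ref > 0.0 else x * alpha
    else:                    # case 32: second derivative of a piecewise-linear function
        y = 0.0
    return y * scale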
stylegan2/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,31 @@
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
  CHECK_INPUT(input);
  CHECK_INPUT(kernel);

  at::DeviceGuard guard(input.device());

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
stylegan2/op/upfirdn2d.py
ADDED
@@ -0,0 +1,232 @@
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
import warnings

module_path = os.path.dirname(os.path.abspath(__file__))

try:
    upfirdn2d_op = load(
        "upfirdn2d",
        sources=[
            os.path.join(module_path, "upfirdn2d.cpp"),
            os.path.join(module_path, "upfirdn2d_kernel.cu"),
        ],
    )
except Exception:
    warnings.warn(
        "(This is not an error) Falling back to the native implementation"
    )

    upfirdn2d_op = None


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == "cpu":
        out = _upfirdn2d_native(input, kernel, *up, *down, *pad)

    else:
        out = UpFirDn2d.apply(input, kernel, up, down, pad)

    return out


def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    out = _upfirdn2d_native(input, kernel, *up, *down, *pad)

    return out


def _upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
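A short usage sketch for the `upfirdn2d` helper above as StyleGAN2's upsample-and-blur primitive; the kernel is the usual [1, 3, 3, 1] binomial filter, and the padding values follow the same convention the model's Upsample block uses (shapes here are illustrative, and the CPU fallback is exercised when the extension is not built):

import torch
from stylegan2.op.upfirdn2d import upfirdn2d

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k = torch.outer(k, k)
k = k / k.sum()                      # normalized separable blur kernel

x = torch.randn(1, 3, 64, 64)
# 2x nearest upsample interleaved with the FIR filter; kernel gain of up**2 = 4.
y = upfirdn2d(x, k * 4, up=2, down=1, pad=(2, 1))
print(y.shape)                       # torch.Size([1, 3, 128, 128])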
stylegan2/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,369 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}