# stylegan3_clip/viz/renderer.py
# (Hosting-page residue: uploaded by akhaliq3 to the "stylegan3" repository,
#  revision a8c8bc6, 15.3 kB.)
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import sys
import copy
import traceback
import numpy as np
import torch
import torch.fft
import torch.nn
import matplotlib.cm
import dnnlib
from torch_utils.ops import upfirdn2d
import legacy # pylint: disable=import-error
#----------------------------------------------------------------------------
class CapturedException(Exception):
    """Exception that carries a pre-rendered message, usually a traceback.

    If ``msg`` is omitted, the instance must be created while an exception is
    being handled (inside an ``except`` clause). When the exception currently
    being handled is itself a ``CapturedException``, its message is reused
    verbatim. Otherwise the full formatted traceback becomes the message.
    """

    def __init__(self, msg=None):
        if msg is None:
            _exc_type, current, _tb = sys.exc_info()
            assert current is not None
            msg = str(current) if isinstance(current, CapturedException) else traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)
#----------------------------------------------------------------------------
class CaptureSuccess(Exception):
    """Control-flow exception that carries a captured value out of a call stack.

    The exception is raised to abort a computation early; the handler reads the
    value it should return from ``.out``.
    """

    def __init__(self, out):
        super().__init__()
        self.out = out  # Payload retrieved by the catching code.
#----------------------------------------------------------------------------
def _sinc(x):
y = (x * np.pi).abs()
z = torch.sin(y) / y.clamp(1e-30, float('inf'))
return torch.where(y < 1e-30, torch.ones_like(x), z)
def _lanczos_window(x, a):
    """Lanczos window of support ``a``, evaluated elementwise on tensor ``x``.

    Returns ``_sinc(|x| / a)`` where ``|x| / a < 1`` and 0 elsewhere.
    """
    normalized = x.abs() / a
    inside = normalized < 1
    return torch.where(inside, _sinc(normalized), torch.zeros_like(normalized))
#----------------------------------------------------------------------------
def _construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Construct a windowed-sinc FIR filter for band-limited affine resampling.

    Output coordinates are obtained by mapping the input tap grid through the
    2x2 linear part of ``mat``. The filter is built in these steps:

    1. A sinc band-limit in input coordinates is convolved (via FFT) with a
       sinc band-limit in output coordinates.
    2. The result is multiplied by the analogous convolution of two Lanczos
       windows.
    3. The filter is cropped to ``amax`` input samples on each side of the
       centre and normalized per polyphase component.

    Args:
        mat:        Transform matrix; only ``mat[:2, :2]`` is used here.
        a:          Lanczos window support, in input samples.
        amax:       Half-width of the returned filter, in input samples.
        aflt:       Half-width of the working grid used for the FFT
                    convolutions, in input samples.
        up:         Number of filter taps per input sample (upsampling factor).
        cutoff_in:  Frequency scale applied to the input-space sinc.
        cutoff_out: Frequency scale applied to the output-space sinc.

    Returns:
        Floating-point tensor of shape ``[amax * 2 * up - 1, amax * 2 * up - 1]``
        on ``mat``'s device. The side is odd and the centre tap sits in the
        middle. Each of the ``up x up`` polyphase components sums to
        ``1 / up**2``.
    """
    # The window must fit inside the crop, and the crop inside the working grid.
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    # The taps run from -aflt + 1/up to aflt - 1/up in steps of 1/up. The roll
    # moves the zero tap (index aflt*up - 1) to index 0, i.e. the FFT origin.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    # Default 'ij' indexing: yi varies along dim 0, xi along dim 1.
    yi, xi = torch.meshgrid(taps, taps)
    # Map every (x, y) input tap through the linear part of mat.
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters.
    # (Circular convolution via elementwise product of FFTs.)
    fi = _sinc(xi * cutoff_in) * _sinc(yi * cutoff_in)
    fo = _sinc(xo * cutoff_out) * _sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = _lanczos_window(xi, a) * _lanczos_window(yi, a)
    wo = _lanczos_window(xo, a) * _lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize.
    # Roll the origin back to the centre, then crop (aflt - amax) * up taps
    # from each border, leaving 2 * amax * up - 1 taps per side.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    # Zero-pad one row/column so the filter splits into [amax*2, up, amax*2, up]
    # polyphase components, and normalize each (y-phase, x-phase) component.
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    # Undo the split and drop the zero padding again (restores the odd side).
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f
#----------------------------------------------------------------------------
def _apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Resample an image batch through an affine transform, with band-limiting.

    ``x`` is first upsampled by ``up`` using a filter from
    ``_construct_affine_bandlimit_filter``. The result is then sampled
    bilinearly on the ``affine_grid`` derived from the inverse of ``mat``.

    Args:
        x:    Image batch of shape [N, C, H, W].
        mat:  Transform matrix. It is inverted and its top two rows are used as
              the ``affine_grid`` theta, so presumably it is a 3x3 homogeneous
              matrix — confirm against callers.
        up:   Upsampling factor applied before sampling.
        **filter_kwargs: Forwarded to ``_construct_affine_bandlimit_filter``.

    Returns:
        ``(z, m)``, both of shape [N, C, H, W]:
        - ``z`` is the resampled image.
        - ``m`` is a nearest-sampled validity mask: 1 where the sample falls
          inside the upsampled image, more than ``2p + 1`` pixels from its
          border, and 0 elsewhere.
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = _construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2  # Filter half-width in upsampled pixels; also the padding below.

    # Construct sampling grid.
    theta = mat.inverse()
    # NOTE(review): the x2 presumably converts the translation from image-size
    # units to affine_grid's normalized [-1, 1] units (a span of 2) — confirm.
    theta[:2, 2] *= 2
    # Small sub-pixel offset of 1/up of an input pixel (purpose not visible here).
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    # Rescale so the normalized grid addresses the unpadded interior of y,
    # which is presumably padded by p upsampled pixels (p / up input pixels) per side.
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    # Mark the interior of y (excluding a c-pixel border) and sample it with the
    # same grid, so m flags output pixels drawn from well inside the image.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
#----------------------------------------------------------------------------
class Renderer:
    """Render StyleGAN generator outputs for the interactive visualizer.

    Several things are cached across calls so that repeated ``render()`` calls
    mostly pay only for the forward pass:

    - loaded network pickles;
    - per-device network copies;
    - pinned host staging buffers;
    - colormap lookup tables;
    - per-network layer listings.

    Computation runs on the CUDA device selected in ``__init__``.
    """

    def __init__(self):
        self._device        = torch.device('cuda')
        self._pkl_data      = dict()    # {pkl: dict | CapturedException, ...}
        self._networks      = dict()    # {cache_key: torch.nn.Module | CapturedException, ...}
        self._pinned_bufs   = dict()    # {(shape, dtype): torch.Tensor, ...}
        self._cmaps         = dict()    # {name: torch.Tensor, ...}
        self._is_timing     = False     # Cleared when a render hits an uncached load/build.
        self._start_event   = torch.cuda.Event(enable_timing=True)
        self._end_event     = torch.cuda.Event(enable_timing=True)
        self._net_layers    = dict()    # {cache_key: [dnnlib.EasyDict, ...], ...}

    def render(self, **args):
        """Render one frame and return a ``dnnlib.EasyDict`` of results.

        Keyword arguments are forwarded to ``_render_impl``. An exception raised
        while rendering is captured and returned as a formatted string in
        ``res.error`` instead of propagating. ``res.image`` and ``res.stats``,
        when present, are returned as host NumPy arrays. ``res.render_time``
        (in seconds) is set only if no uncached network load/build happened
        during this call.
        """
        self._is_timing = True
        self._start_event.record(torch.cuda.current_stream(self._device))
        res = dnnlib.EasyDict()
        try:
            self._render_impl(res, **args)
        except Exception: # Report failures via res.error; let KeyboardInterrupt/SystemExit through.
            res.error = CapturedException()
        self._end_event.record(torch.cuda.current_stream(self._device))
        if 'image' in res:
            res.image = self.to_cpu(res.image).numpy()
        if 'stats' in res:
            res.stats = self.to_cpu(res.stats).numpy()
        if 'error' in res:
            res.error = str(res.error)
        if self._is_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
            self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        """Return network ``key`` from pickle ``pkl``, copied to this device.

        Caching:
        - The pickle is loaded once per ``pkl``.
        - The deep-copied, tweaked and device-moved network is cached per
          ``(original network, device, tweak_kwargs)``.
        - Load/build failures are cached as well and re-raised as
          ``CapturedException`` on every later call with the same arguments.

        A cache miss disables timing for the current ``render()`` call.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except Exception: # Don't cache a user interrupt as a permanent load failure.
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()
        if isinstance(data, CapturedException):
            raise data

        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                net = copy.deepcopy(orig_net)
                net = self._tweak_network(net, **tweak_kwargs)
                net.to(self._device)
            except Exception:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def _tweak_network(self, net, **_tweak_kwargs):
        """Hook for modifying a freshly copied network before it is cached.

        Currently a no-op that returns ``net`` unchanged. It accepts the
        ``tweak_kwargs`` that ``get_network`` forwards, so that passing them
        no longer fails with ``TypeError``.
        """
        # Print diagnostics.
        #for name, value in misc.named_params_and_buffers(net):
        #    if name.endswith('.magnitude_ema'):
        #        value = value.rsqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        #    if name.endswith('.weight') and value.ndim == 4:
        #        value = value.square().mean([1,2,3]).sqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        return net

    def _get_pinned_buf(self, ref):
        """Return a cached pinned host buffer with ``ref``'s shape and dtype."""
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Copy ``buf`` to the render device, staging it through a pinned buffer."""
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        """Copy ``buf`` to a fresh host tensor, staging it through a pinned buffer.

        The pinned staging buffer is reused, hence the final ``clone()``.
        """
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        """Suppress ``render_time`` for the current ``render()`` call."""
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map values of ``x`` (expected in [0, 1]) to uint8 RGB via Matplotlib colormap ``name``.

        The colormap is sampled into a 1024-entry lookup table on the device
        and cached per name. Inputs are clamped to the table range. The result
        has shape ``x.shape + (3,)``.
        """
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the
            # matplotlib.colormaps registry exists since 3.5.
            registry = getattr(matplotlib, 'colormaps', None)
            cmap = registry[name] if registry is not None else matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)
        x = torch.nn.functional.embedding(x, cmap)
        return x

    def _render_impl(self, res,
        pkl             = None,
        w0_seeds        = ((0, 1),),
        stylemix_idx    = (),
        stylemix_seed   = 0,
        trunc_psi       = 1,
        trunc_cutoff    = 0,
        random_seed     = 0,
        noise_mode      = 'const',
        force_fp32      = False,
        layer_name      = None,
        sel_channels    = 3,
        base_channel    = 0,
        img_scale_db    = 0,
        img_normalize   = False,
        fft_show        = False,
        fft_all         = True,
        fft_range_db    = 50,
        fft_beta        = 8,
        input_transform = None,
        untransform     = False,
    ):
        """Render one frame into ``res`` (a ``dnnlib.EasyDict``) in place.

        Args:
            pkl:             Network pickle passed to ``get_network``; 'G_ema' is used.
            w0_seeds:        Iterable of ``(seed, weight)`` pairs. The final W is
                             the weighted sum of each seed's ``W - w_avg``, plus
                             ``w_avg``.
            stylemix_idx:    W layer indices replaced by ``stylemix_seed``'s W;
                             out-of-range indices are ignored.
            stylemix_seed:   Seed used for style mixing.
            trunc_psi:       Passed to ``G.mapping`` as ``truncation_psi``.
            trunc_cutoff:    Passed to ``G.mapping`` as ``truncation_cutoff``.
            random_seed:     ``torch.manual_seed`` value set before synthesis.
            noise_mode:      Forwarded to the synthesis network.
            force_fp32:      Forwarded to the synthesis network.
            layer_name:      If set, show that layer's output instead of the
                             final image.
            sel_channels:    Number of channels shown, starting at ``base_channel``.
            base_channel:    First displayed channel (clamped to a valid range).
            img_scale_db:    Display gain in dB.
            img_normalize:   Divide each shown channel by its max absolute value.
            fft_show:        Append a colormapped power spectrum to the image.
            fft_all:         Use all channels (True) or only the shown ones.
            fft_range_db:    Displayed spectrum range in dB.
            fft_beta:        Kaiser window beta used before the FFT.
            input_transform: Matrix whose inverse is written to
                             ``G.synthesis.input.transform``, when that exists.
            untransform:     Resample the output with that transform.

        Sets ``res.img_resolution``, ``num_ws``, ``has_noise``,
        ``has_input_transform``, ``layers``, ``stats`` and ``image``, plus
        ``error`` if ``input_transform`` is singular.
        """
        # Dig up network details.
        G = self.get_network(pkl, 'G_ema')
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))

        # Set input transform.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                # Singular matrix: report it, but keep rendering with the identity.
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        # Generate random latents (one z, and a one-hot class if conditional, per unique seed).
        all_seeds = [seed for seed, _weight in w0_seeds] + [stylemix_seed]
        all_seeds = list(set(all_seeds))
        all_zs = np.zeros([len(all_seeds), G.z_dim], dtype=np.float32)
        all_cs = np.zeros([len(all_seeds), G.c_dim], dtype=np.float32)
        for idx, seed in enumerate(all_seeds):
            rnd = np.random.RandomState(seed)
            all_zs[idx] = rnd.randn(G.z_dim)
            if G.c_dim > 0:
                all_cs[idx, rnd.randint(G.c_dim)] = 1

        # Run mapping network. Ws are kept relative to w_avg so the weighted
        # sum below blends offsets from the average; w_avg is added back after.
        w_avg = G.mapping.w_avg
        all_zs = self.to_device(torch.from_numpy(all_zs))
        all_cs = self.to_device(torch.from_numpy(all_cs))
        all_ws = G.mapping(z=all_zs, c=all_cs, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) - w_avg
        all_ws = dict(zip(all_seeds, all_ws))

        # Calculate final W.
        w = torch.stack([all_ws[seed] * weight for seed, weight in w0_seeds]).sum(dim=0, keepdim=True)
        stylemix_idx = [idx for idx in stylemix_idx if 0 <= idx < G.num_ws]
        if len(stylemix_idx) > 0:
            w[:, stylemix_idx] = all_ws[stylemix_seed][np.newaxis, stylemix_idx]
        w += w_avg

        # Run synthesis network.
        synthesis_kwargs = dnnlib.EasyDict(noise_mode=noise_mode, force_fp32=force_fp32)
        torch.manual_seed(random_seed)
        out, layers = self.run_synthesis_net(G.synthesis, w, capture_layer=layer_name, **synthesis_kwargs)

        # Update layer list. Capturing a layer stops the forward pass early, so
        # its layer list is incomplete; run one full pass to fill the cache.
        cache_key = (G.synthesis, tuple(sorted(synthesis_kwargs.items())))
        if cache_key not in self._net_layers:
            if layer_name is not None:
                torch.manual_seed(random_seed)
                _out, layers = self.run_synthesis_net(G.synthesis, w, **synthesis_kwargs)
            self._net_layers[cache_key] = layers
        res.layers = self._net_layers[cache_key]

        # Untransform.
        if untransform and res.has_input_transform:
            out, _mask = _apply_affine_transformation(out.to(torch.float32), G.synthesis.input.transform, amax=6) # Override amax to hit the fast path in upfirdn2d.

        # Select channels and compute statistics.
        out = out[0].to(torch.float32)
        if sel_channels > out.shape[0]:
            sel_channels = 1
        base_channel = max(min(base_channel, out.shape[0] - sel_channels), 0)
        sel = out[base_channel : base_channel + sel_channels]
        res.stats = torch.stack([
            out.mean(), sel.mean(),
            out.std(), sel.std(),
            out.norm(float('inf')), sel.norm(float('inf')),
        ])

        # Scale and convert to uint8 (HWC).
        img = sel
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        res.image = img

        # FFT: mean-removed, Kaiser-windowed power spectrum summed over
        # channels, centred, in dB relative to its mean, then colormapped and
        # placed to the right of the image.
        if fft_show:
            sig = out if fft_all else sel
            sig = sig.to(torch.float32)
            sig = sig - sig.mean(dim=[1,2], keepdim=True)
            sig = sig * torch.kaiser_window(sig.shape[1], periodic=False, beta=fft_beta, device=self._device)[None, :, None]
            sig = sig * torch.kaiser_window(sig.shape[2], periodic=False, beta=fft_beta, device=self._device)[None, None, :]
            fft = torch.fft.fftn(sig, dim=[1,2]).abs().square().sum(dim=0)
            fft = fft.roll(shifts=[fft.shape[0] // 2, fft.shape[1] // 2], dims=[0,1])
            fft = (fft / fft.mean()).log10() * 10 # dB
            fft = self._apply_cmap((fft / fft_range_db + 1) / 2)
            res.image = torch.cat([img.expand_as(fft), fft], dim=1)

    @staticmethod
    def run_synthesis_net(net, *args, capture_layer=None, **kwargs): # => out, layers
        """Run ``net`` while recording every 4D/5D tensor output of its submodules.

        Returns ``(out, layers)``, where ``layers`` holds one
        ``EasyDict(name, shape, dtype)`` per recorded output, in execution
        order. 5D outputs are averaged over dim 2 before recording.

        If ``capture_layer`` matches a recorded name, the forward pass is
        aborted at that point and that tensor is returned as ``out``.
        Forward hooks are always removed, even if the forward pass raises.
        """
        submodule_names = {mod: name for name, mod in net.named_modules()}
        unique_names = set()
        layers = []

        def module_hook(module, _inputs, outputs):
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [out for out in outputs if isinstance(out, torch.Tensor) and out.ndim in [4, 5]]
            for idx, out in enumerate(outputs):
                if out.ndim == 5: # G-CNN => remove group dimension.
                    out = out.mean(2)
                name = submodule_names[module]
                if name == '':
                    name = 'output'
                if len(outputs) > 1:
                    name += f':{idx}'
                if name in unique_names:
                    # Disambiguate modules that run more than once.
                    suffix = 2
                    while f'{name}_{suffix}' in unique_names:
                        suffix += 1
                    name += f'_{suffix}'
                unique_names.add(name)
                shape = [int(x) for x in out.shape]
                dtype = str(out.dtype).split('.')[-1]
                layers.append(dnnlib.EasyDict(name=name, shape=shape, dtype=dtype))
                if name == capture_layer:
                    raise CaptureSuccess(out)

        hooks = [module.register_forward_hook(module_hook) for module in net.modules()]
        try:
            out = net(*args, **kwargs)
        except CaptureSuccess as e:
            out = e.out
        finally:
            # net is typically a cached module (e.g. G.synthesis). Hooks left
            # behind by a failed forward pass would keep firing on later calls,
            # appending to stale lists or raising a stale CaptureSuccess.
            for hook in hooks:
                hook.remove()
        return out, layers
#----------------------------------------------------------------------------