# Hugging Face Space status when this file was captured: Runtime error
#@title Gradio demo (used in space: ) | |
from matplotlib import pyplot as plt | |
from huggingface_hub import PyTorchModelHubMixin | |
import numpy as np | |
import gradio as gr | |
### A BIG CHUNK OF THIS IS COPIED FROM LIGHTWEIGHTGAN since the original has an assert requiring GPU | |
import os | |
import json | |
import multiprocessing | |
from random import random | |
import math | |
from math import log2, floor | |
from functools import partial | |
from contextlib import contextmanager, ExitStack | |
from pathlib import Path | |
from shutil import rmtree | |
import torch | |
from torch.cuda.amp import autocast, GradScaler | |
from torch.optim import Adam | |
from torch import nn, einsum | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset, DataLoader | |
from torch.autograd import grad as torch_grad | |
from torch.utils.data.distributed import DistributedSampler | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from PIL import Image | |
import torchvision | |
from torchvision import transforms | |
from kornia.filters import filter2d | |
from tqdm import tqdm | |
from einops import rearrange, reduce, repeat | |
from adabelief_pytorch import AdaBelief | |
# helpers | |
def DiffAugment(x, types=()):
    """Apply the augmentation groups named in `types` to `x`, in order.

    Args:
        x: tensor to augment.
        types: iterable of keys into `AUGMENT_FNS`; each key maps to a
            sequence of callables applied one after another. Defaults to an
            empty tuple (no augmentation) instead of a shared mutable list.

    Returns:
        The (possibly augmented) tensor, made contiguous.

    NOTE(review): `AUGMENT_FNS` is not defined in this file, so calling with a
    non-empty `types` raises NameError here. Only the empty default is safe
    in this demo.
    """
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()
@contextmanager
def null_context():
    """Context manager that does nothing and yields None.

    Serves as a no-op stand-in wherever a context manager is expected. The
    `@contextmanager` decorator is required. Without it this is a plain
    generator function, and `with null_context():` fails because generators
    have no `__enter__`/`__exit__`.
    """
    yield
def combine_contexts(contexts):
    """Build a context-manager factory that enters several contexts together.

    Args:
        contexts: iterable of zero-argument callables, each returning a
            context manager.

    Returns:
        A zero-argument callable. `with combine_contexts(cms)() as values:`
        enters every context in order and binds `values` to the list of their
        `__enter__` results. The ExitStack exits all entered contexts in
        reverse order, even if a later one fails to enter.
    """
    # The decorator is required: without it `multi_contexts()` returns a bare
    # generator that cannot be used in a `with` statement.
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts
def exists(val):
    """Report whether `val` was provided, i.e. is anything other than None.

    Falsy values such as 0, '' or [] still count as existing.
    """
    return val is not None
def is_power_of_two(val):
    """Return True if `val` is an exact power of two.

    Python ints are checked with exact bit arithmetic, so large values such as
    2**60 + 1 are not misclassified by float rounding inside `log2`. Other
    numbers fall back to testing whether `log2(val)` is integral, so 0.5
    counts as a power of two. Non-positive values return False instead of
    raising ValueError from `log2`.
    """
    if isinstance(val, int):
        # A positive power of two has exactly one set bit.
        return val > 0 and val & (val - 1) == 0
    return val > 0 and log2(val).is_integer()
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d
def set_requires_grad(model, bool):
    """Set `requires_grad` on every parameter yielded by `model.parameters()`.

    Args:
        model: object exposing `.parameters()` (e.g. an nn.Module).
        bool: flag to assign. The name shadows the builtin but is kept because
            callers may pass it by keyword.
    """
    flag = bool
    for param in model.parameters():
        param.requires_grad = flag
def cycle(iterable):
    """Yield the items of `iterable` over and over, forever.

    The iterable is re-iterated on every pass, so it must support repeated
    iteration (a list or a DataLoader, not a one-shot generator). An empty
    iterable makes this loop without yielding.
    """
    while True:
        yield from iterable
def raise_if_nan(t):
    """Raise NanException if any element of tensor `t` is NaN.

    The original `if torch.isnan(t)` only worked for single-element tensors.
    For larger tensors PyTorch raises "Boolean value of Tensor with more than
    one value is ambiguous". Reducing with `.any()` handles tensors of any
    shape and behaves the same for scalars.

    Raises:
        NanException: when `t` contains at least one NaN.
    """
    if torch.isnan(t).any():
        raise NanException
def evaluate_in_chunks(max_batch_size, model, *args):
    """Run `model` over `args` in batch-dimension chunks of `max_batch_size`.

    Every tensor in `args` is split along dim 0. The i-th chunks of all args
    are passed together to `model`, and the per-chunk outputs are
    concatenated along dim 0. When there is only one chunk, the model output
    is returned as-is, without a `torch.cat` copy.
    """
    per_arg_chunks = [arg.split(max_batch_size, dim=0) for arg in args]
    outputs = [model(*chunk_group) for chunk_group in zip(*per_arg_chunks)]
    if len(outputs) == 1:
        return outputs[0]
    return torch.cat(outputs, dim=0)
def slerp(val, low, high, eps=1e-7):
    """Spherical linear interpolation between rows of `low` and `high`.

    Args:
        val: interpolation weight (0 -> `low`, 1 -> `high`). A scalar, or a
            tensor broadcastable against the per-row angle (assumed shape
            (batch,) -- TODO confirm if tensors are ever passed).
        low, high: 2-D tensors; rows are interpolated pairwise (norms are
            taken over dim 1).
        eps: threshold on sin(omega) below which the two directions are
            treated as (anti-)parallel.

    Returns:
        Tensor shaped like `low`.

    The original divided by sin(omega). That is 0 when a row of `low` and
    `high` point the same way (e.g. interpolating a latent with itself), which
    produced NaN. Those rows now fall back to plain linear interpolation,
    which is the limit of the slerp weights as omega -> 0. The cosine is also
    clamped to [-1, 1] so float round-off cannot push `acos` out of its domain.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    near_parallel = so.abs() < eps
    # Avoid dividing by ~0; the affected rows are overwritten below anyway.
    safe_so = torch.where(near_parallel, torch.ones_like(so), so)
    w_low = torch.sin((1.0 - val) * omega) / safe_so
    w_high = torch.sin(val * omega) / safe_so
    w_low = torch.where(near_parallel, (1.0 - val) * torch.ones_like(so), w_low)
    w_high = torch.where(near_parallel, val * torch.ones_like(so), w_high)

    res = w_low.unsqueeze(1) * low + w_high.unsqueeze(1) * high
    return res
def safe_div(n, d):
    """Return n / d. Division by zero gives signed infinity instead of raising.

    The sign follows the numerator: +inf when n >= 0 (including 0 / 0),
    -inf otherwise.
    """
    try:
        return n / d
    except ZeroDivisionError:
        return float('inf') if n >= 0 else float('-inf')
# loss functions | |
def gen_hinge_loss(fake, real):
    """Generator hinge loss: the mean of the discriminator output on fakes.

    `real` is accepted for signature parity with the other losses but unused.
    """
    loss = fake.mean()
    return loss
def hinge_loss(real, fake):
    """Discriminator hinge loss: mean of relu(1 + real) + relu(1 - fake).

    Note the sign convention: real logits are penalised when above -1 and
    fake logits when below 1.
    """
    real_term = F.relu(1 + real)
    fake_term = F.relu(1 - fake)
    return (real_term + fake_term).mean()
def dual_contrastive_loss(real_logits, fake_logits):
    """Dual contrastive discriminator loss.

    Both logit tensors are flattened. Each real logit is contrasted, as the
    positive class, against all fake logits. Then each negated fake logit is
    contrasted against all negated real logits. The result is the sum of the
    two cross-entropy terms.
    """
    device = real_logits.device
    real_flat = real_logits.reshape(-1)
    fake_flat = fake_logits.reshape(-1)

    def contrast(anchors, negatives):
        # Row i is [anchors[i], negatives[0], ..., negatives[-1]]; class 0 is the anchor.
        num_anchors = anchors.shape[0]
        tiled_negatives = negatives.unsqueeze(0).expand(num_anchors, -1)
        logits = torch.cat((anchors.unsqueeze(-1), tiled_negatives), dim = -1)
        targets = torch.zeros(num_anchors, device = device, dtype = torch.long)
        return F.cross_entropy(logits, targets)

    return contrast(real_flat, fake_flat) + contrast(-fake_flat, -real_flat)
# helper classes | |
class NanException(Exception):
    """Raised by `raise_if_nan` when a tensor contains NaN values."""
class EMA:
    """Exponential moving average with decay factor `beta`."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend `new` into the running value `old`.

        Returns `new` unchanged when there is no previous value (`old` is None);
        otherwise old * beta + (1 - beta) * new.
        """
        return old * self.beta + (1 - self.beta) * new if exists(old) else new
class RandomApply(nn.Module):
    """Apply `fn` with probability `prob`, otherwise `fn_else` (identity by default).

    A fresh `random()` draw is made on every forward call.
    """

    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob

    def forward(self, x):
        if random() < self.prob:
            return self.fn(x)
        return self.fn_else(x)
class ChanNorm(nn.Module):
    """Normalise across the channel axis (dim 1), then apply a learned scale and shift.

    Statistics are computed per spatial position. `g` and `b` have shape
    (1, dim, 1, 1) and broadcast over batch and space.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # Registration order (g, then b) is kept so state_dict layout is unchanged.
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mean = x.mean(dim = 1, keepdim = True)
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        std = (var + self.eps).sqrt()
        return (x - mean) / std * self.g + self.b
class PreNorm(nn.Module):
    """Apply ChanNorm over `dim` channels to the input before calling `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        # Attribute order (fn, then norm) is kept so state_dict layout is unchanged.
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)
class Residual(nn.Module):
    """Wrap `fn` with an identity skip connection: fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(x)
        return transformed + x
class SumBranches(nn.Module):
    """Feed the same input to every branch module and sum their outputs."""

    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return sum(branch(x) for branch in self.branches)
class Blur(nn.Module):
    """Blur with a normalised 3x3 binomial kernel built from [1, 2, 1].

    The 1-D kernel is stored as buffer `f`, so it follows the module across
    devices. The 2-D kernel is its outer product, shaped (1, 3, 3) for
    kornia's `filter2d`.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer('f', torch.Tensor([1, 2, 1]))

    def forward(self, x):
        kernel_1d = self.f
        kernel_2d = kernel_1d[None, None, :] * kernel_1d[None, :, None]
        return filter2d(x, kernel_2d, normalized=True)
# attention | |
class DepthWiseConv2d(nn.Module):
    """Depthwise-separable convolution.

    A per-channel (groups=dim_in) spatial conv is followed by a 1x1 pointwise
    conv that maps dim_in -> dim_out channels.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        # Created depthwise-first, as before, so parameter init order and the
        # state_dict keys (net.0.*, net.1.*) are unchanged.
        depthwise = nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding,
                              groups = dim_in, stride = stride, bias = bias)
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        return self.net(x)
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Queries come from a 1x1 conv. Keys and values come from a
    depthwise-separable 3x3 conv. Queries are softmaxed over their feature
    axis and keys over the positions. Keys and values are summarised into a
    per-head (dim_head x dim_head) `context`, so the cost grows linearly with
    the number of pixels. The result goes through GELU and a 1x1 output conv
    back to `dim` channels.
    """

    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        # Submodules are created in the original order so RNG consumption and
        # state_dict keys (to_q / to_kv / to_out) match saved checkpoints.
        self.nonlin = nn.GELU()
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, fmap):
        heads = self.heads
        height, width = fmap.shape[-2:]

        def to_tokens(t):
            # (b, heads * c, x, y) -> (b * heads, x * y, c)
            return rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = heads)

        queries = to_tokens(self.to_q(fmap))
        keys, values = (to_tokens(t) for t in self.to_kv(fmap).chunk(2, dim = 1))

        queries = queries.softmax(dim = -1) * self.scale
        keys = keys.softmax(dim = -2)

        context = einsum('b n d, b n e -> b d e', keys, values)
        attended = einsum('b n d, b d e -> b n e', queries, context)
        attended = rearrange(attended, '(b h) (x y) d -> b (h d) x y', h = heads, x = height, y = width)
        return self.to_out(self.nonlin(attended))
# global context network | |
# https://arxiv.org/abs/2012.13375 | |
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm | |
class GlobalContext(nn.Module):
    """Attention-pooled channel gate.

    A 1x1 conv scores each spatial position. The scores are softmaxed over
    all positions and used to take a weighted sum of the input's channels.
    That (b, chan_in, 1, 1) summary goes through a small conv MLP ending in
    a sigmoid, giving a (b, chan_out, 1, 1) gate with values in (0, 1).
    """

    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        # Creation order (to_k, then net) is kept for identical init and state_dict keys.
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        hidden = max(3, chan_out // 2)
        self.net = nn.Sequential(
            nn.Conv2d(chan_in, hidden, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weights = self.to_k(x).flatten(2).softmax(dim = -1)
        pooled = einsum('b i n, b c n -> b c i', weights, x.flatten(2))
        return self.net(pooled.unsqueeze(-1))
# dataset | |
def convert_image_to(img_type, image):
    """Return `image` converted to mode `img_type`, or unchanged if already in that mode."""
    return image if image.mode == img_type else image.convert(img_type)
class identity:
    """Callable no-op transform: returns its argument unchanged."""

    def __call__(self, tensor):
        return tensor
class expand_greyscale:
    """Expand a 1- or 2-channel (grey, grey+alpha) image tensor to 3 or 4 channels.

    Targets 4 channels (RGB + alpha) when `transparent`, else 3. Inputs
    already at the target channel count are returned unchanged. A missing
    alpha is filled with ones. Any other channel count raises.
    """

    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        channels = tensor.shape[0]
        target = 4 if self.transparent else 3
        if channels == target:
            return tensor

        if channels == 1:
            color, alpha = tensor.expand(3, -1, -1), None
        elif channels == 2:
            color, alpha = tensor[:1].expand(3, -1, -1), tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        if not self.transparent:
            return color
        if not exists(alpha):
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
        return torch.cat((color, alpha))
def get_1d_dct(i, freq, L):
    """Orthonormal 1-D DCT-II basis value for frequency `freq` at position `i` of `L`."""
    result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    return result * (1 if freq == 0 else math.sqrt(2))


def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a (1, channel, width, width) tensor of 2-D DCT basis images.

    The channels are split into len(fidx_u) equal groups. Group i holds the
    separable basis for the frequency pair (fidx_u[i], fidx_v[i]). Channels
    left over when `channel` is not divisible by the number of pairs stay 0.
    """
    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)
    positions = range(width)
    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        # Separable basis: outer product of the two 1-D bases (built in float64,
        # cast to float32 on assignment).
        basis_x = torch.tensor([get_1d_dct(x, u_x, width) for x in positions], dtype = torch.float64)
        basis_y = torch.tensor([get_1d_dct(y, v_y, width) for y in positions], dtype = torch.float64)
        dct_weights[:, i * c_part: (i + 1) * c_part] = torch.outer(basis_x, basis_y)
    return dct_weights


class FCANet(nn.Module):
    """Frequency channel attention (FcaNet-style) gate.

    Each channel of the input is pooled against a fixed 2-D DCT basis image.
    The basis is registered as buffer `dct_weights`, so it is saved in the
    state_dict. This replaces plain average pooling. The pooled
    (b, chan_in, 1, 1) vector goes through a conv MLP and a sigmoid, giving a
    (b, chan_out, 1, 1) gate.

    The input must be spatially `width` x `width` to multiply with the buffer.
    `get_dct_weights` was previously called but never defined in this file,
    so freq_chan_attn=True raised NameError. It is now defined above.
    """

    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction = 4,
        width
    ):
        super().__init__()

        freq_w, freq_h = ([0] * 8), list(range(8)) # in paper, it seems 16 frequencies was ideal
        # 16 (u, v) pairs: (0, 0..7) followed by (0..7, 0).
        dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
        self.register_buffer('dct_weights', dct_weights)

        chan_intermediate = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # DCT-weighted sum over the spatial dims, keeping them as size-1 axes;
        # equivalent to einops reduce 'b c (h h1) (w w1) -> b c h1 w1' with h1 = w1 = 1.
        x = (x * self.dct_weights).sum(dim = (2, 3), keepdim = True)
        return self.net(x)
# Modifiable global: the normalisation layer class used by Generator.
# It is looked up when a Generator is constructed, so reassigning it before
# construction swaps the norm layer.
norm_class = torch.nn.BatchNorm2d
def upsample(scale_factor = 2):
    """Return a nearest-neighbour nn.Upsample layer scaling H and W by `scale_factor`."""
    layer = nn.Upsample(scale_factor = scale_factor, mode = 'nearest')
    return layer
# generative adversarial network | |
class Generator(nn.Module):
    """Lightweight-GAN generator mapping latent vectors to images.

    A (b, latent_dim) latent is projected to a 4x4 map by `initial_conv`.
    Each entry of `self.layers` then doubles the spatial size
    (upsample -> blur -> conv -> norm -> GLU) until `image_size` is reached,
    and `out_conv` maps to the output channels.

    Keyword-only args:
        image_size: output side length; must be a power of two.
        latent_dim: size of the input latent vector.
        fmap_max: cap on channels per layer (falls back to latent_dim if None).
        fmap_inverse_coef: the layer whose input is 2**n pixels wide gets
            2 ** (fmap_inverse_coef - n) output channels before capping; layers
            with n >= 8 are forced to 3 channels.
        transparent / greyscale: 4 / 1 output channels (default 3).
        attn_res_layers: input widths (e.g. 32) at which a residual
            PreNorm(LinearAttention) is applied before upsampling. None, the new
            default, replaces the former shared mutable `[]` and means no
            attention.
        freq_chan_attn: use FCANet instead of GlobalContext for the
            skip-layer excitation gates.

    Submodule names and creation order are unchanged, so saved state_dicts
    (e.g. from the Hub) load with matching keys.
    """

    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 512,
        fmap_inverse_coef = 12,
        transparent = False,
        greyscale = False,
        attn_res_layers = None,
        freq_chan_attn = False
    ):
        super().__init__()
        # Only used for membership tests below.
        attn_res_layers = default(attn_res_layers, [])
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        # (b, latent_dim, 1, 1) -> (b, latent_dim * 2, 4, 4); GLU halves the channels back.
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim = 1)
        )

        num_layers = int(resolution) - 2
        # Output channels of the layer whose input is 2**n pixels wide.
        features = [
            3 if n >= 8 else min(2 ** (fmap_inverse_coef - n), fmap_max)
            for n in range(2, num_layers + 2)
        ]
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # Skip-layer excitation: the layer at log2 width `src` produces a gate that
        # multiplies the feature map once it is 2**dst pixels wide. Pairs that
        # exceed the output resolution are dropped.
        self.sle_map = {
            src: dst
            for src, dst in ((3, 7), (4, 8), (5, 9), (6, 10))
            if src <= resolution and dst <= resolution
        }

        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # Channels of the feature map at the gated resolution.
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in = chan_out,
                        chan_out = sle_chan_out,
                        width = 2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in = chan_out,
                        chan_out = sle_chan_out
                    )

            # Entries are [upsampling block, sle or None, attn or None]; the
            # None slots are kept so module indices (and state_dict keys) are stable.
            layer = nn.ModuleList([
                nn.Sequential(
                    upsample(),
                    Blur(),
                    nn.Conv2d(chan_in, chan_out * 2, 3, padding = 1),
                    norm_class(chan_out * 2),
                    nn.GLU(dim = 1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)

    def forward(self, x):
        """Map latents of shape (b, latent_dim) to images of shape (b, C, image_size, image_size)."""
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        x = F.normalize(x, dim = 1)

        # Pending SLE gates, keyed by the log2 width at which they are applied.
        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            # After `up`, x is 2 ** (res + 1) pixels wide.
            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
#### ACTUALLY LOAD THE MODEL AND DEFINE THE INTERFACE

# Mix the Hub helpers into the generator so `from_pretrained` can construct a
# Generator and load its weights in one step. from_pretrained is a classmethod,
# so there is no need to build a throwaway Generator and swap its __class__.
class GeneratorWithPyTorchModelHubMixin(Generator, PyTorchModelHubMixin):
    pass


def _load_orbgan(repo_id):
    """Download `repo_id` from the Hub, build the generator and switch it to eval mode.

    The constructor kwargs must match the architecture the checkpoints were
    trained with. eval() makes the BatchNorm layers use their running
    statistics, so outputs do not depend on the sampled batch.
    """
    model = GeneratorWithPyTorchModelHubMixin.from_pretrained(
        repo_id, latent_dim=256, image_size=256, attn_res_layers=[32]
    )
    model.eval()
    return model


# To load from a local state dict instead:
# gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers=[32])
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))
gan_new = _load_orbgan('johnowhitaker/orbgan_e1')
gan_light = _load_orbgan('johnowhitaker/orbgan_light')
gan_dark = _load_orbgan('johnowhitaker/orbgan_dark')
def gen_ims(n_rows, model='both'):
    """Sample an n_rows x n_rows grid of orb images from one of the generators.

    Args:
        n_rows: rows (and columns) in the grid. Cast with int(), because
            the Gradio slider may deliver a float.
        model: "both" (orbgan_e1), "light" or "dark".

    Returns:
        uint8 numpy array of shape (H, W, C) holding the image grid.

    Raises:
        ValueError: if `model` is not one of the known names. Previously an
            unknown name left `ims` unbound and raised UnboundLocalError.
    """
    generators = {"both": gan_new, "light": gan_light, "dark": gan_dark}
    if model not in generators:
        raise ValueError(f"unknown model {model!r}; expected one of {sorted(generators)}")

    n_rows = int(n_rows)
    latent_dim = 256  # must match the latent_dim the checkpoints were built with
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        latents = torch.randn(n_rows ** 2, latent_dim)
        ims = generators[model](latents).clamp_(0., 1.)

    grid = torchvision.utils.make_grid(ims, nrow=n_rows)
    grid = grid.permute(1, 2, 0).cpu().numpy()
    return (grid * 255).astype(np.uint8)
# Gradio >= 3 exposes input/output components at the top level. The legacy
# `gr.inputs` / `gr.outputs` namespaces (and their `default=` / `optional=`
# arguments) were removed in Gradio 4 and crash the app at startup.
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Slider(minimum=1, maximum=6, step=1, value=3, label="N rows"),
        gr.Dropdown(["both", "light", "dark"], type="value", value="dark", label="Orb Type (model)"),
    ],
    outputs=[gr.Image(type="numpy", label="Generated Images")],
    title='Demo for orbgan models',
    article='Models are lightweight-gans trained on johnowhitaker/glid3_orbs. See https://huggingface.co/johnowhitaker/orbgan_e1 for training and inference scripts',
)
iface.launch()