johnowhitaker committed · c9bf0bf
Parent: 5e36def
Update app.py
app.py
CHANGED
@@ -19,8 +19,194 @@ from fastprogress.fastprogress import master_bar, progress_bar
 from IPython.display import HTML
 from base64 import b64encode
 
+# Definitions
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+def sinc(x):
+    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+
+def lanczos(x, a):
+    cond = torch.logical_and(-a < x, x < a)
+    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
+    return out / out.sum()
+
+
+def ramp(ratio, width):
+    n = math.ceil(width / ratio + 1)
+    out = torch.empty([n])
+    cur = 0
+    for i in range(out.shape[0]):
+        out[i] = cur
+        cur += ratio
+    return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
+class Prompt(nn.Module):
+    def __init__(self, embed, weight=1., stop=float('-inf')):
+        super().__init__()
+        self.register_buffer('embed', embed)
+        self.register_buffer('weight', torch.as_tensor(weight))
+        self.register_buffer('stop', torch.as_tensor(stop))
+
+    def forward(self, input):
+        input_normed = F.normalize(input.unsqueeze(1), dim=2)
+        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
+        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
+        dists = dists * self.weight.sign()
+        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
+
+class MakeCutouts(nn.Module):
+    def __init__(self, cut_size, cutn, cut_pow=1.):
+        super().__init__()
+        self.cut_size = cut_size
+        self.cutn = cutn
+        self.cut_pow = cut_pow
+        self.augs = nn.Sequential(
+            K.RandomHorizontalFlip(p=0.5),
+            K.RandomSharpness(0.3, p=0.4),
+            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
+            K.RandomPerspective(0.2, p=0.4),
+            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
+        self.noise_fac = 0.1
+
+    def forward(self, input):
+        sideY, sideX = input.shape[2:4]
+        max_size = min(sideX, sideY)
+        min_size = min(sideX, sideY, self.cut_size)
+        cutouts = []
+        for _ in range(self.cutn):
+            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+            offsetx = torch.randint(0, sideX - size + 1, ())
+            offsety = torch.randint(0, sideY - size + 1, ())
+            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+        batch = self.augs(torch.cat(cutouts, dim=0))
+        if self.noise_fac:
+            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
+            batch = batch + facs * torch.randn_like(batch)
+        return batch
+
+def resample(input, size, align_corners=True):
+    n, c, h, w = input.shape
+    dh, dw = size
+
+    input = input.view([n * c, 1, h, w])
+
+    if dh < h:
+        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+        pad_h = (kernel_h.shape[0] - 1) // 2
+        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
+        input = F.conv2d(input, kernel_h[None, None, :, None])
+
+    if dw < w:
+        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+        pad_w = (kernel_w.shape[0] - 1) // 2
+        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
+        input = F.conv2d(input, kernel_w[None, None, None, :])
+
+    input = input.view([n, c, h, w])
+    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
+
+class ReplaceGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x_forward, x_backward):
+        ctx.shape = x_backward.shape
+        return x_forward
+
+    @staticmethod
+    def backward(ctx, grad_in):
+        return None, grad_in.sum_to_size(ctx.shape)
+
+
+replace_grad = ReplaceGrad.apply
+
+# Set up CLIP
+perceptor = clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)
+normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+                                 std=[0.26862954, 0.26130258, 0.27577711])
+cut_size = perceptor.visual.input_resolution
+cutn = 64
+cut_pow = 1
+make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
+
+# ImStack
+class ImStack(nn.Module):
+    """This class represents an image as a series of stacked arrays, each 1/2
+    the resolution of the next. This is useful, e.g., when trying to create an image to minimise
+    some loss - parameters in the early (small) layers can have an effect on the overall
+    structure and shapes, while those in later layers act as residuals and fill in fine detail.
+    """
+
+    def __init__(self, n_layers=3, base_size=32, scale=2,
+                 init_image=None, out_size=256, decay=0.7):
+        """Constructs the Image Stack
+
+        Args:
+            TODO
+        """
+        super().__init__()
+        self.n_layers = n_layers
+        self.base_size = base_size
+        self.sig = nn.Sigmoid()
+        self.layers = []
+
+        for i in range(n_layers):
+            side = base_size * (scale**i)
+            tim = torch.randn((3, side, side)).to(device) * (decay**i)
+            self.layers.append(tim)
+
+        self.scalers = [nn.Upsample(scale_factor=out_size/(l.shape[1]), mode='bilinear', align_corners=False) for l in self.layers]
+
+        self.preview_scalers = [nn.Upsample(scale_factor=224/(l.shape[1]), mode='bilinear', align_corners=False) for l in self.layers]
+
+        if init_image is not None:  # Given a PIL image, decompose it into a stack
+            downscalers = [nn.Upsample(scale_factor=(l.shape[1]/out_size), mode='bilinear', align_corners=False) for l in self.layers]
+            final_side = base_size * (scale ** n_layers)
+            im = torch.tensor(np.array(init_image.resize((out_size, out_size)))/255).clip(1e-3, 1-1e-3)  # Between 0 and 1 (non-inclusive)
+            im = im.permute(2, 0, 1).unsqueeze(0).to(device)  # torch.log(im/(1-im))
+            for i in range(n_layers): self.layers[i] *= 0  # Zero out the layers
+            for i in range(n_layers):
+                side = base_size * (scale**i)
+                out = self.forward()
+                residual = (torch.logit(im) - torch.logit(out))
+                Image.fromarray((torch.logit(residual).detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8)).save(f'residual{i}.png')
+                self.layers[i] = downscalers[i](residual).squeeze()
+
+        for l in self.layers: l.requires_grad = True
+
+    def forward(self):
+        im = self.scalers[0](self.layers[0].unsqueeze(0))
+        for i in range(1, self.n_layers):
+            im += self.scalers[i](self.layers[i].unsqueeze(0))
+        return self.sig(im)
+
+    def preview(self, n_preview=2):
+        im = self.preview_scalers[0](self.layers[0].unsqueeze(0))
+        for i in range(1, n_preview):
+            im += self.preview_scalers[i](self.layers[i].unsqueeze(0))
+        return self.sig(im)
+
+    def to_pil(self):
+        return Image.fromarray((self.forward().detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8))
+
+    def preview_pil(self):
+        return Image.fromarray((self.preview().detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8))
+
+    def save(self, fn):
+        self.to_pil().save(fn)
+
+    def plot_layers(self):
+        fig, axs = plt.subplots(1, self.n_layers, figsize=(15, 5))
+        for i in range(self.n_layers):
+            im = (self.sig(self.layers[i].unsqueeze(0)).detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8)
+            axs[i].imshow(im)
+
+
+
 
 def generate(text, n_steps):
+    # Encode prompt
+    embed = perceptor.encode_text(clip.tokenize(text).to(device)).float()
     #todo
     return np.random.random((128, 128, 3)).astype(np.uint8)
 
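The ImStack docstring above describes the idea behind the class: the image is stored as a stack of tensors, each half the resolution of the next, which are upsampled, summed and squashed through a sigmoid, so coarse layers set overall structure while later layers add fine detail. A minimal sketch of the init-image round trip, assuming the definitions from this commit are in scope ('init.png' is a hypothetical file, not part of the Space):

    from PIL import Image

    # Decompose a source image into the stack (layer 0 carries coarse structure,
    # later layers hold logit-space residuals), then reconstruct it.
    src = Image.open('init.png').convert('RGB')                        # hypothetical input
    stack = ImStack(n_layers=3, base_size=32, out_size=256, init_image=src)
    stack.plot_layers()                                                # inspect each layer
    stack.to_pil().save('recon.png')                                   # should approximate src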
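generate itself is still a stub that returns random noise, but the pieces defined in this commit (perceptor, Prompt, make_cutouts, normalize, ImStack) are the standard ingredients of a CLIP-guided optimisation loop. A sketch of how they could plug together - an assumption about where the Space is heading, not its actual implementation; the step count and learning rate are arbitrary choices:

    def generate_sketch(text, n_steps=50):
        # Encode the prompt with CLIP, as the commit already does
        embed = perceptor.encode_text(clip.tokenize(text).to(device)).float()
        prompt = Prompt(embed).to(device)

        stack = ImStack(n_layers=3, base_size=32, out_size=256)  # image parameterisation
        opt = torch.optim.Adam(stack.layers, lr=0.05)            # the stack tensors are the parameters

        for _ in range(n_steps):
            opt.zero_grad()
            im = stack()                                  # (1, 3, 256, 256), values in (0, 1)
            batch = normalize(make_cutouts(im))           # cutn augmented crops at CLIP's input size
            loss = prompt(perceptor.encode_image(batch).float())  # spherical distance to the prompt
            loss.backward()
            opt.step()

        return np.array(stack.to_pil())                   # HWC uint8, like the current stub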