Runtime error
Runtime error
committed on
Browse files
@@ -19,8 +19,194 @@ from fastprogress.fastprogress import master_bar, progress_bar
19 |
from IPython.display import HTML
20 |
from base64 import b64encode
21 |
22 |
23 |
def generate(text, n_steps):
    """Return a placeholder 128x128 RGB image filled with random noise.

    ``text`` and ``n_steps`` are accepted so the signature matches the real
    generator, but this stub does not use either of them.

    Returns:
        A ``(128, 128, 3)`` ``numpy.uint8`` array with values in 0..255.
    """
    # Draw integers directly. Casting floats sampled from [0, 1) to uint8
    # truncates every value to 0, which always produced an all-black image.
    return np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
26 |
19 |
from IPython.display import HTML
20 |
from base64 import b64encode
21 |
22 |
# Definitions
23 |
# Prefer the first CUDA GPU when PyTorch reports one; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
24 |
25 |
def sinc(x):
    """Return the normalised sinc, ``sin(pi * x) / (pi * x)``, element-wise.

    Entries where ``x == 0`` are set to 1 instead of the 0/0 the raw
    quotient would give.
    """
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
27 |
28 |
29 |
def lanczos(x, a):
    """Evaluate a Lanczos window of half-width ``a`` at ``x`` and normalise it.

    The window is ``sinc(x) * sinc(x / a)`` strictly inside ``(-a, a)`` and
    zero elsewhere. The result is divided by its own sum so the returned
    weights add up to 1.
    """
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
    return out / out.sum()
33 |
34 |
35 |
def ramp(ratio, width):
    """Return a symmetric 1-D grid of sample positions spaced ``ratio`` apart.

    ``n = ceil(width / ratio + 1)`` non-negative positions ``0, ratio,
    2 * ratio, ...`` are generated, mirrored about zero, and the two outermost
    entries are dropped. The result therefore has ``2 * n - 3`` elements.
    """
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    # Mirror the positive half (without duplicating 0), then trim both ends.
    return torch.cat([-out[1:].flip([0]), out])[1:-1]
43 |
44 |
class Prompt(nn.Module):
    """Loss term measuring how far input embeddings are from a target embedding.

    Args:
        embed: target embedding tensor, stored as a buffer.
        weight: scale of the loss. Its sign is applied to the distances and
            its absolute value to the final mean.
        stop: threshold used in the backward pass only (see ``forward``).
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        # nn.Module must be initialised before any buffer can be registered.
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        """Return the weighted mean distance between ``input`` and the target.

        Both sides are L2-normalised along dim 2 and compared pairwise. With
        ``d`` the Euclidean distance between two unit vectors, the per-pair
        value is ``2 * arcsin(d / 2) ** 2``.
        """
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        # Forward value is `dists`; gradients are taken through
        # max(dists, stop), so entries below `stop` receive no gradient.
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
57 |
58 |
class MakeCutouts(nn.Module):
    """Sample random square crops of an image batch, resize and augment them.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops taken per call.
        cut_pow: exponent applied to the uniform random draw that picks each
            crop's size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        # NOTE(review): the view this code was recovered from has empty
        # entries around these two augmentations. Confirm against the
        # committed file whether further augmentations belong in this list.
        self.augs = nn.Sequential(
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
        self.noise_fac = 0.1

    def forward(self, input):
        """Return the augmented crops concatenated along the batch dimension.

        ``input`` is indexed as a 4-D tensor whose dims 2 and 3 are height
        and width.
        """
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Random side length in [min_size, max_size], shaped by cut_pow.
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            # One noise amplitude per crop, drawn from [0, noise_fac).
            # NOTE(review): the [cutn, 1, 1, 1] shape lines up with `batch`
            # only when `input` has batch size 1 — confirm with callers.
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
88 |
89 |
def resample(input, size, align_corners=True):
    """Resize a 4-D ``(n, c, h, w)`` tensor to ``size = (dh, dw)``.

    Any axis that is being reduced is first low-pass filtered with a
    normalised Lanczos (a=2) kernel, using reflect padding so the spatial size
    is unchanged. The filtered tensor is then resized with bicubic
    interpolation. Axes that are kept or enlarged are not pre-filtered.

    Args:
        input: tensor of shape ``(n, c, h, w)``.
        size: target ``(height, width)``.
        align_corners: forwarded to ``F.interpolate``.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Treat every channel as its own single-channel image so one 1-D kernel
    # can be applied with conv2d. reshape (not view) also accepts
    # non-contiguous inputs.
    input = input.reshape([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.reshape([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
109 |
110 |
class ReplaceGrad(torch.autograd.Function):
    """Return one tensor in the forward pass but route gradients to another.

    ``forward(x_forward, x_backward)`` returns ``x_forward`` unchanged. In the
    backward pass ``x_forward`` receives no gradient, and the incoming
    gradient is summed down to ``x_backward``'s shape and passed to it.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape is needed later, to reduce a broadcast gradient.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply
122 |
123 |
# Set up CLIP
# Load the ViT-B/32 model (first element of what clip.load returns), switch it
# to eval mode, disable gradients on its parameters and move it to `device`.
perceptor = clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)
# Per-channel mean/std normalisation transform. Presumably applied to images
# before they are encoded by CLIP — the use site is not visible here.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
# Crops are resized to the resolution the visual encoder reports.
cut_size = perceptor.visual.input_resolution

# NOTE(review): `cutn` and `cut_pow` are not assigned anywhere in the visible
# code. They must be defined before this line; confirm their values against
# the committed file.
make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
131 |
132 |
# ImStack
133 |
class ImStack(nn.Module):
    """Represent an image as a series of stacked arrays of growing resolution.

    Each layer is ``scale`` times the side length of the previous one (with
    the default ``scale=2``, each is 1/2 the resolution of the next). This is
    useful e.g. when trying to create an image that minimises some loss:
    parameters in the early (small) layers can have an effect on the overall
    structure and shapes, while those in later layers act as residuals and
    fill in fine detail.
    """

    def __init__(self, n_layers=3, base_size=32, scale=2,
                 init_image=None, out_size=256, decay=0.7):
        """Construct the image stack.

        Args:
            n_layers: number of layers in the stack.
            base_size: side length of the first (smallest) layer.
            scale: side-length multiplier between consecutive layers.
            init_image: optional PIL image to decompose into the stack; when
                omitted the layers keep their random initialisation.
            out_size: side length of the image produced by ``forward``.
            decay: layer ``i``'s random initialisation is scaled by
                ``decay ** i``.
        """
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.n_layers = n_layers
        self.base_size = base_size
        self.sig = nn.Sigmoid()
        self.layers = []

        for i in range(n_layers):
            side = base_size * (scale**i)
            tim = torch.randn((3, side, side)).to(device)*(decay**i)
            # NOTE(review): this append is reconstructed. The recovered source
            # built `tim` and never stored it, yet `self.layers` is indexed
            # everywhere below.
            self.layers.append(tim)

        # One upsampler per layer, bringing it to out_size x out_size.
        self.scalers = [nn.Upsample(scale_factor=out_size/(l.shape[1]), mode='bilinear', align_corners=False) for l in self.layers]

        # Same, but to a fixed 224 x 224 preview size.
        self.preview_scalers = [nn.Upsample(scale_factor=224/(l.shape[1]), mode='bilinear', align_corners=False) for l in self.layers]

        if init_image is not None:  # Given a PIL image, decompose it into a stack
            downscalers = [nn.Upsample(scale_factor=(l.shape[1]/out_size), mode='bilinear', align_corners=False) for l in self.layers]
            im = torch.tensor(np.array(init_image.resize((out_size, out_size)))/255).clip(1e-03, 1-1e-3)  # Between 0 and 1 (non-inclusive)
            im = im.permute(2, 0, 1).unsqueeze(0).to(device)
            for i in range(n_layers):
                self.layers[i] *= 0  # Zero out the layers
            for i in range(n_layers):
                # Fit layer i to whatever the coarser layers still miss,
                # measured in logit (pre-sigmoid) space.
                out = self.forward()
                residual = (torch.logit(im) - torch.logit(out))
                # NOTE(review): debug dump kept as-is. `residual` is not
                # confined to [0, 1], so torch.logit(residual) is NaN for
                # most entries and the saved PNG is unlikely to be meaningful.
                Image.fromarray((torch.logit(residual).detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8)).save(f'residual{i}.png')
                self.layers[i] = downscalers[i](residual).squeeze()

        for l in self.layers:
            l.requires_grad = True

    def forward(self):
        """Return the sigmoid of the sum of all layers upscaled to out_size."""
        im = self.scalers[0](self.layers[0].unsqueeze(0))
        for i in range(1, self.n_layers):
            im += self.scalers[i](self.layers[i].unsqueeze(0))
        return self.sig(im)

    def preview(self, n_preview=2):
        """Return a 224x224 image built from the first ``n_preview`` layers."""
        im = self.preview_scalers[0](self.layers[0].unsqueeze(0))
        for i in range(1, n_preview):
            im += self.preview_scalers[i](self.layers[i].unsqueeze(0))
        return self.sig(im)

    def to_pil(self):
        """Return the full-resolution image as an 8-bit PIL image."""
        return Image.fromarray((self.forward().detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8))

    def preview_pil(self):
        """Return the preview image as an 8-bit PIL image."""
        return Image.fromarray((self.preview().detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8))

    def save(self, fn):
        """Write the full-resolution image to ``fn``."""
        # NOTE(review): the body was missing from the recovered source.
        # Saving the to_pil() image is the presumed intent — confirm.
        self.to_pil().save(fn)

    def plot_layers(self):
        """Show every layer (after the sigmoid) side by side in one figure."""
        # squeeze=False keeps `axs` 2-D even when there is only one layer.
        fig, axs = plt.subplots(1, self.n_layers, figsize=(15, 5), squeeze=False)
        for i in range(self.n_layers):
            im = (self.sig(self.layers[i].unsqueeze(0)).detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8)
            # NOTE(review): the drawing call was missing from the recovered
            # source; imshow is the presumed intent — confirm.
            axs[0, i].imshow(im)
202 |
203 |
204 |
205 |
206 |
207 |
def generate(text, n_steps):
    """Encode ``text`` with CLIP and return a placeholder 128x128 RGB image.

    The text embedding is computed but not yet used, and ``n_steps`` is
    ignored; the returned image is random noise.

    Returns:
        A ``(128, 128, 3)`` ``numpy.uint8`` array with values in 0..255.
    """
    # Encode prompt
    embed = perceptor.encode_text(clip.tokenize(text).to(device)).float()

    # Draw integers directly. Casting floats sampled from [0, 1) to uint8
    # truncates every value to 0, which always produced an all-black image.
    return np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
212 |