Robert001 committed on
Commit
6d05c2f
1 Parent(s): 6d16592

first commit

Files changed (1)
  1. lib/ddim.py +348 -0
lib/ddim.py ADDED
@@ -0,0 +1,348 @@
'''
 * Copyright (c) 2023 Salesforce, Inc.
 * All rights reserved.
 * SPDX-License-Identifier: Apache License 2.0
 * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
 * By Can Qin
 * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
 * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
'''

"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm

from lib.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
    extract_into_tensor


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # unlike nn.Module.register_buffer, this just stores the tensor as a
        # plain attribute, moving it to CUDA first if necessary
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            # classifier-free guidance: run the unconditional and conditional
            # inputs through the model in one batch, then mix the predictions
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [torch.cat([
                            unconditional_conditioning[k][i],
                            c[k][i]]) for i in range(len(c[k]))]
                    else:
                        c_in[k] = torch.cat([
                            unconditional_conditioning[k],
                            c[k]])
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        # deterministic DDIM encoding: run the update rule forwards to map x0
        # towards the noise distribution over t_enc steps
        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        # counterpart to encode/stochastic_encode: denoise a latent from step
        # t_start back towards x_0
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
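
For orientation, a minimal usage sketch (not part of this commit). The sampler only assumes a LatentDiffusion-style `model` exposing the attributes read above (`num_timesteps`, `betas`, `alphas_cumprod`, `alphas_cumprod_prev`, `device`, `parameterization`) plus `apply_model`; `load_model`, `get_learned_conditioning`, and `decode_first_stage` below are hypothetical placeholders for whatever your checkpoint wrapper provides.

    import torch
    from lib.ddim import DDIMSampler

    model = load_model("path/to/checkpoint")  # hypothetical loader, not in this repo
    sampler = DDIMSampler(model)

    cond = model.get_learned_conditioning(["a photo of a cat"])  # assumed wrapper API
    uncond = model.get_learned_conditioning([""])                # assumed wrapper API

    samples, intermediates = sampler.sample(
        S=50,                              # number of DDIM steps
        batch_size=1,
        shape=(4, 64, 64),                 # latent (C, H, W); 4x64x64 is typical for SD at 512px
        conditioning=cond,
        eta=0.0,                           # eta=0 -> sigma_t=0, deterministic DDIM updates
        unconditional_guidance_scale=7.5,  # classifier-free guidance weight
        unconditional_conditioning=uncond,
    )
    images = model.decode_first_stage(samples)  # assumed wrapper API: latents -> pixels

Setting eta > 0 re-introduces stochasticity through the sigma schedule registered in make_schedule, interpolating between deterministic DDIM (eta = 0) and DDPM-like ancestral sampling (eta = 1) over the chosen timestep subsequence.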