""" | |
wild mixture of | |
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py | |
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py | |
https://github.com/CompVis/taming-transformers | |
-- merci | |
""" | |
import time, math
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from einops import rearrange
from tqdm.auto import trange, tqdm
from pytorch_lightning.utilities.distributed import rank_zero_only

from ldmlib.util import exists, default, instantiate_from_config
from ldmlib.models.autoencoder import VQModelInterface
from ldmlib.modules.distributions.distributions import DiagonalGaussianDistribution
from ldmlib.modules.diffusionmodules.util import (make_beta_schedule, make_ddim_sampling_parameters,
                                                  make_ddim_timesteps, extract_into_tensor, noise_like)
from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims, linear_multistep_coeff


def disabled_train(self):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self
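

# Illustrative helper (an assumption for exposition, not used elsewhere in
# this file): the freezing pattern that disabled_train enables, mirroring what
# instantiate_first_stage and instantiate_cond_stage do below.
def _example_freeze(module):
    module = module.eval()
    module.train = disabled_train  # .train() becomes a no-op returning self
    for param in module.parameters():
        param.requires_grad = False
    return module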


class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 timesteps=1000,
                 beta_schedule="linear",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config
        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
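

# A minimal sketch (illustrative, not used by the classes here) of the
# schedule registered above: make_beta_schedule's "linear" schedule is linear
# in sqrt(beta), and alphas_cumprod is the running product
# alpha_bar_t = prod_{s<=t} (1 - beta_s) that the noising formulas rely on.
def _demo_linear_schedule(timesteps=1000, linear_start=1e-4, linear_end=2e-2):
    betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
    alphas_cumprod = np.cumprod(1. - betas, axis=0)
    return betas, alphas_cumprod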


class FirstStage(DDPM):
    """Wrapper around DDPM that owns the (frozen) first-stage autoencoder."""
    def __init__(self,
                 first_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__()
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        self.instantiate_first_stage(first_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
    def instantiate_first_stage(self, config):
        # instantiate the autoencoder, freeze it, and make .train() a no-op
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()
        z = 1. / self.scale_factor * z
        # the split-input and default paths decode identically here, so the
        # former branch on split_input_params is collapsed
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)
    def encode_first_stage(self, x):
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # e.g. (128, 128)
                stride = self.split_input_params["stride"]  # e.g. (64, 64)
                df = self.split_input_params["vqf"]
                self.split_input_params['original_image_size'] = x.shape[-2:]
                bs, nc, h, w = x.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing kernel")
                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")
                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
                z = unfold(x)  # (bn, nc * prod(ks), L)
                # reshape to image shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L)
                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]
                o = torch.stack(output_list, dim=-1)
                o = o * weighting
                # reverse reshape to image shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization
                return decoded
            else:
                return self.first_stage_model.encode(x)
        else:
            return self.first_stage_model.encode(x)
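

# The patch-wise branch above splits a large input into overlapping ks-sized
# crops with unfold, encodes each crop independently, and stitches the results
# back together with fold, dividing by a normalization map so overlapping
# regions are averaged. A minimal self-contained sketch of that round trip
# (uniform weighting and an identity "encoder"; purely illustrative):
def _demo_patchwise_roundtrip(x, ks=(8, 8), stride=(4, 4)):
    b, c, h, w = x.shape
    unfold = nn.Unfold(kernel_size=ks, stride=stride)
    fold = nn.Fold(output_size=(h, w), kernel_size=ks, stride=stride)
    patches = unfold(x)  # (b, c * ks[0] * ks[1], L)
    normalization = fold(unfold(torch.ones_like(x)))  # overlap counts per pixel
    return fold(patches) / normalization  # recovers x up to float error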


class CondStage(DDPM):
    """Wrapper around DDPM that owns the conditioning-stage model."""
    def __init__(self,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__()
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c
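

# Illustrative usage of the conditioning path (names and the config object are
# placeholders, not defined in this file): with a text encoder that exposes
# .encode(), get_learned_conditioning prefers it, and collapses a
# DiagonalGaussianDistribution result to its mode instead of sampling:
#   cond_stage = CondStage(cond_stage_config, timesteps=1000)
#   c = cond_stage.get_learned_conditioning(["a photograph of an astronaut"])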


class DiffusionWrapper(pl.LightningModule):
    """Wraps the encoder half of the split U-Net; its forward returns (h, emb, hs)."""
    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, cc):
        out = self.diffusion_model(x, t, context=cc)
        return out


class DiffusionWrapperOut(pl.LightningModule):
    """Wraps the decoder half of the split U-Net; it consumes (h, emb, hs)."""
    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, h, emb, tp, hs, cc):
        return self.diffusion_model(h, emb, tp, hs, context=cc)


class UNet(DDPM):
    """Split U-Net wrapper: model1 encodes, model2 decodes, batched in chunks of unet_bs."""
    def __init__(self,
                 unetConfigEncode,
                 unetConfigDecode,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 unet_bs=1,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.num_downs = 0
        self.cdevice = "cuda"
        self.unetConfigEncode = unetConfigEncode
        self.unetConfigDecode = unetConfigDecode
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None
        self.model1 = DiffusionWrapper(self.unetConfigEncode)
        self.model2 = DiffusionWrapperOut(self.unetConfigDecode)
        self.model1.eval()
        self.model2.eval()
        self.turbo = False
        self.unet_bs = unet_bs
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
    def make_cond_schedule(self):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids

    def on_train_batch_start(self, batch, batch_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.cdevice)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")
    def apply_model(self, x_noisy, t, cond, return_ids=False):
        # run the encoder half in chunks of unet_bs to cap peak memory
        if not self.turbo:
            self.model1.to(self.cdevice)
        step = self.unet_bs
        h, emb, hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
        bs = cond.shape[0]
        lenhs = len(hs)
        for i in range(step, bs, step):
            h_temp, emb_temp, hs_temp = self.model1(x_noisy[i:i + step], t[i:i + step], cond[i:i + step])
            h = torch.cat((h, h_temp))
            emb = torch.cat((emb, emb_temp))
            for j in range(lenhs):
                hs[j] = torch.cat((hs[j], hs_temp[j]))
        if not self.turbo:
            # in low-memory mode, swap the encoder out for the decoder
            self.model1.to("cpu")
            self.model2.to(self.cdevice)
        hs_temp = [hs[j][:step] for j in range(lenhs)]
        x_recon = self.model2(h[:step], emb[:step], x_noisy.dtype, hs_temp, cond[:step])
        for i in range(step, bs, step):
            hs_temp = [hs[j][i:i + step] for j in range(lenhs)]
            x_recon1 = self.model2(h[i:i + step], emb[i:i + step], x_noisy.dtype, hs_temp, cond[i:i + step])
            x_recon = torch.cat((x_recon, x_recon1))
        if not self.turbo:
            self.model2.to("cpu")
        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon
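
    # The chunking above is the standard micro-batching pattern; a compact
    # sketch of the same idea for a single module (illustrative only):
    #   outs = [net(x[i:i + bs]) for i in range(0, x.shape[0], bs)]
    #   out = torch.cat(outs)
    # Splitting the U-Net into model1/model2 additionally lets the two halves
    # take turns on the GPU when self.turbo is False.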
    def register_buffer1(self, name, attr):
        # setattr-based variant of register_buffer that also moves tensors to
        # the sampling device
        if isinstance(attr, torch.Tensor):
            if attr.device != torch.device(self.cdevice):
                attr = attr.to(torch.device(self.cdevice))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.num_timesteps, verbose=verbose)
        assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.to(self.cdevice)
        self.register_buffer1('betas', to_torch(self.betas))
        self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod))
        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer1('ddim_sigmas', ddim_sigmas)
        self.register_buffer1('ddim_alphas', ddim_alphas)
        self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer1('ddim_sqrt_one_minus_alphas', torch.sqrt(1. - ddim_alphas))
    def sample(self,
               S,
               conditioning,
               x0=None,
               shape=None,
               seed=1234,
               callback=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               sampler="plms",
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               ):
        if self.turbo:
            self.model1.to(self.cdevice)
            self.model2.to(self.cdevice)
        if x0 is None:
            batch_size, b1, b2, b3 = shape
            img_shape = (1, b1, b2, b3)
            tens = []
            print("seeds used = ", [seed + s for s in range(batch_size)])
            for _ in range(batch_size):
                torch.manual_seed(seed)
                tens.append(torch.randn(img_shape, device=self.cdevice))
                seed += 1
            noise = torch.cat(tens)
            del tens
        x_latent = noise if x0 is None else x0
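
        # Each sample in the batch gets its own consecutive seed, so individual
        # images are reproducible independent of batch size. An equivalent,
        # purely illustrative pattern using a local generator instead of the
        # global RNG state would be:
        #   g = torch.Generator(device="cpu").manual_seed(seed + s)
        #   torch.randn(img_shape, generator=g)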

        # sampling
        if sampler in ('ddim', 'dpm2', 'heun', 'dpm2_a', 'lms') and not hasattr(self, 'ddim_timesteps'):
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)

        if sampler == "plms":
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
            print(f'Data shape for PLMS sampling is {shape}')
            # batch size is taken from the latent itself so the img2img path
            # (x0 given, shape unused) also works
            samples = self.plms_sampling(conditioning, x_latent.shape[0], x_latent,
                                         callback=callback,
                                         img_callback=img_callback,
                                         quantize_denoised=quantize_x0,
                                         mask=mask, x0=x0,
                                         ddim_use_original_steps=False,
                                         noise_dropout=noise_dropout,
                                         temperature=temperature,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs,
                                         log_every_t=log_every_t,
                                         unconditional_guidance_scale=unconditional_guidance_scale,
                                         unconditional_conditioning=unconditional_conditioning,
                                         )
        elif sampler == "ddim":
            samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
                                         unconditional_conditioning=unconditional_conditioning,
                                         mask=mask, init_latent=x_T, use_original_steps=False,
                                         callback=callback, img_callback=img_callback)
        elif sampler == "euler":
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
            samples = self.euler_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          img_callback=img_callback)
        elif sampler == "euler_a":
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
            samples = self.euler_ancestral_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    img_callback=img_callback)
        elif sampler == "dpm2":
            samples = self.dpm_2_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          img_callback=img_callback)
        elif sampler == "heun":
            samples = self.heun_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                         unconditional_guidance_scale=unconditional_guidance_scale,
                                         img_callback=img_callback)
        elif sampler == "dpm2_a":
            samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    img_callback=img_callback)
        elif sampler == "lms":
            samples = self.lms_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        img_callback=img_callback)
        yield from samples

        if self.turbo:
            self.model1.to("cpu")
            self.model2.to("cpu")
    def plms_sampling(self, cond, b, img,
                      ddim_use_original_steps=False,
                      callback=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        device = self.betas.device
        timesteps = self.ddim_timesteps
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img
            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                # PLMS is at most 4th order; keep only the last three eps values
                old_eps.pop(0)
            if callback: yield from callback(i)
            if img_callback: yield from img_callback(pred_x0, i)
        if img_callback: yield from img_callback(img, len(iterator) - 1)
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            if score_corrector is not None:
                assert self.parameterization == "eps"
                e_t = score_corrector.modify_score(self, e_t, x, t, c, **corrector_kwargs)
            return e_t

        alphas = self.ddim_alphas
        alphas_prev = self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
        return x_prev, pred_x0, e_t
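
    # The multistep weights above are the standard Adams-Bashforth
    # coefficients for uniformly spaced steps, applied to the eps history:
    # AB2: (3/2, -1/2); AB3: (23/12, -16/12, 5/12);
    # AB4: (55/24, -59/24, 37/24, -9/24). The very first step has no history,
    # so a second-order Heun-style average of e_t and e_t_next is used instead.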
    def stochastic_encode(self, x0, t, seed, ddim_eta, ddim_steps, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
        if noise is None:
            b0, b1, b2, b3 = x0.shape
            img_shape = (1, b1, b2, b3)
            tens = []
            print("seeds used = ", [seed + s for s in range(b0)])
            for _ in range(b0):
                torch.manual_seed(seed)
                tens.append(torch.randn(img_shape, device=x0.device))
                seed += 1
            noise = torch.cat(tens)
            del tens
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
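
    # Both stochastic_encode and add_noise implement the closed-form forward
    # process q(x_t | x_0), here indexed over the DDIM sub-schedule:
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)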
    def add_noise(self, x0, t):
        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
        noise = torch.randn(x0.shape, device=x0.device)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
    def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                      mask=None, init_latent=None, use_original_steps=False,
                      callback=None, img_callback=None):
        timesteps = self.ddim_timesteps
        timesteps = timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        x0 = init_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            if mask is not None:
                # alternatively, noise the masked region to the current step:
                # x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
                x0_noisy = x0
                x_dec = x0_noisy * mask + (1. - mask) * x_dec
            x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning)
            if callback: yield from callback(i)
            if img_callback: yield from img_callback(x_dec, i)
        if mask is not None:
            x_dec = x0 * mask + (1. - mask) * x_dec
        if img_callback: yield from img_callback(x_dec, len(iterator) - 1)
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            e_t = score_corrector.modify_score(self, e_t, x, t, c, **corrector_kwargs)

        alphas = self.ddim_alphas
        alphas_prev = self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev
    def euler_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                       extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
                       img_callback=None):
        """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = (sigmas[i] * (gamma + 1)).half()
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            s_i = sigma_hat * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            dt = sigmas[i + 1] - sigma_hat
            # Euler method
            x = x + d * dt
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
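
    # In the Karras et al. (2022) samplers above, "churn" optionally re-adds a
    # little noise each step: gamma lifts the current noise level to
    # sigma_hat = sigma_i * (1 + gamma), and x is perturbed so its total noise
    # matches sigma_hat before the ODE step toward sigmas[i + 1] is taken.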
    def euler_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                                 extra_args=None, callback=None, disable=None,
                                 img_callback=None):
        """Ancestral sampling with Euler method steps."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            s_i = sigmas[i] * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigmas[i], denoised)
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
            x = x + torch.randn_like(x) * sigma_up
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def heun_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                      extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
                      img_callback=None):
        """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(alphas_cumprod=ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = (sigmas[i] * (gamma + 1)).half()
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            s_i = sigma_hat * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            dt = sigmas[i + 1] - sigma_hat
            if sigmas[i + 1] == 0:
                # Euler method
                x = x + d * dt
            else:
                # Heun's method: correct the Euler step with the averaged slope
                x_2 = x + d * dt
                s_i = sigmas[i + 1] * s_in
                x_in = torch.cat([x_2] * 2)
                t_in = torch.cat([s_i] * 2)
                cond_in = torch.cat([unconditional_conditioning, cond])
                c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
                eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
                e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
                denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
                d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def dpm_2_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                       extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
                       img_callback=None):
        """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = sigmas[i] * (gamma + 1)
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            s_i = sigma_hat * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigma_hat, denoised)
            # midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
            sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            s_i = sigma_mid * s_in
            x_in = torch.cat([x_2] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def dpm_2_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                                 extra_args=None, callback=None, disable=None,
                                 img_callback=None):
        """Ancestral sampling with DPM-Solver inspired second-order steps."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            s_i = sigmas[i] * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigmas[i], denoised)
            # midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
            sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            s_i = sigma_mid * s_in
            x_in = torch.cat([x_2] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
            x = x + torch.randn_like(x) * sigma_up
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def lms_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                     extra_args=None, callback=None, disable=None, order=4,
                     img_callback=None):
        """Linear multistep sampling over the Karras sigma schedule."""
        extra_args = {} if extra_args is None else extra_args
        s_in = x.new_ones([x.shape[0]])
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps")
        ds = []
        for i in trange(len(sigmas) - 1, disable=disable):
            s_i = sigmas[i] * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigmas[i], denoised)
            ds.append(d)
            if len(ds) > order:
                ds.pop(0)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            cur_order = min(i + 1, order)
            coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
            x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
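

# Every sampler above applies classifier-free guidance the same way: one model
# call on a doubled batch (unconditional first, conditional second), then an
# extrapolation from the unconditional prediction toward the conditional one.
# A minimal standalone sketch of that pattern (apply_fn stands in for
# apply_model; purely illustrative):
def _demo_classifier_free_guidance(apply_fn, x, t, cond, uncond, scale):
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([uncond, cond])
    e_t_uncond, e_t = apply_fn(x_in, t_in, c_in).chunk(2)
    return e_t_uncond + scale * (e_t - e_t_uncond)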