"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import numpy as np
import pytorch_lightning as pl
from functools import partial
from tqdm import tqdm
from tqdm.auto import trange
from einops import rearrange
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldmlib.modules.distributions.distributions import DiagonalGaussianDistribution
from ldmlib.models.autoencoder import VQModelInterface
from ldmlib.util import exists, default, instantiate_from_config
from ldmlib.modules.diffusionmodules.util import (make_beta_schedule, make_ddim_sampling_parameters,
                                                  make_ddim_timesteps, extract_into_tensor, noise_like)
from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims, linear_multistep_coeff
def disabled_train(self):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
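# Hedged usage sketch for `disabled_train` (see `instantiate_first_stage`
# below): binding it over a module's `train` method makes later `.train()`
# calls no-ops, so Lightning cannot flip a frozen stage back into training
# mode. Illustrative only:
#
#     frozen = instantiate_from_config(cfg).eval()
#     frozen.train = disabled_train      # .train() now returns self unchanged
#     for p in frozen.parameters():
#         p.requires_grad = False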
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
timesteps=1000,
beta_schedule="linear",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
):
super().__init__()
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
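# Hedged sketch (not used by the classes below): the buffers registered above
# define the closed-form forward process
#     q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
# which is what `stochastic_encode`/`add_noise` later compute with the DDIM
# alphas. `_q_sample_sketch` is a hypothetical helper for illustration.
def _q_sample_sketch(x0, t, alphas_cumprod, noise=None):
    noise = torch.randn_like(x0) if noise is None else noise
    a_bar = alphas_cumprod[t].reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast over image dims
    return a_bar.sqrt() * x0 + (1. - a_bar).sqrt() * noise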
class FirstStage(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__()
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
self.instantiate_first_stage(first_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # e.g. (128, 128)
                stride = self.split_input_params["stride"]  # e.g. (64, 64)
                df = self.split_input_params["vqf"]
                self.split_input_params['original_image_size'] = x.shape[-2:]
                bs, nc, h, w = x.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing kernel size")
                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")
                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
                z = unfold(x)  # (bn, nc * prod(ks), L)
                # reshape to per-patch image shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L)
                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]
                o = torch.stack(output_list, dim=-1)
                o = o * weighting
                # flatten the patches back out
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
else:
return self.first_stage_model.encode(x)
else:
return self.first_stage_model.encode(x)
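# Hedged sketch of the sliding-window scheme in `encode_first_stage` above:
# patches are cut out with unfold, processed per patch, then stitched back
# with fold; dividing by fold(unfold(ones)) averages the overlapping regions.
# `_fold_unfold_sketch` is illustrative only, assumes ks/stride tile the image
# exactly, and omits the per-patch weighting that `get_fold_unfold` applies.
def _fold_unfold_sketch(x, ks=(128, 128), stride=(64, 64)):
    import torch.nn.functional as F
    h, w = x.shape[-2:]
    patches = F.unfold(x, kernel_size=ks, stride=stride)        # (b, c*ks[0]*ks[1], L)
    stitched = F.fold(patches, (h, w), kernel_size=ks, stride=stride)
    counts = F.fold(F.unfold(torch.ones_like(x), kernel_size=ks, stride=stride),
                    (h, w), kernel_size=ks, stride=stride)      # overlap counts
    return stitched / counts                                    # identity round trip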
class CondStage(DDPM):
"""main class"""
def __init__(self,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__':
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__()
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
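# Hedged usage sketch for `get_learned_conditioning`: with a text encoder as
# cond stage, `encode` returns the context tensor fed to cross-attention.
# Hypothetical calls (prompt strings are placeholders):
#
#     c = model.get_learned_conditioning(["a photograph of an astronaut"])
#     uc = model.get_learned_conditioning([""])  # empty prompt for classifier-free guidance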
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
def forward(self, x, t, cc):
out = self.diffusion_model(x, t, context=cc)
return out
class DiffusionWrapperOut(pl.LightningModule):
def __init__(self, diff_model_config):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
    def forward(self, h, emb, tp, hs, cc):
        return self.diffusion_model(h, emb, tp, hs, context=cc)
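# Note on the split UNet below: `DiffusionWrapper` wraps the encoder half and
# returns (h, emb, hs) -- bottleneck features, timestep embedding, and skip
# activations -- while `DiffusionWrapperOut` wraps the decoder half that
# consumes those skips. Keeping the halves separate lets `UNet.apply_model`
# hold only one half on the GPU at a time when `turbo` is off.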
class UNet(DDPM):
"""main class"""
def __init__(self,
unetConfigEncode,
unetConfigDecode,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
                 unet_bs=1,
scale_by_std=False,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
self.num_downs = 0
self.cdevice = "cuda"
self.unetConfigEncode = unetConfigEncode
self.unetConfigDecode = unetConfigDecode
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.model1 = DiffusionWrapper(self.unetConfigEncode)
self.model2 = DiffusionWrapperOut(self.unetConfigDecode)
self.model1.eval()
self.model2.eval()
self.turbo = False
self.unet_bs = unet_bs
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
    def make_cond_schedule(self):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
x = super().get_input(batch, self.first_stage_key)
x = x.to(self.cdevice)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
def apply_model(self, x_noisy, t, cond, return_ids=False):
        if not self.turbo:
            self.model1.to(self.cdevice)
        step = self.unet_bs
        # run the UNet encoder half in chunks of `unet_bs` along the batch
        h, emb, hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
        bs = cond.shape[0]
        lenhs = len(hs)
        for i in range(step, bs, step):
            h_temp, emb_temp, hs_temp = self.model1(x_noisy[i:i + step], t[i:i + step], cond[i:i + step])
            h = torch.cat((h, h_temp))
            emb = torch.cat((emb, emb_temp))
            for j in range(lenhs):
                hs[j] = torch.cat((hs[j], hs_temp[j]))
        if not self.turbo:
            # swap halves between CPU and GPU to bound peak VRAM
            self.model1.to("cpu")
            self.model2.to(self.cdevice)
        # run the decoder half in the same chunks, feeding the skip activations
        hs_temp = [hs[j][:step] for j in range(lenhs)]
        x_recon = self.model2(h[:step], emb[:step], x_noisy.dtype, hs_temp, cond[:step])
        for i in range(step, bs, step):
            hs_temp = [hs[j][i:i + step] for j in range(lenhs)]
            x_recon1 = self.model2(h[i:i + step], emb[i:i + step], x_noisy.dtype, hs_temp, cond[i:i + step])
            x_recon = torch.cat((x_recon, x_recon1))
        if not self.turbo:
            self.model2.to("cpu")
        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        return x_recon
    def register_buffer1(self, name, attr):
        # plain setattr rather than nn.Module.register_buffer: tensors are
        # moved to self.cdevice but are not tracked in the state_dict
        if isinstance(attr, torch.Tensor):
            if attr.device != torch.device(self.cdevice):
                attr = attr.to(torch.device(self.cdevice))
        setattr(self, name, attr)
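    # Hedged summary of the DDIM update that the schedule below parameterizes
    # (see `p_sample_ddim`):
    #     pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
    #     x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * noise
    # With eta = 0 every sigma_t is zero and sampling is deterministic.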
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.num_timesteps,verbose=verbose)
assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.to(self.cdevice)
self.register_buffer1('betas', to_torch(self.betas))
self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer1('ddim_sigmas', ddim_sigmas)
self.register_buffer1('ddim_alphas', ddim_alphas)
self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
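    # `sample` is a generator: the samplers below yield intermediate latents
    # through `img_callback`, so callers must iterate it. Hypothetical usage
    # (shape and guidance values are placeholders):
    #
    #     for _ in model.sample(S=50, conditioning=c, shape=(n, 4, 64, 64),
    #                           unconditional_conditioning=uc,
    #                           unconditional_guidance_scale=7.5,
    #                           sampler="euler_a"):
    #         pass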
@torch.no_grad()
def sample(self,
S,
conditioning,
x0=None,
shape = None,
seed=1234,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
sampler = "plms",
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
):
        if self.turbo:
            self.model1.to(self.cdevice)
            self.model2.to(self.cdevice)
if x0 is None:
batch_size, b1, b2, b3 = shape
img_shape = (1, b1, b2, b3)
tens = []
print("seeds used = ", [seed+s for s in range(batch_size)])
for _ in range(batch_size):
torch.manual_seed(seed)
tens.append(torch.randn(img_shape, device=self.cdevice))
seed+=1
noise = torch.cat(tens)
del tens
x_latent = noise if x0 is None else x0
# sampling
if sampler in ('ddim', 'dpm2', 'heun', 'dpm2_a', 'lms') and not hasattr(self, 'ddim_timesteps'):
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
if sampler == "plms":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
print(f'Data shape for PLMS sampling is {shape}')
samples = self.plms_sampling(conditioning, batch_size, x_latent,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
elif sampler == "ddim":
            samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
                                         unconditional_conditioning=unconditional_conditioning,
                                         mask=mask, init_latent=x_T, use_original_steps=False,
                                         callback=callback, img_callback=img_callback)
elif sampler == "euler":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale,
img_callback=img_callback)
elif sampler == "euler_a":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale,
img_callback=img_callback)
elif sampler == "dpm2":
samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale,
img_callback=img_callback)
elif sampler == "heun":
samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale,
img_callback=img_callback)
elif sampler == "dpm2_a":
samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale,
img_callback=img_callback)
elif sampler == "lms":
samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale,
img_callback=img_callback)
yield from samples
        if self.turbo:
            self.model1.to("cpu")
            self.model2.to("cpu")
@torch.no_grad()
    def plms_sampling(self, cond, b, img,
ddim_use_original_steps=False,
callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
device = self.betas.device
timesteps = self.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: yield from callback(i)
if img_callback: yield from img_callback(pred_x0, i)
        if img_callback:
            yield from img_callback(img, len(iterator) - 1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.ddim_alphas
alphas_prev = self.ddim_alphas_prev
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
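    # `stochastic_encode` noises a clean latent x0 to DDIM step t in closed
    # form -- the img2img-style entry point, where t is typically derived from
    # a denoising-strength fraction of the DDIM steps (assumption based on
    # common usage; the caller chooses t).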
@torch.no_grad()
    def stochastic_encode(self, x0, t, seed, ddim_eta, ddim_steps, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
if noise is None:
            b0, b1, b2, b3 = x0.shape
            img_shape = (1, b1, b2, b3)
            tens = []
            print("seeds used = ", [seed + s for s in range(b0)])
            for _ in range(b0):
                torch.manual_seed(seed)
                tens.append(torch.randn(img_shape, device=x0.device))
                seed += 1
            noise = torch.cat(tens)
            del tens
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
@torch.no_grad()
def add_noise(self, x0, t):
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
noise = torch.randn(x0.shape, device=x0.device)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
@torch.no_grad()
    def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                      mask=None, init_latent=None, use_original_steps=False,
                      callback=None, img_callback=None):
        timesteps = self.ddim_timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            if mask is not None:
                # keep the masked region pinned to the init latent while the rest is denoised
                x_dec = x0 * mask + (1. - mask) * x_dec
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: yield from callback(i)
if img_callback: yield from img_callback(x_dec, i)
        if mask is not None:
            x_dec = x0 * mask + (1. - mask) * x_dec
        if img_callback:
            yield from img_callback(x_dec, len(iterator) - 1)
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
            assert self.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.ddim_alphas
alphas_prev = self.ddim_alphas_prev
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev
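    # The k-diffusion-style samplers below treat denoising as an ODE in sigma:
    # `to_d` converts a denoised prediction into the derivative
    #     d = (x - denoised) / sigma,
    # and an Euler step is simply x <- x + d * (sigma_next - sigma_hat).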
@torch.no_grad()
    def euler_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                       extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'),
                       s_noise=1., img_callback=None):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
cvd = CompVisDenoiser(ac)
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps")
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = (sigmas[i] * (gamma + 1)).half()
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
s_i = sigma_hat * s_in
x_in = torch.cat([x] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
if img_callback: yield from img_callback(x, i)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
        if img_callback:
            yield from img_callback(x, len(sigmas) - 1)
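    # Ancestral variants split each transition sigma_i -> sigma_{i+1} into a
    # deterministic step down to sigma_down plus fresh noise of scale sigma_up,
    # with sigma_down**2 + sigma_up**2 == sigmas[i + 1]**2 (up to clamping in
    # `get_ancestral_step`), so the marginal noise level still matches.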
@torch.no_grad()
    def euler_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                                 extra_args=None, callback=None, disable=None, img_callback=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
cvd = CompVisDenoiser(ac)
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps")
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
s_i = sigmas[i] * s_in
x_in = torch.cat([x] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + torch.randn_like(x) * sigma_up
        if img_callback:
            yield from img_callback(x, len(sigmas) - 1)
@torch.no_grad()
    def heun_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                      extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'),
                      s_noise=1., img_callback=None):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
cvd = CompVisDenoiser(alphas_cumprod=ac)
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps")
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = (sigmas[i] * (gamma + 1)).half()
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
s_i = sigma_hat * s_in
x_in = torch.cat([x] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
if img_callback: yield from img_callback(x, i)
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
s_i = sigmas[i + 1] * s_in
x_in = torch.cat([x_2] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
        if img_callback:
            yield from img_callback(x, len(sigmas) - 1)
@torch.no_grad()
    def dpm_2_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                       extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'),
                       s_noise=1., img_callback=None):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
cvd = CompVisDenoiser(ac)
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps")
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
s_i = sigma_hat * s_in
x_in = torch.cat([x] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if img_callback: yield from img_callback(x, i)
d = to_d(x, sigma_hat, denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
s_i = sigma_mid * s_in
x_in = torch.cat([x_2] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
        if img_callback:
            yield from img_callback(x, len(sigmas) - 1)
@torch.no_grad()
    def dpm_2_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                                 extra_args=None, callback=None, disable=None, img_callback=None):
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
extra_args = {} if extra_args is None else extra_args
cvd = CompVisDenoiser(ac)
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps")
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
s_i = sigmas[i] * s_in
x_in = torch.cat([x] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
s_i = sigma_mid * s_in
x_in = torch.cat([x_2] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + torch.randn_like(x) * sigma_up
        if img_callback:
            yield from img_callback(x, len(sigmas) - 1)
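    # LMS: `linear_multistep_coeff(order, sigmas, i, j)` integrates the j-th
    # Lagrange basis polynomial over [sigmas[i], sigmas[i + 1]], so the update
    #     x <- x + sum_j coeff_j * d_{i-j}
    # is an Adams-Bashforth-style multistep method on the same ODE, with the
    # order ramping up from 1 over the first steps (cur_order = min(i + 1, order)).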
@torch.no_grad()
    def lms_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                     extra_args=None, callback=None, disable=None, order=4, img_callback=None):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
cvd = CompVisDenoiser(ac)
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps")
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
s_i = sigmas[i] * s_in
x_in = torch.cat([x] * 2)
t_in = torch.cat([s_i] * 2)
cond_in = torch.cat([unconditional_conditioning, cond])
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
        if img_callback:
            yield from img_callback(x, len(sigmas) - 1)