# 1st edit by https://github.com/CompVis/latent-diffusion
# 2nd edit by https://github.com/Stability-AI/stablediffusion
# 3rd edit by https://github.com/Stability-AI/generative-models
# 4th edit by https://github.com/comfyanonymous/ComfyUI
# This file is only for reference, and not used in the backend or runtime.


import torch
import torch.nn as nn
import numpy as np
from functools import partial

from .util import extract_into_tensor, make_beta_schedule
from ldm_patched.ldm.util import default


class AbstractLowScaleModel(nn.Module):
    """Base class for concatenating a downsampled image to the latent representation.

    When a noise schedule config is supplied, the schedule and several
    quantities derived from it are registered as buffers so that
    ``q_sample`` can noise an input at a given timestep.
    """

    def __init__(self, noise_schedule_config=None):
        """Initialise the module and, if configured, register the schedule.

        Args:
            noise_schedule_config: Optional mapping of keyword arguments
                forwarded to ``register_schedule``. When ``None`` no
                schedule buffers are registered.
        """
        super().__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the beta schedule and register derived buffers.

        All arguments are forwarded to ``make_beta_schedule``.
        ``linear_start`` and ``linear_end`` are also stored on the instance.

        Raises:
            AssertionError: If the cumulative alpha product does not have
                one entry per timestep.
        """
        # presumably a 1-D numpy array of betas -- confirm against make_beta_schedule
        betas = make_beta_schedule(beta_schedule, timesteps,
                                   linear_start=linear_start,
                                   linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Shifted by one step, with 1.0 prepended for the first timestep.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        # The actual number of steps is taken from the schedule that was
        # returned, not from the requested ``timesteps`` argument.
        (schedule_length,) = betas.shape
        self.num_timesteps = int(schedule_length)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None, seed=None):
        """Combine ``x_start`` with noise using the registered schedule at ``t``.

        Args:
            x_start: Input tensor to be noised.
            t: Timestep argument passed to ``extract_into_tensor``
                (presumably per-sample indices -- confirm against helper).
            noise: Optional noise tensor. When ``None`` it is drawn here.
            seed: Optional seed, used only when ``noise`` is ``None``.

        Returns:
            ``a * x_start + b * noise`` where ``a`` and ``b`` are obtained by
            ``extract_into_tensor`` from ``sqrt_alphas_cumprod`` and
            ``sqrt_one_minus_alphas_cumprod`` respectively.
        """
        if noise is None:
            if seed is None:
                noise = torch.randn_like(x_start)
            else:
                # NOTE: torch.manual_seed reseeds the *global* default RNG and
                # returns it. The noise is generated without a device
                # argument and then moved to x_start's device.
                noise = torch.randn(x_start.size(),
                                    dtype=x_start.dtype,
                                    layout=x_start.layout,
                                    generator=torch.manual_seed(seed)).to(x_start.device)

        signal_scale = extract_into_tensor(
            self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape)
        noise_scale = extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x):
        """Return ``x`` unchanged together with ``None`` as the noise level."""
        return x, None

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x


class SimpleImageConcat(AbstractLowScaleModel):
    """Low-scale model with no noise level conditioning."""

    def __init__(self):
        """Initialise without a noise schedule; ``max_noise_level`` is 0."""
        super().__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        """Return ``x`` and a zero noise level (int64) per batch element."""
        # fix to constant noise level
        return x, torch.zeros(x.shape[0], device=x.device).long()


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Low-scale model that noises its input via ``q_sample``."""

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        """Register the schedule and store the maximum noise level.

        Args:
            noise_schedule_config: Keyword arguments for ``register_schedule``.
            max_noise_level: Exclusive upper bound used when sampling a
                random noise level in ``forward``.
            to_cuda: Unused in this class; kept so existing callers that
                pass it continue to work.
        """
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None, seed=None):
        """Noise ``x`` at ``noise_level`` and return both.

        Args:
            x: Input tensor; its first dimension sets how many noise levels
                are sampled when ``noise_level`` is ``None``.
            noise_level: Optional tensor of noise levels. When ``None``, one
                integer level per element of the first dimension is drawn
                from ``[0, max_noise_level)``.
            seed: Optional seed forwarded to ``q_sample``.

        Returns:
            A tuple ``(z, noise_level)`` where ``z`` is the result of
            ``q_sample(x, noise_level, seed=seed)``.

        Raises:
            AssertionError: If ``noise_level`` is given but is not a tensor.
        """
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor), 'noise_level must be a torch.Tensor'
        z = self.q_sample(x, noise_level, seed=seed)
        return z, noise_level