fluxdev's picture
Upload folder using huggingface_hub
b1bd80d verified
raw
history blame
No virus
4 kB
# 1st edit by https://github.com/CompVis/latent-diffusion
# 2nd edit by https://github.com/Stability-AI/stablediffusion
# 3rd edit by https://github.com/Stability-AI/generative-models
# 4th edit by https://github.com/comfyanonymous/ComfyUI
# This file is only for reference, and not used in the backend or runtime.
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from .util import extract_into_tensor, make_beta_schedule
from ldm_patched.ldm.util import default
class AbstractLowScaleModel(nn.Module):
    """Base module for concatenating a downsampled image to the latent representation.

    When a noise schedule configuration is supplied, the diffusion
    coefficients derived from it are registered as buffers so that
    ``q_sample`` can noise an input at a given timestep. Without a
    configuration no buffers are registered and only the pass-through
    ``forward`` / ``decode`` methods are usable.
    """

    def __init__(self, noise_schedule_config=None):
        """Create the module.

        Args:
            noise_schedule_config: Optional mapping of keyword arguments
                forwarded to ``register_schedule``. ``None`` skips schedule
                registration entirely.
        """
        super().__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the schedule coefficients and register them as float32 buffers.

        Args:
            beta_schedule: Schedule name passed through to ``make_beta_schedule``.
            timesteps: Requested number of steps; the stored ``num_timesteps``
                is taken from the shape of the betas actually returned.
            linear_start: Passed through to ``make_beta_schedule`` and stored.
            linear_end: Passed through to ``make_beta_schedule`` and stored.
            cosine_s: Passed through to ``make_beta_schedule``.
        """
        # NOTE(review): make_beta_schedule is defined in .util; the code below
        # treats its result as a 1-D array with a ``.shape`` -- confirm there.
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Shift right by one step, using 1.0 as the value "before" step 0.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        # Single-element unpacking fails loudly unless betas is one-dimensional.
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None, seed=None):
        """Mix ``x_start`` with noise using the registered schedule at step ``t``.

        Args:
            x_start: Tensor to be noised.
            t: Timestep index/indices handed to ``extract_into_tensor``.
            noise: Optional pre-drawn noise. When given, ``seed`` is ignored.
            seed: Optional seed for reproducible noise. Only consulted when
                ``noise`` is ``None``.

        Returns:
            ``sqrt_alphas_cumprod`` (selected for ``t``) times ``x_start`` plus
            ``sqrt_one_minus_alphas_cumprod`` (selected for ``t``) times the noise.
        """
        if noise is None:
            if seed is None:
                noise = torch.randn_like(x_start)
            else:
                # Use a private generator: torch.manual_seed() would reseed the
                # process-wide RNG and make every later random draw in the
                # program deterministic as a hidden side effect.
                generator = torch.Generator().manual_seed(int(seed))
                # No device is passed to randn; the result is moved to
                # x_start's device afterwards.
                noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout,
                                    generator=generator).to(x_start.device)
        # NOTE(review): extract_into_tensor comes from .util; presumably it picks
        # the coefficient for each entry of t and reshapes it to broadcast
        # against x_start -- confirm there.
        return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)

    def forward(self, x):
        """Return ``x`` unchanged together with ``None`` in place of a noise level."""
        return x, None

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
class SimpleImageConcat(AbstractLowScaleModel):
    """Low-scale model without noise level conditioning.

    No noise schedule is registered; ``forward`` always reports noise
    level zero.
    """

    def __init__(self):
        super().__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        """Return ``x`` untouched plus a constant zero noise level.

        The noise level is an int64 tensor of zeros with ``x.shape[0]``
        entries, created on ``x``'s device.
        """
        # fix to constant noise level
        batch_size = x.shape[0]
        zero_level = torch.zeros(batch_size, device=x.device).long()
        return x, zero_level
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Low-scale model that noises its input via the registered schedule."""

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        """Register the schedule and remember the noise level upper bound.

        Args:
            noise_schedule_config: Keyword arguments for the base class
                schedule registration.
            max_noise_level: Exclusive upper bound used when ``forward``
                draws a random noise level.
            to_cuda: Accepted but not used anywhere in this class.
        """
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None, seed=None):
        """Noise ``x`` at ``noise_level`` and return both.

        Args:
            x: Input tensor to be noised.
            noise_level: Optional tensor of noise levels. When ``None``, one
                level per ``x.shape[0]`` entry is drawn with ``torch.randint``
                from ``0`` up to ``max_noise_level`` on ``x``'s device.
            seed: Forwarded to ``q_sample``.

        Returns:
            A tuple ``(z, noise_level)`` where ``z`` is the result of
            ``q_sample`` and ``noise_level`` is the level tensor actually used.
        """
        if noise_level is not None:
            assert isinstance(noise_level, torch.Tensor)
        else:
            batch_size = x.shape[0]
            drawn = torch.randint(0, self.max_noise_level, (batch_size,), device=x.device)
            noise_level = drawn.long()
        noised = self.q_sample(x, noise_level, seed=seed)
        return noised, noise_level