# Doven · "update code." · commit f7009b3 · 9.87 kB
import torch
from torch import nn
import torch.nn.functional as F
from .denoiser import ConditionalUNet
import numpy as np
def extract(v, i, shape):
    """Pick ``v[i]`` for every index in ``i`` and shape it for broadcasting.

    Args:
        v: 1-D table of per-timestep coefficients.
        i: 1-D LongTensor of indices, one per batch element.
        shape: Shape of the tensor the result will be multiplied with; only
            its length (rank) is used.

    Returns:
        float32 tensor on ``i``'s device with shape
        ``(len(i), 1, ..., 1)`` having ``len(shape)`` dimensions in total.
    """
    picked = v.gather(0, i).to(device=i.device, dtype=torch.float32)
    # One leading batch axis, then singleton axes so it broadcasts against `shape`.
    target = (i.shape[0],) + (1,) * (len(shape) - 1)
    return picked.view(target)
class GaussianDiffusionTrainer(nn.Module):
    """Compute the epsilon-prediction diffusion training loss.

    For each sample a timestep ``t`` is drawn, ``x_0`` is noised in closed form
    to ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon`` and
    the wrapped model's prediction of ``epsilon`` is scored with an MSE loss.

    Args:
        model: Noise predictor, called as ``model(x_t, t, z)``.
        beta: ``(beta_start, beta_end)`` of the linear beta schedule.
        T: Number of diffusion timesteps.
    """

    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        # T betas linearly spaced from beta[0] to beta[1].
        self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))
        # alpha_bar_t = prod_{s <= t} (1 - beta_s)
        alpha_t = 1.0 - self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        # Mean and std coefficients of q(x_t | x_0).
        self.register_buffer("signal_rate", torch.sqrt(alpha_t_bar))
        self.register_buffer("noise_rate", torch.sqrt(1.0 - alpha_t_bar))

    def forward(self, x_0, z, **kwargs):
        """Return the mean elementwise MSE between predicted and true noise.

        NaN entries of ``x_0`` are treated as missing: they are zero-filled
        before noising and excluded from the average.

        Args:
            x_0: Clean data; the first dimension is the batch.
            z: Condition forwarded to the model.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Scalar tensor, the ``nanmean`` of the elementwise loss (NaN only
            if every element is NaN, e.g. when ``x_0`` is entirely NaN).
        """
        mask = torch.isnan(x_0)
        x_0 = torch.nan_to_num(x_0, 0.)
        # One timestep per sample, uniform over {0, ..., T-1} (0-indexed).
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        epsilon = torch.randn_like(x_0)
        # Sample x_t ~ q(x_t | x_0) directly from x_0 (no per-step iteration).
        x_t = (extract(self.signal_rate, t, x_0.shape) * x_0 +
               extract(self.noise_rate, t, x_0.shape) * epsilon)
        epsilon_theta = self.model(x_t, t, z)
        loss = F.mse_loss(epsilon_theta, epsilon, reduction="none")
        # Mark originally-missing positions so nanmean skips them.
        loss[mask] = torch.nan
        return loss.nanmean()
class DDPMSampler(nn.Module):
def __init__(self, model: nn.Module, beta: tuple[int, int], T: int):
super().__init__()
self.model = model
self.T = T
# generate T steps of beta
self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))
# calculate the cumulative product of $\alpha$ , named $\bar{\alpha_t}$ in paper
alpha_t = 1.0 - self.beta_t
alpha_t_bar = torch.cumprod(alpha_t, dim=0)
alpha_t_bar_prev = F.pad(alpha_t_bar[:-1], (1, 0), value=1.0)
self.register_buffer("coeff_1", torch.sqrt(1.0 / alpha_t))
self.register_buffer("coeff_2", self.coeff_1 * (1.0 - alpha_t) / torch.sqrt(1.0 - alpha_t_bar))
self.register_buffer("posterior_variance", self.beta_t * (1.0 - alpha_t_bar_prev) / (1.0 - alpha_t_bar))
@torch.no_grad()
def cal_mean_variance(self, x_t, t, c):
# """ Calculate the mean and variance for $q(x_{t-1} | x_t, x_0)$ """
epsilon_theta = self.model(x_t, t, c)
mean = extract(self.coeff_1, t, x_t.shape) * x_t - extract(self.coeff_2, t, x_t.shape) * epsilon_theta
# var is a constant
var = extract(self.posterior_variance, t, x_t.shape)
return mean, var
@torch.no_grad()
def sample_one_step(self, x_t, time_step, c):
# """ Calculate $x_{t-1}$ according to $x_t$ """
t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
mean, var = self.cal_mean_variance(x_t, t, c)
z = torch.randn_like(x_t) if time_step > 0 else 0
x_t_minus_one = mean + torch.sqrt(var) * z
if torch.isnan(x_t_minus_one).int().sum() != 0:
raise ValueError("nan in tensor!")
return x_t_minus_one
@torch.no_grad()
def forward(self, x_t, c, only_return_x_0=True, interval=1, **kwargs):
x = [x_t]
for time_step in reversed(range(self.T)):
x_t = self.sample_one_step(x_t, time_step, c)
if not only_return_x_0 and ((self.T - time_step) % interval == 0 or time_step == 0):
x.append(x_t)
if only_return_x_0:
return x_t # [batch_size, channels, height, width]
return torch.stack(x, dim=1) # [batch_size, sample, channels, height, width]
class DDIMSampler(nn.Module):
    """DDIM sampler that denoises over a sub-sequence of the T training timesteps.

    Args:
        model: Noise predictor, called as ``model(x_t, t, c)``.
        beta: ``(beta_start, beta_end)`` of the linear beta schedule.
        T: Number of diffusion timesteps used in training.
    """

    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T
        beta_t = torch.linspace(*beta, T, dtype=torch.float32)
        alpha_t = 1.0 - beta_t
        # Only the cumulative product alpha_bar_t is needed by the DDIM update.
        self.register_buffer("alpha_t_bar", torch.cumprod(alpha_t, dim=0))

    @torch.no_grad()
    def sample_one_step(self, x_t, time_step, c, prev_time_step, eta):
        """Move ``x_t`` from ``time_step`` to ``prev_time_step`` with one DDIM update.

        ``eta`` scales the injected noise; ``eta = 0`` makes the update deterministic.
        """
        t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
        prev_t = torch.full((x_t.shape[0],), prev_time_step, device=x_t.device, dtype=torch.long)
        # NOTE: despite the names, these hold cumulative alpha_bar values.
        alpha_t = extract(self.alpha_t_bar, t, x_t.shape)
        alpha_t_prev = extract(self.alpha_t_bar, prev_t, x_t.shape)
        epsilon_theta_t = self.model(x_t, t, c)
        sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))
        epsilon_t = torch.randn_like(x_t)
        # x_prev = sqrt(a_prev / a_t) * x_t
        #          + (sqrt(1 - a_prev - sigma^2) - sqrt(a_prev * (1 - a_t) / a_t)) * eps_theta
        #          + sigma * noise
        x_t_minus_one = (torch.sqrt(alpha_t_prev / alpha_t) * x_t +
                         (torch.sqrt(1 - alpha_t_prev - sigma_t ** 2) - torch.sqrt(
                             (alpha_t_prev * (1 - alpha_t)) / alpha_t)) * epsilon_theta_t +
                         sigma_t * epsilon_t)
        return x_t_minus_one

    @torch.no_grad()
    def forward(self, x_t, c, steps=60, method="linear", eta=0.05, only_return_x_0=True, interval=1, **kwargs):
        """Run DDIM sampling starting from ``x_t``.

        Args:
            x_t: Starting noise.
            c: Condition forwarded to the model.
            steps: Requested number of denoising steps. ``0`` returns ``c``
                unchanged. For the linear schedule, values above ``T`` are
                capped to ``T`` steps.
            method: ``"linear"`` (evenly strided timesteps) or ``"quadratic"``
                (timesteps spaced quadratically up to ``0.8 * T``).
            eta: Noise scale passed to every step.
            only_return_x_0: If True, return only the final sample.
            interval: When keeping the trajectory, keep every ``interval``-th
                step, plus the final one.
            **kwargs: Ignored.

        Returns:
            The final sample, or the kept states (initial ``x_t`` first)
            stacked along dim 1.

        Raises:
            ValueError: if ``steps`` is negative.
            NotImplementedError: for an unknown ``method``.
        """
        if steps == 0:
            # NOTE(review): returns the condition `c`, not `x_t`; kept as-is — confirm intended.
            return c
        if steps < 0:
            raise ValueError(f"steps must be non-negative, got {steps}")
        if method == "linear":
            # Stride >= 1 so steps > T does not produce a zero range step.
            stride = max(self.T // steps, 1)
            # Only the first `steps` strided timesteps are ever used.
            time_steps = np.arange(0, self.T, stride)[:steps]
        elif method == "quadratic":
            # `np.int` was removed in NumPy 1.24; use a concrete integer dtype.
            time_steps = (np.linspace(0, np.sqrt(self.T * 0.8), steps) ** 2).astype(np.int64)
        else:
            raise NotImplementedError(f"sampling method {method} is not implemented!")
        # Shift by one so the final alpha values line up, but never index past T - 1
        # (steps == T would otherwise produce index T and an out-of-bounds gather).
        time_steps = np.minimum(time_steps + 1, self.T - 1)
        time_steps_prev = np.concatenate([[0], time_steps[:-1]])
        # The linear schedule yields only T entries when more steps were requested.
        steps = len(time_steps)
        x = [x_t]
        for i in reversed(range(steps)):
            x_t = self.sample_one_step(x_t, int(time_steps[i]), c, int(time_steps_prev[i]), eta)
            if not only_return_x_0 and ((steps - i) % interval == 0 or i == 0):
                x.append(x_t)
        if only_return_x_0:
            return x_t
        return torch.stack(x, dim=1)
class DiffusionLoss(nn.Module):
    """Conditional diffusion head exposing a training loss and a sampler.

    Hyper-parameters are read from the class-level ``config`` dict, which must
    be filled in before instantiation. Keys used here: ``layer_channels``,
    ``model_dim``, ``condition_dim``, ``kernel_size``, ``beta``, ``T``,
    ``sample_mode`` (the sampler class to build), and optionally
    ``diffusion_batch`` and ``forward_once``.
    """
    config = {}

    def __init__(self):
        super().__init__()
        cfg = self.config
        self.net = ConditionalUNet(
            layer_channels=cfg["layer_channels"],
            model_dim=cfg["model_dim"],
            condition_dim=cfg["condition_dim"],
            kernel_size=cfg["kernel_size"],
        )
        schedule = {"beta": cfg["beta"], "T": cfg["T"]}
        # Trainer and sampler share the same denoiser network and schedule.
        self.diffusion_trainer = GaussianDiffusionTrainer(model=self.net, **schedule)
        self.diffusion_sampler = cfg["sample_mode"](model=self.net, **schedule)

    @staticmethod
    def _row_chunks(x, c, batch):
        """Yield aligned ``(x, c)`` row slices: full chunks of ``batch`` rows, then the remainder."""
        full_chunks = (x.size(0) - 1) // batch
        for _ in range(full_chunks):
            yield x[:batch], c[:batch]
            x, c = x[batch:], c[batch:]
        yield x, c

    def forward(self, x, c, **kwargs):
        """Return the diffusion training loss for targets ``x`` under condition ``c``.

        Both inputs are flattened to 2-D over their last dimension. If
        ``parameter_weight_decay`` is truthy, ``x`` is first scaled by
        ``1 - parameter_weight_decay``. With ``forward_once`` a random subset of
        at most ``diffusion_batch`` rows is kept. When more rows remain than
        ``diffusion_batch``, the trainer runs on chunks and the chunk losses are
        averaged weighted by row count. Extra kwargs are forwarded to the trainer.
        """
        decay = kwargs.get("parameter_weight_decay")
        if decay:
            x = x * (1.0 - decay)
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        batch = self.config.get("diffusion_batch")
        if self.config.get("forward_once"):
            keep = torch.randperm(x.size(0))[:batch]
            x, c = x[keep], c[keep]
        total_rows = x.size(0)
        if batch is None or total_rows <= batch:
            return self.diffusion_trainer(x, c, **kwargs)
        loss = 0.
        for x_part, c_part in self._row_chunks(x, c, batch):
            loss += self.diffusion_trainer(x_part, c_part, **kwargs) * x_part.size(0)
        return loss / total_rows

    @torch.no_grad()
    def sample(self, x, c, **kwargs):
        """Run the reverse process from starting noise ``x`` conditioned on ``c``.

        If ``only_return_x_0`` is explicitly ``False``, the full batch is sampled
        at once and the sampler's trajectory is returned permuted by
        ``(1, 0, 2)`` (sample axis first). Otherwise rows are sampled in chunks of
        ``diffusion_batch`` (when set and exceeded) and the result is reshaped to
        ``x.shape``.
        """
        batch = self.config.get("diffusion_batch")
        x_shape = x.shape
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        if kwargs.get("only_return_x_0") is False:
            trajectory = self.diffusion_sampler(x, c, **kwargs)
            return torch.permute(trajectory, (1, 0, 2))
        if batch is None or x.size(0) <= batch:
            return self.diffusion_sampler(x, c, **kwargs).view(x_shape)
        pieces = [self.diffusion_sampler(x_part, c_part, **kwargs)
                  for x_part, c_part in self._row_chunks(x, c, batch)]
        return torch.cat(pieces, dim=0).view(x_shape)