import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from .denoiser import ConditionalUNet


def extract(v, i, shape):
    """Gather the per-timestep coefficients v[i] and reshape them so they
    broadcast over a batch of shape `shape`, i.e. to [batch, 1, ..., 1]."""
    out = torch.gather(v, index=i, dim=0)
    out = out.to(device=i.device, dtype=torch.float32)
    out = out.view([i.shape[0]] + [1] * (len(shape) - 1))
    return out
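

# Example (illustrative only, not part of this module's API): for a length-10
# schedule buffer and one timestep per example, extract returns coefficients
# shaped to broadcast against [3, D] inputs:
#
#   v = torch.linspace(0.0, 1.0, 10)    # e.g. a length-10 schedule
#   i = torch.tensor([0, 4, 9])         # one timestep per example
#   extract(v, i, (3, 8)).shape         # -> torch.Size([3, 1])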


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T

        # Linear noise schedule beta_1, ..., beta_T.
        self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))

        alpha_t = 1.0 - self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)

        # Closed-form forward process:
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon.
        self.register_buffer("signal_rate", torch.sqrt(alpha_t_bar))
        self.register_buffer("noise_rate", torch.sqrt(1.0 - alpha_t_bar))

    def forward(self, x_0, z, **kwargs):
        # NaN entries mark invalid (e.g. padded) positions: zero them for the
        # forward pass and exclude them from the loss below.
        mask = torch.isnan(x_0)
        x_0 = torch.nan_to_num(x_0, nan=0.0)

        # Sample one timestep per example and diffuse x_0 to x_t in one shot.
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        epsilon = torch.randn_like(x_0)
        x_t = (extract(self.signal_rate, t, x_0.shape) * x_0 +
               extract(self.noise_rate, t, x_0.shape) * epsilon)

        # Simple (unweighted) DDPM objective: MSE between the predicted noise
        # and the true noise, averaged over non-masked positions only.
        epsilon_theta = self.model(x_t, t, z)
        loss = F.mse_loss(epsilon_theta, epsilon, reduction="none")
        loss[mask] = torch.nan
        return loss.nanmean()
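

# Training sketch (illustrative; `net` is any denoiser with the signature
# net(x_t, t, condition), e.g. the ConditionalUNet built in DiffusionLoss
# below, and the schedule values are common defaults, not this repo's config):
#
#   trainer = GaussianDiffusionTrainer(model=net, beta=(1e-4, 0.02), T=1000)
#   loss = trainer(x_0, condition)   # x_0: [B, D], condition: [B, C]
#   loss.backward()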


class DDPMSampler(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T

        self.register_buffer("beta_t", torch.linspace(*beta, T, dtype=torch.float32))

        alpha_t = 1.0 - self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        alpha_t_bar_prev = F.pad(alpha_t_bar[:-1], (1, 0), value=1.0)

        # Coefficients of the posterior mean
        # mu_t = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta).
        self.register_buffer("coeff_1", torch.sqrt(1.0 / alpha_t))
        self.register_buffer("coeff_2", self.coeff_1 * (1.0 - alpha_t) / torch.sqrt(1.0 - alpha_t_bar))
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.register_buffer("posterior_variance", self.beta_t * (1.0 - alpha_t_bar_prev) / (1.0 - alpha_t_bar))

    @torch.no_grad()
    def cal_mean_variance(self, x_t, t, c):
        epsilon_theta = self.model(x_t, t, c)
        mean = extract(self.coeff_1, t, x_t.shape) * x_t - extract(self.coeff_2, t, x_t.shape) * epsilon_theta
        var = extract(self.posterior_variance, t, x_t.shape)
        return mean, var

    @torch.no_grad()
    def sample_one_step(self, x_t, time_step, c):
        t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
        mean, var = self.cal_mean_variance(x_t, t, c)

        # No noise is added on the final step (t == 0).
        z = torch.randn_like(x_t) if time_step > 0 else 0
        x_t_minus_one = mean + torch.sqrt(var) * z

        if torch.isnan(x_t_minus_one).any():
            raise ValueError("nan in tensor!")
        return x_t_minus_one

    @torch.no_grad()
    def forward(self, x_t, c, only_return_x_0=True, interval=1, **kwargs):
        # Run the full reverse chain from t = T-1 down to t = 0.
        x = [x_t]
        for time_step in reversed(range(self.T)):
            x_t = self.sample_one_step(x_t, time_step, c)
            # Optionally keep every `interval`-th intermediate state.
            if not only_return_x_0 and ((self.T - time_step) % interval == 0 or time_step == 0):
                x.append(x_t)
        if only_return_x_0:
            return x_t
        return torch.stack(x, dim=1)
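

# Sampling sketch (illustrative): start from pure Gaussian noise and run the
# full T-step reverse chain, conditioning on `condition`:
#
#   sampler = DDPMSampler(model=net, beta=(1e-4, 0.02), T=1000)
#   x_0 = sampler(torch.randn(B, D), condition)            # final sample only
#   traj = sampler(torch.randn(B, D), condition,
#                  only_return_x_0=False, interval=100)    # [B, K, D] trajectory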


class DDIMSampler(nn.Module):
    def __init__(self, model: nn.Module, beta: tuple[float, float], T: int):
        super().__init__()
        self.model = model
        self.T = T

        beta_t = torch.linspace(*beta, T, dtype=torch.float32)

        alpha_t = 1.0 - beta_t
        self.register_buffer("alpha_t_bar", torch.cumprod(alpha_t, dim=0))

    @torch.no_grad()
    def sample_one_step(self, x_t, time_step, c, prev_time_step, eta):
        t = torch.full((x_t.shape[0],), time_step, device=x_t.device, dtype=torch.long)
        prev_t = torch.full((x_t.shape[0],), prev_time_step, device=x_t.device, dtype=torch.long)

        alpha_t = extract(self.alpha_t_bar, t, x_t.shape)
        alpha_t_prev = extract(self.alpha_t_bar, prev_t, x_t.shape)

        epsilon_theta_t = self.model(x_t, t, c)

        # DDIM update (Song et al., 2021): eta = 0 gives deterministic
        # sampling, eta = 1 recovers DDPM-like stochasticity.
        sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))
        epsilon_t = torch.randn_like(x_t)
        x_t_minus_one = (torch.sqrt(alpha_t_prev / alpha_t) * x_t +
                         (torch.sqrt(1 - alpha_t_prev - sigma_t ** 2) - torch.sqrt(
                             (alpha_t_prev * (1 - alpha_t)) / alpha_t)) * epsilon_theta_t +
                         sigma_t * epsilon_t)
        return x_t_minus_one

    @torch.no_grad()
    def forward(self, x_t, c, steps=60, method="linear", eta=0.05, only_return_x_0=True, interval=1, **kwargs):
        # With no denoising steps, return the condition unchanged.
        if steps == 0:
            return c
        # Choose the subsequence of timesteps to visit.
        if method == "linear":
            a = self.T // steps
            time_steps = np.asarray(list(range(0, self.T, a)))
        elif method == "quadratic":
            time_steps = (np.linspace(0, np.sqrt(self.T * 0.8), steps) ** 2).astype(int)
        else:
            raise NotImplementedError(f"sampling method {method} is not implemented!")

        # Shift to one-based timesteps and pair each step with its predecessor.
        time_steps = time_steps + 1
        time_steps_prev = np.concatenate([[0], time_steps[:-1]])

        x = [x_t]
        for i in reversed(range(0, steps)):
            x_t = self.sample_one_step(x_t, time_steps[i], c, time_steps_prev[i], eta)
            if not only_return_x_0 and ((steps - i) % interval == 0 or i == 0):
                x.append(x_t)
        if only_return_x_0:
            return x_t
        return torch.stack(x, dim=1)
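

# For reference, the dense expression in DDIMSampler.sample_one_step above is
# the x_0-prediction form of the DDIM update: with
#   x0_pred = (x_t - sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_bar_t),
# it is algebraically identical to
#   x_{t-1} = sqrt(alpha_bar_prev) * x0_pred
#             + sqrt(1 - alpha_bar_prev - sigma_t^2) * eps_theta
#             + sigma_t * eps.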


class DiffusionLoss(nn.Module):
    # Filled in externally (e.g. by the training script) before instantiation.
    config = {}

    def __init__(self):
        super().__init__()
        self.net = ConditionalUNet(
            layer_channels=self.config["layer_channels"],
            model_dim=self.config["model_dim"],
            condition_dim=self.config["condition_dim"],
            kernel_size=self.config["kernel_size"],
        )
        self.diffusion_trainer = GaussianDiffusionTrainer(
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )
        # config["sample_mode"] is a sampler class, e.g. DDPMSampler or DDIMSampler.
        self.diffusion_sampler = self.config["sample_mode"](
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )

    def forward(self, x, c, **kwargs):
        if kwargs.get("parameter_weight_decay"):
            x = x * (1.0 - kwargs["parameter_weight_decay"])

        # Flatten to [N, D] token batches.
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        real_batch = x.size(0)
        batch = self.config.get("diffusion_batch")
        if self.config.get("forward_once"):
            # Train on one random sub-batch instead of the full set.
            random_indices = torch.randperm(x.size(0))[:batch]
            x, c = x[random_indices], c[random_indices]
            real_batch = x.size(0)
        if batch is not None and real_batch > batch:
            # Accumulate the loss over chunks of size `batch`, then average
            # over all items so the result matches a single full-batch pass.
            loss = 0.
            num_loops = x.size(0) // batch if x.size(0) % batch != 0 else x.size(0) // batch - 1
            for _ in range(num_loops):
                loss += self.diffusion_trainer(x[:batch], c[:batch], **kwargs) * batch
                x, c = x[batch:], c[batch:]
            loss += self.diffusion_trainer(x, c, **kwargs) * x.size(0)
            loss = loss / real_batch
        else:
            loss = self.diffusion_trainer(x, c, **kwargs)
        return loss

    @torch.no_grad()
    def sample(self, x, c, **kwargs):
        batch = self.config.get("diffusion_batch")

        x_shape = x.shape
        x = x.view(-1, x.size(-1))
        c = c.view(-1, c.size(-1))
        if kwargs.get("only_return_x_0") is False:
            # Return the whole denoising trajectory, permuted to [steps, N, D].
            diffusion_steps = self.diffusion_sampler(x, c, **kwargs)
            return torch.permute(diffusion_steps, (1, 0, 2))
        if batch is not None and x.size(0) > batch:
            # Sample in chunks of size `batch` to bound memory use.
            result = []
            num_loops = x.size(0) // batch if x.size(0) % batch != 0 else x.size(0) // batch - 1
            for _ in range(num_loops):
                result.append(self.diffusion_sampler(x[:batch], c[:batch], **kwargs))
                x, c = x[batch:], c[batch:]
            result.append(self.diffusion_sampler(x, c, **kwargs))
            return torch.cat(result, dim=0).view(x_shape)
        else:
            return self.diffusion_sampler(x, c, **kwargs).view(x_shape)
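

# End-to-end sketch (illustrative; the config keys mirror the ones read in
# __init__ above, but the values are placeholders, not this repo's defaults):
#
#   DiffusionLoss.config = {
#       "layer_channels": [1, 32, 64, 32, 1], "model_dim": 1024,
#       "condition_dim": 1024, "kernel_size": 7,
#       "beta": (1e-4, 0.02), "T": 1000,
#       "sample_mode": DDIMSampler, "diffusion_batch": 512,
#   }
#   criterion = DiffusionLoss()
#   loss = criterion(x, c)                               # training step
#   samples = criterion.sample(torch.randn_like(x), c)   # generation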