|
import numpy as np |
|
import torch as th |
|
|
|
from .gaussian_diffusion import GaussianDiffusion, GaussianDiffusionDDPM |
|
|
|
def space_timesteps(num_timesteps, sample_timesteps):
    """
    Pick evenly spaced timesteps from an original diffusion process.

    Position ``x`` (for ``x`` in ``range(sample_timesteps)``) is mapped to
    ``floor(num_timesteps * x / sample_timesteps)``, so the selection always
    starts at step 0 whenever ``sample_timesteps`` is positive.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param sample_timesteps: the number of timesteps to use for sampling.
    :return: a set of diffusion steps from the original process to use. It
             holds fewer than ``sample_timesteps`` entries when several
             positions land on the same step (e.g. when ``sample_timesteps``
             exceeds ``num_timesteps``), and is empty when
             ``sample_timesteps`` is not positive.
    """
    # Multiply first and floor-divide so that integer inputs are handled with
    # exact integer arithmetic. Truncating the float product
    # ``(num_timesteps / sample_timesteps) * x`` can land one step too low
    # when rounding leaves it just below an exact integer.
    return {
        int(num_timesteps * x // sample_timesteps)
        for x in range(sample_timesteps)
    }
|
|
|
class SpacedDiffusion(GaussianDiffusion):
    """
    Diffusion process restricted to a subset of a base process's timesteps.

    :param use_timesteps: a collection (sequence or set) of timestep indices
                          from the original diffusion process to retain.
    :param kwargs: the kwargs used to create the base diffusion process; the
                   ``sqrt_etas`` entry must describe the full, unspaced
                   process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.original_num_steps = len(kwargs["sqrt_etas"])
        self.timestep_map = []

        # Build the full process once so the schedule is read from the
        # attribute the base class exposes rather than from the raw kwargs.
        full_process = GaussianDiffusion(**kwargs)
        retained = [
            (step, eta)
            for step, eta in enumerate(full_process.sqrt_etas)
            if step in self.use_timesteps
        ]

        # timestep_map[k] is the original index of the k-th retained step.
        self.timestep_map.extend(step for step, _ in retained)
        kwargs["sqrt_etas"] = np.array([eta for _, eta in retained])
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):
        """Call the parent ``p_mean_variance`` with a timestep-remapping model."""
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):
        """Call the parent ``training_losses`` with a timestep-remapping model."""
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        """Return ``model`` wrapped in ``_WrappedModel`` unless it already is one."""
        already_wrapped = isinstance(model, _WrappedModel)
        return (
            model
            if already_wrapped
            else _WrappedModel(model, self.timestep_map, self.original_num_steps)
        )
|
|
|
class _WrappedModel:
    """
    Callable adapter that converts spaced timestep indices back to the
    indices of the original process before invoking the wrapped model.

    :param model: the callable to delegate to; invoked as
                  ``model(x, new_ts, **kwargs)``.
    :param timestep_map: sequence whose k-th entry is the original timestep
                         index for spaced timestep ``k``.
    :param original_num_steps: number of steps in the original process
                               (stored, not used by ``__call__``).
    """

    def __init__(self, model, timestep_map, original_num_steps):
        self.original_num_steps = original_num_steps
        self.timestep_map = timestep_map
        self.model = model

    def __call__(self, x, ts, **kwargs):
        """Look up the original timesteps for ``ts`` and forward to the model."""
        # The lookup table is created with the device and dtype of ``ts`` and
        # then indexed by ``ts`` itself.
        lookup = th.tensor(self.timestep_map, dtype=ts.dtype, device=ts.device)
        return self.model(x, lookup[ts], **kwargs)
|
|
|
class SpacedDiffusionDDPM(GaussianDiffusionDDPM):
    """
    DDPM diffusion process restricted to a subset of a base process's
    timesteps.

    :param use_timesteps: a collection (sequence or set) of timestep indices
                          from the original diffusion process to retain.
    :param kwargs: the kwargs used to create the base diffusion process; the
                   ``betas`` entry must describe the full, unspaced process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.original_num_steps = len(kwargs["betas"])
        self.timestep_map = []

        # Build the full process once to obtain its cumulative alpha products.
        full_process = GaussianDiffusionDDPM(**kwargs)
        retained = [
            (step, cumprod)
            for step, cumprod in enumerate(full_process.alphas_cumprod)
            if step in self.use_timesteps
        ]
        kept_cumprods = [cumprod for _, cumprod in retained]

        # Each new beta is one minus the ratio of a retained cumulative
        # product to the previously retained one (1.0 before the first).
        preceding = [1.0] + kept_cumprods[:-1]
        respaced_betas = [
            1 - current / previous
            for current, previous in zip(kept_cumprods, preceding)
        ]

        # timestep_map[k] is the original index of the k-th retained step.
        self.timestep_map.extend(step for step, _ in retained)
        kwargs["betas"] = np.array(respaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):
        """Call the parent ``p_mean_variance`` with a timestep-remapping model."""
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):
        """Call the parent ``training_losses`` with a timestep-remapping model."""
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        """Return ``model`` wrapped in ``_WrappedModel`` unless it already is one."""
        already_wrapped = isinstance(model, _WrappedModel)
        return (
            model
            if already_wrapped
            else _WrappedModel(model, self.timestep_map, self.original_num_steps)
        )
|
|
|
|