import torch
from abc import ABC
from torch import nn
from torch.nn import functional as F
from .diffusion import DiffusionLoss, DDIMSampler, DDPMSampler
from .transformer import TransformerModel
from .mamba import MambaModel
from .lstm import LstmModel
from .gatemlp import GMLPModel


class ModelDiffusion(nn.Module, ABC):
    """Diffusion wrapper base class: subclasses set ``self.model`` to a sequence
    backbone whose output conditions the DiffusionLoss head."""
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        DiffusionLoss.config = self.config
        self.criteria = DiffusionLoss()
        if self.config.get("post_d_model") is None:
            # Without a post projection, d_model must match condition_dim directly.
            assert self.config["d_model"] == self.config["condition_dim"]
        self.sequence_length = sequence_length

        self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        # Initialize all permutation-state embeddings to the same constant vector.
        self.to_permutation_state.weight = \
            nn.Parameter(torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"])

    def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
        if condition is not None:
            assert len(condition.shape) == 2
            assert condition.shape[-1] == self.config["d_condition"]
            condition = self.to_condition(condition.to(self.device)[:, None, :])
        else:
            # Unconditional placeholder (assumes d_condition == 1).
            condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))

        if kwargs.get("sample"):
            if permutation_state is not False:
                # Draw a random permutation state for sampling.
                permutation_state = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:
                permutation_state = 0.
            return self.sample(x=None, condition=condition + permutation_state)
        else:
            if permutation_state is not None:
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:
                permutation_state = 0.

        # Training path: the backbone output conditions the diffusion loss.
        c = self.model(output_shape, condition + permutation_state)
        loss = self.criteria(x=x_0, c=c, **kwargs)
        return loss

    @torch.no_grad()
    def sample(self, x=None, condition=None):
        z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
        if x is None:
            # Start from Gaussian noise with the token width expected by the diffusion head.
            x = torch.randn((1, self.sequence_length, self.config["model_dim"]), device=z.device)
        x = self.criteria.sample(x, z)
        return x

    @property
    def device(self):
        return next(self.parameters()).device


class ModelMSELoss(nn.Module, ABC):
    """Regression baseline: the backbone predicts tokens directly and is trained
    with a NaN-masked MSE loss instead of a diffusion head."""
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        if self.config.get("post_d_model") is None:
            assert self.config["d_model"] == self.config["condition_dim"]
        self.sequence_length = sequence_length

        self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        self.to_permutation_state.weight = \
            nn.Parameter(torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"])

    def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
        if condition is not None:
            assert len(condition.shape) == 2
            assert condition.shape[-1] == self.config["d_condition"]
            condition = self.to_condition(condition.to(self.device)[:, None, :])
        else:
            # Unconditional placeholder (assumes d_condition == 1).
            condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))

        if kwargs.get("sample"):
            if permutation_state is not False:
                permutation_state = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:
                permutation_state = 0.
            return self.sample(x=None, condition=condition + permutation_state)
        else:
            if permutation_state is not None:
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:
                permutation_state = 0.

        c = self.model(output_shape, condition + permutation_state)
        assert c.shape[-1] == x_0.shape[-1], "d_model should be equal to dim_per_token"

        # NaN entries in x_0 mark padded positions; exclude them from the loss.
        mask = torch.isnan(x_0)
        x_0 = torch.nan_to_num(x_0, 0.)

        loss = F.mse_loss(c, x_0, reduction="none")
        loss[mask] = torch.nan
        return loss.nanmean()

    @torch.no_grad()
    def sample(self, x=None, condition=None):
        z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
        return z

    @property
    def device(self):
        return next(self.parameters()).device


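# Usage sketch for the concrete diffusion wrappers below (illustrative only: the
# config values, `pos_emb`, and `tokens` are placeholders, and MambaModel /
# DiffusionLoss require further keys of their own):
#
#   MambaDiffusion.config = {
#       "d_condition": 1,        # width of the external condition (1 => unconditional path)
#       "d_model": 8192,         # backbone width
#       "condition_dim": 8192,   # must equal d_model unless "post_d_model" is set
#       "model_dim": 8192,       # width of the tokens denoised by DiffusionLoss
#       "num_permutation": 1,
#   }
#   model = MambaDiffusion(sequence_length=256, positional_embedding=pos_emb)
#   loss = model(output_shape=[1, 256, 8192], x_0=tokens)   # training: returns the diffusion loss
#   generated = model(sample=True)                          # generation via self.sample()

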
class MambaDiffusion(ModelDiffusion):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        MambaModel.config = self.config
        self.model = MambaModel(positional_embedding=positional_embedding)


class TransformerDiffusion(ModelDiffusion):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        TransformerModel.config = self.config
        self.model = TransformerModel(positional_embedding=positional_embedding)


class LstmDiffusion(ModelDiffusion):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        LstmModel.config = self.config
        self.model = LstmModel(positional_embedding=positional_embedding)


class GMLPDiffusion(ModelDiffusion):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        GMLPModel.config = self.config
        self.model = GMLPModel(positional_embedding=positional_embedding)


class MambaMSELoss(ModelMSELoss):
    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        MambaModel.config = self.config
        self.model = MambaModel(positional_embedding=positional_embedding)


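# Class-conditional variant: `condition` is expected to be a (batch, input_class)
# one-hot (or soft-label) float tensor, mapped through get_condition and injected
# only at the last 8 sequence positions. Illustrative training call, with
# `pos_emb`, `labels`, `tokens`, and the shapes as placeholders:
#
#   model = ClassConditionMambaDiffusion(sequence_length=256, positional_embedding=pos_emb)
#   loss = model(output_shape=[1, 256, 8192], x_0=tokens,
#                condition=F.one_hot(labels, num_classes=10).float())

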
class ClassConditionMambaDiffusion(MambaDiffusion):
    def __init__(self, sequence_length, positional_embedding, input_class=10):
        super().__init__(sequence_length, positional_embedding)
        self.get_condition = nn.Sequential(
            nn.Linear(input_class, self.config["d_condition"]),
            nn.SiLU(),
        )
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])

        self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
        # Gate that passes the condition only through the last 8 sequence positions.
        to_condition_gate = torch.zeros(size=(1, sequence_length, 1))
        to_condition_gate[:, -8:, :] = 1.
        self.register_buffer("to_condition_gate", to_condition_gate)

        # Replace the parent's linear condition projection with the gated one below.
        del self.to_condition
        self.to_condition = self._to_condition

    def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
        condition = self.get_condition(condition.to(self.device))
        return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)

    def _to_condition(self, x):
        assert len(x.shape) == 3
        x = self.to_condition_linear(x)
        x = x * self.to_condition_gate
        return x


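# "Full" class-conditional variant: the condition is spread over every sequence
# position via to_condition_conv. Passing pre_training=True zeroes the condition
# path so the backbone can first be trained without labels. Illustrative calls
# (placeholders as above):
#
#   loss = model(output_shape=[1, 256, 8192], x_0=tokens, pre_training=True)
#   loss = model(output_shape=[1, 256, 8192], x_0=tokens,
#                condition=F.one_hot(labels, num_classes=10).float())

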
class ClassConditionMambaDiffusionFull(MambaDiffusion):
    def __init__(self, sequence_length, positional_embedding, input_class=10, init_noise_intensity=1e-4):
        super().__init__(sequence_length, positional_embedding)
        self.get_condition = nn.Sequential(
            nn.Linear(input_class, self.config["d_condition"]),
            nn.LayerNorm(self.config["d_condition"]),
        )
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])

        self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
        # Expand the single condition token to sequence_length positions with 1D
        # convolutions over the feature dimension.
        self.to_condition_conv = nn.Sequential(
            nn.Conv1d(1, sequence_length, 9, 1, 4),
            nn.GroupNorm(num_groups=1, num_channels=sequence_length),
            nn.Conv1d(sequence_length, sequence_length, 9, 1, 4),
        )

        # The condition projection is chosen per call in forward().
        del self.to_condition

    def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
        if kwargs.get("pre_training"):
            # Pre-training: ignore the class condition entirely.
            self.to_condition = self._zero_condition
            condition = None
        else:
            self.to_condition = self._to_condition
            condition = self.get_condition(condition.to(self.device))
        return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)

    def _to_condition(self, x):
        assert len(x.shape) == 3
        x = self.to_condition_linear(x)
        x = self.to_condition_conv(x)
        return x

    def _zero_condition(self, x):
        return torch.zeros(size=(x.shape[0], self.sequence_length, self.config["d_model"]), device=x.device)
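

# Sampling sketch for either class-conditional variant (illustrative; `model` and
# `class_id` are placeholders):
#
#   one_hot = F.one_hot(torch.tensor([class_id]), num_classes=10).float()
#   generated = model(condition=one_hot, sample=True)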