import torch
from abc import ABC
from torch import nn
from torch.nn import functional as F
from .diffusion import DiffusionLoss, DDIMSampler, DDPMSampler
from .transformer import TransformerModel
from .mamba import MambaModel
from .lstm import LstmModel
from .gatemlp import GMLPModel


class ModelDiffusion(nn.Module, ABC):
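    # Diffusion wrapper: a backbone (assigned to self.model by subclasses) maps
    # output_shape + condition to a per-token conditioning signal, and DiffusionLoss
    # turns that into a denoising loss at training time and a sampler at inference.
    # The class-level `config` dict must be filled in before instantiation.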
config = {}
def __init__(self, sequence_length):
super().__init__()
DiffusionLoss.config = self.config
self.criteria = DiffusionLoss()
if self.config.get("post_d_model") is None:
assert self.config["d_model"] == self.config["condition_dim"]
self.sequence_length = sequence_length
        # subclasses assign self.model after this constructor runs
self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
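        # initialize every permutation embedding to the same constant (1 / d_model)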
self.to_permutation_state.weight = \
nn.Parameter(torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"])
def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
# condition
if condition is not None:
assert len(condition.shape) == 2
assert condition.shape[-1] == self.config["d_condition"]
condition = self.to_condition(condition.to(self.device)[:, None, :])
        else: # no condition given: feed a zero placeholder (assumes d_condition == 1)
condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))
# process
if kwargs.get("sample"):
if permutation_state is not False:
permutation_state = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
else: # permutation state == False
permutation_state = 0.
return self.sample(x=None, condition=condition+permutation_state)
else: # train
if permutation_state is not None:
permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else: # no permutation state provided
permutation_state = 0.
# Given condition c and ground truth token x, compute loss
c = self.model(output_shape, condition+permutation_state)
loss = self.criteria(x=x_0, c=c, **kwargs)
return loss
@torch.no_grad()
def sample(self, x=None, condition=None):
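        # z conditions the reverse diffusion; x starts from Gaussian noise unless given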
z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
if x is None:
x = torch.randn((1, self.sequence_length, self.config["model_dim"]), device=z.device)
x = self.criteria.sample(x, z)
return x
@property
def device(self):
return next(self.parameters()).device


class ModelMSELoss(nn.Module, ABC):
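    # Regression baseline with the same interface as ModelDiffusion: the backbone
    # output is matched to the target tokens directly with a NaN-masked MSE loss,
    # and sample() simply returns the backbone output.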
config = {}
def __init__(self, sequence_length):
super().__init__()
if self.config.get("post_d_model") is None:
assert self.config["d_model"] == self.config["condition_dim"]
self.sequence_length = sequence_length
        # subclasses assign self.model after this constructor runs
self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
self.to_permutation_state.weight = \
nn.Parameter(torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"])
def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
# condition
if condition is not None:
assert len(condition.shape) == 2
assert condition.shape[-1] == self.config["d_condition"]
condition = self.to_condition(condition.to(self.device)[:, None, :])
        else: # no condition given: feed a zero placeholder (assumes d_condition == 1)
condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))
# process
if kwargs.get("sample"):
if permutation_state is not False:
permutation_state = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
else: # permutation state == False
permutation_state = 0.
return self.sample(x=None, condition=condition+permutation_state)
else: # train
if permutation_state is not None:
permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else: # no permutation state provided
permutation_state = 0.
# Given condition c and ground truth token x, compute loss
c = self.model(output_shape, condition+permutation_state)
assert c.shape[-1] == x_0.shape[-1], "d_model should be equal to dim_per_token"
# preprocess nan to zero
mask = torch.isnan(x_0)
x_0 = torch.nan_to_num(x_0, 0.)
        # masked MSE: entries that were NaN in the target do not contribute to the loss
loss = F.mse_loss(c, x_0, reduction="none")
loss[mask] = torch.nan
return loss.nanmean()
@torch.no_grad()
def sample(self, x=None, condition=None):
z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
return z
@property
def device(self):
return next(self.parameters()).device
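

# Backbone-specific wrappers: each subclass hands the shared config to its backbone
# class and instantiates it as self.model.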
class MambaDiffusion(ModelDiffusion):
def __init__(self, sequence_length, positional_embedding):
super().__init__(sequence_length=sequence_length)
MambaModel.config = self.config
self.model = MambaModel(positional_embedding=positional_embedding)


class TransformerDiffusion(ModelDiffusion):
def __init__(self, sequence_length, positional_embedding):
super().__init__(sequence_length=sequence_length)
TransformerModel.config = self.config
self.model = TransformerModel(positional_embedding=positional_embedding)


class LstmDiffusion(ModelDiffusion):
def __init__(self, sequence_length, positional_embedding):
super().__init__(sequence_length=sequence_length)
LstmModel.config = self.config
self.model = LstmModel(positional_embedding=positional_embedding)


class GMLPDiffusion(ModelDiffusion):
def __init__(self, sequence_length, positional_embedding):
super().__init__(sequence_length=sequence_length)
GMLPModel.config = self.config
self.model = GMLPModel(positional_embedding=positional_embedding)


class MambaMSELoss(ModelMSELoss):
def __init__(self, sequence_length, positional_embedding):
super().__init__(sequence_length=sequence_length)
MambaModel.config = self.config
self.model = MambaModel(positional_embedding=positional_embedding)


class ClassConditionMambaDiffusion(MambaDiffusion):
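    # Class-conditional variant: a class-label vector of size input_class is embedded
    # into d_condition, projected to d_model, and gated so that it only affects the
    # last 8 positions of the sequence.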
def __init__(self, sequence_length, positional_embedding, input_class=10):
super().__init__(sequence_length, positional_embedding)
self.get_condition = nn.Sequential(
nn.Linear(input_class, self.config["d_condition"]),
nn.SiLU(),
) # to condition
self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
# condition module
self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
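        # gate mask: the condition is injected only into the last 8 sequence positions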
to_condition_gate = torch.zeros(size=(1, sequence_length, 1))
to_condition_gate[:, -8:, :] = 1.
self.register_buffer("to_condition_gate", to_condition_gate)
        # replace the parent's linear to_condition with the gated version below
del self.to_condition
self.to_condition = self._to_condition
def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
condition = self.get_condition(condition.to(self.device))
return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)
def _to_condition(self, x):
assert len(x.shape) == 3
x = self.to_condition_linear(x)
x = x * self.to_condition_gate
return x


class ClassConditionMambaDiffusionFull(MambaDiffusion):
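    # Full-sequence class conditioning: the projected condition is spread across every
    # sequence position by a Conv1d stack; when forward() is called with pre_training=True
    # the condition is replaced with zeros, so the model is trained without class information.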
def __init__(self, sequence_length, positional_embedding, input_class=10, init_noise_intensity=1e-4):
super().__init__(sequence_length, positional_embedding)
self.get_condition = nn.Sequential(
nn.Linear(input_class, self.config["d_condition"]),
nn.LayerNorm(self.config["d_condition"]),
) # to condition
self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
# condition module
self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
self.to_condition_conv = nn.Sequential(
nn.Conv1d(1, sequence_length, 9, 1, 4),
nn.GroupNorm(num_groups=1, num_channels=sequence_length),
nn.Conv1d(sequence_length, sequence_length, 9, 1, 4),
) # [batch_size, sequence_length, d_model]
        # remove the parent's to_condition; it is re-bound in forward() depending on pre_training
del self.to_condition
def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
if kwargs.get("pre_training"):
self.to_condition = self._zero_condition
condition = None
else: # train with condition
self.to_condition = self._to_condition
condition = self.get_condition(condition.to(self.device))
return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)
def _to_condition(self, x):
assert len(x.shape) == 3
x = self.to_condition_linear(x) # [batch_size, 1, d_model]
x = self.to_condition_conv(x) # [batch_size, sequence_length, d_model]
return x
def _zero_condition(self, x):
return torch.zeros(size=(x.shape[0], self.sequence_length, self.config["d_model"]), device=x.device)
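

# Usage sketch (illustrative only): the exact config keys depend on the chosen backbone
# and on DiffusionLoss, and the values shown here are placeholders, not defaults.
#
#     class Demo(MambaDiffusion):
#         config = {
#             "d_condition": 1,          # placeholder; 1 lets the unconditional zero placeholder work
#             "d_model": ...,            # internal width; must match condition_dim when post_d_model is unset
#             "num_permutation": ...,    # number of permutation states
#             ...                        # plus whatever MambaModel / DiffusionLoss expect
#         }
#
#     model = Demo(sequence_length=..., positional_embedding=...)
#     loss = model(output_shape, x_0)          # training: returns the diffusion loss
#     tokens = model(sample=True)              # inference: returns generated tokens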