from functools import partial

import torch
import torch.nn as nn

from .modulate_layers import modulate
from ..utils.helpers import to_2tuple


class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        # bias and drop may be scalars or 2-tuples; a scalar applies to both layers.
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        # With use_conv=True, 1x1 convolutions act as per-position linear layers.
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(
            in_channels, hidden_channels, bias=bias[0], **factory_kwargs
        )
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_channels, **factory_kwargs)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(
            hidden_channels, out_features, bias=bias[1], **factory_kwargs
        )
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

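# Illustrative usage (a sketch; the channel sizes below are assumptions, not values
# taken from this repo): a token-wise MLP expanding 1024 -> 4096 -> 1024, applied
# over a (batch, tokens, channels) tensor.
#
#   mlp = MLP(in_channels=1024, hidden_channels=4096)
#   y = mlp(torch.randn(2, 256, 1024))  # -> (2, 256, 1024), out_features defaults to in_channels

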
class MLPEmbedder(nn.Module):
    """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""

    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))

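# Illustrative usage (a sketch; the dimensions are assumptions): embedding a
# 256-dim timestep/conditioning feature into the model width via Linear -> SiLU -> Linear.
#
#   embedder = MLPEmbedder(in_dim=256, hidden_dim=3072)
#   vec = embedder(torch.randn(2, 256))  # -> (2, 3072)

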
class FinalLayer(nn.Module):
    """The final layer of DiT."""

    def __init__(
        self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.norm_final = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        # Project each token to a flattened output patch; patch_size may be an int
        # (square spatial patch) or a 3-tuple (temporal, height, width).
        if isinstance(patch_size, int):
            self.linear = nn.Linear(
                hidden_size,
                patch_size * patch_size * out_channels,
                bias=True,
                **factory_kwargs,
            )
        else:
            self.linear = nn.Linear(
                hidden_size,
                patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
                bias=True,
                **factory_kwargs,
            )
        # Zero-init so the layer's output starts at zero (standard DiT final-layer init).
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # adaLN modulation predicts a per-sample shift and scale from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift=shift, scale=scale)
        x = self.linear(x)
        return x
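# Illustrative usage (a sketch; the sizes are assumptions and modulate() is assumed
# to apply the DiT-style shift/scale to the normalized tokens): project modulated
# tokens to flattened output patches.
#
#   final = FinalLayer(hidden_size=3072, patch_size=(1, 2, 2), out_channels=16,
#                      act_layer=nn.SiLU)
#   x = torch.randn(2, 100, 3072)   # (batch, tokens, hidden)
#   c = torch.randn(2, 3072)        # per-sample conditioning vector
#   out = final(x, c)               # -> (2, 100, 1 * 2 * 2 * 16)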