import math
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx
from einops import rearrange
from flux.wrapper import TorchWrapper
from flux.math import attention, rope
class EmbedND(nnx.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int], dtype=jnp.float32, rngs: nnx.Rngs = None):
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def __call__(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
# emb = torch.cat(
# [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
# dim=-3,
# )
emb = jnp.concatenate(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
axis=-3,
)
# return emb.unsqueeze(1)
return jnp.expand_dims(emb, 1)
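# Usage sketch (illustrative only): the axes_dim/theta values below follow the
# reference FLUX configuration and are assumptions, not read from this file.
#   pe_embedder = EmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56])
#   ids = jnp.zeros((1, 4608, 3))        # (batch, tokens, n_axes)
#   pe = pe_embedder(ids)                # (1, 1, 4608, 64, 2, 2) given rope()'s (..., dim/2, 2, 2) layout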
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scale factor applied to t before computing the embedding.
    :return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
# t.device
# )
freqs = jnp.exp(-math.log(max_period) * jnp.arange(half, dtype=jnp.float32) / half)
# args = t[:, None].float() * freqs[None]
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
args = t[:, None] * freqs[None]
embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
if dim % 2:
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
# if torch.is_floating_point(t):
# embedding = embedding.to(t)
# return embedding
if jnp.issubdtype(t.dtype, jnp.floating):
embedding = embedding.astype(t.dtype)
return embedding
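# Example (sketch): a batch of four fractional timesteps mapped to 256-dim embeddings.
#   t = jnp.array([0.0, 0.25, 0.5, 1.0])
#   emb = timestep_embedding(t, dim=256)     # (4, 256), float32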
class MLPEmbedder(nnx.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
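# Usage sketch (dimensions are illustrative, not prescribed by this file):
#   embedder = MLPEmbedder(in_dim=256, hidden_dim=3072, rngs=nnx.Rngs(0))
#   y = embedder(timestep_embedding(jnp.ones((2,)), 256))   # (2, 3072)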
class RMSNorm(nnx.Module):
def __init__(self, dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
# self.scale = nn.Parameter(torch.ones(dim))
self.scale = nn.Parameter(jnp.ones((dim,)))
def __call__(self, x: Tensor):
x_dtype = x.dtype
# x = x.float()
x = x.astype(jnp.float32)
# rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
rrms = jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)
# return (x * rrms).to(dtype=x_dtype) * self.scale
        return (x * rrms).astype(x_dtype) * self.scale
RMSNorm_class = RMSNorm  # module-level alias; consumers re-declare it with their rngs/dtype via nn.declare_with_rng
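# Shape note (sketch): RMSNorm normalizes over the last axis only, so it accepts both
# (B, L, D) activations and the per-head (B, H, L, head_dim) tensors used by QKNorm.
#   norm = RMSNorm(dim=128, rngs=nnx.Rngs(0))
#   y = norm(jnp.ones((2, 24, 64, 128)))     # same shape, rescaled by `scale`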
class QKNorm(nnx.Module):
def __init__(self, dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
RMSNorm = nn.declare_with_rng(RMSNorm_class)
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def __call__(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
# return q.to(v), k.to(v)
return q.astype(v.dtype), k.astype(v.dtype)
QKNorm_class = QKNorm  # module-level alias, as with RMSNorm_class above
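# Usage sketch: q/k/v are per-head tensors of shape (B, H, L, head_dim); v only
# determines the output dtype.
#   qknorm = QKNorm(dim=128, rngs=nnx.Rngs(0))
#   q = k = v = jnp.ones((1, 24, 4608, 128), dtype=jnp.bfloat16)
#   qn, kn = qknorm(q, k, v)                 # both bfloat16, shapes unchanged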
class SelfAttention(nnx.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=jnp.float32, rngs: nnx.Rngs = None):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
QKNorm = nn.declare_with_rng(QKNorm_class)
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def __call__(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
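# Usage sketch (hidden size and head count are illustrative FLUX-dev-like values):
#   attn = SelfAttention(dim=3072, num_heads=24, qkv_bias=True, rngs=nnx.Rngs(0))
#   x = jnp.ones((1, 4608, 3072))
#   pe = pe_embedder(jnp.zeros((1, 4608, 3)))    # see the EmbedND sketch above
#   y = attn(x, pe)                              # (1, 4608, 3072)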
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nnx.Module):
def __init__(self, dim: int, double: bool, dtype=jnp.float32, rngs: nnx.Rngs = None):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def __call__(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
# out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
out = self.lin(nnx.silu(vec))[:, None, :]
out = jnp.split(out, self.multiplier, axis=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
Modulation_class, SelfAttention_class = Modulation, SelfAttention  # module-level aliases, as above
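# Usage sketch: `vec` is the pooled conditioning vector; with double=True the layer
# returns two (shift, scale, gate) triples, each entry of shape (B, 1, dim).
#   mod = Modulation(dim=3072, double=True, rngs=nnx.Rngs(0))
#   m1, m2 = mod(jnp.ones((2, 3072)))
#   m1.shift.shape                               # (2, 1, 3072)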
class DoubleStreamBlock(nnx.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=jnp.float32, rngs: nnx.Rngs = None):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
Modulation, SelfAttention = nn.declare_with_rng(Modulation_class, SelfAttention_class)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def __call__(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
# q = torch.cat((txt_q, img_q), dim=2)
# k = torch.cat((txt_k, img_k), dim=2)
# v = torch.cat((txt_v, img_v), dim=2)
q = jnp.concatenate((txt_q, img_q), axis=2)
k = jnp.concatenate((txt_k, img_k), axis=2)
v = jnp.concatenate((txt_v, img_v), axis=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
        # calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
        # calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
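# Usage sketch (all sizes are illustrative; the real values come from the model config):
#   block = DoubleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0,
#                             qkv_bias=True, rngs=nnx.Rngs(0))
#   img, txt = jnp.ones((1, 4096, 3072)), jnp.ones((1, 512, 3072))
#   vec = jnp.ones((1, 3072))
#   pe = pe_embedder(jnp.zeros((1, 4608, 3)))    # positions for txt + img tokens
#   img, txt = block(img, txt, vec, pe)          # shapes preserved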
class SingleStreamBlock(nnx.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
dtype=jnp.float32, rngs: nnx.Rngs = None
):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
QKNorm, Modulation = nn.declare_with_rng(QKNorm_class, Modulation_class)
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def __call__(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
# qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        qkv, mlp = jnp.split(self.linear1(x_mod), [3 * self.hidden_size], axis=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
# output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
output = self.linear2(jnp.concatenate((attn, self.mlp_act(mlp)), axis=2))
return x + mod.gate * output
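# Usage sketch (illustrative sizes; x is the concatenated txt+img stream):
#   block = SingleStreamBlock(hidden_size=3072, num_heads=24, rngs=nnx.Rngs(0))
#   x = jnp.ones((1, 4608, 3072))
#   y = block(x, jnp.ones((1, 3072)), pe)        # pe as in the EmbedND sketch; (1, 4608, 3072)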
class LastLayer(nnx.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
nn = TorchWrapper(rngs=rngs, dtype=dtype)
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def __call__(self, x: Tensor, vec: Tensor) -> Tensor:
# shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
shift, scale = jnp.split(self.adaLN_modulation(vec), 2, axis=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
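# Usage sketch: projects final hidden states to per-patch output channels. The
# patch_size/out_channels values below are the usual FLUX settings, assumed here.
#   last = LastLayer(hidden_size=3072, patch_size=1, out_channels=64, rngs=nnx.Rngs(0))
#   y = last(jnp.ones((1, 4096, 3072)), jnp.ones((1, 3072)))   # (1, 4096, 64)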