from dataclasses import dataclass

from einops import rearrange
import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx

from flux.wrapper import TorchWrapper
from flux.math import dot_product_attention


@dataclass
class AutoEncoderParams:
    """Hyper-parameters shared by the Encoder, Decoder and AutoEncoder below."""

    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    # Latents are mapped as scale_factor * (z - shift_factor) in AutoEncoder.encode
    # and inverted in AutoEncoder.decode.
    scale_factor: float
    shift_factor: float


swish = nnx.swish


class AttnBlock(nnx.Module):
    """Single-head spatial self-attention with a residual connection.

    Works on channels-last feature maps; ``attention`` unpacks ``q.shape`` as
    ``(b, h, w, c)`` and flattens the ``h * w`` grid into one sequence.
    """

    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1 convolutions act as per-pixel linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize, project to q/k/v and attend over all spatial positions."""
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Channels-last layout; the PyTorch original unpacked (b, c, h, w).
        b, h, w, c = q.shape
        # Flatten the spatial grid into a sequence and insert a singleton head axis.
        q = rearrange(q, "b h w c -> b 1 (h w) c")
        k = rearrange(k, "b h w c -> b 1 (h w) c")
        v = rearrange(v, "b h w c -> b 1 (h w) c")
        h_ = dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b h w c", h=h, w=w, c=c, b=b)

    def __call__(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nnx.Module):
    """Residual block: (GroupNorm -> swish -> 3x3 conv) applied twice.

    When ``in_channels != out_channels``, the skip path is projected with a 1x1
    convolution (``nin_shortcut``) so that the residual sum is shape-compatible.
    """

    def __init__(self, in_channels: int, out_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.in_channels = in_channels
        # ``None`` means "keep the input channel count".
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def __call__(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nnx.Module):
    """Halve spatial resolution with a stride-2 3x3 conv after asymmetric padding."""

    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        # No asymmetric padding option in the torch-style conv; __call__ pads manually.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def __call__(self, x: Tensor):
        # Zero-pad one row at the end of axis 1 and one column at the end of axis 2.
        # This is the channels-last counterpart of torch's F.pad(x, (0, 1, 0, 1)) on NCHW.
        x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)), mode="constant")
        x = self.conv(x)
        return x


class Upsample(nnx.Module):
    """Double spatial resolution (nearest neighbour) followed by a 3x3 conv."""

    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def __call__(self, x: Tensor):
        # Channels-last equivalent of F.interpolate(x, scale_factor=2.0, mode="nearest").
        B, H, W, C = x.shape
        x = jax.image.resize(x, (B, H * 2, W * 2, C), method="nearest")
        x = self.conv(x)
        return x


# Keep handles to the raw classes: Encoder/Decoder rebind these names locally to
# the wrapped versions returned by ``nn.declare_with_rng``.
ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class = ResnetBlock, Downsample, Upsample, AttnBlock


class Encoder(nnx.Module):
    """Convolutional encoder producing a ``2 * z_channels`` feature map.

    There are ``len(ch_mult)`` resolution levels, each with ``num_res_blocks``
    ResnetBlocks. Every level except the last is followed by a Downsample, so the
    input is halved ``len(ch_mult) - 1`` times. A ResnetBlock/Attn/ResnetBlock
    middle stage and GroupNorm/swish/conv head come last.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        dtype=jnp.float32,
        rngs: nnx.Rngs = None,
    ):
        nn = TorchWrapper(rngs, dtype=dtype)
        # NOTE(review): presumably binds this module's rngs/dtype into the block
        # constructors (per the helper's name) -- confirm in flux.wrapper.
        ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(
            ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class
        )
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # Never populated here, so the attention branch in __call__ is inert.
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        # Twice z_channels: DiagonalGaussian later splits this into mean and logvar.
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def __call__(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nnx.Module):
    """Convolutional decoder mapping a ``z_channels`` latent to ``out_ch`` channels.

    It mirrors the Encoder. A conv_in layer is followed by a middle stage, then
    ``len(ch_mult)`` levels of ``num_res_blocks + 1`` ResnetBlocks. Every level
    except level 0 is followed by an Upsample, and a GroupNorm/swish/conv head
    comes last.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
        dtype=jnp.float32,
        rngs: nnx.Rngs = None,
    ):
        nn = TorchWrapper(rngs, dtype=dtype)
        # NOTE(review): presumably binds this module's rngs/dtype into the block
        # constructors (per the helper's name) -- confirm in flux.wrapper.
        ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(
            ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class
        )
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute block_in and curr_res at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # NOTE(review): this tuple is in NCHW order, while the rest of this port is
        # channels-last. Confirm what callers expect before relying on it.
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            # Never populated here, so the attention branch in __call__ is inert.
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
            self.up.insert(0, up)  # prepend so self.up[i] is level i

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling, from the lowest-resolution level to level 0
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nnx.Module):
    """Split features into (mean, logvar) and optionally sample from N(mean, exp(logvar)).

    Args:
        sample: if True, return ``mean + std * noise``; otherwise return ``mean``.
        chunk_dim: axis split in half into mean and logvar. The default is -1,
            the last axis (channels in this channels-last port).
        dtype: stored but not used by this module.
        rngs: source of PRNG keys. It is required when ``sample`` is True.
    """

    def __init__(self, sample: bool = True, chunk_dim: int = -1, dtype=jnp.float32, rngs: nnx.Rngs = None):
        self.sample = sample
        self.chunk_dim = chunk_dim
        self.rngs = rngs
        self.dtype = dtype

    def __call__(self, z: Tensor) -> Tensor:
        mean, logvar = jnp.split(z, 2, axis=self.chunk_dim)
        if not self.sample:
            return mean
        if self.rngs is None:
            raise ValueError("DiagonalGaussian(sample=True) requires rngs to draw noise")
        std = jnp.exp(0.5 * logvar)
        # Match mean's dtype, like torch.randn_like. jax.random.normal defaults to
        # float32, which would otherwise promote bf16/fp16 latents to float32.
        noise = jax.random.normal(self.rngs(), mean.shape, dtype=mean.dtype)
        return mean + std * noise


Encoder_class, Decoder_class, DiagonalGaussian_class = Encoder, Decoder, DiagonalGaussian


class AutoEncoder(nnx.Module):
    """Encoder + DiagonalGaussian bottleneck + Decoder, with latent scale/shift.

    ``encode`` returns ``scale_factor * (z - shift_factor)``, and ``decode`` applies
    the inverse before running the decoder.
    """

    def __init__(self, params: AutoEncoderParams, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        Encoder, Decoder, DiagonalGaussian = nn.declare_with_rng(Encoder_class, Decoder_class, DiagonalGaussian_class)
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def __call__(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))