import torch
import torch.nn as nn
import torch.nn.functional as F
# This architecture was my attempt at the following Simple Diffusion paper with some modifications:
# https://arxiv.org/pdf/2410.19324v1
# Very similar to GeGLU or SwiGLU: a learned gate multiplies a value path, but arctan is used as the gate activation.
class xATGLU(nn.Module):
def __init__(self, input_dim, output_dim, bias=True):
super().__init__()
# GATE path | VALUE path
self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')
self.alpha = nn.Parameter(torch.zeros(1))
self.half_pi = torch.pi / 2
self.inv_pi = 1 / torch.pi
def forward(self, x):
projected = self.proj(x)
gate_path, value_path = projected.chunk(2, dim=-1)
# Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768
gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha
return expanded_gate * value_path # g(x) × y
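# Quick shape sketch for xATGLU (illustrative only, not part of the model):
# the layer maps input_dim -> output_dim; the 2x projection is internal.
#
#   glu = xATGLU(64, 128)
#   y = glu(torch.randn(8, 64))   # y.shape == torch.Size([8, 128])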
class ResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.norm1 = nn.GroupNorm(32, channels)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.norm2 = nn.GroupNorm(32, channels)
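        # Note: GroupNorm(32, ...) assumes channels is divisible by 32,
        # which holds for the default base_channels=128 and its doublings.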
def forward(self, x):
h = self.conv1(F.silu(self.norm1(x)))
h = self.conv2(F.silu(self.norm2(h)))
return x + h
class TransformerBlock(nn.Module):
def __init__(self, channels, num_heads=8):
super().__init__()
self.norm1 = nn.LayerNorm(channels)
self.attn = nn.MultiheadAttention(channels, num_heads)
self.norm2 = nn.LayerNorm(channels)
self.mlp = nn.Sequential(
xATGLU(channels, 4 * channels),
nn.Linear(4 * channels, channels)
)
def forward(self, x):
# Reshape for attention [B, C, H, W] -> [H*W, B, C]
b, c, h, w = x.shape
x = x.flatten(2).permute(2, 0, 1)
# Self attention
h_attn = self.norm1(x)
h_attn, _ = self.attn(h_attn, h_attn, h_attn)
x = x + h_attn
# MLP
h_mlp = self.norm2(x)
h_mlp = self.mlp(h_mlp)
x = x + h_mlp
# Reshape back [H*W, B, C] -> [B, C, H, W]
return x.permute(1, 2, 0).reshape(b, c, h, w)
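# Shape sketch for TransformerBlock (illustrative only): the block preserves [B, C, H, W].
#
#   blk = TransformerBlock(128)
#   y = blk(torch.randn(2, 128, 16, 16))   # y.shape == torch.Size([2, 128, 16, 16])
#
# nn.MultiheadAttention defaults to batch_first=False, hence the permute to
# [H*W, B, C] before attention and back afterwards.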
class LevelBlock(nn.Module):
def __init__(self, channels, num_blocks, block_type='res'):
super().__init__()
self.blocks = nn.ModuleList()
for _ in range(num_blocks):
if block_type == 'transformer':
self.blocks.append(TransformerBlock(channels))
else:
self.blocks.append(ResBlock(channels))
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class AsymmetricResidualUDiT(nn.Module):
def __init__(self,
                 in_channels=3,                # Input color channels
                 base_channels=128,            # Initial feature width; dramatically increases the parameter count of the network.
                 patch_size=2,                 # Smaller patches dramatically increase FLOPs and compute cost. Recommend >=4 unless you have real compute.
                 num_levels=3,                 # Number of feature downsamples, essentially the U-Net depth -- so we down/upsample three times. Parameters grow quickly as this increases.
                 encoder_blocks=3,             # Can be a different number of blocks vs decoder_blocks
                 decoder_blocks=7,             # Can be a different number of blocks vs encoder_blocks
                 encoder_transformer_thresh=2, # Level at which the encoder starts using transformer blocks instead of res blocks. (>=)
                 decoder_transformer_thresh=4, # Level at which the decoder stops using transformer blocks instead of res blocks. (<=)
                 mid_blocks=16                 # Number of middle transformer blocks. Relatively cheap, since they run at the bottleneck resolution.
):
super().__init__()
# Initial projection from image space
self.patch_embed = nn.Conv2d(in_channels, base_channels,
kernel_size=patch_size, stride=patch_size)
# Create encoder levels
self.encoders = nn.ModuleList()
curr_channels = base_channels
for level in range(num_levels):
# Create the main processing blocks for this level
            use_transformer = level >= encoder_transformer_thresh # Use transformer blocks for the deeper encoder levels
            # Encoder blocks -- encoder_blocks
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks,
                           'transformer' if use_transformer else 'res')
            )
# Add channel scaling for next level
# Doubles the size of the feature space for each step, except for the last level.
if level < num_levels - 1:
self.encoders.append(
nn.Conv2d(curr_channels, curr_channels * 2, 1)
)
curr_channels *= 2
# Middle transformer blocks -- mid_blocks
self.middle = nn.ModuleList([
TransformerBlock(curr_channels) for _ in range(mid_blocks)
])
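        # With the defaults, curr_channels at the bottleneck is base_channels * 2**(num_levels - 1) = 512.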
# Create decoder levels
self.decoders = nn.ModuleList()
for level in range(num_levels):
# Create the main processing blocks for this level
            use_transformer = level <= decoder_transformer_thresh # Use transformer blocks while level <= threshold (mirrors the encoder)
            # Decoder blocks -- decoder_blocks
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks,
                           'transformer' if use_transformer else 'res')
            )
# Add channel scaling for next level
# Halves the size of the feature space for each step, except for the last level.
if level < num_levels - 1:
self.decoders.append(
nn.Conv2d(curr_channels, curr_channels // 2, 1)
)
curr_channels //= 2
# Final projection back to image space
self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
kernel_size=patch_size, stride=patch_size)
def downsample(self, x):
return F.avg_pool2d(x, kernel_size=2)
def upsample(self, x):
return F.interpolate(x, scale_factor=2, mode='nearest')
def forward(self, x, t=None):
# Start by patch embedding the inputs.
x = self.patch_embed(x)
# Track residual path and features at each spatial level
        # The paper was very specific about the residual flow path; I tried my best to copy how they described it.
        # *Per resolution, i.e. roughly once per num_levels resolution block:
# f(x) = fu( U(fm(D(h)) - D(h)) + h ) where h = fd(x)
#
# Where
# 1. h = fd(x) : Encoder path processes input
# 2. D(h) : Downsample the encoded features
# 3. fm(D(h)) : Middle transformer blocks process downsampled features
# 4. fm(D(h))-D(h): Subtract original downsampled features (residual connection)
# 5. U(...) : Upsample the processed features
# 6. ... + h : Add back original encoder features (skip connection)
# 7. fu(...) : Decoder path processes the combined features
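        # Mapping the symbols above onto the code below: fd = the encoder loop, D = self.downsample,
        # fm = self.middle, U = self.upsample, fu = the decoder loop.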
residuals = []
curr_res = x
# Encoder path (computing h = fd(x))
h = x
for i, blocks in enumerate(self.encoders):
if isinstance(blocks, LevelBlock):
h = blocks(h)
else:
# Save residual before downsampling
residuals.append(curr_res)
# Downsample and update current residual
h = self.downsample(blocks(h))
curr_res = h
# Middle blocks (fm)
x = h
for block in self.middle:
x = block(x)
# Subtract the residual at this level (D(h))
x = x - curr_res
# Decoder path (fu)
for i, blocks in enumerate(self.decoders):
if isinstance(blocks, LevelBlock):
x = blocks(x)
else:
# Channel reduction
x = blocks(x)
# Upsample
x = self.upsample(x)
                # Add the residual saved by the encoder at this level. Residuals pop LIFO:
                # the last one saved is the first one needed, because of the U shape.
curr_res = residuals.pop()
x = x + curr_res
# Final projection
x = self.final_proj(x)
return x
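# Minimal smoke test (my own sketch): with the default configuration
# (patch_size=2, num_levels=3) the spatial size must be divisible by
# patch_size * 2**(num_levels - 1) = 8, and the output shape matches the input.
if __name__ == "__main__":
    model = AsymmetricResidualUDiT()
    dummy = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # Expected: torch.Size([1, 3, 64, 64])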