import torch
import torch.nn as nn
import torch.nn.functional as F

# This architecture is my attempt at the Simple Diffusion paper below, with some modifications:
# https://arxiv.org/pdf/2410.19324v1


class xATGLU(nn.Module):
    """Gated linear unit with an arctan gate (similar in spirit to GeGLU / SwiGLU).

    The input is projected to ``2 * output_dim`` features and split into a gate half
    and a value half. The gate is squashed into (0, 1) via ``(arctan(g) + pi/2) / pi``
    and then affinely stretched by a learned scalar ``alpha`` (initialised to 0) to
    ``gate * (1 + 2*alpha) - alpha``, i.e. the range (-alpha, 1 + alpha).
    See https://arxiv.org/pdf/2405.20768 for the expanded-range gating idea.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        bias: Whether the fused gate/value projection has a bias.
    """

    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # Fused projection: first half -> GATE path, second half -> VALUE path.
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')
        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        gate_path, value_path = self.proj(x).chunk(2, dim=-1)
        # arctan maps to (-pi/2, pi/2); shifting and scaling gives a gate in (0, 1).
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        # Learned alpha widens (or narrows) the gate range symmetrically.
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha
        return expanded_gate * value_path  # g(x) * y


class ResBlock(nn.Module):
    """Pre-norm residual conv block: ``x + conv2(silu(norm2(conv1(silu(norm1(x))))))``.

    ``channels`` must be divisible by 32 because of ``GroupNorm(32, channels)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h


class TransformerBlock(nn.Module):
    """Pre-norm transformer block applied to a (B, C, H, W) feature map.

    The map is flattened into H*W tokens of width C, passed through self-attention
    and an xATGLU MLP (each with a residual connection), then reshaped back.
    No positional encoding is added here; spatial information must come from
    earlier layers.
    """

    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        # batch_first is left False, so attention expects [seq, batch, embed].
        self.attn = nn.MultiheadAttention(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.mlp = nn.Sequential(
            xATGLU(channels, 4 * channels),
            nn.Linear(4 * channels, channels),
        )

    def forward(self, x):
        # Reshape for attention: [B, C, H, W] -> [H*W, B, C]
        b, c, h, w = x.shape
        x = x.flatten(2).permute(2, 0, 1)

        # Self attention. The attention weights are never used, so don't compute them.
        h_attn = self.norm1(x)
        h_attn, _ = self.attn(h_attn, h_attn, h_attn, need_weights=False)
        x = x + h_attn

        # MLP
        x = x + self.mlp(self.norm2(x))

        # Reshape back: [H*W, B, C] -> [B, C, H, W]
        return x.permute(1, 2, 0).reshape(b, c, h, w)


class LevelBlock(nn.Module):
    """A sequential stack of ``num_blocks`` blocks at a single resolution.

    Args:
        channels: Feature channels of every block.
        num_blocks: Number of stacked blocks.
        block_type: The string ``'transformer'`` builds TransformerBlocks; any other
            value builds ResBlocks.
    """

    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            if block_type == 'transformer':
                self.blocks.append(TransformerBlock(channels))
            else:
                self.blocks.append(ResBlock(channels))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class AsymmetricResidualUDiT(nn.Module):
    """Residual U-shaped network mixing conv (ResBlock) and transformer levels.

    Layout: patch embedding -> ``num_levels`` encoder levels (channels double and
    resolution halves between levels) -> ``mid_blocks`` transformer blocks ->
    ``num_levels`` decoder levels (channels halve and resolution doubles between
    levels) -> transposed-conv projection back to image space.

    Input H and W should be divisible by ``patch_size * 2 ** (num_levels - 1)``.
    Otherwise the avg-pool / nearest-upsample round trip produces a shape mismatch
    at the skip additions, or the patch embedding drops trailing pixels.

    NOTE(review): with the default thresholds (decoder_transformer_thresh=4 >=
    num_levels - 1), every decoder level uses transformer blocks. That includes the
    full patch-grid resolution, where attention cost is quadratic in
    (H / patch_size) * (W / patch_size).
    """

    def __init__(
        self,
        in_channels=3,  # Input color channels.
        base_channels=128,  # Initial feature size; dramatically increases parameter count.
        patch_size=2,  # Smaller patches dramatically increase FLOPs. Recommend >=4 unless you have real compute.
        num_levels=3,  # U-Net depth: number of resolutions, i.e. num_levels - 1 down/upsamples. Parameters grow quickly with it.
        encoder_blocks=3,  # Blocks per encoder level (may differ from decoder_blocks).
        decoder_blocks=7,  # Blocks per decoder level (may differ from encoder_blocks).
        encoder_transformer_thresh=2,  # Encoder levels with index >= this use transformer blocks instead of res blocks.
        decoder_transformer_thresh=4,  # Decoder levels with index <= this use transformer blocks (decoder level 0 is the lowest resolution).
        mid_blocks=16,  # Middle transformer blocks; relatively cheap at the bottleneck resolution.
    ):
        super().__init__()

        # Initial projection from image space to non-overlapping patches.
        self.patch_embed = nn.Conv2d(in_channels, base_channels, kernel_size=patch_size, stride=patch_size)

        # Encoder: [LevelBlock, 1x1 channel-doubling conv] per level; no conv after the last level.
        self.encoders = nn.ModuleList()
        curr_channels = base_channels
        for level in range(num_levels):
            # LevelBlock expects the *string* 'transformer'; passing the bool would
            # silently build ResBlocks for every level.
            block_type = 'transformer' if level >= encoder_transformer_thresh else 'res'
            self.encoders.append(LevelBlock(curr_channels, encoder_blocks, block_type))
            if level < num_levels - 1:
                self.encoders.append(nn.Conv2d(curr_channels, curr_channels * 2, 1))
                curr_channels *= 2

        # Middle transformer blocks at the bottleneck resolution.
        self.middle = nn.ModuleList([TransformerBlock(curr_channels) for _ in range(mid_blocks)])

        # Decoder: [LevelBlock, 1x1 channel-halving conv] per level; no conv after the last level.
        self.decoders = nn.ModuleList()
        for level in range(num_levels):
            # Decoder level index counts from the bottleneck, so low indices are low resolution.
            block_type = 'transformer' if level <= decoder_transformer_thresh else 'res'
            self.decoders.append(LevelBlock(curr_channels, decoder_blocks, block_type))
            if level < num_levels - 1:
                self.decoders.append(nn.Conv2d(curr_channels, curr_channels // 2, 1))
                curr_channels //= 2

        # Final projection back to image space.
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels, kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        """Run the network on an image batch.

        The residual flow follows, per resolution level,
            f(x) = fu( U(fm(D(h)) - D(h)) + h ),   h = fd(x)
        1. h = fd(x)        : encoder blocks process the level input
        2. D(h)             : channel-scale and downsample the encoder features
        3. fm(D(h))         : deeper levels / middle transformer blocks
        4. fm(D(h)) - D(h)  : subtract the downsampled features (residual)
        5. U(...)           : upsample the processed features
        6. ... + h          : add back the encoder features (skip connection)
        7. fu(...)          : decoder blocks process the combined features
        Step 4 is applied only at the bottleneck. Intermediate levels go straight
        from the decoder blocks to channel-scaling and upsampling.

        Args:
            x: Image batch, presumably (B, in_channels, H, W).
            t: Accepted but currently unused. NOTE(review): there is no timestep
                conditioning anywhere in this network.

        Returns:
            Tensor with ``in_channels`` channels at the input resolution
            (given the divisibility requirement in the class docstring).
        """
        x = self.patch_embed(x)

        skips = []
        # Input to the bottom-most encoder level; this is the D(h) subtracted after
        # the middle blocks.
        bottom_input = x

        # Encoder path (h = fd(x) at each level).
        h = x
        for module in self.encoders:
            if isinstance(module, LevelBlock):
                h = module(h)
            else:
                # Skip carries the encoder *output* h for this level, as in "+ h" above.
                skips.append(h)
                h = self.downsample(module(h))
                bottom_input = h

        # Middle blocks (fm), then subtract D(h).
        for block in self.middle:
            h = block(h)
        x = h - bottom_input

        # Decoder path (fu).
        for module in self.decoders:
            if isinstance(module, LevelBlock):
                x = module(x)
            else:
                # Channel reduction, upsample, then add the matching encoder skip.
                # Skips are popped LIFO because the deepest skip is needed first.
                x = self.upsample(module(x)) + skips.pop()

        return self.final_proj(x)