# movie-diffusion/unet.py
import torch
import torch.nn as nn
from einops import rearrange
from collections import defaultdict
# Map activation names to modules; unknown names fall back to SiLU.
# Sharing module instances is safe here because these activations are stateless.
str_to_act = defaultdict(lambda: nn.SiLU())
str_to_act.update({
"relu": nn.ReLU(),
"silu": nn.SiLU(),
"gelu": nn.GELU(),
})
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
    def forward(self, t):
        # t: (batch_size,) tensor of timesteps
        device = t.device
        t = t.unsqueeze(-1)
        # transformer-style frequency schedule: 10000^(-2i/dim) for i = 0 .. dim/2 - 1
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        sin_enc = torch.sin(t.repeat(1, self.dim // 2) * inv_freq)
        cos_enc = torch.cos(t.repeat(1, self.dim // 2) * inv_freq)
        # (batch_size, dim): sine half followed by cosine half
        pos_enc = torch.cat([sin_enc, cos_enc], dim=-1)
        return pos_enc
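
# A minimal shape check for the encoding above (illustrative sketch, not part of
# the model; dim=128 is an arbitrary example value):
#
#   enc = SinusoidalPositionalEncoding(dim=128)
#   t = torch.arange(4)              # four timesteps
#   assert enc(t).shape == (4, 128)  # half sine features, half cosine features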
class TimeEmbedding(nn.Module):
    def __init__(self, model_dim: int, emb_dim: int, act="silu"):
        """Two-layer MLP that projects the sinusoidal encoding into the time-embedding space."""
        super().__init__()
        self.lin = nn.Linear(model_dim, emb_dim)
        self.act = str_to_act[act]
        self.lin2 = nn.Linear(emb_dim, emb_dim)
def forward(self, x):
x = self.lin(x)
x = self.act(x)
x = self.lin2(x)
return x
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, act="silu", dropout=None, zero=False):
        super().__init__()
        # pre-activation ordering: GroupNorm -> activation -> (dropout) -> conv
        self.norm = nn.GroupNorm(
            num_groups=32,
            num_channels=in_channels,
        )
        self.act = str_to_act[act]
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        if zero:
            # zero-init the conv (weight and bias) so the enclosing ResBlock
            # starts out as an identity mapping, as in the iDDPM codebase
            nn.init.zeros_(self.conv.weight)
            nn.init.zeros_(self.conv.bias)
    def forward(self, x):
        x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.conv(x)
        return x
class EmbeddingBlock(nn.Module):
    def __init__(self, channels: int, emb_dim: int, act="silu"):
        """Projects the time embedding down to a per-channel bias for a ResBlock."""
        super().__init__()
        self.act = str_to_act[act]
        self.lin = nn.Linear(emb_dim, channels)
def forward(self, x):
x = self.act(x)
x = self.lin(x)
return x
class ResBlock(nn.Module):
    def __init__(self, channels: int, emb_dim: int, dropout: float = 0, out_channels=None):
        """A residual block with a time embedding and an optional change in channel count."""
        super().__init__()
        if out_channels is None:
            out_channels = channels
        self.conv1 = ConvBlock(channels, out_channels)
        self.emb = EmbeddingBlock(out_channels, emb_dim)
        self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout, zero=True)
        if channels != out_channels:
            self.skip_connection = nn.Conv2d(channels, out_channels, kernel_size=1)
        else:
            self.skip_connection = nn.Identity()
def forward(self, x, t):
original = x
x = self.conv1(x)
t = self.emb(t)
# t: (batch_size, time_embedding_dim) = (batch_size, out_channels)
# x: (batch_size, out_channels, height, width)
# we repeat the time embedding to match the shape of x
t = t.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, x.shape[2], x.shape[3])
x = x + t
x = self.conv2(x)
x = x + self.skip_connection(original)
return x
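
# Illustrative shape contract for ResBlock (example values, not part of the
# model): the time embedding enters as (B, emb_dim) and is added as a
# per-channel bias after the first conv.
#
#   block = ResBlock(channels=128, emb_dim=512, out_channels=256)
#   x, t = torch.randn(2, 128, 32, 32), torch.randn(2, 512)
#   assert block(x, t).shape == (2, 256, 32, 32)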
class SelfAttentionBlock(nn.Module):
def __init__(self, channels, num_heads=1):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.norm = nn.GroupNorm(32, channels)
self.attention = nn.MultiheadAttention(
embed_dim=channels,
num_heads=num_heads,
dropout=0,
batch_first=True,
bias=True,
)
    def forward(self, x):
        h, w = x.shape[-2:]
        original = x
        x = self.norm(x)
        # flatten the spatial grid so each pixel becomes a token of dimension `channels`
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.attention(x, x, x)[0]
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        return x + original
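
# Quick sanity check for the attention block (illustrative example values);
# self-attention over the flattened h*w positions is shape-preserving:
#
#   attn = SelfAttentionBlock(channels=64, num_heads=1)
#   x = torch.randn(2, 64, 16, 16)
#   assert attn(x).shape == x.shape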
class Downsample(nn.Module):
def __init__(self, channels):
super().__init__()
        # the original U-Net downsamples with 2x2 max pooling;
        # here we follow iDDPM and use a strided convolution instead
self.down = nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=2,
padding=1,
)
def forward(self, x):
return self.down(x)
class DownBlock(nn.Module):
"""According to U-Net paper
'The contracting path follows the typical architecture of a convolutional network.
It consists of the repeated application of two 3x3 convolutions (unpadded convolutions),
each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2
for downsampling. At each downsampling step we double the number of feature channels.'
"""
def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, downsample=True, width=1):
"""in_channels will typically be half of out_channels"""
super().__init__()
self.width = width
self.use_attn = use_attn
self.do_downsample = downsample
self.blocks = nn.ModuleList()
for _ in range(width):
self.blocks.append(ResBlock(
channels=in_channels,
out_channels=out_channels,
emb_dim=time_embedding_dim,
dropout=dropout,
))
if self.use_attn:
self.blocks.append(SelfAttentionBlock(
channels=out_channels,
))
in_channels = out_channels
if self.do_downsample:
self.downsample = Downsample(out_channels)
    def forward(self, x, t):
        for block in self.blocks:
            if isinstance(block, ResBlock):
                x = block(x, t)
            elif isinstance(block, SelfAttentionBlock):
                x = block(x)
        # keep the pre-downsampling activation as the skip connection for the expansive path
        residual = x
        if self.do_downsample:
            x = self.downsample(x)
        return x, residual
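
# Illustrative contract of a DownBlock (example values): it returns both the
# downsampled features and the pre-downsampling activation for the skip connection.
#
#   down = DownBlock(in_channels=128, out_channels=256, time_embedding_dim=512)
#   x, res = down(torch.randn(2, 128, 32, 32), torch.randn(2, 512))
#   assert x.shape == (2, 256, 16, 16) and res.shape == (2, 256, 32, 32)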
class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # nearest-neighbour upsampling followed by a 3x3 conv, as in DDPM
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )
def forward(self, x):
x = self.upsample(x)
x = self.conv(x)
return x
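
# Illustrative: Upsample doubles the spatial resolution and keeps the channel
# count fixed (example values):
#
#   up = Upsample(channels=64)
#   assert up(torch.randn(2, 64, 8, 8)).shape == (2, 64, 16, 16)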
class UpBlock(nn.Module):
"""According to U-Net paper
Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2
convolution (“up-convolution”) that halves the number of feature channels, a concatenation with
the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions,
each followed by a ReLU.
"""
def __init__(self, in_channels, out_channels, time_embedding_dim, use_attn=False, dropout=0, upsample=True, width=1):
"""in_channels will typically be double of out_channels
"""
super().__init__()
self.use_attn = use_attn
self.do_upsample = upsample
self.blocks = nn.ModuleList()
for _ in range(width):
self.blocks.append(ResBlock(
channels=in_channels,
out_channels=out_channels,
emb_dim=time_embedding_dim,
dropout=dropout,
))
if self.use_attn:
self.blocks.append(SelfAttentionBlock(
channels=out_channels,
))
in_channels = out_channels
if self.do_upsample:
self.upsample = Upsample(out_channels)
def forward(self, x, t):
for block in self.blocks:
if isinstance(block, ResBlock):
x = block(x, t)
elif isinstance(block, SelfAttentionBlock):
x = block(x)
if self.do_upsample:
x = self.upsample(x)
return x
class Bottleneck(nn.Module):
    def __init__(self, channels, dropout, time_embedding_dim):
        super().__init__()
        self.resblock_1 = ResBlock(
            channels=channels,
            out_channels=channels,
            dropout=dropout,
            emb_dim=time_embedding_dim,
        )
        self.attention_block = SelfAttentionBlock(
            channels=channels,
        )
        self.resblock_2 = ResBlock(
            channels=channels,
            out_channels=channels,
            dropout=dropout,
            emb_dim=time_embedding_dim,
        )
def forward(self, x, t):
x = self.resblock_1(x, t)
x = self.attention_block(x)
x = self.resblock_2(x, t)
return x
class Unet(nn.Module):
def __init__(
self,
image_channels=3,
res_block_width=2,
starting_channels=128,
dropout=0,
channel_mults=(1, 2, 2, 4, 4),
attention_layers=(False, False, False, True, False)
):
super().__init__()
self.is_conditional = False
#channel_mults = (1, 2, 2, 2)
#attention_layers = (False, False, True, False)
#res_block_width=3
self.image_channels = image_channels
self.starting_channels = starting_channels
time_embedding_dim = 4 * starting_channels
self.time_encoding = SinusoidalPositionalEncoding(dim=starting_channels)
self.time_embedding = TimeEmbedding(model_dim=starting_channels, emb_dim=time_embedding_dim)
        self.input = nn.Conv2d(image_channels, starting_channels, kernel_size=3, padding=1)
current_channel_count = starting_channels
input_channel_counts = []
self.contracting_path = nn.ModuleList([])
for i, channel_multiplier in enumerate(channel_mults):
is_last_layer = i == len(channel_mults) - 1
next_channel_count = channel_multiplier * starting_channels
self.contracting_path.append(DownBlock(
in_channels=current_channel_count,
out_channels=next_channel_count,
time_embedding_dim=time_embedding_dim,
use_attn=attention_layers[i],
dropout=dropout,
downsample=not is_last_layer,
width=res_block_width,
))
current_channel_count = next_channel_count
input_channel_counts.append(current_channel_count)
self.bottleneck = Bottleneck(channels=current_channel_count, time_embedding_dim=time_embedding_dim, dropout=dropout)
self.expansive_path = nn.ModuleList([])
for i, channel_multiplier in enumerate(reversed(channel_mults)):
next_channel_count = channel_multiplier * starting_channels
self.expansive_path.append(UpBlock(
in_channels=current_channel_count + input_channel_counts.pop(),
out_channels=next_channel_count,
time_embedding_dim=time_embedding_dim,
use_attn=list(reversed(attention_layers))[i],
dropout=dropout,
upsample=i != len(channel_mults) - 1,
width=res_block_width,
))
current_channel_count = next_channel_count
        last_conv = nn.Conv2d(
            in_channels=starting_channels,
            out_channels=image_channels,
            kernel_size=3,
            padding=1,
        )
        # zero-init the output conv (weight and bias) so the untrained network predicts zeros
        nn.init.zeros_(last_conv.weight)
        nn.init.zeros_(last_conv.bias)
self.head = nn.Sequential(
nn.GroupNorm(32, starting_channels),
nn.SiLU(),
last_conv,
)
def forward(self, x, t):
t = self.time_encoding(t)
return self._forward(x, t)
def _forward(self, x, t):
t = self.time_embedding(t)
x = self.input(x)
residuals = []
for contracting_block in self.contracting_path:
x, residual = contracting_block(x, t)
residuals.append(residual)
x = self.bottleneck(x, t)
for expansive_block in self.expansive_path:
# Add the residual
residual = residuals.pop()
x = torch.cat([x, residual], dim=1)
x = expansive_block(x, t)
x = self.head(x)
return x
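
# End-to-end shape sketch for the unconditional model (illustrative; 64x64 inputs
# are an example, any size divisible by 2 ** (len(channel_mults) - 1) works):
#
#   net = Unet(image_channels=3, starting_channels=128)
#   x = torch.randn(2, 3, 64, 64)
#   t = torch.randint(0, 1000, (2,))
#   assert net(x, t).shape == x.shape  # the U-Net predicts per-pixel noise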
class ConditionalUnet(nn.Module):
def __init__(self, unet, num_classes):
super().__init__()
self.is_conditional = True
self.unet = unet
self.num_classes = num_classes
        # +1 embedding slot so index 0 can serve as the "no class" padding index
        self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
    def forward(self, x, t, cond=None):
        # cond: (batch_size, n) integer class indices, where n is the number of
        # classes that we are conditioning on; index 0 is the "no class" padding
        t = self.unet.time_encoding(t)
        if cond is not None:
            cond = cond.to(t.device)
            cond = self.class_embedding(cond)
            # sum across the classes so we get a single vector representing the set of classes
            cond = cond.sum(dim=1)
            t = t + cond
        return self.unet._forward(x, t)
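
if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch; the class count, image size, and
    # starting_channels below are arbitrary example values, not training settings).
    # Conditioning labels use index 0 as "no class" padding, so real classes
    # occupy indices 1..num_classes.
    unet = Unet(image_channels=3, starting_channels=32)
    model = ConditionalUnet(unet, num_classes=10)
    x = torch.randn(2, 3, 64, 64)
    t = torch.randint(0, 1000, (2,))
    cond = torch.tensor([[1, 3], [2, 0]])  # second sample conditions on one class
    out = model(x, t, cond=cond)
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])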