Spaces:
Sleeping
Sleeping
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
import torch.nn as nn | |
from modules.encoder.position_encoder import PositionEncoder | |
from modules.general.utils import append_dims, ConvNd, normalization, zero_module | |
from .attention import AttentionBlock | |
from .resblock import Downsample, ResBlock, Upsample | |
class UNet(nn.Module): | |
r"""The full UNet model with attention and timestep embedding. | |
Args: | |
dims: determines if the signal is 1D (temporal), 2D(spatial). | |
in_channels: channels in the input Tensor. | |
model_channels: base channel count for the model. | |
out_channels: channels in the output Tensor. | |
num_res_blocks: number of residual blocks per downsample. | |
channel_mult: channel multiplier for each level of the UNet. | |
num_attn_blocks: number of attention blocks at place. | |
attention_resolutions: a collection of downsample rates at which attention will | |
take place. May be a set, list, or tuple. For example, if this contains 4, | |
then at 4x downsampling, attention will be used. | |
num_heads: the number of attention heads in each attention layer. | |
num_head_channels: if specified, ignore num_heads and instead use a fixed | |
channel width per attention head. | |
d_context: if specified, use for cross-attention channel project. | |
p_dropout: the dropout probability. | |
use_self_attention: Apply self attention before cross attention. | |
num_classes: if specified (as an int), then this model will be class-conditional | |
with ``num_classes`` classes. | |
use_extra_film: if specified, use an extra FiLM-like conditioning mechanism. | |
d_emb: if specified, use for FiLM-like conditioning. | |
use_scale_shift_norm: use a FiLM-like conditioning mechanism. | |
resblock_updown: use residual blocks for up/downsampling. | |
""" | |
def __init__( | |
self, | |
dims: int = 1, | |
in_channels: int = 100, | |
model_channels: int = 128, | |
out_channels: int = 100, | |
h_dim: int = 128, | |
num_res_blocks: int = 1, | |
channel_mult: tuple = (1, 2, 4), | |
num_attn_blocks: int = 1, | |
attention_resolutions: tuple = (1, 2, 4), | |
num_heads: int = 1, | |
num_head_channels: int = -1, | |
d_context: int = None, | |
context_hdim: int = 128, | |
p_dropout: float = 0.0, | |
num_classes: int = -1, | |
use_extra_film: str = None, | |
d_emb: int = None, | |
use_scale_shift_norm: bool = True, | |
resblock_updown: bool = False, | |
): | |
super().__init__() | |
self.dims = dims | |
self.in_channels = in_channels | |
self.model_channels = model_channels | |
self.out_channels = out_channels | |
self.num_res_blocks = num_res_blocks | |
self.channel_mult = channel_mult | |
self.num_attn_blocks = num_attn_blocks | |
self.attention_resolutions = attention_resolutions | |
self.num_heads = num_heads | |
self.num_head_channels = num_head_channels | |
self.d_context = d_context | |
self.p_dropout = p_dropout | |
self.num_classes = num_classes | |
self.use_extra_film = use_extra_film | |
self.d_emb = d_emb | |
self.use_scale_shift_norm = use_scale_shift_norm | |
self.resblock_updown = resblock_updown | |
time_embed_dim = model_channels * 4 | |
self.pos_enc = PositionEncoder(model_channels, time_embed_dim) | |
assert ( | |
num_classes == -1 or use_extra_film is None | |
), "You cannot set both num_classes and use_extra_film." | |
if self.num_classes > 0: | |
# TODO: if used for singer, norm should be 1, correct? | |
self.label_emb = nn.Embedding(num_classes, time_embed_dim, max_norm=1.0) | |
elif use_extra_film is not None: | |
assert ( | |
d_emb is not None | |
), "d_emb must be specified if use_extra_film is not None" | |
assert use_extra_film in [ | |
"add", | |
"concat", | |
], f"use_extra_film only supported by add or concat. Your input is {use_extra_film}" | |
self.use_extra_film = use_extra_film | |
self.film_emb = ConvNd(dims, d_emb, time_embed_dim, 1) | |
if use_extra_film == "concat": | |
time_embed_dim *= 2 | |
# Input blocks | |
ch = input_ch = int(channel_mult[0] * model_channels) | |
self.input_blocks = nn.ModuleList( | |
[UNetSequential(ConvNd(dims, in_channels, ch, 3, padding=1))] | |
) | |
self._feature_size = ch | |
input_block_chans = [ch] | |
ds = 1 | |
for level, mult in enumerate(channel_mult): | |
for _ in range(num_res_blocks): | |
layers = [ | |
ResBlock( | |
ch, | |
time_embed_dim, | |
p_dropout, | |
out_channels=int(mult * model_channels), | |
dims=dims, | |
use_scale_shift_norm=use_scale_shift_norm, | |
) | |
] | |
ch = int(mult * model_channels) | |
if ds in attention_resolutions: | |
for _ in range(num_attn_blocks): | |
layers.append( | |
AttentionBlock( | |
ch, | |
num_heads=num_heads, | |
num_head_channels=num_head_channels, | |
encoder_channels=d_context, | |
dims=dims, | |
h_dim=h_dim // (level + 1), | |
encoder_hdim=context_hdim, | |
p_dropout=p_dropout, | |
) | |
) | |
self.input_blocks.append(UNetSequential(*layers)) | |
self._feature_size += ch | |
input_block_chans.append(ch) | |
if level != len(channel_mult) - 1: | |
out_ch = ch | |
self.input_blocks.append( | |
UNetSequential( | |
ResBlock( | |
ch, | |
time_embed_dim, | |
p_dropout, | |
out_channels=out_ch, | |
dims=dims, | |
use_scale_shift_norm=use_scale_shift_norm, | |
down=True, | |
) | |
if resblock_updown | |
else Downsample(ch, dims=dims, out_channels=out_ch) | |
) | |
) | |
ch = out_ch | |
input_block_chans.append(ch) | |
ds *= 2 | |
self._feature_size += ch | |
# Middle blocks | |
self.middle_block = UNetSequential( | |
ResBlock( | |
ch, | |
time_embed_dim, | |
p_dropout, | |
dims=dims, | |
use_scale_shift_norm=use_scale_shift_norm, | |
), | |
AttentionBlock( | |
ch, | |
num_heads=num_heads, | |
num_head_channels=num_head_channels, | |
encoder_channels=d_context, | |
dims=dims, | |
h_dim=h_dim // (level + 1), | |
encoder_hdim=context_hdim, | |
p_dropout=p_dropout, | |
), | |
ResBlock( | |
ch, | |
time_embed_dim, | |
p_dropout, | |
dims=dims, | |
use_scale_shift_norm=use_scale_shift_norm, | |
), | |
) | |
self._feature_size += ch | |
# Output blocks | |
self.output_blocks = nn.ModuleList([]) | |
for level, mult in tuple(enumerate(channel_mult))[::-1]: | |
for i in range(num_res_blocks + 1): | |
ich = input_block_chans.pop() | |
layers = [ | |
ResBlock( | |
ch + ich, | |
time_embed_dim, | |
p_dropout, | |
out_channels=int(model_channels * mult), | |
dims=dims, | |
use_scale_shift_norm=use_scale_shift_norm, | |
) | |
] | |
ch = int(model_channels * mult) | |
if ds in attention_resolutions: | |
for _ in range(num_attn_blocks): | |
layers.append( | |
AttentionBlock( | |
ch, | |
num_heads=num_heads, | |
num_head_channels=num_head_channels, | |
encoder_channels=d_context, | |
dims=dims, | |
h_dim=h_dim // (level + 1), | |
encoder_hdim=context_hdim, | |
p_dropout=p_dropout, | |
) | |
) | |
if level and i == num_res_blocks: | |
out_ch = ch | |
layers.append( | |
ResBlock( | |
ch, | |
time_embed_dim, | |
p_dropout, | |
out_channels=out_ch, | |
dims=dims, | |
use_scale_shift_norm=use_scale_shift_norm, | |
up=True, | |
) | |
if resblock_updown | |
else Upsample(ch, dims=dims, out_channels=out_ch) | |
) | |
ds //= 2 | |
self.output_blocks.append(UNetSequential(*layers)) | |
self._feature_size += ch | |
# Final proj out | |
self.out = nn.Sequential( | |
normalization(ch), | |
nn.SiLU(), | |
zero_module(ConvNd(dims, input_ch, out_channels, 3, padding=1)), | |
) | |
def forward(self, x, timesteps=None, context=None, y=None, **kwargs): | |
r"""Apply the model to an input batch. | |
Args: | |
x: an [N x C x ...] Tensor of inputs. | |
timesteps: a 1-D batch of timesteps, i.e. [N]. | |
context: conditioning Tensor with shape of [N x ``d_context`` x ...] plugged | |
in via cross attention. | |
y: an [N] Tensor of labels, if **class-conditional**. | |
an [N x ``d_emb`` x ...] Tensor if **film-embed conditional**. | |
Returns: | |
an [N x C x ...] Tensor of outputs. | |
""" | |
assert (y is None) or ( | |
(y is not None) | |
and ((self.num_classes > 0) or (self.use_extra_film is not None)) | |
), f"y must be specified if num_classes or use_extra_film is not None. \nGot num_classes: {self.num_classes}\t\nuse_extra_film: {self.use_extra_film}\t\n" | |
hs = [] | |
emb = self.pos_enc(timesteps) | |
emb = append_dims(emb, x.dim()) | |
if self.num_classes > 0: | |
assert y.size() == (x.size(0),) | |
emb = emb + self.label_emb(y) | |
elif self.use_extra_film is not None: | |
assert y.size() == (x.size(0), self.d_emb, *x.size()[2:]) | |
y = self.film_emb(y) | |
if self.use_extra_film == "add": | |
emb = emb + y | |
elif self.use_extra_film == "concat": | |
emb = torch.cat([emb, y], dim=1) | |
h = x | |
for module in self.input_blocks: | |
h = module(h, emb, context) | |
hs.append(h) | |
h = self.middle_block(h, emb, context) | |
for module in self.output_blocks: | |
h = torch.cat([h, hs.pop()], dim=1) | |
h = module(h, emb, context) | |
return self.out(h) | |
class UNetSequential(nn.Sequential): | |
r"""A sequential module that passes embeddings to the children that support it.""" | |
def forward(self, x, emb=None, context=None): | |
for layer in self: | |
if isinstance(layer, ResBlock): | |
x = layer(x, emb) | |
elif isinstance(layer, AttentionBlock): | |
x = layer(x, context) | |
else: | |
x = layer(x) | |
return x | |