Spaces:
Sleeping
Sleeping
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from modules.general.utils import Conv1d, normalization, zero_module | |
from .basic import UNetBlock | |
class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Reference from `latent diffusion repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    Depending on ``encoder_channels`` the block runs either a single self-attention
    layer, or a cross-attention layer optionally preceded by a self-attention layer.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head. Takes precedence
            over ``num_heads``; pass ``-1`` to derive it from ``num_heads`` instead.
        num_heads: Number of attention heads. Only used when ``num_head_channels``
            is ``-1``.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, then self-attention is performed.
        use_self_attention: Whether to use self-attention before cross-attention, only
            applicable if encoder_channels is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: The dimension of the height, would be applied if ``dims`` is 2.
        encoder_hdim: The dimension of the height of the encoder output, would be
            applied if ``dims`` is 2.
        p_dropout: Dropout probability.

    Raises:
        ValueError: If ``dims`` is neither 1 nor 2, or if the head configuration
            that is actually used is not a positive integer.
        AssertionError: If the attention width is not divisible by the head
            configuration.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        # Validate ``dims`` once; both the query width and the encoder width
        # below depend on it.
        if dims not in (1, 2):
            raise ValueError(f"invalid number of dimensions: {dims}")

        self.p_dropout = p_dropout
        self.dims = dims

        # For dims == 2 we consider the channel as product of channel and height,
        # i.e. C x H. This is because we want to apply attention on the audio
        # signal, which is 1D.
        self.channels = channels if dims == 1 else channels * h_dim

        if num_head_channels == -1:
            # Head count given explicitly; derive the per-head width from it.
            if num_heads <= 0:
                raise ValueError(
                    f"num_heads must be positive when num_head_channels is -1, got {num_heads}"
                )
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            # Per-head width given; ``num_heads`` is ignored and derived instead.
            if num_head_channels <= 0:
                raise ValueError(
                    f"num_head_channels must be positive or -1, got {num_head_channels}"
                )
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if encoder_channels is None:
            # Pure self-attention mode.
            self.encoder_channels = None
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )
            return

        # Cross-attention mode, optionally preceded by a self-attention layer.
        self.use_self_attention = use_self_attention
        # Same C x H flattening as for the input, applied to the encoder side.
        self.encoder_channels = (
            encoder_channels if dims == 1 else encoder_channels * encoder_hdim
        )

        if use_self_attention:
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )
        self.cross_attention = BasicAttentionBlock(
            self.channels,
            self.num_head_channels,
            self.num_heads,
            self.encoder_channels,
            p_dropout=self.p_dropout,
        )

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...],
                if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]
        """
        shape = x.size()
        # Flatten to [B, self.channels, -1]; for dims == 2 ``self.channels`` is
        # already channels * h_dim, so the height axis is folded into channels.
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)
        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        # Restore the caller's original layout.
        return h.reshape(*shape).contiguous()
class BasicAttentionBlock(nn.Module):
    r"""Multi-head attention layer over 1D sequences with residual connections.

    Performs self-attention when ``context_channels`` is ``None`` and
    cross-attention (queries from the input, keys/values from a context tensor)
    otherwise. The attention output is passed through a pointwise projection and
    added to the input, followed by a second residual projection branch
    (``proj_out``).

    Args:
        channels: Number of channels of the input (and of q, k, v).
        num_head_channels: Number of channels per attention head. Takes precedence
            over ``num_heads``; pass ``-1`` to derive it from ``num_heads`` instead.
        num_heads: Number of attention heads. Only used when ``num_head_channels``
            is ``-1``.
        context_channels: Number of channels of the context tensor for
            cross-attention. If ``None``, then self-attention is performed.
        p_dropout: Dropout probability, used both for the attention weights and
            inside ``proj_out``.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            # Cross-attention: queries come from the (normalized) input, keys and
            # values are produced jointly from the context tensor.
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            # Self-attention: q, k and v are produced jointly from the input.
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # Pointwise output projection. The kernel size is passed explicitly, as
        # for every other convolution in this block (nn.Conv1d requires it).
        self.linear = Conv1d(self.channels, self.channels, 1)

        # NOTE(review): ``zero_module`` presumably zero-initialises the last conv
        # so this branch starts as an identity residual -- confirm in
        # modules.general.utils.
        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T], if ``None``,
                then self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()
        # Keep the untouched input for the residual connection; the projected
        # query below has a different layout and must not be used for it.
        residual = q
        heads, head_dim = self.num_heads, self.num_head_channels

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."

            # [N, C, L] -> [N, heads, L, head_dim]; the batch axis stays separate
            # so samples never attend to each other.
            query = self.to_q(q).reshape(N, heads, head_dim, L).transpose(-1, -2)
            # [N, 2C, T] -> [N, 2, heads, T, head_dim] -> two [N, heads, T, head_dim]
            key, value = (
                self.to_kv(kv)
                .reshape(N, 2, heads, head_dim, -1)
                .transpose(-1, -2)
                .unbind(dim=1)
            )
        else:
            # [N, 3C, L] -> [N, 3, heads, L, head_dim] -> three [N, heads, L, head_dim]
            query, key, value = (
                self.to_qkv(q)
                .reshape(N, 3, heads, head_dim, L)
                .transpose(-1, -2)
                .unbind(dim=1)
            )

        # F.scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # independent of the module's mode, so it must be disabled in eval.
        dropout_p = self.p_dropout if self.training else 0.0
        h = F.scaled_dot_product_attention(
            query.contiguous(),
            key.contiguous(),
            value.contiguous(),
            dropout_p=dropout_p,
        )

        # [N, heads, L, head_dim] -> [N, heads, head_dim, L] -> [N, C, L]
        h = h.transpose(-1, -2).reshape(N, C, L).contiguous()
        h = self.linear(h)

        x = residual + h
        h = self.proj_out(x)

        return x + h