Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,648 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.general.utils import Conv1d, normalization, zero_module
from .basic import UNetBlock
class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Reference from `latent diffusion repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    The block runs self-attention when no encoder is configured, and
    cross-attention (optionally preceded by self-attention) otherwise.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head. Takes
            precedence over ``num_heads`` unless it is set to ``-1``.
        num_heads: Number of attention heads. Only used when
            ``num_head_channels`` is ``-1``.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, then self-attention is performed.
        use_self_attention: Whether to use self-attention before cross-attention,
            only applicable if encoder_channels is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: The dimension of the height, would be applied if ``dims`` is 2.
        encoder_hdim: The dimension of the height of the encoder output,
            would be applied if ``dims`` is 2.
        p_dropout: Dropout probability.

    Raises:
        ValueError: If ``dims`` is neither 1 nor 2.
        AssertionError: If the (height-folded) channel count is not divisible
            by the selected head configuration.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        # Validate once up front; both the input and the encoder channel
        # computations below depend on ``dims`` being 1 or 2.
        if dims not in (1, 2):
            raise ValueError(f"invalid number of dimensions: {dims}")

        self.p_dropout = p_dropout
        self.dims = dims

        # For 2D inputs we consider the channel as product of channel and
        # height, i.e. C x H. This is because we want to apply attention on
        # the audio signal, which is 1D.
        self.channels = channels if dims == 1 else channels * h_dim

        # ``num_head_channels`` wins unless it is the sentinel -1, in which
        # case the head size is derived from ``num_heads``.
        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        # Stored unconditionally so the attribute exists in both modes; it is
        # only consulted by ``forward`` when an encoder is configured.
        self.use_self_attention = use_self_attention

        if encoder_channels is not None:
            # Same height folding as for the input channels.
            self.encoder_channels = (
                encoder_channels if dims == 1 else encoder_channels * encoder_hdim
            )

            if use_self_attention:
                self.self_attention = BasicAttentionBlock(
                    self.channels,
                    self.num_head_channels,
                    self.num_heads,
                    p_dropout=self.p_dropout,
                )
            self.cross_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                self.encoder_channels,
                p_dropout=self.p_dropout,
            )
        else:
            self.encoder_channels = None
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...],
                if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]
        """
        shape = x.size()
        # Flatten to [B, self.channels, -1]; for dims == 2 this folds the
        # height axis into the channel axis.
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)
        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        # Restore the caller's original layout.
        return h.reshape(*shape).contiguous()
class BasicAttentionBlock(nn.Module):
    r"""Multi-head attention over 1D sequences followed by a residual projection.

    With ``context_channels`` unset the block performs self-attention on its
    input; otherwise queries come from the input and keys/values from a
    separate context tensor (cross-attention).

    Args:
        channels: Number of channels of the input (and output).
        num_head_channels: Number of channels per attention head. Takes
            precedence over ``num_heads`` unless it is set to ``-1``.
        num_heads: Number of attention heads. Only used when
            ``num_head_channels`` is ``-1``.
        context_channels: Number of channels of the context tensor for
            cross-attention. If ``None``, self-attention is performed.
        p_dropout: Dropout probability, used for the attention weights
            (training only) and inside the output projection.

    Raises:
        AssertionError: If ``channels`` is not divisible by the selected head
            configuration.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        # ``num_head_channels`` wins unless it is the sentinel -1.
        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # Kernel size given explicitly (pointwise), matching every other
        # projection in this block. NOTE(review): assumes the ``Conv1d`` helper
        # forwards to ``nn.Conv1d``, which has no default kernel size — confirm.
        self.linear = Conv1d(self.channels, self.channels, 1)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T],
                if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()
        # Keep the untouched input for the residual connection; ``q`` itself
        # is only used to derive the per-head queries below.
        residual = q
        heads, head_dim = self.num_heads, self.num_head_channels

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."

            # [N, C, L] -> [N, heads, L, head_dim]; the batch axis is kept
            # separate so samples never attend to each other.
            query = self.to_q(q).reshape(N, heads, head_dim, L).transpose(-1, -2)
            # [N, 2C, T] -> 2 x [N, heads, T, head_dim]; the first C output
            # channels are keys, the remaining C are values.
            key, value = (
                self.to_kv(kv)
                .reshape(N, 2, heads, head_dim, -1)
                .transpose(-1, -2)
                .unbind(dim=1)
            )
        else:
            # [N, 3C, L] -> 3 x [N, heads, L, head_dim] in q, k, v order.
            query, key, value = (
                self.to_qkv(q)
                .reshape(N, 3, heads, head_dim, L)
                .transpose(-1, -2)
                .unbind(dim=1)
            )

        # The functional attention applies dropout whenever dropout_p > 0,
        # regardless of module mode, so it has to be gated on training here.
        attn_dropout = self.p_dropout if self.training else 0.0
        h = F.scaled_dot_product_attention(
            query, key, value, dropout_p=attn_dropout
        )

        # [N, heads, L, head_dim] -> [N, C, L]
        h = h.transpose(-1, -2).reshape(N, -1, L).contiguous()
        h = self.linear(h)

        x = residual + h
        h = self.proj_out(x)

        return x + h
|