Spaces:
Build error
Build error
"""Conformer definition adjusted given the Lucidrain's repo. | |
https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa | |
Copyright PolyAI Limited. | |
""" | |
import math
from collections import namedtuple
from functools import wraps
from typing import Any, Dict, Union

import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import EinMix, Rearrange
from torch import einsum, nn
# rotary embedding | |
class RotaryEmbedding(nn.Module):
    """Rotary position embedding angle table.

    ``forward(seq_len)`` returns the per-position angles that
    ``apply_rotary_pos_emb`` turns into ``cos``/``sin`` rotations.
    """

    def __init__(self, dim, theta = 10000):
        """
        Args:
            dim: size of the feature axis the rotation is applied to
                (``dim_head`` in this file). Frequencies are generated for
                every second index, so an even value is presumably expected.
            theta: base of the geometric frequency progression.
        """
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: derived purely from ``dim``/``theta``, so it is not
        # written to state dicts.
        self.register_buffer("inv_freq", inv_freq, persistent = False)

    @property
    def device(self):
        """Device of the module's buffers."""
        # Must be a property: ``forward`` passes ``self.device`` as a value to
        # ``torch.arange``; a plain method would hand it a bound method.
        return next(self.buffers()).device

    def forward(self, seq_len):
        """Return angles of shape ``(seq_len, 2 * len(inv_freq))``.

        Row ``t`` holds ``t * inv_freq`` repeated twice along the last axis,
        matching the two halves that ``rotate_half`` swaps.
        """
        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs
def rotate_half(x):
    """Swap the two halves of the last axis and negate the new first half.

    ``[a, b] -> [-b, a]``, where ``a``/``b`` come from ``chunk(2, dim=-1)``.
    """
    first, second = x.chunk(2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles in ``pos``.

    Args:
        pos: angle table (e.g. ``RotaryEmbedding.forward`` output); must
            broadcast against ``t``.
        t: tensor to rotate (queries or keys in this file).
    """
    cos_term = t * pos.cos()
    sin_term = rotate_half(t) * pos.sin()
    return cos_term + sin_term
# constants | |
# Backend switches splatted into ``torch.backends.cuda.sdp_kernel`` by
# ``Attend.flash_attn`` (via ``_asdict()``).
EfficientAttentionConfig = namedtuple(
    'EfficientAttentionConfig',
    'enable_flash enable_math enable_mem_efficient',
)
# helpers | |
def exists(val):
    """Return True for any value other than ``None`` (falsy values count)."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if not exists(val):
        return d
    return val
def divisible_by(numer, denom):
    """Return True when ``numer`` leaves no remainder modulo ``denom``."""
    remainder = numer % denom
    return remainder == 0
def calc_same_padding(kernel_size):
    """Return ``(left, right)`` padding summing to ``kernel_size - 1``.

    With a stride-1, unpadded convolution this keeps the length unchanged;
    for even kernels the extra element goes on the left.
    """
    left = kernel_size // 2
    right = left - (kernel_size + 1) % 2
    return left, right
def eval_decorator(fn):
    """Decorate a method so it runs with the model switched to eval mode.

    The model's previous ``training`` flag is restored afterwards, also when
    ``fn`` raises, so a failed call cannot leave the model stuck in eval.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)
    return inner
def once(fn):
    """Wrap ``fn`` so only its first call runs; every later call returns None.

    Arguments (positional and keyword) are forwarded unchanged, and
    ``functools.wraps`` keeps ``fn``'s name/docstring on the wrapper.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)
    return inner


# ``print`` that only emits the first message passed to it; later calls
# (from any call site) are silent no-ops.
print_once = once(print)
# t5 relative positional bias | |
class T5RelativePositionBias(nn.Module):
    """T5-style learned relative position bias for attention logits.

    Relative offsets ``j - i`` (key minus query position) are mapped to
    ``num_buckets`` buckets, with exact buckets for small offsets and
    logarithmically sized ones up to ``max_distance``. Each bucket holds one
    learned scalar per head.
    """

    def __init__(
        self,
        scale = 1.,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        """
        Args:
            scale: multiplier applied to the returned bias.
            num_buckets: total number of buckets (half per offset sign).
            max_distance: offset magnitude at which buckets saturate.
            heads: number of attention heads (one bias value per head).
        """
        super().__init__()
        self.scale = scale
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        """Map integer relative offsets to bucket ids.

        Must be a staticmethod: ``forward`` calls it via ``self`` with the
        offsets as the first positional argument.

        Returns:
            Long tensor shaped like ``relative_position``; positive offsets
            land in the upper half of the buckets.
        """
        ret = 0
        n = -relative_position
        num_buckets //= 2
        # positive offsets (key after query) use the upper half of the buckets
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # The large branch is only selected where n >= max_exact; clamping
        # first avoids computing log(0) = -inf and casting it to an integer.
        n_large = n.clamp(min = max_exact).float()
        val_if_large = max_exact + (
            torch.log(n_large / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(
            val_if_large,
            torch.full_like(val_if_large, num_buckets - 1)
        )

        ret += torch.where(is_small, n, val_if_large)
        return ret

    @property
    def device(self):
        """Device of the learned bias table."""
        # A property because ``forward`` passes ``self.device`` as a value.
        return next(self.parameters()).device

    def forward(self, n):
        """Return the scaled bias for an ``n x n`` attention map.

        Returns:
            Tensor of shape ``(heads, n, n)``; entry ``[h, i, j]`` depends
            only on the offset ``j - i``.
        """
        pos = torch.arange(n, device = self.device).long()
        rel_pos = pos[None, :] - pos[:, None]  # [i, j] = j - i
        rp_bucket = self._relative_position_bucket(
            rel_pos, num_buckets = self.num_buckets,
            max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)  # (i, j, heads)
        bias = values.permute(2, 0, 1)  # (heads, i, j)
        return bias * self.scale
# main class | |
class Attend(nn.Module):
    """Attention core: softmax(q k^T * scale) v, with an optional SDPA path.

    Queries are ``b h i d``; keys/values are ``b h j d`` or single-headed
    ``b j d``. ``mask`` is a boolean key-padding mask where ``True`` marks
    positions that may be attended.
    """

    def __init__(
        self,
        causal = False,
        dropout = 0.,
        flash = False
    ):
        """
        Args:
            causal: mask out keys positioned after each query.
            dropout: dropout probability on the attention weights.
            flash: route ``forward`` through ``F.scaled_dot_product_attention``.
        """
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # only compute capability 8.0 gets the flash backend enabled
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')  # noqa
            self.cuda_config = EfficientAttentionConfig(True, True, True)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')  # noqa
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def get_mask(self, i, j, device):
        """Boolean ``(i, j)`` mask, True where a query must NOT attend.

        The diagonal offset ``j - i + 1`` aligns the last query with the last
        key, so it also works when there are more keys than queries.
        """
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)  # noqa

    def flash_attn(self, q, k, v, mask = None, attn_bias = None):
        """Attention via ``F.scaled_dot_product_attention``.

        If ``attn_bias`` is given, padding (and, when ``self.causal``, the
        causal mask) are folded into it as large negative values and it is
        passed as the additive ``attn_mask``.
        """
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device  # noqa

        # single headed key / values
        if k.ndim == 3:
            k = rearrange(k, 'b n d -> b 1 n d')

        if v.ndim == 3:
            v = rearrange(v, 'b n d -> b 1 n d')

        # Expand a (b, j) key-padding mask to the (b, h, i, j) shape SDPA takes
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # Check if there is a compatible device for flash attention
        config = self.cuda_config if is_cuda else self.cpu_config

        causal = self.causal

        # handle attention bias
        if exists(attn_bias):
            mask_value = -torch.finfo(q.dtype).max // 2

            # Only mask future keys when the module is causal; previously the
            # causal mask was applied unconditionally here.
            if causal:
                causal_mask = self.get_mask(q_len, k_len, device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value)

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value)

            # causality (if any) is already inside the bias
            mask = attn_bias
            causal = False

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out

    def forward(self, q, k, v, mask = None, attn_bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        if self.flash:
            assert not exists(attn_bias)
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # attention bias
        if exists(attn_bias):
            sim = sim + attn_bias

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(q_len, k_len, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # key padding mask
        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class GLU(nn.Module):
    """Gated linear unit along axis ``dim``.

    Splits the input into two halves ``(value, gate)`` and returns
    ``value * sigmoid(gate)``, halving that axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        value, gate = x.chunk(2, dim=self.dim)
        return value * torch.sigmoid(gate)
class DepthWiseConv1d(nn.Module):
    """Grouped 1-D convolution (``groups=chan_in``) with explicit padding.

    ``padding`` is applied with ``F.pad`` to the last axis (a ``(left,
    right)`` pair in this file) before the unpadded convolution.
    """

    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        padded = F.pad(x, self.padding)
        return self.conv(padded)
class Scale(nn.Module):
    """Call the wrapped ``fn`` and multiply its output by ``scale``."""

    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)
        return result * self.scale
class ChanLayerNorm(nn.Module):
    """Normalize over axis 1 (channels) with a learned per-channel gain.

    Expects ``(batch, channels, length)`` input, as implied by the
    ``(1, dim, 1)`` gain. Uses the biased variance and no additive bias;
    ``eps`` is 1e-6 for float32 inputs and 1e-4 otherwise.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1))

    def forward(self, x):
        if x.dtype == torch.float32:
            eps = 1e-6
        else:
            eps = 1e-4
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        mean = x.mean(dim = 1, keepdim = True)
        inv_std = var.clamp(min = eps).rsqrt()
        return (x - mean) * inv_std * self.gamma
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then call the wrapped ``fn``.

    Keyword arguments are forwarded to ``fn`` untouched.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class Attention(nn.Module):
    """Multi-head attention with optional context and rotary embeddings.

    Queries come from ``x``; keys and values come from ``context`` when
    given, otherwise from ``x``. The attention itself is done by ``Attend``.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        flash = True
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        # NOTE(review): ``scale`` and ``dropout`` are not used by ``forward``;
        # score scaling and attention dropout happen inside ``Attend``.
        self.scale = dim_head ** -0.5

        self.attend = Attend(flash = flash, dropout = dropout)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        """
        Args:
            x: input of shape ``(b, n, dim)``.
            context: optional key/value source; defaults to ``x``.
            mask: key-padding mask forwarded to ``Attend``.
            rotary_emb: optional angle table applied to queries and keys.
            attn_bias: optional additive bias forwarded to ``Attend``.
        """
        source = x if context is None else context

        q = self.to_q(x)
        k, v = self.to_kv(source).chunk(2, dim = -1)

        # split heads: (b, n, h*d) -> (b, h, n, d)
        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
            for t in (q, k, v)
        )

        if rotary_emb is not None:
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        attended = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
class FeedForward(nn.Module):
    """Position-wise MLP.

    Applies ``Linear(dim -> dim*mult)``, then Swish, then dropout, then
    ``Linear(dim*mult -> dim)``, then dropout.
    """

    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Linear(dim, hidden),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class ConformerConvModule(nn.Module):
    """Conformer convolution module on ``(batch, seq, dim)`` input.

    The stages are:
    1. LayerNorm.
    2. Pointwise conv to ``2 * inner_dim``.
    3. GLU.
    4. Depthwise conv.
    5. Swish.
    6. Channel LayerNorm.
    7. Pointwise conv back to ``dim``.
    8. Dropout.

    The padding totals ``kernel_size - 1``, so the sequence length is
    preserved in both modes.
    """

    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim * expansion_factor

        if causal:
            # left-only padding: each output sees only current/past frames
            padding = (kernel_size - 1, 0)
        else:
            padding = calc_same_padding(kernel_size)

        pointwise_in = [
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),  # channels-first for Conv1d
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),  # gate over the channel axis
        ]
        depthwise = [
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size = kernel_size,
                padding = padding
            ),
            Swish(),
            ChanLayerNorm(inner_dim),
        ]
        pointwise_out = [
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*pointwise_in, *depthwise, *pointwise_out)

    def forward(self, x):
        return self.net(x)
# Conformer Block | |
class ConformerBlock(nn.Module):
    """One Conformer block.

    The residual sub-layers run in this order:
    1. Half-scaled feed-forward.
    2. Self-attention.
    3. Convolution module.
    4. Half-scaled feed-forward.

    A final LayerNorm follows. The feed-forward and attention sub-layers are
    pre-normed; the convolution module has its own LayerNorm.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        attn_flash = True,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        # Build the sub-modules in the original order (parameter init order).
        ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        attn = Attention(
            dim = dim, dim_head = dim_head, heads = heads,
            dropout = attn_dropout, flash = attn_flash
        )
        conv = ConformerConvModule(
            dim = dim, causal = conv_causal,
            expansion_factor = conv_expansion_factor,
            kernel_size = conv_kernel_size, dropout = conv_dropout
        )
        ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        # Register in the original attribute order: ff1, attn, conv, ff2.
        self.ff1 = Scale(0.5, PreNorm(dim, ff1))
        self.attn = PreNorm(dim, attn)
        self.conv = conv
        self.ff2 = Scale(0.5, PreNorm(dim, ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        x = x + self.ff1(x)
        x = x + self.attn(
            x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias)
        x = x + self.conv(x)
        x = x + self.ff2(x)
        return self.post_norm(x)
# Conformer | |
class Conformer(nn.Module):
    """Stack of ``num_layers`` ConformerBlocks with shared positional info.

    Uses rotary embeddings by default, or a learned T5 relative position
    bias when ``t5_rel_pos_bias`` is set (incompatible with flash attention).
    """

    def __init__(
        self,
        dim,
        *,
        num_layers,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False,
        attn_flash = True,
        t5_rel_pos_bias = False
    ):
        super().__init__()
        assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'  # noqa
        self.dim = dim
        # Registered first (and filled later) to keep the original module
        # order; the positional modules are created before the blocks.
        self.layers = nn.ModuleList([])

        if t5_rel_pos_bias:
            self.rotary_emb = None
            self.rel_pos_bias = T5RelativePositionBias(
                dim_head ** 0.5, heads = heads)
        else:
            self.rotary_emb = RotaryEmbedding(dim_head)
            self.rel_pos_bias = None

        block_kwargs = dict(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            conv_expansion_factor = conv_expansion_factor,
            conv_kernel_size = conv_kernel_size,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            conv_dropout = conv_dropout,
            conv_causal = conv_causal,
            attn_flash = attn_flash
        )
        self.layers.extend(
            ConformerBlock(**block_kwargs) for _ in range(num_layers))

    def forward(self, x, mask = None):
        """Run all blocks; positional tensors are computed once per call."""
        seq_len = x.shape[-2]
        rotary_emb = None if self.rotary_emb is None else self.rotary_emb(seq_len)  # noqa
        attn_bias = None if self.rel_pos_bias is None else self.rel_pos_bias(seq_len)  # noqa

        for block in self.layers:
            x = block(
                x,
                mask = mask,
                rotary_emb = rotary_emb,
                attn_bias = attn_bias
            )
        return x
# conformer with sum reduction across quantized tokens at the beginning, | |
# along with heads | |
class ConformerWrapper(nn.Module):
    """Conformer over interleaved quantizer tokens with per-codebook heads.

    The input is a flat sequence containing ``g * q`` tokens per time step
    (``g = grouped_quantizers``, ``q = num_quantizers``). Processing runs as
    follows:
    1. Each token is embedded with its quantizer's slice of ``code_embeds``.
    2. The ``q`` embeddings of a group are summed.
    3. The ``g`` group embeddings are concatenated and projected to ``dim``.
    4. The Conformer output is expanded back to one embedding per input
       token.
    5. Each embedding gets its own ``codebook_size``-way logits.
    """

    def __init__(
        self,
        *,
        codebook_size,
        num_quantizers,
        conformer: Union[Conformer, Dict[str, Any]],
        grouped_quantizers = 1
    ):
        """
        Args:
            codebook_size: codes per quantizer. The embedding table reserves
                ``codebook_size + 1`` ids per quantizer; the extra id is
                presumably the mask token (confirm with the caller).
            num_quantizers: quantizers per group.
            conformer: a ``Conformer`` or a dict of its constructor kwargs.
            grouped_quantizers: number of quantizer groups; when > 1 the
                concatenated group embeddings are projected back to ``dim``.
        """
        super().__init__()

        if isinstance(conformer, dict):
            conformer = Conformer(**conformer)
        self.conformer = conformer

        dim = self.conformer.dim

        self.embedding_proj = nn.Sequential(
            nn.Linear(dim * grouped_quantizers, dim),
            nn.LayerNorm(dim)
        ) if grouped_quantizers > 1 else nn.Identity()

        num_codes_with_mask = codebook_size + 1
        num_effective_quantizers = num_quantizers * grouped_quantizers

        # One shared table; quantizer i owns ids
        # [i * num_codes_with_mask, (i + 1) * num_codes_with_mask).
        self.code_embeds = nn.Embedding(
            num_codes_with_mask * num_effective_quantizers, dim)

        self.register_buffer(
            'quantizer_offsets',
            torch.arange(num_effective_quantizers) * num_codes_with_mask,
            persistent = False
        )
        # NOTE(review): offset_i + num_codes_with_mask equals offset_{i+1},
        # i.e. the first id of the NEXT quantizer (and one past the end of
        # code_embeds for the last one). The mask slot of quantizer i is
        # presumably offset_i + codebook_size. This buffer is not read in
        # this file; confirm intent before relying on it.
        self.register_buffer(
            'mask_tokens', self.quantizer_offsets + num_codes_with_mask,
            persistent = False
        )

        self.dim = dim
        self.codebook_size = codebook_size

        self.num_codes_with_mask = num_codes_with_mask
        self.num_quantizers = num_quantizers
        self.grouped_quantizers = grouped_quantizers

        # expand each time step back into one embedding per quantizer token
        self.heads = nn.Sequential(
            nn.Linear(dim, dim * num_effective_quantizers),
            Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
        )

        # each quantizer codebook would require its own logits weight
        # and bias matrices
        # the amazing einops makes this easy with 'EinMix'
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
            EinMix(
                'b n gq d -> b n gq l',
                weight_shape = 'gq d l',
                bias_shape = 'gq l',
                gq = num_effective_quantizers,
                l = codebook_size,
                d = dim
            ),
            Rearrange('b ... d -> b (...) d')
        )

    def forward(
        self,
        x,
        *,
        mask = None,
        cond = None,
        sum_embeds = None,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        """
        einops notation:
        b - batch
        n - sequence
        g - groups
        q - quantizers
        d - feature dimension
        """
        n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
        assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers'  # noqa

        # shift each quantizer's tokens into its own slice of code_embeds
        x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
        x = x + self.quantizer_offsets

        x = self.code_embeds(x)

        # sum within a group, concatenate across groups
        x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)

        x = self.embedding_proj(x)

        if exists(sum_embeds):
            x = x + sum_embeds

        if exists(cond):
            if cond.ndim == 2:
                cond = rearrange(cond, 'b d -> b 1 d')

            x = x + cond

        x = self.conformer(x, mask = mask)
        embeds = self.heads(x)

        if return_embeddings or not exists(self.to_logits):
            return embeds

        logits = self.to_logits(embeds)

        if return_logits_and_embeddings:
            return logits, embeds

        return logits