# PolyAI-pheme / modules / conformer.py
"""Conformer definition adjusted given the Lucidrain's repo.
https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa
Copyright PolyAI Limited.
"""
import math
from collections import namedtuple
from functools import wraps
from typing import Any, Dict, Union
import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import EinMix, Rearrange
from torch import einsum, nn
# rotary embedding
class RotaryEmbedding(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent = False)
@property
def device(self):
return next(self.buffers()).device
def forward(self, seq_len):
t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
return freqs
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(pos, t):
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
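# Usage sketch (illustrative shapes only): rotary frequencies for a length-8
# sequence with per-head dim 64, applied to a (batch, heads, seq, dim_head) query:
#   rope = RotaryEmbedding(dim=64)
#   freqs = rope(8)                              # (8, 64)
#   q = torch.randn(2, 8, 8, 64)
#   q_rot = apply_rotary_pos_emb(freqs, q)       # same shape as q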
# constants
EfficientAttentionConfig = namedtuple(
'EfficientAttentionConfig',
['enable_flash', 'enable_math', 'enable_mem_efficient']
)
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def divisible_by(numer, denom):
return (numer % denom) == 0
def calc_same_padding(kernel_size):
pad = kernel_size // 2
return (pad, pad - (kernel_size + 1) % 2)
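# e.g. calc_same_padding(31) -> (15, 15) and calc_same_padding(4) -> (2, 1),
# so the depthwise conv below preserves the sequence length for odd and even kernels.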
def eval_decorator(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# t5 relative positional bias
class T5RelativePositionBias(nn.Module):
def __init__(
self,
scale = 1.,
num_buckets = 32,
max_distance = 128,
heads = 8
):
super().__init__()
self.scale = scale
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(
relative_position,
num_buckets = 32,
max_distance = 128
):
ret = 0
n = -relative_position
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(
max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(
val_if_large,
torch.full_like(val_if_large, num_buckets - 1)
)
ret += torch.where(is_small, n, val_if_large)
return ret
@property
def device(self):
return next(self.parameters()).device
def forward(self, n):
pos = torch.arange(n, device = self.device).long()
rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(
rel_pos, num_buckets = self.num_buckets,
max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> h i j')
return bias * self.scale
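# Usage sketch (illustrative values): an additive (heads, n, n) attention bias
# for a length-16 sequence; this path is only used when t5_rel_pos_bias=True.
#   rel_bias = T5RelativePositionBias(scale=1., heads=8)
#   bias = rel_bias(16)                          # (8, 16, 16)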
# main class
class Attend(nn.Module):
def __init__(
self,
causal = False,
dropout = 0.,
flash = False
):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.causal = causal
self.flash = flash
# determine efficient attention configs for cuda and cpu
self.cpu_config = EfficientAttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda') # noqa
self.cuda_config = EfficientAttentionConfig(True, True, True)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') # noqa
self.cuda_config = EfficientAttentionConfig(False, True, True)
def get_mask(self, i, j, device):
return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) # noqa
def flash_attn(self, q, k, v, mask = None, attn_bias = None):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # noqa
# single headed key / values
if k.ndim == 3:
k = rearrange(k, 'b n d -> b 1 n d')
if v.ndim == 3:
v = rearrange(v, 'b n d -> b 1 n d')
# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L
if exists(mask) and mask.ndim != 4:
mask = rearrange(mask, 'b j -> b 1 1 j')
mask = mask.expand(-1, heads, q_len, -1)
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
causal = self.causal
# handle attention bias
if exists(attn_bias):
mask_value = -torch.finfo(q.dtype).max // 2
causal_mask = self.get_mask(q_len, k_len, device)
attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
if exists(mask):
attn_bias = attn_bias.masked_fill(~mask, mask_value)
mask = attn_bias
causal = False
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = causal
)
return out
def forward(self, q, k, v, mask = None, attn_bias = None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
scale = q.shape[-1] ** -0.5
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
if self.flash:
assert not exists(attn_bias)
return self.flash_attn(q, k, v, mask = mask)
# similarity
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
# attention bias
if exists(attn_bias):
sim = sim + attn_bias
# causal mask
if self.causal:
causal_mask = self.get_mask(q_len, k_len, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# key padding mask
if exists(mask):
if mask.ndim != 4:
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
# attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
return out
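# Usage sketch (non-flash CPU path, illustrative shapes):
#   attend = Attend(causal=False, dropout=0., flash=False)
#   q = k = v = torch.randn(1, 8, 16, 64)        # (batch, heads, seq, dim_head)
#   out = attend(q, k, v)                        # (1, 8, 16, 64)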
class Swish(nn.Module):
def forward(self, x):
return x * x.sigmoid()
class GLU(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
out, gate = x.chunk(2, dim=self.dim)
return out * gate.sigmoid()
class DepthWiseConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size, padding):
super().__init__()
self.padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
def forward(self, x):
x = F.pad(x, self.padding)
return self.conv(x)
class Scale(nn.Module):
def __init__(self, scale, fn):
super().__init__()
self.fn = fn
self.scale = scale
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class ChanLayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(1, dim, 1))
def forward(self, x):
eps = 1e-6 if x.dtype == torch.float32 else 1e-4
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class Attention(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
flash = True
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = Attend(
flash = flash,
dropout = dropout
)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(
self,
x,
context = None,
mask = None,
rotary_emb = None,
attn_bias = None
):
n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
context = default(context, x)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
if exists(rotary_emb):
q = apply_rotary_pos_emb(rotary_emb, q)
k = apply_rotary_pos_emb(rotary_emb, k)
out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
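# Usage sketch (illustrative dims): self-attention over a (batch, seq, dim) input:
#   attn = Attention(dim=256, heads=8, dim_head=64, flash=False)
#   y = attn(torch.randn(2, 50, 256))            # (2, 50, 256)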
class FeedForward(nn.Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class ConformerConvModule(nn.Module):
def __init__(
self,
dim,
causal = False,
expansion_factor = 2,
kernel_size = 31,
dropout = 0.
):
super().__init__()
inner_dim = dim * expansion_factor
padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
self.net = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('b n c -> b c n'),
nn.Conv1d(dim, inner_dim * 2, 1),
GLU(dim=1),
DepthWiseConv1d(
inner_dim, inner_dim, kernel_size = kernel_size,
padding = padding
),
Swish(),
ChanLayerNorm(inner_dim),
nn.Conv1d(inner_dim, dim, 1),
Rearrange('b c n -> b n c'),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
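# Usage sketch (illustrative dims): the conv module keeps the (batch, seq, dim) shape:
#   conv = ConformerConvModule(dim=256, kernel_size=31)
#   y = conv(torch.randn(2, 50, 256))            # (2, 50, 256)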
# Conformer Block
class ConformerBlock(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
attn_flash = True,
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False
):
super().__init__()
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
self.attn = Attention(
dim = dim, dim_head = dim_head, heads = heads,
dropout = attn_dropout, flash = attn_flash
)
self.conv = ConformerConvModule(
dim = dim, causal = conv_causal,
expansion_factor = conv_expansion_factor,
kernel_size = conv_kernel_size, dropout = conv_dropout
)
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
self.attn = PreNorm(dim, self.attn)
self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
self.post_norm = nn.LayerNorm(dim)
def forward(
self,
x,
mask = None,
rotary_emb = None,
attn_bias = None
):
x = self.ff1(x) + x
x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x # noqa
x = self.conv(x) + x
x = self.ff2(x) + x
x = self.post_norm(x)
return x
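# Usage sketch (illustrative dims): one macaron block, half-step FF -> attention
# -> conv -> half-step FF -> post-norm, all with residual connections:
#   block = ConformerBlock(dim=256, attn_flash=False)
#   y = block(torch.randn(2, 50, 256))           # (2, 50, 256)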
# Conformer
class Conformer(nn.Module):
def __init__(
self,
dim,
*,
num_layers,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False,
attn_flash = True,
t5_rel_pos_bias = False
):
super().__init__()
assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias' # noqa
self.dim = dim
self.layers = nn.ModuleList([])
self.rotary_emb = RotaryEmbedding(
dim_head) if not t5_rel_pos_bias else None
self.rel_pos_bias = T5RelativePositionBias(
dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None
for _ in range(num_layers):
self.layers.append(ConformerBlock(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
conv_expansion_factor = conv_expansion_factor,
conv_kernel_size = conv_kernel_size,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
conv_dropout = conv_dropout,
conv_causal = conv_causal,
attn_flash = attn_flash
))
def forward(self, x, mask = None):
seq_len = x.shape[-2]
rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None # noqa
attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None #noqa
for block in self.layers:
x = block(
x,
mask = mask,
rotary_emb = rotary_emb,
attn_bias = attn_bias
)
return x
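# Usage sketch (illustrative config): a 2-layer stack using rotary positions,
# which is the default whenever t5_rel_pos_bias is False:
#   model = Conformer(256, num_layers=2, attn_flash=False)
#   y = model(torch.randn(2, 50, 256))           # (2, 50, 256)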
# conformer with a sum reduction across the quantized tokens at the input,
# along with per-quantizer prediction heads at the output
class ConformerWrapper(nn.Module):
def __init__(
self,
*,
codebook_size,
num_quantizers,
conformer: Union[Conformer, Dict[str, Any]],
grouped_quantizers = 1
):
super().__init__()
self.conformer = conformer
if isinstance(conformer, dict):
self.conformer = Conformer(**self.conformer)
dim = self.conformer.dim
self.embedding_proj = nn.Sequential(
nn.Linear(dim * grouped_quantizers, dim),
nn.LayerNorm(dim)
) if grouped_quantizers > 1 else nn.Identity()
num_codes_with_mask = codebook_size + 1
num_effective_quantizers = num_quantizers * grouped_quantizers
self.code_embeds = nn.Embedding(
num_codes_with_mask * num_effective_quantizers, dim)
self.register_buffer(
'quantizer_offsets',
torch.arange(num_effective_quantizers) * num_codes_with_mask,
persistent = False
)
self.register_buffer(
'mask_tokens', self.quantizer_offsets + num_codes_with_mask,
persistent = False
)
self.dim = dim
self.codebook_size = codebook_size
self.num_codes_with_mask = num_codes_with_mask
self.num_quantizers = num_quantizers
self.grouped_quantizers = grouped_quantizers
self.heads = nn.Sequential(
nn.Linear(dim, dim * num_effective_quantizers),
Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
)
# each quantizer codebook would require its own logits weight
# and bias matrices
# the amazing einops makes this easy with 'EinMix'
self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
EinMix(
'b n gq d -> b n gq l',
weight_shape = 'gq d l',
bias_shape = 'gq l',
gq = num_effective_quantizers,
l = codebook_size,
d = dim
),
Rearrange('b ... d -> b (...) d')
)
def forward(
self,
x,
*,
mask = None,
cond = None,
sum_embeds = None,
return_embeddings = False,
return_logits_and_embeddings = False
):
"""
einops notation:
b - batch
n - sequence
g - groups
q - quantizers
d - feature dimension
"""
n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers' # noqa
x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
x = x + self.quantizer_offsets
x = self.code_embeds(x)
x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)
x = self.embedding_proj(x)
if exists(sum_embeds):
x = x + sum_embeds
if exists(cond):
if cond.ndim == 2:
cond = rearrange(cond, 'b d -> b 1 d')
x = x + cond
x = self.conformer(x, mask = mask)
embeds = self.heads(x)
if return_embeddings or not exists(self.to_logits):
return embeds
logits = self.to_logits(embeds)
if return_logits_and_embeddings:
return logits, embeds
return logits
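# Minimal smoke test of the full wrapper (illustrative only: the codebook_size,
# num_quantizers and conformer sizes below are assumptions, not the pheme-small
# configuration).
if __name__ == "__main__":
    wrapper = ConformerWrapper(
        codebook_size=1024,
        num_quantizers=4,
        conformer=dict(dim=256, num_layers=2, attn_flash=False),
    )
    # 10 timesteps x 4 quantizer tokens per timestep -> a flat sequence of 40 ids
    tokens = torch.randint(0, 1024, (1, 40))
    logits = wrapper(tokens)
    print(logits.shape)  # torch.Size([1, 40, 1024])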