import math
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torchmetrics.functional import (
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr,
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)

from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding

def mod_pad(x, chunk_size, pad):
    # Mod pad the input so that it divides into an integer number of
    # chunks, then apply the requested edge padding.
    mod = 0
    if (x.shape[-1] % chunk_size) != 0:
        mod = chunk_size - (x.shape[-1] % chunk_size)
    x = F.pad(x, (0, mod))
    x = F.pad(x, pad)
    return x, mod
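
# Quick sanity check of the padding arithmetic (illustrative values only):
# for a 100-sample input and chunk_size=8, 100 % 8 == 4, so mod = 8 - 4 = 4
# and the padded length becomes 104 before the extra `pad` is applied.
#
#   x, mod = mod_pad(torch.zeros(1, 1, 100), chunk_size=8, pad=(0, 0))
#   assert x.shape[-1] == 104 and mod == 4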

class LayerNormPermuted(nn.LayerNorm):
    """
    LayerNorm for channel-first inputs: `nn.LayerNorm` normalizes over the
    last dimension, so [B, C, T] inputs are permuted to [B, T, C],
    normalized over C, and permuted back.
    """
    def __init__(self, *args, **kwargs):
        super(LayerNormPermuted, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T]
        """
        x = x.permute(0, 2, 1)  # [B, T, C]
        x = super().forward(x)
        x = x.permute(0, 2, 1)  # [B, C, T]
        return x

class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable convolution: a depthwise conv followed by a
    pointwise (1x1) conv, each followed by LayerNorm and ReLU.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation):
        super(DepthwiseSeparableConv, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, kernel_size, stride,
                      padding, groups=in_channels, dilation=dilation),
            LayerNormPermuted(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1,
                      padding=0),
            LayerNormPermuted(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)
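
# Rough parameter accounting for the factorization above (illustrative,
# using the encoder's default width of 512 channels and kernel size 3):
# the depthwise conv has 512 * 1 * 3 = 1,536 weights and the pointwise conv
# 512 * 512 = 262,144, versus 512 * 512 * 3 = 786,432 for a standard Conv1d
# of the same shape, roughly a 3x reduction (ignoring biases and norms).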

class DilatedCausalConvEncoder(nn.Module):
    """
    A dilated causal convolution based encoder for encoding
    time domain audio input into latent space.
    """
    def __init__(self, channels, num_layers, kernel_size=3):
        super(DilatedCausalConvEncoder, self).__init__()
        self.channels = channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size

        # Compute buffer lengths for each layer
        # buf_length[i] = (kernel_size - 1) * dilation[i]
        self.buf_lengths = [(kernel_size - 1) * 2**i
                            for i in range(num_layers)]

        # Compute buffer start indices for each layer
        self.buf_indices = [0]
        for i in range(num_layers - 1):
            self.buf_indices.append(
                self.buf_indices[-1] + self.buf_lengths[i])

        # Dilated causal conv layers aggregate previous context to obtain
        # a context-rich encoding of the input.
        _dcc_layers = OrderedDict()
        for i in range(num_layers):
            dcc_layer = DepthwiseSeparableConv(
                channels, channels, kernel_size=3, stride=1,
                padding=0, dilation=2**i)
            _dcc_layers.update({'dcc_%d' % i: dcc_layer})
        self.dcc_layers = nn.Sequential(_dcc_layers)

    def init_ctx_buf(self, batch_size, device):
        """
        Returns an initialized context buffer for a given batch size.
        """
        return torch.zeros(
            (batch_size, self.channels,
             (self.kernel_size - 1) * (2**self.num_layers - 1)),
            device=device)

    def forward(self, x, ctx_buf):
        """
        Encodes input audio `x` into latent space, and aggregates
        contextual information in `ctx_buf`. Also generates a new context
        buffer with updated context.
        Args:
            x: [B, in_channels, T]
                Input multi-channel audio.
            ctx_buf: [B, channels, sum(self.buf_lengths)]
                Single tensor holding the context for every dilated causal
                conv layer, laid out contiguously; the slice for layer i
                starts at self.buf_indices[i].
        Returns:
            x: [B, channels, T]
                Encoded output.
            ctx_buf: [B, channels, sum(self.buf_lengths)]
                Updated context buffer.
        """
        T = x.shape[-1]  # Sequence length

        for i in range(self.num_layers):
            buf_start_idx = self.buf_indices[i]
            buf_end_idx = self.buf_indices[i] + self.buf_lengths[i]

            # DCC input: concatenation of current output and context
            dcc_in = torch.cat(
                (ctx_buf[..., buf_start_idx:buf_end_idx], x), dim=-1)

            # Push current output to the context buffer
            ctx_buf[..., buf_start_idx:buf_end_idx] = \
                dcc_in[..., -self.buf_lengths[i]:]

            # Residual connection
            x = x + self.dcc_layers[i](dcc_in)

        return x, ctx_buf
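
# Receptive-field sanity check (illustrative): the context buffer holds
# (kernel_size - 1) * (2**num_layers - 1) latent frames in total, which is the
# sum of the per-layer buffers above. With the defaults used later
# (kernel_size=3, num_layers=10) that is 2 * 1023 = 2046 frames, i.e. about
# 2046 * 8 = 16,368 audio samples at the model's default stride of L=8.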

class CausalTransformerDecoderLayer(torch.nn.TransformerDecoderLayer):
    """
    Adapted from:
    "https://github.com/alexmt-scale/causal-transformer-decoder/blob/"
    "0caf6ad71c46488f76d89845b0123d2550ef792f/"
    "causal_transformer_decoder/model.py#L77"
    """
    def forward(
        self,
        tgt: Tensor,
        memory: Optional[Tensor] = None,
        chunk_size: int = 1
    ) -> Tensor:
        tgt_last_tok = tgt[:, -chunk_size:, :]

        # self attention part
        tmp_tgt, sa_map = self.self_attn(
            tgt_last_tok,
            tgt,
            tgt,
            attn_mask=None,  # not needed: only the last `chunk_size` tokens are queries
            key_padding_mask=None,
        )
        tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
        tgt_last_tok = self.norm1(tgt_last_tok)

        # encoder-decoder attention
        if memory is not None:
            tmp_tgt, ca_map = self.multihead_attn(
                tgt_last_tok,
                memory,
                memory,
                attn_mask=None,  # Attend to the entire chunk
                key_padding_mask=None,
            )
            tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
            tgt_last_tok = self.norm2(tgt_last_tok)

        # final feed-forward network
        tmp_tgt = self.linear2(
            self.dropout(self.activation(self.linear1(tgt_last_tok)))
        )
        tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
        tgt_last_tok = self.norm3(tgt_last_tok)
        return tgt_last_tok, sa_map, ca_map
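
# Note on masking: causality is enforced structurally rather than with an
# attention mask. Only the last `chunk_size` positions of `tgt` act as
# queries, while the keys/values cover those positions plus the `ctx_len`
# past frames prepended by the caller, so nothing beyond the current chunk
# is ever visible to the layer.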

class CausalTransformerDecoder(nn.Module):
    """
    A causal transformer decoder which decodes input vectors using
    precisely `ctx_len` past vectors in the sequence, and using no future
    vectors at all.
    """
    def __init__(self, model_dim, ctx_len, chunk_size, num_layers,
                 nhead, use_pos_enc, ff_dim):
        super(CausalTransformerDecoder, self).__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.ctx_len = ctx_len
        self.chunk_size = chunk_size
        self.nhead = nhead
        self.use_pos_enc = use_pos_enc
        self.unfold = nn.Unfold(kernel_size=(ctx_len + chunk_size, 1), stride=chunk_size)
        self.pos_enc = PositionalEncoding(model_dim, max_len=200)
        self.tf_dec_layers = nn.ModuleList([CausalTransformerDecoderLayer(
            d_model=model_dim, nhead=nhead, dim_feedforward=ff_dim,
            batch_first=True) for _ in range(num_layers)])

    def init_ctx_buf(self, batch_size, device):
        return torch.zeros(
            (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim),
            device=device)
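
    # Buffer layout note (derived from the forward pass below): slot 0 of
    # `ctx_buf` holds the last `ctx_len` frames of the encoder memory, and
    # slots 1..num_layers hold the last `ctx_len` frames of each decoder
    # layer's input, so streaming inference can resume without recomputation.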

    def _causal_unfold(self, x):
        """
        Unfolds the sequence into a batch of overlapping windows, each a
        `chunk_size`-frame chunk prepended with its `ctx_len` previous frames.
        Args:
            x: [B, ctx_len + L, C]
        Returns:
            [B * L // chunk_size, ctx_len + chunk_size, C]
        """
        B, T, C = x.shape
        x = x.permute(0, 2, 1)  # [B, C, ctx_len + L]
        x = self.unfold(x.unsqueeze(-1))  # [B, C * (ctx_len + chunk_size), L // chunk_size]
        x = x.permute(0, 2, 1)
        x = x.reshape(B, -1, C, self.ctx_len + self.chunk_size)
        x = x.reshape(-1, C, self.ctx_len + self.chunk_size)
        x = x.permute(0, 2, 1)
        return x
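
    # Shape walkthrough (illustrative, using the defaults Net passes below:
    # ctx_len=100, chunk_size=72, model_dim=256): a 6048-frame sequence
    # (already a multiple of chunk_size) prepended with 100 context frames
    # gives x of shape [2, 6148, 256]; unfolding yields
    # (6148 - 172) / 72 + 1 = 84 windows, so the output is
    # [2 * 84, 172, 256] = [168, 172, 256].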

    def forward(self, tgt, mem, ctx_buf, probe=False):
        """
        Args:
            tgt: [B, model_dim, T]
            mem: [B, model_dim, T]
            ctx_buf: [B, num_layers + 1, ctx_len, model_dim]
        """
        mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
        tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))

        # Input sequence length
        B, C, T = tgt.shape

        tgt = tgt.permute(0, 2, 1)
        mem = mem.permute(0, 2, 1)

        # Prepend mem with the context
        mem = torch.cat((ctx_buf[:, 0, :, :], mem), dim=1)
        ctx_buf[:, 0, :, :] = mem[:, -self.ctx_len:, :]
        mem_ctx = self._causal_unfold(mem)
        if self.use_pos_enc:
            mem_ctx = mem_ctx + self.pos_enc(mem_ctx)

        # Attention chunk size: required to ensure the model
        # doesn't trigger an out-of-memory error when working
        # on long sequences.
        K = 1000

        for i, tf_dec_layer in enumerate(self.tf_dec_layers):
            # Update the tgt with context
            tgt = torch.cat((ctx_buf[:, i + 1, :, :], tgt), dim=1)
            ctx_buf[:, i + 1, :, :] = tgt[:, -self.ctx_len:, :]

            # Compute encoded output
            tgt_ctx = self._causal_unfold(tgt)
            if self.use_pos_enc and i == 0:
                tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)
            tgt = torch.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
            # Process the unfolded chunks in slices of K to bound memory use
            for j in range(int(math.ceil(tgt.shape[0] / K))):
                tgt[j*K:(j+1)*K], _sa_map, _ca_map = tf_dec_layer(
                    tgt_ctx[j*K:(j+1)*K], mem_ctx[j*K:(j+1)*K],
                    self.chunk_size)
            # Fold the per-chunk outputs back into a [B, T, C] sequence
            tgt = tgt.reshape(B, T, C)

        tgt = tgt.permute(0, 2, 1)
        if mod != 0:
            tgt = tgt[..., :-mod]

        return tgt, ctx_buf

class MaskNet(nn.Module):
    def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
                 dec_chunk_size, num_dec_layers, use_pos_enc, skip_connection, proj):
        super(MaskNet, self).__init__()
        self.skip_connection = skip_connection
        self.proj = proj

        # Encoder based on dilated causal convolutions.
        self.encoder = DilatedCausalConvEncoder(channels=enc_dim,
                                                num_layers=num_enc_layers)

        # Project between encoder and decoder dimensions
        self.proj_e2d_e = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_e2d_l = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_d2e = nn.Sequential(
            nn.Conv1d(dec_dim, enc_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())

        # Transformer decoder that operates on chunks of `dec_chunk_size`
        # frames with a context buffer of length `dec_buf_len`.
        self.decoder = CausalTransformerDecoder(
            model_dim=dec_dim, ctx_len=dec_buf_len, chunk_size=dec_chunk_size,
            num_layers=num_dec_layers, nhead=8, use_pos_enc=use_pos_enc,
            ff_dim=2 * dec_dim)

    def forward(self, x, l, enc_buf, dec_buf):
        """
        Generates a mask based on the encoded input and the label embedding.
        Args:
            x: [B, C, T]
                Input audio sequence
            l: [B, C]
                Label embedding
            enc_buf: [B, C, <receptive field of the encoder>]
                Context buffer maintained by the DCC encoder
            dec_buf: [B, num_dec_layers + 1, dec_buf_len, dec_dim]
                Context buffer maintained by the transformer decoder
        """
        # Encode the input
        e, enc_buf = self.encoder(x, enc_buf)

        # Label integration
        l = l.unsqueeze(2) * e

        # Project to `dec_dim` dimensions
        if self.proj:
            e = self.proj_e2d_e(e)
            m = self.proj_e2d_l(l)
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(m, e, dec_buf)
        else:
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(l, e, dec_buf)

        # Project mask to encoder dimensions
        if self.proj:
            m = self.proj_d2e(m)

        # Final mask after residual connection
        if self.skip_connection:
            m = l + m

        return m, enc_buf, dec_buf

class Net(nn.Module):
    def __init__(self, label_len, L=8,
                 enc_dim=512, num_enc_layers=10,
                 dec_dim=256, dec_buf_len=100, num_dec_layers=2,
                 dec_chunk_size=72, out_buf_len=2,
                 use_pos_enc=True, skip_connection=True, proj=True, lookahead=True):
        super(Net, self).__init__()
        self.L = L
        self.out_buf_len = out_buf_len
        self.enc_dim = enc_dim
        self.lookahead = lookahead

        # Input conv to convert input audio to a latent representation
        kernel_size = 3 * L if lookahead else L
        self.in_conv = nn.Sequential(
            nn.Conv1d(in_channels=1,
                      out_channels=enc_dim, kernel_size=kernel_size, stride=L,
                      padding=0, bias=False),
            nn.ReLU())

        # Label embedding layer
        self.label_embedding = nn.Sequential(
            nn.Linear(label_len, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, enc_dim),
            nn.LayerNorm(enc_dim),
            nn.ReLU())

        # Mask generator
        self.mask_gen = MaskNet(
            enc_dim=enc_dim, num_enc_layers=num_enc_layers,
            dec_dim=dec_dim, dec_buf_len=dec_buf_len,
            dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers,
            use_pos_enc=use_pos_enc, skip_connection=skip_connection, proj=proj)

        # Output conv layer
        self.out_conv = nn.Sequential(
            nn.ConvTranspose1d(
                in_channels=enc_dim, out_channels=1,
                kernel_size=(out_buf_len + 1) * L,
                stride=L,
                padding=out_buf_len * L, bias=False),
            nn.Tanh())

    def init_buffers(self, batch_size, device):
        enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device)
        dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device)
        out_buf = torch.zeros(batch_size, self.enc_dim, self.out_buf_len,
                              device=device)
        return enc_buf, dec_buf, out_buf

    def forward(self, x, label, init_enc_buf=None, init_dec_buf=None,
                init_out_buf=None, pad=True):
        """
        Extracts the audio corresponding to the `label` in the given
        mixture `x`. Generates `dec_chunk_size` frames per iteration.
        Args:
            x: [B, 1, T]
                Input audio mixture
            label: [B, num_labels]
                One-hot label
        Returns:
            x: [B, 1, T]
                Extracted audio containing the sounds corresponding to
                the `label`
        """
        mod = 0
        if pad:
            pad_size = (self.L, self.L) if self.lookahead else (0, 0)
            x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)

        if init_enc_buf is None or init_dec_buf is None or init_out_buf is None:
            assert init_enc_buf is None and \
                   init_dec_buf is None and \
                   init_out_buf is None, \
                "All buffers have to be initialized, or " \
                "all of them have to be None."
            enc_buf, dec_buf, out_buf = self.init_buffers(
                x.shape[0], x.device)
        else:
            enc_buf, dec_buf, out_buf = \
                init_enc_buf, init_dec_buf, init_out_buf

        # Generate latent space representation of the input
        x = self.in_conv(x)

        # Generate label embedding
        l = self.label_embedding(label)  # [B, label_len] --> [B, channels]

        # Generate mask corresponding to the label
        m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)

        # Apply mask and decode
        x = x * m
        x = torch.cat((out_buf, x), dim=-1)
        out_buf = x[..., -self.out_buf_len:]
        x = self.out_conv(x)

        # Remove mod padding, if present.
        if mod != 0:
            x = x[:, :, :-mod]

        if init_enc_buf is None:
            return x
        else:
            return x, enc_buf, dec_buf, out_buf

# Define optimizer, loss and metrics

def optimizer(model, data_parallel=False, **kwargs):
    return optim.Adam(model.parameters(), **kwargs)

def loss(pred, tgt):
    return -0.9 * snr(pred, tgt).mean() - 0.1 * si_snr(pred, tgt).mean()
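
# The loss is the negative of a weighted sum of SNR and SI-SNR, so minimizing
# it maximizes a 0.9/0.1 blend of the two ratios between prediction and target.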

def metrics(mixed, output, gt):
    """ Function to compute metrics """
    metrics = {}

    def metric_i(metric, src, pred, tgt):
        _vals = []
        for s, t, p in zip(src, tgt, pred):
            _vals.append((metric(p, t) - metric(s, t)).cpu().item())
        return _vals

    for m_fn in [snr, si_snr]:
        metrics[m_fn.__name__] = metric_i(m_fn,
                                          mixed[:, :gt.shape[1], :],
                                          output,
                                          gt)
    return metrics
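
# Minimal smoke test: a sketch of how the model might be instantiated and run,
# not part of the original training pipeline. `label_len=41` and the 48,000
# sample length are assumed values chosen only for illustration; any positive
# label_len works, since it just sets the width of the label embedding input.
if __name__ == "__main__":
    model = Net(label_len=41)
    model.eval()

    mixture = torch.randn(2, 1, 48000)  # [B, 1, T] mono mixtures
    label = torch.zeros(2, 41)
    label[:, 3] = 1.0                   # query a single (hypothetical) class

    with torch.no_grad():
        out = model(mixture, label)
    print(out.shape)                    # expected: torch.Size([2, 1, 48000])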