|
from typing import Optional |
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
from torch.utils.data import DataLoader |
|
|
|
def atan2(y, x):
    r"""Element-wise arctangent function of y/x.

    Returns a new tensor with signed angles in radians, in (-pi, pi] for
    finite inputs. It is an alternative implementation of torch.atan2 built
    only from elementary operations (atan, comparisons, arithmetic).
    Neither input is modified.

    Args:
        y (Tensor): First input tensor
        x (Tensor): Second input tensor [shape=y.shape]

    Returns:
        Tensor: [shape=y.shape].
    """
    # Where x == y == 0, use x = 1 so that atan(0 / 1) = 0 instead of nan.
    # Done out of place: callers (e.g. wiener) pass views of their own data.
    x = x + ((x == 0) & (y == 0)) * 1.0

    out = torch.atan(y / x)

    # pi in the output's dtype/device so float64 inputs keep full precision.
    pi = (2 * torch.asin(torch.tensor(1.0, dtype=torch.float64))).to(
        dtype=out.dtype, device=out.device
    )

    # Left half-plane: shift atan's (-pi/2, pi/2) result into the right quadrant.
    out += ((y >= 0) & (x < 0)) * pi
    out -= ((y < 0) & (x < 0)) * pi

    # Vertical axis (x == 0, y != 0): force exactly +pi/2 or -pi/2.
    out *= 1 - ((y > 0) & (x == 0)) * 1.0
    out += ((y > 0) & (x == 0)) * (pi / 2)
    out *= 1 - ((y < 0) & (x == 0)) * 1.0
    out += ((y < 0) & (x == 0)) * (-pi / 2)

    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
def _norm(x: torch.Tensor) -> torch.Tensor:
    r"""Squared magnitude of a complex tensor stored as (real, imag) pairs.

    Despite the name, no square root is taken: the result is
    ``real**2 + imag**2`` for every entry.

    Args:
        x (Tensor): Input Tensor of shape [shape=(..., 2)], real part at
            index 0 and imaginary part at index 1 of the last dimension.

    Returns:
        Tensor: shape as x excluding the last dimension.
    """
    real_part = x[..., 0]
    imag_part = x[..., 1]
    return real_part.abs().pow(2) + imag_part.abs().pow(2)
|
|
|
|
|
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts, accumulated into `out`.

    Computes ``out + a * b`` where the last dimension of each tensor holds
    the (real, imaginary) pair. `out` is updated in place and returned; it
    may be the same tensor as `a` or `b`.

    Args:
        a (Tensor): [shape=(..., 2)] first complex factor.
        b (Tensor): [shape=(..., 2)] second complex factor; same rank as `a`,
            each dimension equal to `a`'s or of size 1.
        out (Tensor, optional): accumulator. When None, or when its shape
            differs from the per-dimension max of `a.shape` and `b.shape`, a
            zero tensor with `a`'s dtype/device is allocated instead (so the
            result is just ``a * b``).

    Returns:
        Tensor: the accumulator holding ``out + a * b``.
    """
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)

    # Both parts of the product are computed before anything is written:
    # slices are views, so writing out[..., 0] first would corrupt the
    # inputs whenever `out` aliases `a` or `b`.
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    prod_re = a_re * b_re - a_im * b_im
    prod_im = a_re * b_im + a_im * b_re

    out[..., 0] = out[..., 0] + prod_re
    out[..., 1] = out[..., 1] + prod_im
    return out
|
|
|
|
|
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts.

    The last dimension of each tensor holds the (real, imaginary) pair. The
    product is written into `out`, which may be the same tensor as `a` or
    `b` (in-place multiplication).

    Args:
        a (Tensor): [shape=(..., 2)] first complex factor.
        b (Tensor): [shape=(..., 2)] second complex factor; same rank as `a`,
            each dimension equal to `a`'s or of size 1.
        out (Tensor, optional): destination. When None, or when its shape
            differs from the per-dimension max of `a.shape` and `b.shape`, a
            new zero tensor with `a`'s dtype/device is allocated and `out`
            is left untouched.

    Returns:
        Tensor: the destination holding ``a * b``.
    """
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)

    # Compute both parts first: slices are views, so writing the real part
    # before reading the inputs again would corrupt them when `out` aliases
    # `a` or `b`.
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    prod_re = a_re * b_re - a_im * b_im
    prod_im = a_re * b_im + a_im * b_re

    out[..., 0] = prod_re
    out[..., 1] = prod_im
    return out
|
|
|
|
|
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise reciprocal ``1 / z`` of a complex tensor stored as
    (real, imag) pairs in its last dimension.

    Uses ``1 / z = conj(z) / |z|^2``. Entries equal to 0 are not guarded
    against and yield nan (0 / 0).

    Args:
        z (Tensor): [shape=(..., 2)] input.
        out (Tensor, optional): destination. Used only when its shape equals
            ``z.shape``; it may be `z` itself for an in-place inverse.
            Otherwise a fresh zero tensor like `z` is allocated.

    Returns:
        Tensor: the destination holding ``1 / z``.
    """
    # Computed up front, before any write can touch z through `out`.
    squared_mag = _norm(z)
    needs_alloc = out is None or out.shape != z.shape
    result = torch.zeros_like(z) if needs_alloc else out
    result[..., 0] = z[..., 0] / squared_mag
    # z[..., 1] is still intact here even when result is z.
    result[..., 1] = -z[..., 1] / squared_mag
    return result
|
|
|
|
|
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise complex conjugate of a tensor whose last dimension
    holds (real, imag) pairs.

    Args:
        z (Tensor): [shape=(..., 2)] input.
        out (Tensor, optional): destination. Used only when its shape equals
            ``z.shape``; it may be `z` itself for an in-place conjugate.
            Otherwise a fresh zero tensor like `z` is allocated.

    Returns:
        Tensor: the destination holding (real, -imag).
    """
    reuse_out = out is not None and out.shape == z.shape
    dest = out if reuse_out else torch.zeros_like(z)
    dest[..., 0] = z[..., 0]
    dest[..., 1] = z[..., 1].neg()
    return dest
|
|
|
|
|
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Invert 1x1 or 2x2 complex matrices, element-wise over leading dimensions.

    Singular matrices are not detected: their inverses contain inf/nan
    values. The user must handle this through his own regularization schemes
    (e.g. adding a multiple of the identity beforehand).

    Args:
        M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
            matrices to invert: must be square along dimensions -3 and -2,
            with (real, imag) pairs in the last dimension.
        out (Tensor, optional): destination, used only when its shape equals
            ``M.shape``; otherwise a new tensor is allocated.

    Returns:
        invM (Tensor): [shape=M.shape]
            inverses of M

    Raises:
        Exception: if nb_channels is neither 1 nor 2.
    """
    size = M.shape[-2]
    if out is None or out.shape != M.shape:
        out = torch.empty_like(M)

    if size == 1:
        # A 1x1 matrix is inverted by inverting its single complex entry.
        return _inv(M, out)
    if size != 2:
        raise Exception("Only 2 channels are supported for the torch version.")

    # Closed-form 2x2 inverse: [[a, b], [c, d]]^-1 = [[d, -b], [-c, a]] / det.
    m00, m01 = M[..., 0, 0, :], M[..., 0, 1, :]
    m10, m11 = M[..., 1, 0, :], M[..., 1, 1, :]
    inv_det = _inv(_mul(m00, m11) - _mul(m01, m10))
    neg_inv_det = -inv_det

    out[..., 0, 0, :] = _mul(inv_det, m11, out[..., 0, 0, :])
    out[..., 1, 0, :] = _mul(neg_inv_det, m10, out[..., 1, 0, :])
    out[..., 0, 1, :] = _mul(neg_inv_det, m01, out[..., 0, 1, :])
    out[..., 1, 1, :] = _mul(inv_det, m00, out[..., 1, 1, :])
    return out
|
|
|
|
|
|
|
|
|
|
|
def expectation_maximization(
    y: torch.Tensor,
    x: torch.Tensor,
    iterations: int = 2,
    eps: float = 1e-10,
    batch_size: int = 200,
):
    r"""Expectation maximization algorithm, for refining source separation
    estimates.

    This algorithm allows to make source separation results better by
    enforcing multichannel consistency for the estimates. This usually means
    a better perceptual quality in terms of spatial artifacts.

    The implementation follows the details presented in [1]_, taking
    inspiration from the original EM algorithm proposed in [2]_ and its
    weighted refinement proposed in [3]_, [4]_.
    It works by iteratively:

     * Re-estimating the source parameters: the power spectral densities
       (energy of the current estimates averaged over channels) and the
       spatial covariance matrices (empirical covariances of the current
       estimates summed over frames, normalized by the summed PSDs).

     * Separating the mixture again with the new parameters: the modelled
       mixture covariance matrix (regularization plus the sum over sources of
       PSD times spatial covariance) is inverted, each source's Wiener gain is
       formed from it, and the gains are applied to the mixture.

    Frames are processed in chunks of `batch_size` frames.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
        reverberant audio source separation using a full-rank spatial
        covariance model." IEEE Transactions on Audio, Speech, and Language
        Processing 18.7 (2010): 1830-1840.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            initial estimates for the sources. When `iterations` > 0 this
            tensor is overwritten in place with the refined estimates, unless
            ``y.requires_grad`` is True, in which case a copy is made first.
        x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
            complex STFT of the mixture signal
        iterations (int): [scalar]
            number of iterations for the EM algorithm.
        eps (float): [scalar]
            The epsilon value to use for regularization and filters:
            ``sqrt(eps)`` times the identity is added to every modelled
            mixture covariance matrix, and `eps` is added to the
            normalization weight of the spatial covariance matrices.
        batch_size (int): [scalar]
            number of frames processed at once. A falsy value (0 or None)
            processes all frames in a single batch.

    Returns:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            estimated sources after iterations
        v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
            estimated power spectral densities
        R (list of Tensor): nb_sources tensors, each of
            [shape=(nb_bins, nb_channels, nb_channels, 2)],
            estimated spatial covariance matrices

    Notes:
        * You need an initial estimate for the sources to apply this
          algorithm. This is precisely what the :func:`wiener` function does.
        * This algorithm *is not* an implementation of the "exact" EM
          proposed in [1]_. In particular, it does *not* compute the
          posterior covariance matrices the same (exact) way. Instead, it uses
          the simplified approximate scheme initially proposed in [5]_ and
          further refined in [3]_, [4]_, that boils down to just take the
          empirical covariance of the recent source estimates, followed by a
          weighted average for the update of the spatial covariance matrix.
          It has been empirically demonstrated that this simplified algorithm
          is more robust for music separation.
        * Only 1 or 2 channels are supported (see :func:`_invert`).

    Warning:
        It is *very* important to make sure `x.dtype` is `torch.float64`
        if you want double precision, because this function will **not**
        do such conversion for you from `torch.float32`, in case you want the
        smaller RAM usage on purpose.

        It is usually always better in terms of quality to have double
        precision, by e.g. calling :func:`expectation_maximization`
        with ``x.to(torch.float64)``.
    """
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]

    # A falsy batch size means "all frames in one batch". This is
    # loop-invariant, so resolve it once, before either frame loop uses it.
    batch_size = batch_size if batch_size else nb_frames

    # sqrt(eps) * identity in (real, imag) layout, broadcast over bins:
    # shape (1, nb_bins, nb_channels, nb_channels, 2).
    regularization = torch.cat(
        (
            torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
            torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
        ),
        dim=2,
    )
    regularization = torch.sqrt(torch.as_tensor(eps)) * (
        regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
    )

    # Per-source spatial covariance matrices, each (nb_bins, C, C, 2).
    R = [
        torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
        for _ in range(nb_sources)
    ]
    weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)

    v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)

    # All (row, col, k) triples for the complex matrix product
    # gain[row, col] = sum_k R_j[row, k] * inv_Cxx[k, col]. They depend only
    # on nb_channels, so they are built once rather than once per source.
    indices = torch.cartesian_prod(
        torch.arange(nb_channels),
        torch.arange(nb_channels),
        torch.arange(nb_channels),
    )

    for it in range(iterations):
        # Power spectral densities: energy of the current estimates averaged
        # over channels, shape (nb_frames, nb_bins, nb_sources).
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)

        # Spatial covariance matrices: empirical covariances summed over
        # frames, divided by the summed PSDs per bin (seeded with eps).
        for j in range(nb_sources):
            R[j] = torch.tensor(0.0, device=x.device)
            weight = torch.tensor(eps, device=x.device)
            pos: int = 0
            while pos < nb_frames:
                t = torch.arange(pos, min(nb_frames, pos + batch_size))
                pos = int(t[-1]) + 1

                R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
                weight = weight + torch.sum(v[t, ..., j], dim=0)
            R[j] = R[j] / weight[..., None, None, None]
            weight = torch.zeros_like(weight)

        # y is written in place below; work on a copy when it requires grad.
        if y.requires_grad:
            y = y.clone()

        pos = 0
        while pos < nb_frames:
            t = torch.arange(pos, min(nb_frames, pos + batch_size))
            pos = int(t[-1]) + 1

            y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)

            # Modelled mixture covariance: regularization + sum_j v_j * R_j.
            Cxx = regularization
            for j in range(nb_sources):
                Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())

            inv_Cxx = _invert(Cxx)

            for j in range(nb_sources):
                # Wiener gain of source j: v_j * (R_j @ inv(Cxx)).
                gain = torch.zeros_like(inv_Cxx)
                for index in indices:
                    gain[:, :, index[0], index[1], :] = _mul_add(
                        R[j][None, :, index[0], index[2], :].clone(),
                        inv_Cxx[:, :, index[2], index[1], :],
                        gain[:, :, index[0], index[1], :],
                    )
                gain = gain * v[t, ..., None, None, None, j]

                # Apply the gain to the mixture: y_j = gain @ x.
                for i in range(nb_channels):
                    y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])

    return y, v, R
|
|
|
|
|
def wiener(
    targets_spectrograms: torch.Tensor,
    mix_stft: torch.Tensor,
    iterations: int = 1,
    softmask: bool = False,
    residual: bool = False,
    scale_factor: float = 10.0,
    eps: float = 1e-10,
):
    """Wiener-based separation for multichannel audio.

    The method uses the (possibly multichannel) spectrograms of the
    sources to separate the (complex) Short Term Fourier Transform of the
    mix. Separation is done in a sequential way by:

    * Getting an initial estimate. This can be done in two ways: either by
      directly using the spectrograms with the mixture phase, or
      by using a softmasking strategy. This initial phase is controlled
      by the `softmask` flag.

    * If required, adding an additional residual target as the mix minus
      all targets.

    * Refining these initial estimates through a call to
      :func:`expectation_maximization` if the number of iterations is nonzero.

    This implementation also allows to specify the epsilon value used for
    regularization. It is based on [1]_, [2]_, [3]_, [4]_.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        targets_spectrograms (Tensor): spectrograms of the sources
            [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
            This is a nonnegative tensor that is
            usually the output of the actual separation method of the user. The
            spectrograms may be mono, but they need to be 4-dimensional in all
            cases.
        mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
            STFT of the mixture signal.
        iterations (int): [scalar]
            number of iterations for the EM algorithm. With 0, the initial
            estimates are returned directly.
        softmask (bool): Describes how the initial estimates are obtained.
            * if `False`, then the mixture phase will directly be used with the
              spectrogram as initial estimates.
            * if `True`, initial estimates are obtained by multiplying the
              complex mix element-wise with the ratio of each target spectrogram
              to the sum of them all. This strategy is better if the models
              are not very accurate, and worse otherwise.
        residual (bool): if `True`, an additional target is created, which is
            equal to the mixture minus the other targets, before application of
            expectation maximization
        scale_factor (float): [scalar]
            Before running EM, `mix_stft` and the initial estimates are
            divided by ``max(1, max(|mix_stft|) / scale_factor)`` and the EM
            output is multiplied back by the same factor, so EM works on a
            mixture whose largest magnitude is at most `scale_factor`.
            Unused when `iterations` is 0.
        eps (float): Epsilon value to use for computing the separations.
            This is used whenever division with a model energy is
            performed, i.e. when softmasking and when iterating the EM.
            It can be understood as the energy of the additional white noise
            that is taken out when separating.

    Returns:
        Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
            STFT of estimated sources. When `residual` is True, the last axis
            has nb_sources + 1 entries, the last one being the residual.

    Notes:
        * Be careful that you need *magnitude spectrogram estimates* for the
          case `softmask==False`.
        * `softmask=False` is recommended
        * The epsilon value will have a huge impact on performance. If it's
          large, only the parts of the signal with a significant energy will
          be kept in the sources. This epsilon then directly controls the
          energy of the reconstruction error.

    Warning:
        As in :func:`expectation_maximization`, we recommend converting the
        mixture `mix_stft` to double precision `torch.float64` *before*
        calling :func:`wiener`.
    """
    if softmask:
        # Initial estimates: the mixture STFT scaled by each target's share of
        # the summed target spectrograms.
        y = (
            mix_stft[..., None]
            * (
                targets_spectrograms
                / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
            )[..., None, :]
        )
    else:
        # Initial estimates: target magnitudes combined with the mixture phase.
        angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
        nb_sources = targets_spectrograms.shape[-1]
        y = torch.zeros(
            mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
        )
        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
        y[..., 1, :] = targets_spectrograms * torch.sin(angle)

    if residual:
        # Extra target: what remains of the mixture after removing all estimates.
        y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)

    if iterations == 0:
        return y

    # Bring the largest mixture magnitude down to at most scale_factor
    # (never scale up), and undo the scaling on the EM output.
    max_abs = torch.max(
        torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
        torch.sqrt(_norm(mix_stft)).max() / scale_factor,
    )

    mix_stft = mix_stft / max_abs
    y = y / max_abs

    # Only the refined source estimates are kept; PSDs and covariances are dropped.
    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]

    y = y * max_abs
    return y
|
|
|
|
|
def _covariance(y_j):
    """
    Compute the empirical covariance for a source.

    For every time-frequency bin this is the outer product of the channel
    vector with its own conjugate, ``y_j[f, b] @ conj(y_j[f, b]).T``. It is
    computed with broadcasting instead of a Python loop over channel pairs.

    Args:
        y_j (Tensor): complex stft of the source, real and imaginary parts in
            the last dimension.
            [shape=(nb_frames, nb_bins, nb_channels, 2)].

    Returns:
        Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
            ``Cj[..., c, d, :] = y_j[..., c, :] * conj(y_j[..., d, :])``,
            i.e. the empirical covariance for each TF bin, with the same
            dtype and device as `y_j`.

    Raises:
        ValueError: if `y_j` is not 4-dimensional with a last dimension of 2.
    """
    if y_j.dim() != 4 or y_j.shape[-1] != 2:
        raise ValueError(
            "expected y_j of shape (nb_frames, nb_bins, nb_channels, 2), "
            f"got {tuple(y_j.shape)}"
        )
    re = y_j[..., 0]
    im = y_j[..., 1]
    # Channel c indexes rows (first factor), channel d indexes columns
    # (conjugated second factor): (a + ib)(c - id) = (ac + bd) + i(bc - ad).
    re_rows, re_cols = re[..., :, None], re[..., None, :]
    im_rows, im_cols = im[..., :, None], im[..., None, :]
    real_part = re_rows * re_cols + im_rows * im_cols
    imag_part = im_rows * re_cols - re_rows * im_cols
    return torch.stack((real_part, imag_part), dim=-1)
|
|