|
from typing import List, Optional, Tuple, Union

from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet2.enh.layers.mask_estimator import MaskEstimator
|
|
|
|
|
class DNN_WPE(torch.nn.Module):
    """WPE dereverberation whose power estimate can be weighted by DNN masks.

    When ``use_dnn_mask`` is True, a ``MaskEstimator`` predicts ``nmask``
    masks in the first iteration. Each mask weights the signal power that is
    handed to ``wpe_one_iteration``, yielding one dereverberated output per
    mask. When ``use_dnn_mask`` is False, a single output is produced from
    the unweighted power.

    Args:
        wtype, widim, wlayers, wunits, wprojs, dropout_rate: Forwarded
            positionally to ``MaskEstimator`` (presumably network type, input
            dim, number of layers, units, projection size and dropout rate;
            confirm against ``MaskEstimator``).
        taps: Passed as ``taps`` to ``wpe_one_iteration``.
        delay: Passed as ``delay`` to ``wpe_one_iteration``.
        use_dnn_mask: Build and use the mask estimator. If False, ``nmask``
            is forced to 1 and no masks are produced.
        nmask: Number of masks/outputs when ``use_dnn_mask`` is True.
        nonlinear: Forwarded as ``nonlinear`` to ``MaskEstimator``.
        iterations: Number of WPE iterations run by :meth:`forward`.
        normalization: Divide each mask by its sum over the last axis.
        eps: Lower bound applied to the channel-averaged power.
        diagonal_loading, diag_eps, use_torch_solver: Stored as attributes;
            not read by any method of this class.
        mask_flooring: Clamp masks from below at ``flooring_thres``.
        flooring_thres: Floor value used when ``mask_flooring`` is True.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        nmask: int = 1,
        nonlinear: str = "sigmoid",
        iterations: int = 1,
        normalization: bool = False,
        eps: float = 1e-6,
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres: float = 1e-6,
        use_torch_solver: bool = True,
    ):
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay
        self.eps = eps

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        # Always forwarded as ``inverse_power`` to ``wpe_one_iteration``.
        self.inverse_power = True
        # Kept as attributes only; nothing in this class reads them.
        self.diagonal_loading = diagonal_loading
        self.diag_eps = diag_eps
        self.mask_flooring = mask_flooring
        self.flooring_thres = flooring_thres
        self.use_torch_solver = use_torch_solver

        if self.use_dnn_mask:
            self.nmask = nmask
            self.mask_est = MaskEstimator(
                wtype,
                widim,
                wlayers,
                wunits,
                wprojs,
                dropout_rate,
                nmask=nmask,
                nonlinear=nonlinear,
            )
        else:
            # Without a mask estimator there is exactly one output.
            self.nmask = 1

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[
        Union[ComplexTensor, List[ComplexTensor]],
        torch.LongTensor,
        Optional[Union[torch.Tensor, List[torch.Tensor]]],
        Optional[List[torch.Tensor]],
    ]:
        """DNN_WPE forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, C, F);
                a single tensor when ``nmask == 1``, otherwise one entry per
                mask.
            ilens: (B,), returned unchanged.
            masks (torch.Tensor or List[torch.Tensor] or None): (B, T, C, F);
                a single tensor when ``nmask == 1``, a list otherwise, and
                None when ``use_dnn_mask`` is False or ``iterations == 0``.
            power (List[torch.Tensor] or None): (B, F, T) channel-averaged
                power from the last iteration; None when ``iterations == 0``.
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        # WPE runs in double precision. ``data`` is not modified below, so
        # convert it once instead of once per mask per iteration.
        data_double = data.contiguous().double()
        enhanced = [data] * self.nmask
        masks = None
        power = None

        for it in range(self.iterations):
            # Power |X|^2 of each current estimate: (B, F, C, T)
            power = [enh.real ** 2 + enh.imag ** 2 for enh in enhanced]
            if it == 0 and self.use_dnn_mask:
                # Masks are predicted once, from the (permuted) input.
                masks, _ = self.mask_est(data, ilens)
                if self.mask_flooring:
                    # Keep masks away from zero for numerical stability.
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Normalize along the last axis (presumably T).
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                # Weight each estimate's power by its corresponding mask.
                power = [p * m for p, m in zip(power, masks)]

            # Average over channels, (B, F, C, T) -> (B, F, T), and floor at
            # eps so the power is never zero.
            power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

            enhanced = [
                wpe_one_iteration(
                    data_double,
                    p.double(),
                    taps=self.taps,
                    delay=self.delay,
                    inverse_power=self.inverse_power,
                )
                for p in power
            ]
            # Cast back to the input dtype and zero out padded frames.
            enhanced = [
                enh.to(dtype=data.dtype).masked_fill(make_pad_mask(ilens, enh.real), 0)
                for enh in enhanced
            ]

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced]
        if masks is not None:
            masks = (
                [m.transpose(-1, -3) for m in masks]
                if self.nmask > 1
                else masks[0].transpose(-1, -3)
            )
        if self.nmask == 1:
            enhanced = enhanced[0]

        return enhanced, ilens, masks, power

    def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[
        Optional[Union[torch.Tensor, List[torch.Tensor]]], torch.LongTensor
    ]:
        """Predict mask for WPE dereverberation.

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            masks (torch.Tensor or List[torch.Tensor] or None): (B, T, C, F);
                a single tensor when ``nmask == 1``, a list otherwise, and
                None when ``use_dnn_mask`` is False.
            ilens (torch.Tensor): (B,), as returned by the mask estimator, or
                the input ``ilens`` unchanged when no estimator is used.
        """
        if not self.use_dnn_mask:
            return None, ilens

        # (B, T, C, F) -> (B, F, C, T); the estimator is fed single precision.
        masks, ilens = self.mask_est(data.permute(0, 3, 2, 1).float(), ilens)
        # Swap the F and T axes back: -> (B, T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        if self.nmask == 1:
            masks = masks[0]
        return masks, ilens
|
|