File size: 10,296 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 |
#!/usr/bin/env python3
# 2020, Technische Universität München; Ludwig Kürzinger
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Sinc convolutions for raw audio input."""
from collections import OrderedDict
import math
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import torch
from typeguard import check_argument_types

from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.layers.sinc_conv import LogCompression
from espnet2.layers.sinc_conv import SincConv
class LightweightSincConvs(AbsPreEncoder):
    """Lightweight Sinc Convolutions.

    Instead of using precomputed features, end-to-end speech recognition
    can also be done directly from raw audio using sinc convolutions, as
    described in "Lightweight End-to-End Speech Recognition from Raw Audio
    Data Using Sinc-Convolutions" by Kürzinger et al.
    https://arxiv.org/abs/2010.07597

    To use Sinc convolutions in your model instead of the default f-bank
    frontend, set this module as your pre-encoder with `preencoder: sinc`
    and use the input of the sliding window frontend with
    `frontend: sliding_window` in your yaml configuration file.
    The resulting processing pipeline is:

    Frontend (SlidingWindow) -> SpecAug -> Normalization ->
    Pre-encoder (LightweightSincConvs) -> Encoder -> Decoder

    Note that this method also performs data augmentation in time domain
    (vs. in spectral domain in the default frontend).
    Use `plot_sinc_filters.py` to visualize the learned Sinc filters.
    """

    def __init__(
        self,
        fs: Union[int, str, float] = 16000,
        in_channels: int = 1,
        out_channels: int = 256,
        activation_type: str = "leakyrelu",
        dropout_type: str = "dropout",
        windowing_type: str = "hamming",
        scale_type: str = "mel",
    ):
        """Initialize the module.

        Args:
            fs: Sample rate. A string is converted with
                ``humanfriendly.parse_size``.
            in_channels: Number of input channels.
            out_channels: Number of output channels of the final block.
            activation_type: Choice of activation function
                ("leakyrelu" or "relu").
            dropout_type: Choice of dropout function
                ("dropout", "spatial" or "dropout2d").
            windowing_type: Choice of windowing function (passed to SincConv).
            scale_type: Choice of filter-bank initialization scale
                (passed to SincConv).

        Raises:
            NotImplementedError: If ``dropout_type`` or ``activation_type``
                is not one of the supported choices.
        """
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)
        self.fs = fs
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation_type = activation_type
        self.dropout_type = dropout_type
        self.windowing_type = windowing_type
        self.scale_type = scale_type

        self.choices_dropout = {
            "dropout": torch.nn.Dropout,
            "spatial": SpatialDropout,
            "dropout2d": torch.nn.Dropout2d,
        }
        if dropout_type not in self.choices_dropout:
            raise NotImplementedError(
                f"Dropout type has to be one of "
                f"{list(self.choices_dropout.keys())}",
            )

        self.choices_activation = {
            "leakyrelu": torch.nn.LeakyReLU,
            "relu": torch.nn.ReLU,
        }
        if activation_type not in self.choices_activation:
            raise NotImplementedError(
                f"Activation type has to be one of "
                f"{list(self.choices_activation.keys())}",
            )

        # initialization
        self._create_sinc_convs()
        # Sinc filters require custom initialization
        self.espnet_initialization_fn()

    def _create_sinc_convs(self):
        """Build ``self.filters`` and the sequential stack ``self.blocks``."""
        blocks = OrderedDict()

        # SincConvBlock: learnable sinc filter bank followed by log
        # compression, batch normalization and pooling.
        out_channels = 128
        self.filters = SincConv(
            self.in_channels,
            out_channels,
            kernel_size=101,
            stride=1,
            fs=self.fs,
            window_func=self.windowing_type,
            scale_type=self.scale_type,
        )
        block = OrderedDict(
            [
                ("Filters", self.filters),
                ("LogCompression", LogCompression()),
                ("BatchNorm", torch.nn.BatchNorm1d(out_channels, affine=True)),
                ("AvgPool", torch.nn.AvgPool1d(2)),
            ]
        )
        blocks["SincConvBlock"] = torch.nn.Sequential(block)
        in_channels = out_channels

        # First convolutional block, connects the sinc output to the front-end "body"
        out_channels = 128
        blocks["DConvBlock1"] = self.gen_lsc_block(
            in_channels,
            out_channels,
            depthwise_kernel_size=25,
            depthwise_stride=2,
            pointwise_groups=0,
            avgpool=True,
            dropout_probability=0.1,
        )
        in_channels = out_channels

        # Second convolutional block, multiple convolutional layers
        out_channels = self.out_channels
        for layer in [2, 3, 4]:
            blocks[f"DConvBlock{layer}"] = self.gen_lsc_block(
                in_channels, out_channels, depthwise_kernel_size=9, depthwise_stride=1
            )
            in_channels = out_channels

        # Third Convolutional block, acts as coupling to encoder
        out_channels = self.out_channels
        blocks["DConvBlock5"] = self.gen_lsc_block(
            in_channels,
            out_channels,
            depthwise_kernel_size=7,
            depthwise_stride=1,
            pointwise_groups=0,
        )

        self.blocks = torch.nn.Sequential(blocks)

    def gen_lsc_block(
        self,
        in_channels: int,
        out_channels: int,
        depthwise_kernel_size: int = 9,
        depthwise_stride: int = 1,
        depthwise_groups=None,
        pointwise_groups=0,
        dropout_probability: float = 0.15,
        avgpool=False,
    ):
        """Generate a convolutional block for Lightweight Sinc convolutions.

        Each block consists of a depthwise (grouped) convolution, optionally
        followed by a pointwise convolution, then activation,
        batch normalization, an optional average-pooling layer and dropout.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            depthwise_kernel_size: Kernel size of the depthwise convolution.
            depthwise_stride: Stride of the depthwise convolution.
            depthwise_groups: Number of groups of the depthwise convolution.
                If falsy, GCD(in_channels, out_channels) is used.
            pointwise_groups: Number of groups of the pointwise convolution;
                0 (falsy) omits the pointwise convolution.
            dropout_probability: Dropout probability in the block.
            avgpool: If True, an AvgPool layer is inserted.

        Returns:
            torch.nn.Sequential: Neural network building block.
        """
        block = OrderedDict()
        if not depthwise_groups:
            # Conv1d requires both channel counts to be divisible by the
            # number of groups; their GCD is the largest value that is.
            depthwise_groups = math.gcd(in_channels, out_channels)
        block["depthwise"] = torch.nn.Conv1d(
            in_channels,
            out_channels,
            depthwise_kernel_size,
            depthwise_stride,
            groups=depthwise_groups,
        )
        if pointwise_groups:
            block["pointwise"] = torch.nn.Conv1d(
                out_channels, out_channels, 1, 1, groups=pointwise_groups
            )
        block["activation"] = self.choices_activation[self.activation_type]()
        block["batchnorm"] = torch.nn.BatchNorm1d(out_channels, affine=True)
        if avgpool:
            block["avgpool"] = torch.nn.AvgPool1d(2)
        block["dropout"] = self.choices_dropout[self.dropout_type](dropout_probability)
        return torch.nn.Sequential(block)

    def espnet_initialization_fn(self):
        """Initialize sinc filters with filterbank values.

        Also resets every affine BatchNorm1d layer inside ``self.blocks`` to
        weight 1 and bias 0.
        """
        self.filters.init_filters()
        for block in self.blocks:
            for layer in block:
                if isinstance(layer, torch.nn.BatchNorm1d) and layer.affine:
                    layer.weight.data[:] = 1.0
                    layer.bias.data[:] = 0.0

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply Lightweight Sinc Convolutions.

        Args:
            input: Tensor of shape (B, T, C_in, D_in) with B as batch size,
                T as time dimension, C_in as channels and D_in as feature
                dimension.
            input_lengths: Returned unchanged (presumably the per-utterance
                lengths along T — confirm with caller).

        Returns:
            Tuple of the output of shape (B, T, C_out*D_out) and
            ``input_lengths``. C_out equals the ``out_channels`` given at
            initialization: the sinc block maps C_in channels to 128 and the
            following blocks use channel counts independent of C_in.
            The current module structure only handles D_in=400, so that
            D_out=1.
        """
        # Transform input data:
        #   (B, T, C_in, D_in) -> (B*T, C_in, D_in)
        # reshape (not view) so that non-contiguous inputs are accepted too.
        B, T, C_in, D_in = input.size()
        input_frames = input.reshape(B * T, C_in, D_in)
        output_frames = self.blocks(input_frames)

        # ---TRANSFORM: (B*T, C_out, D_out) -> (B, T, C_out*D_out)
        _, C_out, D_out = output_frames.size()
        output_frames = output_frames.reshape(B, T, C_out * D_out)
        return output_frames, input_lengths  # no state in this layer

    def output_size(self) -> int:
        """Get the output size.

        The final block emits ``self.out_channels`` channels independent of
        ``in_channels``, and D_out is 1 for the supported D_in=400, so the
        per-frame feature size produced by ``forward`` is ``out_channels``.
        """
        return self.out_channels
class SpatialDropout(torch.nn.Module):
    """Spatial dropout module.

    Apply dropout to full channels on tensors of input (B, C, D).

    The input is permuted by ``shape`` (default ``(0, 2, 1)``), passed
    through ``torch.nn.Dropout2d`` and permuted back with the inverse
    permutation, so the output has the same layout as the input.
    NOTE(review): Dropout2d zeroes whole slices along dim 1 of the tensor it
    receives; with the default permutation that dim is the input's last
    axis (D), not C — confirm this matches the intended "full channels".
    """

    def __init__(
        self,
        dropout_probability: float = 0.15,
        shape: Optional[Union[tuple, list]] = None,
    ):
        """Initialize.

        Args:
            dropout_probability: Dropout probability.
            shape (tuple, list): Axis permutation applied to the input
                before dropout; defaults to (0, 2, 1).
        """
        assert check_argument_types()
        super().__init__()
        if shape is None:
            shape = (0, 2, 1)
        self.dropout = torch.nn.Dropout2d(dropout_probability)
        # Kept wrapped in a 1-tuple (as before) so `permute(*self.shape)`
        # receives the whole permutation as one sequence argument.
        self.shape = (shape,)

    @staticmethod
    def _inverse_permutation(perm) -> tuple:
        """Return the permutation that undoes ``x.permute(perm)``."""
        ndim = len(perm)
        inverse = [0] * ndim
        for position, axis in enumerate(perm):
            # `% ndim` also accepts negative axis indices.
            inverse[axis % ndim] = position
        return tuple(inverse)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward of spatial dropout module.

        Args:
            x: Input tensor whose rank matches the length of ``shape``.

        Returns:
            Tensor with the same shape and layout as ``x``.
        """
        y = x.permute(*self.shape)
        y = self.dropout(y)
        # Undo the permutation with its inverse; re-applying `shape` itself is
        # only correct for self-inverse permutations such as (0, 2, 1).
        # Computed here rather than cached so older pickled instances work.
        return y.permute(self._inverse_permutation(self.shape[0]))
|