Spaces:
Running
on
Zero
Running
on
Zero
File size: 20,817 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 |
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
#
"""Encoder definition."""
from typing import Tuple
import torch
from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
from modules.wenet_extractor.transformer.attention import (
RelPositionMultiHeadedAttention,
)
from modules.wenet_extractor.transformer.convolution import ConvolutionModule
from modules.wenet_extractor.transformer.embedding import PositionalEncoding
from modules.wenet_extractor.transformer.embedding import RelPositionalEncoding
from modules.wenet_extractor.transformer.embedding import NoPositionalEncoding
from modules.wenet_extractor.transformer.encoder_layer import TransformerEncoderLayer
from modules.wenet_extractor.transformer.encoder_layer import ConformerEncoderLayer
from modules.wenet_extractor.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling4
from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling6
from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling8
from modules.wenet_extractor.transformer.subsampling import LinearNoSubsampling
from modules.wenet_extractor.utils.common import get_activation
from modules.wenet_extractor.utils.mask import make_pad_mask
from modules.wenet_extractor.utils.mask import add_optional_chunk_mask
class BaseEncoder(torch.nn.Module):
    """Shared front-end and forward logic for transformer-style encoders.

    This class builds the optional global CMVN, the subsampling/embedding
    input layer and the final LayerNorm. It does NOT create the encoder
    layers: subclasses must assign ``self.encoders`` (an iterable of layers)
    before ``forward`` / ``forward_chunk`` are called.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
    ):
        """
        Args:
            input_size (int): input feature dim
            output_size (int): dimension of attention / encoder output
            attention_heads (int): number of attention heads (used by
                subclasses when building layers)
            linear_units (int): hidden units of the position-wise feed
                forward (used by subclasses)
            num_blocks (int): number of encoder blocks (used by subclasses)
            dropout_rate (float): dropout rate
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            attention_dropout_rate (float): dropout rate in attention (used
                by subclasses)
            input_layer (str): input layer type, one of
                [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): positional encoding layer type, one of
                [abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer
                    (and apply ``after_norm`` to the encoder output).
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training. Use either a fixed chunk (static_chunk_size > 0)
                or a dynamic chunk size (use_dynamic_chunk = True).
            global_cmvn (Optional[torch.nn.Module]): optional GlobalCMVN
                module applied to the raw input features
            use_dynamic_left_chunk (bool): whether to use a dynamic left
                chunk in dynamic chunk training
        Raises:
            ValueError: if ``pos_enc_layer_type`` or ``input_layer`` is not
                one of the supported values listed above.
        """
        super().__init__()
        self._output_size = output_size
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        if input_layer == "linear":
            subsampling_class = LinearNoSubsampling
        elif input_layer == "conv2d":
            subsampling_class = Conv2dSubsampling4
        elif input_layer == "conv2d6":
            subsampling_class = Conv2dSubsampling6
        elif input_layer == "conv2d8":
            subsampling_class = Conv2dSubsampling8
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.global_cmvn = global_cmvn
        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

    def output_size(self) -> int:
        """Return the encoder output dimension given at construction."""
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a padded batch of input features (full-utterance mode).

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                (interpreted by ``add_optional_chunk_mask``)
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for
                decoding, the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(
            xs,
            masks,
            self.use_dynamic_chunk,
            self.use_dynamic_left_chunk,
            decoding_chunk_size,
            self.static_chunk_size,
            num_decoding_left_chunks,
        )
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward just one chunk.

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
                An empty (0, 0, 0, 0) tensor means "no cache yet".
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            att_mask (torch.Tensor): attention mask passed unchanged to every
                layer; empty by default.
        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        # Positional encoding must cover the cached keys as well as the new
        # chunk, so it starts cache_t1 frames before the current offset.
        pos_emb = self.embed.position_encoding(
            offset=offset - cache_t1, size=attention_key_size
        )
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i : i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
            )
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
        if self.normalize_before:
            xs = self.after_norm(xs)
        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
        return (xs, r_att_cache, r_cnn_cache)

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward input chunk by chunk with chunk_size like a streaming
        fashion.

        Here we should pay special attention to computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache for:
            1. We can control subsampling module to output the right result by
               overlapping input instead of cache left context, even though it
               wastes some computation, but subsampling only takes a very
               small fraction of computation in the whole model.
            2. Typically, there are several convolution layers with subsampling
               in subsampling module, it is tricky and complicated to do cache
               with different convolution layers with different subsampling
               rate.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling, we need to rewrite it to make it work
               with cache, which is not preferred.

        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size (in output frames),
                must be > 0
            num_decoding_left_chunks (int): number of left chunks kept in the
                attention cache; <0 keeps the whole history
        Returns:
            ys: encoder output, all chunks concatenated along time,
                shape (1, T', hidden-dim)
            masks: all-True mask of shape (1, 1, T')
        Raises:
            ValueError: if ``xs`` is shorter than the receptive field of one
                output frame (``subsample.right_context + 1`` input frames),
                in which case no chunk could be produced.
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        # Without this guard the loop below runs zero times and torch.cat
        # fails on an empty list with an unhelpful message.
        if num_frames < context:
            raise ValueError(
                f"input has {num_frames} frames, but at least {context} "
                "frames are required to produce one encoder output frame"
            )
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        # A negative product (num_decoding_left_chunks < 0) asks
        # forward_chunk to keep the full history.
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache, cnn_cache) = self.forward_chunk(
                chunk_xs, offset, required_cache_size, att_cache, cnn_cache
            )
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
        return ys, masks
class TransformerEncoder(BaseEncoder):
    """Encoder made of ``num_blocks`` stacked TransformerEncoderLayer blocks.

    Every constructor argument has the same meaning as in BaseEncoder; this
    subclass only adds the construction of ``self.encoders``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
    ):
        """Build the shared BaseEncoder front-end, then the transformer stack.

        See BaseEncoder.__init__ for the meaning of each parameter.
        """
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            attention_heads=attention_heads,
            linear_units=linear_units,
            num_blocks=num_blocks,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            input_layer=input_layer,
            pos_enc_layer_type=pos_enc_layer_type,
            normalize_before=normalize_before,
            static_chunk_size=static_chunk_size,
            use_dynamic_chunk=use_dynamic_chunk,
            global_cmvn=global_cmvn,
            use_dynamic_left_chunk=use_dynamic_left_chunk,
        )
        blocks = []
        for _ in range(num_blocks):
            # Sub-modules are instantiated in the same order the layer
            # receives them (attention first, then feed-forward) so that
            # parameter initialisation consumes the RNG identically.
            self_attention = MultiHeadedAttention(
                attention_heads, output_size, attention_dropout_rate
            )
            feed_forward = PositionwiseFeedForward(
                output_size, linear_units, dropout_rate
            )
            blocks.append(
                TransformerEncoderLayer(
                    output_size,
                    self_attention,
                    feed_forward,
                    dropout_rate,
                    normalize_before,
                )
            )
        self.encoders = torch.nn.ModuleList(blocks)
class ConformerEncoder(BaseEncoder):
    """Encoder made of ``num_blocks`` stacked ConformerEncoderLayer blocks."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
    ):
        """Build the shared BaseEncoder front-end, then the conformer stack.

        Args:
            input_size .. use_dynamic_left_chunk: see BaseEncoder.__init__.
            positionwise_conv_kernel_size (int): accepted for configuration
                compatibility; not used by this constructor.
            macaron_style (bool): if True, each layer gets a second
                position-wise feed-forward module (macaron style).
            selfattention_layer_type (str): accepted for configuration
                compatibility; not used. The attention class is chosen from
                ``pos_enc_layer_type`` instead ("rel_pos" selects
                RelPositionMultiHeadedAttention, anything else
                MultiHeadedAttention).
            activation_type (str): activation name passed to get_activation;
                the resulting module is shared by the feed-forward and
                convolution modules.
            use_cnn_module (bool): whether each layer gets a convolution
                module.
            cnn_module_kernel (int): kernel size of the convolution module.
            causal (bool): whether the convolution module is causal.
            cnn_module_norm (str): normalization type passed to
                ConvolutionModule.
        """
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            attention_heads=attention_heads,
            linear_units=linear_units,
            num_blocks=num_blocks,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            input_layer=input_layer,
            pos_enc_layer_type=pos_enc_layer_type,
            normalize_before=normalize_before,
            static_chunk_size=static_chunk_size,
            use_dynamic_chunk=use_dynamic_chunk,
            global_cmvn=global_cmvn,
            use_dynamic_left_chunk=use_dynamic_left_chunk,
        )
        activation = get_activation(activation_type)
        attention_cls = (
            RelPositionMultiHeadedAttention
            if pos_enc_layer_type == "rel_pos"
            else MultiHeadedAttention
        )

        blocks = []
        for _ in range(num_blocks):
            # Instantiate sub-modules in the same order as the layer's
            # positional arguments so parameter init consumes the RNG
            # identically: attention, feed-forward, macaron feed-forward,
            # convolution.
            self_attention = attention_cls(
                attention_heads, output_size, attention_dropout_rate
            )
            feed_forward = PositionwiseFeedForward(
                output_size, linear_units, dropout_rate, activation
            )
            if macaron_style:
                macaron_feed_forward = PositionwiseFeedForward(
                    output_size, linear_units, dropout_rate, activation
                )
            else:
                macaron_feed_forward = None
            if use_cnn_module:
                conv_module = ConvolutionModule(
                    output_size,
                    cnn_module_kernel,
                    activation,
                    cnn_module_norm,
                    causal,
                )
            else:
                conv_module = None
            blocks.append(
                ConformerEncoderLayer(
                    output_size,
                    self_attention,
                    feed_forward,
                    macaron_feed_forward,
                    conv_module,
                    dropout_rate,
                    normalize_before,
                )
            )
        self.encoders = torch.nn.ModuleList(blocks)
|