# [file-viewer residue, not part of the module] yuancwang | init | b725c5a | raw | history blame | 13.8 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
#
from typing import Dict, List, Optional, Tuple

import torch

from modules.wenet_extractor.cif.predictor import MAELoss
from modules.wenet_extractor.paraformer.search.beam_search import Hypothesis
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import TransformerDecoder
from modules.wenet_extractor.transformer.encoder import TransformerEncoder
from modules.wenet_extractor.utils.common import IGNORE_ID, add_sos_eos, th_accuracy
from modules.wenet_extractor.utils.mask import make_pad_mask
class Paraformer(ASRModel):
    """Paraformer: Fast and Accurate Parallel Transformer for
    Non-autoregressive End-to-End Speech Recognition
    see https://arxiv.org/pdf/2206.08317.pdf

    Extends ``ASRModel`` with a length/embedding ``predictor`` whose output
    is fed to the decoder in place of the token embeddings, plus an extra
    length-prediction loss (``MAELoss``) used during training.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        predictor,
        ctc_weight: float = 0.5,
        predictor_weight: float = 1.0,
        predictor_bias: int = 0,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
    ):
        """Build the model.

        Args:
            vocab_size: forwarded to ``ASRModel``.
            encoder: forwarded to ``ASRModel``.
            decoder: forwarded to ``ASRModel``.
            ctc: forwarded to ``ASRModel``.
            predictor: callable module invoked as
                ``predictor(encoder_out, target, encoder_mask, ignore_id=...)``
                and expected to return four values (see ``calc_predictor``).
            ctc_weight: weight of the CTC loss; must lie in ``[0, 1]``.
            predictor_weight: weight of the predictor (length) loss; must lie
                in ``[0, 1]``.
            predictor_bias: added to the target lengths in the attention
                branch; when equal to ``1`` the targets are additionally
                passed through ``add_sos_eos`` and its second output is used.
            ignore_id: padding label id, forwarded to ``ASRModel`` and to the
                predictor.
            reverse_weight: forwarded to ``ASRModel``.
            lsm_weight: forwarded to ``ASRModel``.
            length_normalized_loss: forwarded to ``ASRModel`` and used as
                ``normalize_length`` of the predictor loss.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= predictor_weight <= 1.0, predictor_weight

        super().__init__(
            vocab_size,
            encoder,
            decoder,
            ctc,
            ctc_weight,
            ignore_id,
            reverse_weight,
            lsm_weight,
            length_normalized_loss,
        )
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.criterion_pre = MAELoss(normalize_length=length_normalized_loss)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            Dict with keys ``loss``, ``loss_att``, ``loss_ctc`` and
            ``loss_pre``. ``loss_ctc`` is ``None`` when ``ctc_weight == 0``;
            ``loss_att`` and ``loss_pre`` are zero-valued placeholder tensors
            when ``ctc_weight == 1``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, loss_pre = self._calc_att_loss(
                encoder_out, encoder_mask, text, text_lengths
            )
        else:
            # Tensor placeholders (rather than None) keep the variable types
            # uniform across both branches.
            loss_att: torch.Tensor = torch.tensor(0)
            loss_pre: torch.Tensor = torch.tensor(0)

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
        else:
            loss_ctc = None

        if loss_ctc is None:
            loss = loss_att + self.predictor_weight * loss_pre
        elif self.ctc_weight == 1.0:
            # Decide on the configuration, not on the *value* of loss_att:
            # a genuine attention loss that happens to be exactly zero must
            # not discard the attention/predictor terms.
            loss = loss_ctc
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + self.predictor_weight * loss_pre
            )
        return {
            "loss": loss,
            "loss_att": loss_att,
            "loss_ctc": loss_ctc,
            "loss_pre": loss_pre,
        }

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
        """Compute the attention loss, its accuracy and the predictor loss.

        Args:
            encoder_out: encoder output.
            encoder_mask: mask returned by the encoder alongside
                ``encoder_out``.
            ys_pad: padded target token ids.
            ys_pad_lens: target lengths.

        Returns:
            ``(loss_att, acc_att, loss_pre)``.
        """
        if self.predictor_bias == 1:
            # Use the second output of add_sos_eos as the target.
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(
            encoder_out, ys_pad, encoder_mask, ignore_id=self.ignore_id
        )

        # 1. Forward decoder, fed with the predictor embeddings
        decoder_out, _, _ = self.decoder(
            encoder_out, encoder_mask, pre_acoustic_embeds, ys_pad_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        # 3. Predictor loss: target lengths vs. predicted token lengths
        loss_pre: torch.Tensor = self.criterion_pre(
            ys_pad_lens.type_as(pre_token_length), pre_token_length
        )
        return loss_att, acc_att, loss_pre

    def calc_predictor(self, encoder_out, encoder_mask):
        """Run the predictor without targets (inference).

        Args:
            encoder_out: encoder output; its size along dim 1 is used as
                ``max_len`` for ``make_pad_mask``.
            encoder_mask: passed as first argument to ``make_pad_mask``.

        Returns:
            The four predictor outputs:
            ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``.
        """
        # NOTE(review): the search methods of this class pass the mask
        # returned by `_forward_encoder` here, not a lengths tensor.
        # `make_pad_mask` is defined elsewhere -- confirm it accepts that.
        encoder_mask = (
            ~make_pad_mask(encoder_mask, max_len=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        """Run the decoder on predictor embeddings and log-softmax the output.

        Args:
            encoder_out: encoder output.
            encoder_out_lens: second positional argument of the decoder.
            sematic_embeds: embeddings fed to the decoder.
            ys_pad_lens: lengths fed to the decoder; returned unchanged.

        Returns:
            ``(decoder_out, ys_pad_lens)`` where ``decoder_out`` holds
            log-probabilities over the last dimension.
        """
        # NOTE(review): `_calc_att_loss` passes the encoder *mask* in this
        # position while callers here pass *lengths* -- confirm the decoder
        # handles both.
        decoder_out, _, _ = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def recognize(self):
        """Not supported by this model; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def paraformer_greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> List[List[int]]:
        """Apply greedy (per-position argmax) search on the decoder output

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            num_decoding_left_chunks (int): forwarded to the encoder.
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion

        Returns:
            List[List[int]]: one list of token ids per utterance, with
                sos/eos and ids 0 and 1 removed. Every list is empty when the
                predictor predicts no token for the whole batch.
        """
        return self._paraformer_search(
            speech,
            speech_lengths,
            None,
            decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming,
        )

    def paraformer_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_search: Optional[torch.nn.Module] = None,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> List[List[int]]:
        """Apply beam search on attention decoder

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_search (torch.nn.Module): beam search module, called as
                ``beam_search(x=..., am_scores=...)``; when ``None`` the
                method falls back to greedy search.
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            num_decoding_left_chunks (int): forwarded to the encoder.
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion

        Returns:
            List[List[int]]: one list of token ids per utterance (best
                hypothesis only), with sos/eos and ids 0 and 1 removed.
                Every list is empty when the predictor predicts no token for
                the whole batch.
        """
        return self._paraformer_search(
            speech,
            speech_lengths,
            beam_search,
            decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming,
        )

    def _paraformer_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_search: Optional[torch.nn.Module],
        decoding_chunk_size: int,
        num_decoding_left_chunks: int,
        simulate_streaming: bool,
    ) -> List[List[int]]:
        """Shared implementation of the greedy and beam search entry points."""
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_mask = self._forward_encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming,
        )  # (B, maxlen, encoder_dim)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2. Predictor
        pre_acoustic_embeds, pre_token_length, _, _ = self.calc_predictor(
            encoder_out, encoder_mask
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            # Nothing to decode: keep the return type identical to the
            # normal path (one, here empty, hypothesis per utterance).
            return [[] for _ in range(batch_size)]

        # 3. Decoder forward
        decoder_out, _ = self.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )

        hyps = []
        for i in range(decoder_out.size(0)):
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if beam_search is not None:
                x = encoder_out[i, : encoder_out_lens[i], :]
                # Keep only the first hypothesis returned by the module.
                nbest_hyps = beam_search(x=x, am_scores=am_scores)[:1]
            else:
                nbest_hyps = [self._greedy_hypothesis(am_scores)]
            for hyp in nbest_hyps:
                hyps.append(self._hypothesis_to_tokens(hyp))
        return hyps

    def _greedy_hypothesis(self, am_scores: torch.Tensor) -> Hypothesis:
        """Build a ``Hypothesis`` from the per-position argmax of ``am_scores``."""
        yseq = am_scores.argmax(dim=-1)
        score = torch.sum(am_scores.max(dim=-1)[0], dim=-1)
        # Wrap with sos/eos so the result has the same layout that
        # `_hypothesis_to_tokens` strips from every hypothesis.
        yseq = torch.tensor(
            [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
        )
        return Hypothesis(yseq=yseq, score=score)

    def _hypothesis_to_tokens(self, hyp) -> List[int]:
        """Strip sos/eos from a hypothesis and drop token ids 0 and 1."""
        assert isinstance(hyp, Hypothesis), type(hyp)
        # remove sos/eos (first and last position)
        if isinstance(hyp.yseq, list):
            token_int = hyp.yseq[1:-1]
        else:
            token_int = hyp.yseq[1:-1].tolist()
        # remove blank symbol id and unk id, which is assumed to be 0 and 1
        return [tok for tok in token_int if tok != 0 and tok != 1]