from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import logging

import torch

from funasr_detach.metrics import ErrorCalculator
from funasr_detach.metrics.compute_acc import th_accuracy
from funasr_detach.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr_detach.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr_detach.models.ctc import CTC
from funasr_detach.models.decoder.abs_decoder import AbsDecoder
from funasr_detach.models.encoder.abs_encoder import AbsEncoder
from funasr_detach.frontends.abs_frontend import AbsFrontend
from funasr_detach.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_detach.models.specaug.abs_specaug import AbsSpecAug
from funasr_detach.layers.abs_normalize import AbsNormalize
from funasr_detach.train_utils.device_funcs import force_gatherable
from funasr_detach.models.base_model import FunASRModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


import pdb
import random
import math


class MFCCA(FunASRModel):
    """
    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA: Multi-Frame Cross-Channel Attention for Multi-Speaker ASR in Multi-Party Meeting Scenario
    https://arxiv.org/abs/2210.05265
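
    Example (illustrative sketch, not part of the original source; assumes
    ``model`` is an already-constructed MFCCA instance whose frontend accepts
    8-channel waveforms):
        >>> speech = torch.randn(2, 16000, 8)  # (Batch, NSamples, Channels)
        >>> speech_lengths = torch.tensor([16000, 16000])
        >>> text = torch.randint(1, model.vocab_size - 1, (2, 10))
        >>> text_lengths = torch.tensor([10, 10])
        >>> loss, stats, weight = model(speech, speech_lengths, text, text_lengths)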
""" | |

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None = None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        preencoder: Optional[AbsPreEncoder] = None,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.token_list = token_list.copy()
        self.mask_ratio = mask_ratio

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.rnnt_decoder = rnnt_decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        # pdb.set_trace()
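        # Channel-dropout augmentation: with probability ``mask_ratio``, keep only
        # a random subset of the 8 input channels (a single retained channel
        # collapses the channel axis), so training also covers fewer-channel input.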
        if speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0:
            rate_num = random.random()
            # rate_num = 0.1
            if rate_num <= self.mask_ratio:
                retain_channel = math.ceil(random.random() * 8)
                if retain_channel > 1:
                    speech = speech[
                        :, :, torch.randperm(8)[0:retain_channel].sort().values
                    ]
                else:
                    speech = speech[:, :, torch.randperm(8)[0]]
        # pdb.set_trace()
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
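        # 3. Interpolate the two losses:
        # loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att;
        # a weight of 0.0 or 1.0 falls back to pure attention or pure CTC training.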
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths, channel_size = self._extract_feats(
            speech, speech_lengths
        )
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths, channel_size = self._extract_feats(
                speech, speech_lengths
            )

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # pdb.set_trace()
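        # 4. Forward encoder; ``channel_size`` is passed along so the encoder knows
        # how many input channels the (possibly channel-masked) batch carries.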
        encoder_out, encoder_out_lens, _ = self.encoder(
            feats, feats_lengths, channel_size
        )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
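        # The encoder may return a 4-D output (presumably keeping a per-channel
        # axis at index 1); in that case the time axis is at index 2, otherwise
        # at index 1.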
        if encoder_out.dim() == 4:
            assert encoder_out.size(2) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        else:
            assert encoder_out.size(1) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
            channel_size = 1
        return feats, feats_lengths, channel_size

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
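        # Teacher forcing: the decoder input is <sos> + targets and the training
        # target is targets + <eos>, so the input length grows by one.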
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
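        # A 4-D encoder output is averaged over dim 1 (presumably the per-channel
        # axis) so the CTC branch always sees a (Batch, Time, Dim) representation.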
        if encoder_out.dim() == 4:
            encoder_out = encoder_out.mean(1)
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        raise NotImplementedError