"""Conformer encoder with a CTC output layer, plus mel-feature helpers."""

import math
from typing import List, Tuple, Optional

import numpy as np
import torch
import torchaudio
from torch import nn
from torchaudio.models import Conformer
from torchaudio.models.rnnt import _TimeReduction
from transformers import PretrainedConfig, PreTrainedModel

# Character vocabulary: 40 entries. Index 0 is the empty string
# (presumably the CTC blank -- confirm against `config.pad_token_id`).
HF_CTC_VOCAB = [
    '',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    ' ', '?', '_',
]

# GAIN = 10 ** (2 * log10(int16 max)), i.e. (int16 max) squared up to
# floating-point rounding.
# NOTE(review): presumably rescales spectrogram values of [-1, 1] audio to
# what int16-range audio would produce -- confirm.
DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
GAIN = 10 ** (0.05 * DECIBEL)

# Shared mel front end: 80 mel bins, 400-sample FFT, 160-sample hop,
# configured for 16 kHz audio.
spectrogram_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=400,
    n_mels=80,
    hop_length=160,
)


def piecewise_linear_log(x):
    """Scale ``x`` by ``GAIN`` and compress it with a log / linear mapping.

    Entries of the scaled tensor greater than ``e`` are replaced by their
    natural log; afterwards every entry that is ``<= e`` is divided by ``e``.

    The second mask is evaluated *after* the log step, so an entry whose
    log is ``<= e`` (scaled value in ``(e, e**e]``) is divided by ``e`` as
    well. This is the behaviour of the code as written.
    NOTE(review): confirm it is intended before changing it -- altering it
    changes the features the model sees.

    The multiplication allocates a new tensor, so the argument itself is
    not modified.
    """
    scaled = x * GAIN

    above_e = scaled > math.e
    scaled[above_e] = torch.log(scaled[above_e])

    # Deliberately recomputed on the updated tensor (see docstring).
    at_or_below_e = scaled <= math.e
    scaled[at_or_below_e] = scaled[at_or_below_e] / math.e

    return scaled


def melspectrogram(x):
    """Return compressed mel features for ``x``.

    ``x`` may be a ``numpy.ndarray`` (converted with ``torch.Tensor``) or
    anything ``spectrogram_transform`` accepts directly. The transform
    output has its first two dimensions swapped before compression.
    Assumes a single un-batched waveform so the result is
    ``(frames, n_mels)`` -- TODO confirm with callers.
    """
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x)
    mel = spectrogram_transform(x)
    return piecewise_linear_log(mel.transpose(1, 0))


class ConformerConfig(PretrainedConfig):
    """Configuration for :class:`ConformerEncoder`.

    The encoder reads these attributes from the config:
    ``time_reduction_stride``, ``input_dim``, ``conformer_input_dim``,
    ``conformer_num_layers``, ``conformer_ffn_dim``, ``conformer_num_heads``,
    ``conformer_depthwise_conv_kernel_size``, ``conformer_dropout``,
    ``output_dim``, ``pad_token_id``, ``ctc_loss_reduction`` and
    ``ctc_zero_infinity``.
    """

    model_type = 'conformer'


class ConformerEncoder(PreTrainedModel):
    """Time reduction -> linear -> Conformer -> linear, with optional CTC loss."""

    config_class = ConformerConfig

    def __init__(
        self,
        config,
    ) -> None:
        """Build the sub-modules from ``config`` (a ``ConformerConfig``)."""
        super().__init__(config)

        stride = config.time_reduction_stride
        hidden_dim = config.conformer_input_dim

        self.time_reduction = _TimeReduction(stride)
        # The linear layer is sized for `stride` input frames' worth of
        # features per reduced frame.
        self.input_linear = nn.Linear(config.input_dim * stride, hidden_dim)
        self.conformer = Conformer(
            num_layers=config.conformer_num_layers,
            input_dim=hidden_dim,
            ffn_dim=config.conformer_ffn_dim,
            num_heads=config.conformer_num_heads,
            depthwise_conv_kernel_size=config.conformer_depthwise_conv_kernel_size,
            dropout=config.conformer_dropout,
            use_group_norm=True,
            convolution_first=True,
        )
        self.output_linear = nn.Linear(hidden_dim, config.output_dim)

    def forward(self, inputs, lengths, labels=None):
        """Encode ``inputs`` and optionally compute the CTC loss.

        Args:
            inputs: feature tensor passed to the time-reduction layer
                (presumably ``(batch, time, input_dim)`` -- confirm).
            lengths: lengths passed alongside ``inputs``
                (presumably valid frames per sequence -- confirm).
            labels: optional integer targets. Entries ``< 0`` are excluded
                from the loss; the remaining entries per row are that
                row's target sequence.

        Returns:
            ``(logits, input_lengths)`` when ``labels`` is ``None``,
            otherwise ``(loss, logits, input_lengths)``. ``input_lengths``
            are the lengths returned by the Conformer.
        """
        reduced, reduced_lengths = self.time_reduction(inputs, lengths)
        projected = self.input_linear(reduced)
        encoded, input_lengths = self.conformer(projected, reduced_lengths)
        logits = self.output_linear(encoded)

        if labels is None:
            return (logits, input_lengths)

        # Negative label ids are excluded from the targets.
        keep = labels >= 0
        target_lengths = keep.sum(-1)
        flattened_targets = labels.masked_select(keep)

        # Log-probabilities are computed in float32 whatever the logits'
        # dtype, then the first two dims are swapped for ctc_loss
        # (presumably batch-first -> time-first -- confirm).
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        )
        log_probs = log_probs.transpose(0, 1)

        # cuDNN is disabled only for the loss call.
        # NOTE(review): presumably to force the native CTC kernel -- confirm.
        with torch.backends.cudnn.flags(enabled=False):
            loss = nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

        return (loss, logits, input_lengths)