# coding=utf-8 # Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch VITS model.""" import math from dataclasses import dataclass from typing import Any, Optional, Tuple, Union import numpy as np import torch import torch.utils.checkpoint from scipy.signal import get_window, kaiser from torch import nn from transformers.activations import ACT2FN from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.fsdp import is_fsdp_managed_module from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_outputs import ( BaseModelOutput, ModelOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_vits import VitsConfig logger = logging.get_logger(__name__) # General docstring _CONFIG_FOR_DOC = "VitsConfig" @dataclass class VitsModelOutput(ModelOutput): """ Describes the outputs for the VITS model, with potential hidden states and attentions. Args: waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): The final audio waveform predicted by the model. sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`): The length in samples of each element in the `waveform` batch. spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi GAN decoder model to obtain the final audio waveform. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ waveform: torch.FloatTensor = None sequence_lengths: torch.FloatTensor = None spectrogram: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @dataclass class VitsTextEncoderOutput(ModelOutput): """ Describes the outputs for the VITS text encoder model, with potential hidden states and attentions. Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The predicted mean values of the prior distribution for the latent text variables. prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The predicted log-variance values of the prior distribution for the latent text variables. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: torch.FloatTensor = None prior_means: torch.FloatTensor = None prior_log_variances: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @torch.jit.script def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels): in_act = input_a + input_b t_act = torch.tanh(in_act[:, :num_channels, :]) s_act = torch.sigmoid(in_act[:, num_channels:, :]) acts = t_act * s_act return acts def _unconstrained_rational_quadratic_spline( inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse=False, tail_bound=5.0, min_bin_width=1e-3, min_bin_height=1e-3, min_derivative=1e-3, ): """ This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the `tail_bound`, the transform behaves as an identity function. Args: inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Second half of the hidden-states input to the Vits convolutional flow module. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module reverse (`bool`, *optional*, defaults to `False`): Whether the model is being run in reverse mode. tail_bound (`float`, *optional* defaults to 5): Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the transform behaves as an identity function. min_bin_width (`float`, *optional*, defaults to 1e-3): Minimum bin value across the width dimension for the piecewise rational quadratic function. min_bin_height (`float`, *optional*, defaults to 1e-3): Minimum bin value across the height dimension for the piecewise rational quadratic function. min_derivative (`float`, *optional*, defaults to 1e-3): Minimum bin value across the derivatives for the piecewise rational quadratic function. Returns: outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits applied. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound` limits applied. """ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) outside_interval_mask = ~inside_interval_mask outputs = torch.zeros_like(inputs) log_abs_det = torch.zeros_like(inputs) constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1)) unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., -1] = constant outputs[outside_interval_mask] = inputs[outside_interval_mask] log_abs_det[outside_interval_mask] = 0.0 outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline( inputs=inputs[inside_interval_mask], unnormalized_widths=unnormalized_widths[inside_interval_mask, :], unnormalized_heights=unnormalized_heights[inside_interval_mask, :], unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], reverse=reverse, tail_bound=tail_bound, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, ) return outputs, log_abs_det def _rational_quadratic_spline( inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse, tail_bound, min_bin_width, min_bin_height, min_derivative, ): """ This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`. Args: inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Second half of the hidden-states input to the Vits convolutional flow module. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection layer in the convolutional flow module reverse (`bool`): Whether the model is being run in reverse mode. tail_bound (`float`): Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the transform behaves as an identity function. min_bin_width (`float`): Minimum bin value across the width dimension for the piecewise rational quadratic function. min_bin_height (`float`): Minimum bin value across the height dimension for the piecewise rational quadratic function. min_derivative (`float`): Minimum bin value across the derivatives for the piecewise rational quadratic function. Returns: outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Hidden-states as transformed by the piecewise rational quadratic function. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: Logarithm of the absolute value of the determinants corresponding to the `outputs`. """ upper_bound = tail_bound lower_bound = -tail_bound if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound: raise ValueError("Input to a transform is not within its domain") num_bins = unnormalized_widths.shape[-1] if min_bin_width * num_bins > 1.0: raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}") if min_bin_height * num_bins > 1.0: raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}") widths = nn.functional.softmax(unnormalized_widths, dim=-1) widths = min_bin_width + (1 - min_bin_width * num_bins) * widths cumwidths = torch.cumsum(widths, dim=-1) cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound cumwidths[..., 0] = lower_bound cumwidths[..., -1] = upper_bound widths = cumwidths[..., 1:] - cumwidths[..., :-1] derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives) heights = nn.functional.softmax(unnormalized_heights, dim=-1) heights = min_bin_height + (1 - min_bin_height * num_bins) * heights cumheights = torch.cumsum(heights, dim=-1) cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) cumheights = (upper_bound - lower_bound) * cumheights + lower_bound cumheights[..., 0] = lower_bound cumheights[..., -1] = upper_bound heights = cumheights[..., 1:] - cumheights[..., :-1] bin_locations = cumheights if reverse else cumwidths bin_locations[..., -1] += 1e-6 bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 bin_idx = bin_idx[..., None] input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] input_bin_widths = widths.gather(-1, bin_idx)[..., 0] input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] delta = heights / widths input_delta = delta.gather(-1, bin_idx)[..., 0] input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] input_heights = heights.gather(-1, bin_idx)[..., 0] intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta if not reverse: theta = (inputs - input_cumwidths) / input_bin_widths theta_one_minus_theta = theta * (1 - theta) numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) denominator = input_delta + intermediate1 * theta_one_minus_theta outputs = input_cumheights + numerator / denominator derivative_numerator = input_delta.pow(2) * ( input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2) ) log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator) return outputs, log_abs_det else: # find the roots of a quadratic equation intermediate2 = inputs - input_cumheights intermediate3 = intermediate2 * intermediate1 a = input_heights * (input_delta - input_derivatives) + intermediate3 b = input_heights * input_derivatives - intermediate3 c = -input_delta * intermediate2 discriminant = b.pow(2) - 4 * a * c if not (discriminant >= 0).all(): raise RuntimeError(f"invalid discriminant {discriminant}") root = (2 * c) / (-b - torch.sqrt(discriminant)) outputs = root * input_bin_widths + input_cumwidths theta_one_minus_theta = root * (1 - root) denominator = input_delta + intermediate1 * theta_one_minus_theta derivative_numerator = input_delta.pow(2) * ( input_derivatives_plus_one * root.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).pow(2) ) log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator) return outputs, -log_abs_det class VitsWaveNet(torch.nn.Module): def __init__(self, config: VitsConfig, num_layers: int): super().__init__() self.hidden_size = config.hidden_size self.num_layers = num_layers self.in_layers = torch.nn.ModuleList() self.res_skip_layers = torch.nn.ModuleList() self.dropout = nn.Dropout(config.wavenet_dropout) if hasattr(nn.utils.parametrizations, "weight_norm"): weight_norm = nn.utils.parametrizations.weight_norm else: weight_norm = nn.utils.weight_norm if config.speaker_embedding_size != 0: cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1) self.cond_layer = weight_norm(cond_layer, name="weight") for i in range(num_layers): dilation = config.wavenet_dilation_rate**i padding = (config.wavenet_kernel_size * dilation - dilation) // 2 in_layer = torch.nn.Conv1d( in_channels=config.hidden_size, out_channels=2 * config.hidden_size, kernel_size=config.wavenet_kernel_size, dilation=dilation, padding=padding, ) in_layer = weight_norm(in_layer, name="weight") self.in_layers.append(in_layer) # last one is not necessary if i < num_layers - 1: res_skip_channels = 2 * config.hidden_size else: res_skip_channels = config.hidden_size res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1) res_skip_layer = weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) def forward(self, inputs, padding_mask, global_conditioning=None): outputs = torch.zeros_like(inputs) num_channels_tensor = torch.IntTensor([self.hidden_size]) if global_conditioning is not None: global_conditioning = self.cond_layer(global_conditioning) for i in range(self.num_layers): hidden_states = self.in_layers[i](inputs) if global_conditioning is not None: cond_offset = i * 2 * self.hidden_size global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :] else: global_states = torch.zeros_like(hidden_states) acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0]) acts = self.dropout(acts) res_skip_acts = self.res_skip_layers[i](acts) if i < self.num_layers - 1: res_acts = res_skip_acts[:, : self.hidden_size, :] inputs = (inputs + res_acts) * padding_mask outputs = outputs + res_skip_acts[:, self.hidden_size :, :] else: outputs = outputs + res_skip_acts return outputs * padding_mask def remove_weight_norm(self): if self.speaker_embedding_size != 0: torch.nn.utils.remove_weight_norm(self.cond_layer) for layer in self.in_layers: torch.nn.utils.remove_weight_norm(layer) for layer in self.res_skip_layers: torch.nn.utils.remove_weight_norm(layer) class VitsPosteriorEncoder(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.out_channels = config.flow_size self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1) self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers) self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1) def forward(self, inputs, padding_mask, global_conditioning=None): inputs = self.conv_pre(inputs) * padding_mask inputs = self.wavenet(inputs, padding_mask, global_conditioning) stats = self.conv_proj(inputs) * padding_mask mean, log_stddev = torch.split(stats, self.out_channels, dim=1) sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask return sampled, mean, log_stddev # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock class HifiGanResidualBlock(nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): super().__init__() self.leaky_relu_slope = leaky_relu_slope self.convs1 = nn.ModuleList( [ nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=dilation[i], padding=self.get_padding(kernel_size, dilation[i]), ) for i in range(len(dilation)) ] ) self.convs2 = nn.ModuleList( [ nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=1, padding=self.get_padding(kernel_size, 1), ) for _ in range(len(dilation)) ] ) def get_padding(self, kernel_size, dilation=1): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): weight_norm = nn.utils.weight_norm if hasattr(nn.utils.parametrizations, "weight_norm"): weight_norm = nn.utils.parametrizations.weight_norm for layer in self.convs1: weight_norm(layer) for layer in self.convs2: weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: nn.utils.remove_weight_norm(layer) for layer in self.convs2: nn.utils.remove_weight_norm(layer) def forward(self, hidden_states): for conv1, conv2 in zip(self.convs1, self.convs2): residual = hidden_states hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) hidden_states = conv1(hidden_states) hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) hidden_states = conv2(hidden_states) hidden_states = hidden_states + residual return hidden_states class VitsHifiGan(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.config = config self.num_kernels = len(config.resblock_kernel_sizes) self.num_upsamples = len(config.upsample_rates) self.conv_pre = nn.Conv1d( config.flow_size, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3, ) self.upsampler = nn.ModuleList() for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): self.upsampler.append( nn.ConvTranspose1d( config.upsample_initial_channel // (2**i), config.upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel_size, stride=upsample_rate, padding=(kernel_size - upsample_rate) // 2, ) ) self.resblocks = nn.ModuleList() for i in range(len(self.upsampler)): channels = config.upsample_initial_channel // (2 ** (i + 1)) for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False) if config.speaker_embedding_size != 0: self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1) def apply_weight_norm(self): weight_norm = nn.utils.weight_norm if hasattr(nn.utils.parametrizations, "weight_norm"): weight_norm = nn.utils.parametrizations.weight_norm for layer in self.upsampler: weight_norm(layer) for layer in self.resblocks: layer.apply_weight_norm() def remove_weight_norm(self): for layer in self.upsampler: nn.utils.remove_weight_norm(layer) for layer in self.resblocks: layer.remove_weight_norm() def forward( self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None ) -> torch.FloatTensor: r""" Converts a spectrogram into a speech waveform. Args: spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`): Tensor containing the spectrograms. global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*): Tensor containing speaker embeddings, for multispeaker models. Returns: `torch.FloatTensor`: Tensor of shape shape `(batch_size, 1, num_frames)` containing the speech waveform. """ hidden_states = self.conv_pre(spectrogram) if global_conditioning is not None: hidden_states = hidden_states + self.cond(global_conditioning) for i in range(self.num_upsamples): hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope) hidden_states = self.upsampler[i](hidden_states) res_state = self.resblocks[i * self.num_kernels](hidden_states) for j in range(1, self.num_kernels): res_state += self.resblocks[i * self.num_kernels + j](hidden_states) hidden_states = res_state / self.num_kernels hidden_states = nn.functional.leaky_relu(hidden_states) hidden_states = self.conv_post(hidden_states) waveform = torch.tanh(hidden_states) return waveform class VitsISTFT(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.config = config self.gen_istft_n_fft = config.gen_istft_n_fft self.gen_istft_hop_size = config.gen_istft_hop_size self.post_n_fft = config.gen_istft_n_fft if config.istft_decoder in ["ms_istft", "mb_istft"]: self.subbands = config.subbands if config.istft_decoder == "mb_istft": self.pqmf = PQMF(subbands=self.subbands) else: updown_filter = torch.zeros((self.subbands, self.subbands, self.subbands)).float() for k in range(self.subbands): updown_filter[k, k, 0] = 1.0 self.register_buffer("updown_filter", updown_filter) self.multistream_conv_post = nn.Conv1d( 4, 1, kernel_size=63, bias=False, padding=self.get_padding(63, 1) ) self.num_kernels = len(config.resblock_kernel_sizes) self.num_upsamples = len(config.upsample_rates) self.conv_pre = nn.Conv1d( config.flow_size, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3, ) self.upsampler = nn.ModuleList() for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): self.upsampler.append( nn.ConvTranspose1d( config.upsample_initial_channel // (2**i), config.upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel_size, stride=upsample_rate, padding=(kernel_size - upsample_rate) // 2, ) ) self.resblocks = nn.ModuleList() for i in range(len(self.upsampler)): channels = config.upsample_initial_channel // (2 ** (i + 1)) for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) if config.istft_decoder == "istft": self.conv_post = nn.Conv1d(channels, self.post_n_fft + 2, kernel_size=7, stride=1, padding=3, bias=True) elif config.istft_decoder in ["ms_istft", "mb_istft"]: self.conv_post = nn.Conv1d( channels, self.subbands * (self.post_n_fft + 2), kernel_size=7, stride=1, padding=3, bias=True ) self.reflection_pad = nn.ReflectionPad1d((1, 0)) self.stft = TorchSTFT( filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft ) if config.speaker_embedding_size != 0: self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1) def get_padding(self, kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) def apply_weight_norm(self): weight_norm = nn.utils.weight_norm if hasattr(nn.utils.parametrizations, "weight_norm"): weight_norm = nn.utils.parametrizations.weight_norm for layer in self.upsampler: weight_norm(layer) for layer in self.resblocks: layer.apply_weight_norm() weight_norm(self.conv_pre) weight_norm(self.conv_post) if self.config.istft_decoder == "ms_istft": weight_norm(self.multistream_conv_post) def remove_weight_norm(self): for layer in self.upsampler: nn.utils.remove_weight_norm(layer) for layer in self.resblocks: layer.remove_weight_norm() nn.utils.remove_weight_norm(self.conv_pre) nn.utils.remove_weight_norm(self.conv_post) if self.config.istft_decoder == "ms_istft": nn.utils.remove_weight_norm(self.multistream_conv_post) def forward( self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None ) -> torch.FloatTensor: r""" Converts a spectrogram into a speech waveform. Args: spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`): Tensor containing the spectrograms. global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*): Tensor containing speaker embeddings, for multispeaker models. Returns: `torch.FloatTensor`: Tensor of shape shape `(batch_size, 1, num_frames)` containing the speech waveform. """ hidden_states = self.conv_pre(spectrogram) if global_conditioning is not None: hidden_states = hidden_states + self.cond(global_conditioning) for i in range(self.num_upsamples): hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope) hidden_states = self.upsampler[i](hidden_states) res_state = self.resblocks[i * self.num_kernels](hidden_states) for j in range(1, self.num_kernels): res_state += self.resblocks[i * self.num_kernels + j](hidden_states) hidden_states = res_state / self.num_kernels hidden_states = nn.functional.leaky_relu(hidden_states) hidden_states = self.reflection_pad(hidden_states) hidden_states = self.conv_post(hidden_states) if self.config.istft_decoder == "istft": spec = torch.exp(hidden_states[:, : self.post_n_fft // 2 + 1, :]) phase = math.pi * torch.sin(hidden_states[:, self.post_n_fft // 2 + 1 :, :]) waveform = self.stft.inverse(spec, phase) elif self.config.istft_decoder in ["mb_istft", "ms_istft"]: hidden_states = torch.reshape( hidden_states, ( hidden_states.shape[0], self.subbands, hidden_states.shape[1] // self.subbands, hidden_states.shape[-1], ), ) spec = torch.exp(hidden_states[:, :, : self.post_n_fft // 2 + 1, :]) phase = math.pi * torch.sin(hidden_states[:, :, self.post_n_fft // 2 + 1 :, :]) waveform_mb = self.stft.inverse( torch.reshape(spec, (spec.shape[0] * self.subbands, self.gen_istft_n_fft // 2 + 1, spec.shape[-1])), torch.reshape(phase, (phase.shape[0] * self.subbands, self.gen_istft_n_fft // 2 + 1, phase.shape[-1])), ) waveform_mb = torch.reshape(waveform_mb, (hidden_states.shape[0], self.subbands, 1, waveform_mb.shape[-1])) waveform_mb = waveform_mb.squeeze(-2) if self.config.istft_decoder == "mb_istft": waveform = self.pqmf.synthesis(waveform_mb) else: waveform_mb = torch.nn.functional.conv_transpose1d( waveform_mb, self.updown_filter * self.subbands, stride=self.subbands ) waveform = self.multistream_conv_post(waveform_mb) return waveform class PQMF(torch.nn.Module): """PQMF module. This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. .. _`Near-perfect-reconstruction pseudo-QMF banks`: https://ieeexplore.ieee.org/document/258122 """ def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0): """Initilize PQMF module. Args: subbands (int): The number of subbands. taps (int): The number of filter taps. cutoff_ratio (float): Cut-off frequency ratio. beta (float): Beta coefficient for kaiser window. """ super(PQMF, self).__init__() # define filter coefficient h_proto = self.design_prototype_filter(taps, cutoff_ratio, beta) h_analysis = np.zeros((subbands, len(h_proto))) h_synthesis = np.zeros((subbands, len(h_proto))) for k in range(subbands): h_analysis[k] = ( 2 * h_proto * np.cos( (2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + (-1) ** k * np.pi / 4 ) ) h_synthesis[k] = ( 2 * h_proto * np.cos( (2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) - (-1) ** k * np.pi / 4 ) ) # convert to tensor analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) # register coefficients as beffer self.register_buffer("analysis_filter", analysis_filter) self.register_buffer("synthesis_filter", synthesis_filter) # filter for downsampling & upsampling updown_filter = torch.zeros((subbands, subbands, subbands)).float() for k in range(subbands): updown_filter[k, k, 0] = 1.0 self.register_buffer("updown_filter", updown_filter) self.subbands = subbands # keep padding info self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) def design_prototype_filter(self, taps=62, cutoff_ratio=0.15, beta=9.0): """Design prototype filter for PQMF. This method is based on `A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`_. Args: taps (int): The number of filter taps. cutoff_ratio (float): Cut-off frequency ratio. beta (float): Beta coefficient for kaiser window. Returns: ndarray: Impluse response of prototype filter (taps + 1,). .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: https://ieeexplore.ieee.org/abstract/document/681427 """ # check the arguments are valid assert taps % 2 == 0, "The number of taps mush be even number." assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." # make initial filter omega_c = np.pi * cutoff_ratio with np.errstate(invalid="ignore"): h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (np.pi * (np.arange(taps + 1) - 0.5 * taps)) h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form # apply kaiser window w = kaiser(taps + 1, beta) h = h_i * w return h def analysis(self, x): """Analysis with PQMF. Args: x (Tensor): Input tensor (B, 1, T). Returns: Tensor: Output tensor (B, subbands, T // subbands). """ x = torch.nn.functional.conv1d(self.pad_fn(x), self.analysis_filter) return torch.nn.functional.conv1d(x, self.updown_filter, stride=self.subbands) def synthesis(self, x): """Synthesis with PQMF. Args: x (Tensor): Input tensor (B, subbands, T // subbands). Returns: Tensor: Output tensor (B, 1, T). """ # NOTE(kan-bayashi): Power will be dreased so here multipy by # subbands. # Not sure this is the correct way, it is better to check again. # TODO(kan-bayashi): Understand the reconstruction procedure x = torch.nn.functional.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands) return torch.nn.functional.conv1d(self.pad_fn(x), self.synthesis_filter) class TorchSTFT(torch.nn.Module): def __init__(self, filter_length=800, hop_length=200, win_length=800, window="hann"): super().__init__() self.filter_length = filter_length self.hop_length = hop_length self.win_length = win_length self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32)) def transform(self, input_data): forward_transform = torch.stft( input_data, self.filter_length, self.hop_length, self.win_length, window=self.window, return_complex=True ) return torch.abs(forward_transform), torch.angle(forward_transform) def inverse(self, magnitude, phase): inverse_transform = torch.istft( magnitude * torch.exp(phase * 1j), self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device), ) return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation def forward(self, input_data): self.magnitude, self.phase = self.transform(input_data) reconstruction = self.inverse(self.magnitude, self.phase) return reconstruction class VitsResidualCouplingLayer(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.half_channels = config.flow_size // 2 self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1) self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers) self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1) def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1) hidden_states = self.conv_pre(first_half) * padding_mask hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning) mean = self.conv_post(hidden_states) * padding_mask log_stddev = torch.zeros_like(mean) if not reverse: second_half = mean + second_half * torch.exp(log_stddev) * padding_mask outputs = torch.cat([first_half, second_half], dim=1) log_determinant = torch.sum(log_stddev, [1, 2]) return outputs, log_determinant else: second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask outputs = torch.cat([first_half, second_half], dim=1) return outputs, None class VitsResidualCouplingBlock(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.flows = nn.ModuleList() for _ in range(config.prior_encoder_num_flows): self.flows.append(VitsResidualCouplingLayer(config)) def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): if not reverse: for flow in self.flows: inputs, _ = flow(inputs, padding_mask, global_conditioning) inputs = torch.flip(inputs, [1]) else: for flow in reversed(self.flows): inputs = torch.flip(inputs, [1]) inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True) return inputs class VitsDilatedDepthSeparableConv(nn.Module): def __init__(self, config: VitsConfig, dropout_rate=0.0): super().__init__() kernel_size = config.duration_predictor_kernel_size channels = config.hidden_size self.num_layers = config.depth_separable_num_layers self.dropout = nn.Dropout(dropout_rate) self.convs_dilated = nn.ModuleList() self.convs_pointwise = nn.ModuleList() self.norms_1 = nn.ModuleList() self.norms_2 = nn.ModuleList() for i in range(self.num_layers): dilation = kernel_size**i padding = (kernel_size * dilation - dilation) // 2 self.convs_dilated.append( nn.Conv1d( in_channels=channels, out_channels=channels, kernel_size=kernel_size, groups=channels, dilation=dilation, padding=padding, ) ) self.convs_pointwise.append(nn.Conv1d(channels, channels, 1)) self.norms_1.append(nn.LayerNorm(channels)) self.norms_2.append(nn.LayerNorm(channels)) def forward(self, inputs, padding_mask, global_conditioning=None): if global_conditioning is not None: inputs = inputs + global_conditioning for i in range(self.num_layers): hidden_states = self.convs_dilated[i](inputs * padding_mask) hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1) hidden_states = nn.functional.gelu(hidden_states) hidden_states = self.convs_pointwise[i](hidden_states) hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1) hidden_states = nn.functional.gelu(hidden_states) hidden_states = self.dropout(hidden_states) inputs = inputs + hidden_states return inputs * padding_mask class VitsConvFlow(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.filter_channels = config.hidden_size self.half_channels = config.depth_separable_channels // 2 self.num_bins = config.duration_predictor_flow_bins self.tail_bound = config.duration_predictor_tail_bound self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1) self.conv_dds = VitsDilatedDepthSeparableConv(config) self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1) def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1) hidden_states = self.conv_pre(first_half) hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning) hidden_states = self.conv_proj(hidden_states) * padding_mask batch_size, channels, length = first_half.shape hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2) unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels) unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :] second_half, log_abs_det = _unconstrained_rational_quadratic_spline( second_half, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse=reverse, tail_bound=self.tail_bound, ) outputs = torch.cat([first_half, second_half], dim=1) * padding_mask if not reverse: log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2]) return outputs, log_determinant else: return outputs, None class VitsElementwiseAffine(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.channels = config.depth_separable_channels self.translate = nn.Parameter(torch.zeros(self.channels, 1)) self.log_scale = nn.Parameter(torch.zeros(self.channels, 1)) def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): if not reverse: outputs = self.translate + torch.exp(self.log_scale) * inputs outputs = outputs * padding_mask log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2]) return outputs, log_determinant else: outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask return outputs, None class VitsStochasticDurationPredictor(nn.Module): def __init__(self, config): super().__init__() embed_dim = config.speaker_embedding_size filter_channels = config.hidden_size self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1) self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.conv_dds = VitsDilatedDepthSeparableConv( config, dropout_rate=config.duration_predictor_dropout, ) if embed_dim != 0: self.cond = nn.Conv1d(embed_dim, filter_channels, 1) self.flows = nn.ModuleList() self.flows.append(VitsElementwiseAffine(config)) for _ in range(config.duration_predictor_num_flows): self.flows.append(VitsConvFlow(config)) self.post_conv_pre = nn.Conv1d(1, filter_channels, 1) self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.post_conv_dds = VitsDilatedDepthSeparableConv( config, dropout_rate=config.duration_predictor_dropout, ) self.post_flows = nn.ModuleList() self.post_flows.append(VitsElementwiseAffine(config)) for _ in range(config.duration_predictor_num_flows): self.post_flows.append(VitsConvFlow(config)) def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0): inputs = torch.detach(inputs) inputs = self.conv_pre(inputs) if global_conditioning is not None: global_conditioning = torch.detach(global_conditioning) inputs = inputs + self.cond(global_conditioning) inputs = self.conv_dds(inputs, padding_mask) inputs = self.conv_proj(inputs) * padding_mask if not reverse: hidden_states = self.post_conv_pre(durations) hidden_states = self.post_conv_dds(hidden_states, padding_mask) hidden_states = self.post_conv_proj(hidden_states) * padding_mask random_posterior = ( torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype) * padding_mask ) log_determinant_posterior_sum = 0 latents_posterior = random_posterior for flow in self.post_flows: latents_posterior, log_determinant = flow( latents_posterior, padding_mask, global_conditioning=inputs + hidden_states ) latents_posterior = torch.flip(latents_posterior, [1]) log_determinant_posterior_sum += log_determinant first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1) log_determinant_posterior_sum += torch.sum( (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2] ) logq = ( torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2]) - log_determinant_posterior_sum ) first_half = (durations - torch.sigmoid(first_half)) * padding_mask first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask log_determinant_sum = torch.sum(-first_half, [1, 2]) latents = torch.cat([first_half, second_half], dim=1) for flow in self.flows: latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs) latents = torch.flip(latents, [1]) log_determinant_sum += log_determinant nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum return nll + logq else: flows = list(reversed(self.flows)) flows = flows[:-2] + [flows[-1]] # remove a useless vflow latents = ( torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype) * noise_scale ) for flow in flows: latents = torch.flip(latents, [1]) latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True) log_duration, _ = torch.split(latents, [1, 1], dim=1) return log_duration class VitsDurationPredictor(nn.Module): def __init__(self, config): super().__init__() kernel_size = config.duration_predictor_kernel_size filter_channels = config.duration_predictor_filter_channels self.dropout = nn.Dropout(config.duration_predictor_dropout) self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2) self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps) self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps) self.proj = nn.Conv1d(filter_channels, 1, 1) if config.speaker_embedding_size != 0: self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1) def forward(self, inputs, padding_mask, global_conditioning=None): inputs = torch.detach(inputs) if global_conditioning is not None: global_conditioning = torch.detach(global_conditioning) inputs = inputs + self.cond(global_conditioning) inputs = self.conv_1(inputs * padding_mask) inputs = torch.relu(inputs) inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1) inputs = self.dropout(inputs) inputs = self.conv_2(inputs * padding_mask) inputs = torch.relu(inputs) inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1) inputs = self.dropout(inputs) inputs = self.proj(inputs * padding_mask) return inputs * padding_mask class VitsAttention(nn.Module): """Multi-headed attention with relative positional representation.""" def __init__(self, config: VitsConfig): super().__init__() self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.dropout = config.attention_dropout self.window_size = config.window_size self.head_dim = self.embed_dim // self.num_heads self.scaling = self.head_dim**-0.5 if (self.head_dim * self.num_heads) != self.embed_dim: raise ValueError( f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}" f" and `num_attention_heads`: {self.num_heads})." ) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) if self.window_size: self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling) self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f" {attn_weights.size()}" ) if self.window_size is not None: key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len) relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1)) rel_pos_bias = self._relative_position_to_absolute_position(relative_logits) attn_weights += rel_pos_bias if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) if self.window_size is not None: value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len) relative_weights = self._absolute_position_to_relative_position(attn_probs) rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings) attn_output += rel_pos_bias attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped def _get_relative_embeddings(self, relative_embeddings, length): pad_length = max(length - (self.window_size + 1), 0) if pad_length > 0: relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) slice_start_position = max((self.window_size + 1) - length, 0) slice_end_position = slice_start_position + 2 * length - 1 return relative_embeddings[:, slice_start_position:slice_end_position] def _relative_position_to_absolute_position(self, x): batch_heads, length, _ = x.size() # Concat columns of pad to shift from relative to absolute indexing. x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0]) # Concat extra elements so to add up to shape (len+1, 2*len-1). x_flat = x.view([batch_heads, length * 2 * length]) x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0]) # Reshape and slice out the padded elements. x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1]) x_final = x_final[:, :length, length - 1 :] return x_final def _absolute_position_to_relative_position(self, x): batch_heads, length, _ = x.size() # Pad along column x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0]) x_flat = x.view([batch_heads, length * (2 * length - 1)]) # Add 0's in the beginning that will skew the elements after reshape x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0]) x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:] return x_final class VitsFeedForward(nn.Module): def __init__(self, config): super().__init__() self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size) self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size) self.dropout = nn.Dropout(config.activation_dropout) if isinstance(config.hidden_act, str): self.act_fn = ACT2FN[config.hidden_act] else: self.act_fn = config.hidden_act if config.ffn_kernel_size > 1: pad_left = (config.ffn_kernel_size - 1) // 2 pad_right = config.ffn_kernel_size // 2 self.padding = [pad_left, pad_right, 0, 0, 0, 0] else: self.padding = None def forward(self, hidden_states, padding_mask): hidden_states = hidden_states.permute(0, 2, 1) padding_mask = padding_mask.permute(0, 2, 1) hidden_states = hidden_states * padding_mask if self.padding is not None: hidden_states = nn.functional.pad(hidden_states, self.padding) hidden_states = self.conv_1(hidden_states) hidden_states = self.act_fn(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = hidden_states * padding_mask if self.padding is not None: hidden_states = nn.functional.pad(hidden_states, self.padding) hidden_states = self.conv_2(hidden_states) hidden_states = hidden_states * padding_mask hidden_states = hidden_states.permute(0, 2, 1) return hidden_states class VitsEncoderLayer(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.attention = VitsAttention(config) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.feed_forward = VitsFeedForward(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, padding_mask: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ): residual = hidden_states hidden_states, attn_weights = self.attention( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = self.dropout(hidden_states) hidden_states = self.layer_norm(residual + hidden_states) residual = hidden_states hidden_states = self.feed_forward(hidden_states, padding_mask) hidden_states = self.dropout(hidden_states) hidden_states = self.final_layer_norm(residual + hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class VitsEncoder(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.config = config self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self.layerdrop = config.layerdrop def forward( self, hidden_states: torch.FloatTensor, padding_mask: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) hidden_states = hidden_states * padding_mask synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = np.random.uniform(0, 1) skip_the_layer = self.training and (dropout_probability < self.layerdrop) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, padding_mask, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask=attention_mask, padding_mask=padding_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) hidden_states = hidden_states * padding_mask if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class VitsTextEncoder(nn.Module): """ Transformer encoder that uses relative positional representation instead of absolute positional encoding. """ def __init__(self, config: VitsConfig): super().__init__() self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.encoder = VitsEncoder(config) self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1) def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def forward( self, input_ids: torch.Tensor, padding_mask: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = True, ) -> Union[Tuple[torch.Tensor], VitsTextEncoderOutput]: hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size) encoder_outputs = self.encoder( hidden_states=hidden_states, padding_mask=padding_mask, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2) if not return_dict: outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:] return outputs return VitsTextEncoderOutput( last_hidden_state=last_hidden_state, prior_means=prior_means, prior_log_variances=prior_log_variances, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class VitsPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = VitsConfig base_model_prefix = "vits" main_input_name = "input_ids" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() VITS_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`VitsConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ VITS_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) speaker_id (`int`, *optional*): Which speaker embedding to use. Only used for multispeaker models. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The complete VITS model, for text-to-speech synthesis.", VITS_START_DOCSTRING, ) class VitsModel(VitsPreTrainedModel): def __init__(self, config: VitsConfig): super().__init__(config) self.config = config self.text_encoder = VitsTextEncoder(config) self.flow = VitsResidualCouplingBlock(config) if config.istft_decoder in ["istft", "mb_istft", "ms_istft"]: self.decoder = VitsISTFT(config) else: self.decoder = VitsHifiGan(config) if config.use_stochastic_duration_prediction: self.duration_predictor = VitsStochasticDurationPredictor(config) else: self.duration_predictor = VitsDurationPredictor(config) if config.num_speakers > 1: self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size) # This is used only for training. self.posterior_encoder = VitsPosteriorEncoder(config) # These parameters control the synthesised speech properties self.speaking_rate = config.speaking_rate self.noise_scale = config.noise_scale self.noise_scale_duration = config.noise_scale_duration # Initialize weights and apply final processing self.post_init() def get_encoder(self): return self.text_encoder @add_start_docstrings_to_model_forward(VITS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VitsModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, speaker_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.FloatTensor] = None, ) -> Union[Tuple[Any], VitsModelOutput]: r""" labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*): Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss computation. Returns: Example: ```python >>> from transformers import VitsTokenizer, VitsModel, set_seed >>> import torch >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng") >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng") >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt") >>> set_seed(555) # make deterministic >>> with torch.no_grad(): ... outputs = model(inputs["input_ids"]) >>> outputs.waveform.shape torch.Size([1, 45824]) ``` """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: raise NotImplementedError("Training of VITS is not supported yet.") if attention_mask is not None: input_padding_mask = attention_mask.unsqueeze(-1).float() else: input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float() if self.config.num_speakers > 1 and speaker_id is not None: if not 0 <= speaker_id < self.config.num_speakers: raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.") if isinstance(speaker_id, int): speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device) speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1) else: speaker_embeddings = None text_encoder_output = self.text_encoder( input_ids=input_ids, padding_mask=input_padding_mask, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state hidden_states = hidden_states.transpose(1, 2) input_padding_mask = input_padding_mask.transpose(1, 2) prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances if self.config.use_stochastic_duration_prediction: log_duration = self.duration_predictor( hidden_states, input_padding_mask, speaker_embeddings, reverse=True, noise_scale=self.noise_scale_duration, ) else: log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings) length_scale = 1.0 / self.speaking_rate duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale) predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long() # Create a padding mask for the output lengths of shape (batch, 1, max_output_length) indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device) output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1) output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype) # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length) attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1) batch_size, _, output_length, input_length = attn_mask.shape cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1) indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device) valid_indices = indices.unsqueeze(0) < cum_duration valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length) padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1] attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask # Expand prior distribution prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2) prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2) prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True) spectrogram = latents * output_padding_mask waveform = self.decoder(spectrogram, speaker_embeddings) waveform = waveform.squeeze(1) sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates) if not return_dict: outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:] return outputs return VitsModelOutput( waveform=waveform, sequence_lengths=sequence_lengths, spectrogram=spectrogram, hidden_states=text_encoder_output.hidden_states, attentions=text_encoder_output.attentions, )