import torch
from typing import Optional

from rvc.lib.algorithm.nsf import GeneratorNSF
from rvc.lib.algorithm.generators import Generator
from rvc.lib.algorithm.commons import slice_segments, rand_slice_segments
from rvc.lib.algorithm.residuals import ResidualCouplingBlock
from rvc.lib.algorithm.encoders import TextEncoder, PosteriorEncoder


def _strip_weight_norm_hooks(module):
    """Remove weight norm from ``module`` if it carries a WeightNorm pre-hook.

    The hooks are scanned first and the removal is done once afterwards:
    ``torch.nn.utils.remove_weight_norm`` deletes an entry from
    ``module._forward_pre_hooks``, so calling it while iterating that same
    dict would fail on the next iteration step.

    Args:
        module (torch.nn.Module): Module whose forward pre-hooks are inspected.
    """
    # NOTE(review): the legacy hook class appears to live in
    # ``torch.nn.utils.weight_norm``; confirm that this module path can ever
    # match, otherwise this check is a no-op.
    has_weight_norm_hook = any(
        hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
        and hook.__class__.__name__ == "WeightNorm"
        for hook in module._forward_pre_hooks.values()
    )
    if has_weight_norm_hook:
        torch.nn.utils.remove_weight_norm(module)


class Synthesizer(torch.nn.Module):
    """
    Base Synthesizer model.

    Wires together a text encoder (``enc_p``), a posterior encoder
    (``enc_q``), a residual coupling flow (``flow``), a decoder (``dec``) and
    a speaker embedding table (``emb_g``).

    Args:
        spec_channels (int): Number of channels in the spectrogram.
        segment_size (int): Size of the audio segment.
        inter_channels (int): Number of channels in the intermediate layers.
        hidden_channels (int): Number of channels in the hidden layers.
        filter_channels (int): Number of channels in the filter layers.
        n_heads (int): Number of attention heads.
        n_layers (int): Number of layers in the encoder.
        kernel_size (int): Size of the convolution kernel.
        p_dropout (float): Dropout probability (converted with ``float()``).
        resblock (str): Type of residual block.
        resblock_kernel_sizes (list): Kernel sizes for the residual blocks.
        resblock_dilation_sizes (list): Dilation sizes for the residual blocks.
        upsample_rates (list): Upsampling rates for the decoder.
        upsample_initial_channel (int): Number of channels in the initial
            upsampling layer.
        upsample_kernel_sizes (list): Kernel sizes for the upsampling layers.
        spk_embed_dim (int): Number of entries in the speaker embedding table.
        gin_channels (int): Number of channels in the global conditioning
            vector (also the width of each speaker embedding).
        sr (int): Sampling rate of the audio; only forwarded to the decoder
            when ``use_f0`` is true.
        use_f0 (bool): Whether to use F0 information. Selects ``GeneratorNSF``
            (true) or ``Generator`` (false) as the decoder.
        text_enc_hidden_dim (int): Hidden dimension for the text encoder.
        kwargs: Additional keyword arguments. ``kwargs["is_half"]`` is
            required when ``use_f0`` is true; other keys are ignored here.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr,
        use_f0,
        text_enc_hidden_dim=768,
        **kwargs
    ):
        super().__init__()

        # Keep the hyper-parameters on the instance.
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.use_f0 = use_f0

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
            text_enc_hidden_dim,
            f0=use_f0,
        )

        if use_f0:
            # Raises KeyError if the caller did not pass ``is_half``.
            self.dec = GeneratorNSF(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
                sr=sr,
                is_half=kwargs["is_half"],
            )
        else:
            self.dec = Generator(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
            )

        # NOTE(review): the positional literals below (5, 1, 16 and 5, 1, 3)
        # are presumably kernel size, dilation rate and layer/flow count --
        # confirm against PosteriorEncoder / ResidualCouplingBlock.
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Removes weight normalization from the model.

        Delegates to the ``remove_weight_norm`` method of the decoder, the
        flow and, when it is still attached, the posterior encoder.
        """
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        # ``__prepare_scriptable__`` already treats ``enc_q`` as optional, so
        # an instance without it must not fail here either.
        if hasattr(self, "enc_q"):
            self.enc_q.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip weight-norm hooks from the sub-modules and return ``self``.

        ``enc_q`` is only processed when the attribute exists.
        """
        _strip_weight_norm_hooks(self.dec)
        _strip_weight_norm_hooks(self.flow)
        if hasattr(self, "enc_q"):
            _strip_weight_norm_hooks(self.enc_q)
        return self

    @torch.jit.ignore
    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        y_lengths: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass of the model.

        Args:
            phone (torch.Tensor): Phoneme sequence.
            phone_lengths (torch.Tensor): Lengths of the phoneme sequences.
            pitch (torch.Tensor, optional): Pitch sequence, passed to the
                text encoder.
            pitchf (torch.Tensor, optional): Fine-grained pitch sequence; it
                is sliced and fed to the decoder when ``use_f0`` is true.
            y (torch.Tensor, optional): Target spectrogram. When None, only
                the text encoder is run.
            y_lengths (torch.Tensor, optional): Lengths of the target
                spectrograms.
            ds (torch.Tensor, optional): Speaker indices looked up in the
                embedding table. Despite the None default it is always
                passed to ``emb_g``, so a value must be supplied.

        Returns:
            tuple: ``(o, ids_slice, x_mask, y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q))``. When ``y`` is None every
            entry except ``x_mask``, ``m_p`` and ``logs_p`` is None.
        """
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        if y is None:
            return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        # Decode only a random window of the latent.
        z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)

        if self.use_f0:
            # Cut the matching window out of the pitch track.
            # NOTE(review): the trailing ``2`` is presumably a dimension
            # argument -- confirm against slice_segments.
            pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
            o = self.dec(z_slice, pitchf, g=g)
        else:
            o = self.dec(z_slice, g=g)

        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        nsff0: Optional[torch.Tensor] = None,
        # Kept as a plain Tensor annotation: this method is exported to
        # TorchScript and ``sid`` is passed straight to the embedding.
        sid: torch.Tensor = None,
        rate: Optional[torch.Tensor] = None,
    ):
        """
        Inference of the model.

        Args:
            phone (torch.Tensor): Phoneme sequence.
            phone_lengths (torch.Tensor): Lengths of the phoneme sequences.
            pitch (torch.Tensor, optional): Pitch sequence, passed to the
                text encoder.
            nsff0 (torch.Tensor, optional): Fine-grained pitch sequence, fed
                to the decoder when ``use_f0`` is true.
            sid (torch.Tensor): Speaker indices looked up in the embedding
                table.
            rate (torch.Tensor, optional): Single-element tensor giving the
                fraction of trailing frames to keep; the leading
                ``int(frames * (1 - rate))`` frames are dropped from the
                latent, the mask and (with ``use_f0``) ``nsff0``.
                Defaults to None (keep everything).

        Returns:
            tuple: ``(o, x_mask, (z, z_p, m_p, logs_p))``.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample around the prior mean with the noise scaled by 0.66666.
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        if rate is not None:
            # Kept as an assert (not a raise) for TorchScript type refinement.
            assert isinstance(rate, torch.Tensor)
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
            if self.use_f0:
                nsff0 = nsff0[:, head:]

        # The reverse flow is identical for both decoder variants.
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        if self.use_f0:
            o = self.dec(z * x_mask, nsff0, g=g)
        else:
            o = self.dec(z * x_mask, g=g)

        return o, x_mask, (z, z_p, m_p, logs_p)