import torch

from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig
from TTS.tts.utils.helpers import rand_segments
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

args = DelightfulTtsArgs()
v_args = VocoderConfig()

config = DelightfulTTSConfig(
    model_args=args,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
)
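
# init_from_config builds the phoneme tokenizer and returns it together with the
# config updated with the character set it constructed. The tokenizer maps text
# to the id sequences the model consumes, e.g. tokenizer.text_to_ids("hello").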
tokenizer, config = TTSTokenizer.init_from_config(config)


def test_acoustic_model():
    # Stand-in inputs: token ids for a 41-token utterance and a random
    # 100-bin mel spectrogram with 207 frames. Any ids inside the tokenizer
    # vocabulary would do here.
    dummy_tokens = torch.randint(0, 10, (1, 41)).long().to(device)
    dummy_text_lens = torch.tensor([41]).long().to(device)
    dummy_spec = torch.rand((1, 100, 207)).to(device)
    dummy_spec_lens = torch.tensor([207]).to(device)
    # Pitch and energy are continuous conditioning signals and must stay float;
    # casting torch.rand(...) to long would zero them out.
    dummy_pitch = torch.rand((1, 1, 207)).to(device)
    dummy_energy = torch.rand((1, 1, 207)).to(device)

    # Match the model's output size to the 100-bin dummy spectrogram.
    args.out_channels = 100
    args.num_mels = 100

    acoustic_model = AcousticModel(args=args, tokenizer=tokenizer, speaker_manager=None).to(device)
    acoustic_model = acoustic_model.train()

    output = acoustic_model(
        tokens=dummy_tokens,
        src_lens=dummy_text_lens,
        mel_lens=dummy_spec_lens,
        mels=dummy_spec,
        pitches=dummy_pitch,
        energies=dummy_energy,
        attn_priors=None,
        d_vectors=None,
        speaker_idx=None,
    )
    # The model predicts one 100-bin mel frame per input spectrogram frame.
    assert list(output["model_outputs"].shape) == [1, 207, 100]


def test_hifi_decoder():
    # Fake acoustic-model output: 207 frames of 100-bin mel features.
    dummy_input = torch.rand((1, 207, 100)).to(device)
    dummy_spec_lens = torch.tensor([207]).to(device)

    # HiFi-GAN generator configured from VocoderConfig: 100 input channels
    # (mel bins) in, a single waveform channel out.
    waveform_decoder = HifiganGenerator(
        100,
        1,
        v_args.resblock_type_decoder,
        v_args.resblock_dilation_sizes_decoder,
        v_args.resblock_kernel_sizes_decoder,
        v_args.upsample_kernel_sizes_decoder,
        v_args.upsample_initial_channel_decoder,
        v_args.upsample_rates_decoder,
        inference_padding=0,
        cond_channels=0,
        conv_pre_weight_norm=False,
        conv_post_weight_norm=False,
        conv_post_bias=False,
    ).to(device)
    waveform_decoder = waveform_decoder.train()

    # Slice a random 32-frame segment per item, mirroring how the vocoder is
    # fed during end-to-end training.
    vocoder_input_slices, slice_ids = rand_segments(
        x=dummy_input.transpose(1, 2),
        x_lengths=dummy_spec_lens,
        segment_size=32,
        let_short_samples=True,
        pad_short=True,
    )

    outputs = waveform_decoder(x=vocoder_input_slices.detach())
    # 32 mel frames times the decoder's total upsampling factor (256 with the
    # default upsample rates) gives 8192 waveform samples.
    assert list(outputs.shape) == [1, 1, 8192]
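

if __name__ == "__main__":
    # Convenience sketch: run both checks directly without pytest.
    test_acoustic_model()
    test_hifi_decoder()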