# IMS-ToucanTTS/InferenceInterfaces/ToucanTTSInterface.py
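"""
Inference wrapper around ToucanTTS: text is turned into articulatory phoneme
features, an acoustic model conditioned on durations, pitch, and energy
predicts a spectrogram, HiFi-GAN converts the spectrogram into a 24 kHz
waveform, and the output is loudness-normalized and marked with an AudioSeal
watermark.
"""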
import itertools
import os
import warnings
from typing import cast
import librosa
import matplotlib.pyplot as plt
from matplotlib import font_manager as fm, rcParams
import pyloudnorm
import sounddevice
import soundfile
import torch
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from audioseal.builder import create_generator
from omegaconf import DictConfig
from omegaconf import OmegaConf
from speechbrain.pretrained import EncoderClassifier
from torchaudio.transforms import Resample
from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Utility.storage_config import MODELS_DIR
from Utility.utils import cumsum_durations
from Utility.utils import float2pcm
class ToucanTTSInterface(torch.nn.Module):
def __init__(
self,
device="cpu", # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
tts_model_path=os.path.join(
MODELS_DIR, f"ToucanTTS_Shan", "best.pt"
), # path to the ToucanTTS checkpoint or just a shorthand if run standalone
vocoder_model_path=os.path.join(
MODELS_DIR, f"Vocoder", "best.pt"
), # path to the Vocoder checkpoint
language="eng", # initial language of the model, can be changed later with the setter methods
enhance=None, # legacy argument
):
super().__init__()
self.device = device
if not tts_model_path.endswith(".pt"):
# default to shorthand system
tts_model_path = os.path.join(
MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt"
)
if "USER" not in os.environ:
os.environ["USER"] = (
"" # that's the case under Windows, but omegaconf needs this
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
watermark_conf = cast(
DictConfig,
OmegaConf.load("InferenceInterfaces/audioseal_wm_16bits.yaml"),
)
self.watermark = create_generator(watermark_conf)
self.watermark.load_state_dict(
torch.load("Models/audioseal/generator.pth", map_location="cpu")[
"model"
]
) # downloaded from https://dl.fbaipublicfiles.com/audioseal/6edcf62f/generator.pth originally
################################
# build text to phone #
################################
self.text2phone = ArticulatoryCombinedTextFrontend(
language=language, add_silence_to_end=True
)
#####################################
# load phone to features model #
#####################################
checkpoint = torch.load(tts_model_path, map_location="cpu")
self.phone2mel = ToucanTTS(
weights=checkpoint["model"], config=checkpoint["config"]
)
with torch.no_grad():
self.phone2mel.store_inverse_all() # this also removes weight norm
self.phone2mel = self.phone2mel.to(torch.device(device))
######################################
# load features to style models #
######################################
self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
run_opts={"device": str(device)},
savedir=os.path.join(
MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"
),
)
################################
# load mel to wave model #
################################
vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
self.vocoder = HiFiGAN()
self.vocoder.load_state_dict(vocoder_checkpoint)
self.vocoder = self.vocoder.to(device).eval()
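        # weight norm is a training-time reparameterization; folding it into the
        # weights speeds up inference without changing the output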
self.vocoder.remove_weight_norm()
        self.meter = pyloudnorm.Meter(24000)  # ITU-R BS.1770 loudness meter at the vocoder's 24 kHz output rate
################################
# set defaults #
################################
self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, device=device)
self.phone2mel.eval()
self.vocoder.eval()
self.lang_id = get_language_id(language)
self.to(torch.device(device))
self.eval()
def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
if embedding is not None:
self.default_utterance_embedding = embedding.squeeze().to(self.device)
return
        if not isinstance(path_to_reference_audio, list):
path_to_reference_audio = [path_to_reference_audio]
if len(path_to_reference_audio) > 0:
for path in path_to_reference_audio:
assert os.path.exists(path)
speaker_embs = list()
for path in path_to_reference_audio:
wave, sr = soundfile.read(path)
            if len(wave.shape) > 1:  # multichannel audio has to be mixed down to mono
                if len(wave[0]) == 2:  # soundfile returns (samples, channels), librosa expects (channels, samples)
                    wave = wave.transpose()
                wave = librosa.to_mono(wave)
wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(
torch.tensor(wave, device=self.device, dtype=torch.float32)
)
speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(
wavs=wave.to(self.device).squeeze().unsqueeze(0)
).squeeze()
speaker_embs.append(speaker_embedding)
self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)
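    # Example (hypothetical file names): clone a voice by averaging the ECAPA
    # embeddings of several reference recordings of the target speaker:
    #   tts.set_utterance_embedding(["speaker_a_1.wav", "speaker_a_2.wav"])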
def set_language(self, lang_id):
"""
The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
"""
self.set_phonemizer_language(lang_id=lang_id)
self.set_accent_language(lang_id=lang_id)
def set_phonemizer_language(self, lang_id):
self.text2phone = ArticulatoryCombinedTextFrontend(
language=lang_id, add_silence_to_end=True
)
def set_accent_language(self, lang_id):
if lang_id in [
"ajp",
"ajt",
"lak",
"lno",
"nul",
"pii",
"plj",
"slq",
"smd",
"snb",
"tpw",
"wya",
"zua",
"en-us",
"en-sc",
"fr-be",
"fr-sw",
"pt-br",
"spa-lat",
"vi-ctr",
"vi-so",
]:
if lang_id == "vi-so" or lang_id == "vi-ctr":
lang_id = "vie"
elif lang_id == "spa-lat":
lang_id = "spa"
elif lang_id == "pt-br":
lang_id = "por"
elif lang_id == "fr-sw" or lang_id == "fr-be":
lang_id = "fra"
elif lang_id == "en-sc" or lang_id == "en-us":
lang_id = "eng"
else:
# no clue where these others are even coming from, they are not in ISO 639-2
lang_id = "eng"
self.lang_id = get_language_id(lang_id).to(self.device)
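    # Example (sketch): switch both the phonemizer and the accent embedding to a
    # new language in one call, assuming the shorthand is supported by the text
    # frontend:
    #   tts.set_language("eng")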
def forward(
self,
text,
view=False,
duration_scaling_factor=1.0,
pitch_variance_scale=1.0,
energy_variance_scale=1.0,
pause_duration_scaling_factor=1.0,
durations=None,
pitch=None,
energy=None,
input_is_phones=False,
return_plot_as_filepath=False,
loudness_in_db=-24.0,
glow_sampling_temperature=0.2,
):
"""
duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
1.0 means no scaling happens, higher values increase durations for the whole
utterance, lower values decrease durations for the whole utterance.
pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
1.0 means no scaling happens, higher values increase variance of the pitch curve,
lower values decrease variance of the pitch curve.
energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
1.0 means no scaling happens, higher values increase variance of the energy curve,
lower values decrease variance of the energy curve.
"""
with torch.inference_mode():
phones = self.text2phone.string_to_tensor(
text, input_phonemes=input_is_phones
).to(torch.device(self.device))
mel, durations, pitch, energy = self.phone2mel(
phones,
return_duration_pitch_energy=True,
utterance_embedding=self.default_utterance_embedding,
durations=durations,
pitch=pitch,
energy=energy,
lang_id=self.lang_id,
duration_scaling_factor=duration_scaling_factor,
pitch_variance_scale=pitch_variance_scale,
energy_variance_scale=energy_variance_scale,
pause_duration_scaling_factor=pause_duration_scaling_factor,
glow_sampling_temperature=glow_sampling_temperature,
)
wave, _, _ = self.vocoder(mel.unsqueeze(0))
wave = wave.squeeze().cpu()
wave = wave.numpy()
        sr = 24000  # the HiFi-GAN vocoder outputs audio at 24 kHz
try:
loudness = self.meter.integrated_loudness(wave)
wave = pyloudnorm.normalize.loudness(wave, loudness, loudness_in_db)
except ValueError:
# if the audio is too short, a value error will arise
pass
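        # mix a low-amplitude (0.1x) AudioSeal watermark into the synthesized
        # waveform so that generated audio can be identified as such later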
        with torch.inference_mode():
            wave_tensor = torch.tensor(wave)
            watermark = self.watermark.get_watermark(
                wave_tensor.to(self.device).unsqueeze(0).unsqueeze(0)
            )
            wave = (wave_tensor + 0.1 * watermark.squeeze().detach().cpu()).numpy()
if view or return_plot_as_filepath:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
# fpath = "./src/fonts/Shan.ttf"
fpath = os.path.join(os.path.dirname(__file__), "src/fonts/Shan.ttf")
prop = fm.FontProperties(fname=fpath)
ax.imshow(mel.cpu().numpy(), origin="lower", cmap="GnBu")
ax.yaxis.set_visible(False)
duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
ax.xaxis.grid(True, which="minor")
ax.set_xticks(label_positions, minor=False)
if input_is_phones:
phones = text.replace(" ", "|")
else:
phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
ax.set_xticklabels(phones)
word_boundaries = list()
for label_index, phone in enumerate(phones):
if phone == "|":
word_boundaries.append(label_positions[label_index])
try:
prev_word_boundary = 0
word_label_positions = list()
for word_boundary in word_boundaries:
word_label_positions.append(
(word_boundary + prev_word_boundary) / 2
)
prev_word_boundary = word_boundary
word_label_positions.append(
(duration_splits[-1] + prev_word_boundary) / 2
)
secondary_ax = ax.secondary_xaxis("bottom")
secondary_ax.tick_params(axis="x", direction="out", pad=24)
secondary_ax.set_xticks(word_label_positions, minor=False)
secondary_ax.set_xticklabels(text.split(), fontproperties=prop)
secondary_ax.tick_params(axis="x", colors="orange")
secondary_ax.xaxis.label.set_color("orange")
except ValueError:
ax.set_title(text)
except IndexError:
ax.set_title(text)
ax.vlines(
x=duration_splits,
colors="green",
linestyles="solid",
ymin=0,
ymax=120,
linewidth=0.5,
)
ax.vlines(
x=word_boundaries,
colors="orange",
linestyles="solid",
ymin=0,
ymax=120,
linewidth=1.0,
)
plt.subplots_adjust(
left=0.02, bottom=0.2, right=0.98, top=0.9, wspace=0.0, hspace=0.0
)
ax.set_aspect("auto")
if return_plot_as_filepath:
plt.savefig("tmp.png")
return wave, sr, "tmp.png"
return wave, sr
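    # Example (sketch): since this class is a torch.nn.Module, calling the
    # instance runs forward; this synthesizes slightly slower speech with a
    # flatter pitch contour:
    #   wave, sr = tts("An example sentence.",
    #                  duration_scaling_factor=1.1, pitch_variance_scale=0.8)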
def read_to_file(
self,
text_list,
file_location,
duration_scaling_factor=1.0,
pitch_variance_scale=1.0,
energy_variance_scale=1.0,
pause_duration_scaling_factor=1.0,
silent=False,
dur_list=None,
pitch_list=None,
energy_list=None,
glow_sampling_temperature=0.2,
):
"""
Args:
silent: Whether to be verbose about the process
text_list: A list of strings to be read
file_location: The path and name of the file it should be saved to
energy_list: list of energy tensors to be used for the texts
pitch_list: list of pitch tensors to be used for the texts
dur_list: list of duration tensors to be used for the texts
duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
1.0 means no scaling happens, higher values increase durations for the whole
utterance, lower values decrease durations for the whole utterance.
pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
1.0 means no scaling happens, higher values increase variance of the pitch curve,
lower values decrease variance of the pitch curve.
energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
1.0 means no scaling happens, higher values increase variance of the energy curve,
lower values decrease variance of the energy curve.
"""
if not dur_list:
dur_list = []
if not pitch_list:
pitch_list = []
if not energy_list:
energy_list = []
        silence = torch.zeros([14300])  # roughly 0.6 seconds of silence at 24 kHz
wav = silence.clone()
for text, durations, pitch, energy in itertools.zip_longest(
text_list, dur_list, pitch_list, energy_list
):
if text.strip() != "":
if not silent:
print("Now synthesizing: {}".format(text))
spoken_sentence, sr = self(
text,
durations=(
durations.to(self.device) if durations is not None else None
),
pitch=pitch.to(self.device) if pitch is not None else None,
energy=energy.to(self.device) if energy is not None else None,
duration_scaling_factor=duration_scaling_factor,
pitch_variance_scale=pitch_variance_scale,
energy_variance_scale=energy_variance_scale,
pause_duration_scaling_factor=pause_duration_scaling_factor,
glow_sampling_temperature=glow_sampling_temperature,
)
spoken_sentence = torch.tensor(spoken_sentence).cpu()
wav = torch.cat((wav, spoken_sentence, silence), 0)
soundfile.write(
file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16"
)
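    # Example (hypothetical output path): render several sentences, separated by
    # silence, into one 16-bit PCM wav file:
    #   tts.read_to_file(["First sentence.", "Second sentence."], "out.wav")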
def read_aloud(
self,
text,
view=False,
duration_scaling_factor=1.0,
pitch_variance_scale=1.0,
energy_variance_scale=1.0,
blocking=False,
glow_sampling_temperature=0.2,
):
if text.strip() == "":
return
wav, sr = self(
text,
view,
duration_scaling_factor=duration_scaling_factor,
pitch_variance_scale=pitch_variance_scale,
energy_variance_scale=energy_variance_scale,
glow_sampling_temperature=glow_sampling_temperature,
)
silence = torch.zeros([sr // 2])
wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
sounddevice.play(float2pcm(wav), samplerate=sr)
if view:
plt.show()
if blocking:
sounddevice.wait()
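

if __name__ == "__main__":
    # Minimal usage sketch, assuming the default "ToucanTTS_Shan" and "Vocoder"
    # checkpoints from this repository are available under MODELS_DIR.
    tts = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu")
    tts.read_to_file(["This is a test sentence."], file_location="test.wav")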