|
import argparse |
|
import logging |
|
from typing import Callable |
|
from typing import Collection |
|
from typing import Dict |
|
from typing import List |
|
from typing import Optional |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import torch |
|
from typeguard import check_argument_types |
|
from typeguard import check_return_type |
|
|
|
from espnet2.layers.abs_normalize import AbsNormalize |
|
from espnet2.layers.global_mvn import GlobalMVN |
|
from espnet2.tasks.abs_task import AbsTask |
|
from espnet2.train.class_choices import ClassChoices |
|
from espnet2.train.collate_fn import CommonCollateFn |
|
from espnet2.train.preprocessor import CommonPreprocessor |
|
from espnet2.train.trainer import Trainer |
|
from espnet2.tts.abs_tts import AbsTTS |
|
from espnet2.tts.espnet_model import ESPnetTTSModel |
|
from espnet2.tts.fastspeech import FastSpeech |
|
from espnet2.tts.fastspeech2 import FastSpeech2 |
|
from espnet2.tts.fastespeech import FastESpeech |
|
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract |
|
from espnet2.tts.feats_extract.dio import Dio |
|
from espnet2.tts.feats_extract.energy import Energy |
|
from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank |
|
from espnet2.tts.feats_extract.log_spectrogram import LogSpectrogram |
|
from espnet2.tts.tacotron2 import Tacotron2 |
|
from espnet2.tts.transformer import Transformer |
|
from espnet2.utils.get_default_kwargs import get_default_kwargs |
|
from espnet2.utils.nested_dict_action import NestedDictAction |
|
from espnet2.utils.types import int_or_none |
|
from espnet2.utils.types import str2bool |
|
from espnet2.utils.types import str_or_none |
|
|
|
# Registries of selectable component classes. Each ClassChoices takes an
# option name, a mapping from choice key to implementing class, the base
# class passed as ``type_check``, and the default choice key.

# Feature extractor for the target speech features.
feats_extractor_choices = ClassChoices(
    "feats_extract",
    default="fbank",
    type_check=AbsFeatsExtract,
    classes={
        "fbank": LogMelFbank,
        "spectrogram": LogSpectrogram,
    },
)

# Pitch extractor; optional, nothing selected by default.
pitch_extractor_choices = ClassChoices(
    "pitch_extract",
    optional=True,
    default=None,
    type_check=AbsFeatsExtract,
    classes={"dio": Dio},
)

# Energy extractor; optional, nothing selected by default.
energy_extractor_choices = ClassChoices(
    "energy_extract",
    optional=True,
    default=None,
    type_check=AbsFeatsExtract,
    classes={"energy": Energy},
)

# Normalization layer; optional, defaults to "global_mvn".
normalize_choices = ClassChoices(
    "normalize",
    optional=True,
    default="global_mvn",
    type_check=AbsNormalize,
    classes={"global_mvn": GlobalMVN},
)

# Normalization layer for pitch; optional, nothing selected by default.
pitch_normalize_choices = ClassChoices(
    "pitch_normalize",
    optional=True,
    default=None,
    type_check=AbsNormalize,
    classes={"global_mvn": GlobalMVN},
)

# Normalization layer for energy; optional, nothing selected by default.
energy_normalize_choices = ClassChoices(
    "energy_normalize",
    optional=True,
    default=None,
    type_check=AbsNormalize,
    classes={"global_mvn": GlobalMVN},
)

# TTS network architecture; defaults to "tacotron2".
tts_choices = ClassChoices(
    "tts",
    default="tacotron2",
    type_check=AbsTTS,
    classes={
        "tacotron2": Tacotron2,
        "transformer": Transformer,
        "fastspeech": FastSpeech,
        "fastspeech2": FastSpeech2,
        "fastespeech": FastESpeech,
    },
)
|
|
|
|
|
class TTSTask(AbsTask):
    """Task definition that assembles an ``ESPnetTTSModel`` from CLI options.

    The class registers the task-specific command-line arguments, builds the
    collate and preprocessing callables, declares which data keys are
    required or optional, and constructs the model from the parsed arguments.
    """

    # A single optimizer is declared for this task.
    num_optimizers: int = 1

    # ClassChoices registered on the parser by ``add_task_arguments``.
    # ``build_model`` reads ``args.<name>`` and ``args.<name>_conf`` for each
    # of them (e.g. ``args.tts`` / ``args.tts_conf``).
    class_choices_list = [
        feats_extractor_choices,
        normalize_choices,
        tts_choices,
        pitch_extractor_choices,
        pitch_normalize_choices,
        energy_extractor_choices,
        energy_normalize_choices,
    ]

    # Trainer class associated with this task.
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register the task and preprocessing options on ``parser``.

        Args:
            parser: Parser to extend. Its ``required`` default is extended in
                place with ``"token_list"``.
        """
        assert check_argument_types()
        group = parser.add_argument_group(description="Task related")

        # NOTE(review): assumes the parser already carries a list-valued
        # "required" default; ``+=`` mutates that list in place so the
        # parser's default includes the new entry. Confirm against AbsTask.
        required = parser.get_default("required")
        required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--odim",
            type=int_or_none,
            default=None,
            help="The number of dimension of output feature",
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetTTSModel),
            help="The keyword arguments for model class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="phn",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized in the specified level token",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        # The three options below are consumed by ``build_preprocess_fn``, so
        # they are registered on the "Preprocess related" group like their
        # siblings (parsing is unaffected; only the help grouping changes).
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="non_linguistic_symbols file path",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=[
                None,
                "g2p_en",
                "g2p_en_no_space",
                "pyopenjtalk",
                "pyopenjtalk_kana",
                "pyopenjtalk_accent",
                "pyopenjtalk_accent_with_pause",
                "pypinyin_g2p",
                "pypinyin_g2p_phone",
                "espeak_ng_arabic",
            ],
            default=None,
            help="Specify g2p method if --token_type=phn",
        )

        # Let every ClassChoices add its own options to the current group.
        for class_choices in cls.class_choices_list:
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the collate function used to batch samples.

        Args:
            args: Parsed arguments (not read here).
            train: Training-mode flag (not read here).

        Returns:
            A ``CommonCollateFn`` configured with ``float_pad_value=0.0``,
            ``int_pad_value=0`` and ``"spembs"`` listed in ``not_sequence``.
        """
        assert check_argument_types()
        return CommonCollateFn(
            float_pad_value=0.0, int_pad_value=0, not_sequence=["spembs"]
        )

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Return the per-sample preprocessing callable, if enabled.

        Args:
            args: Parsed arguments; reads ``use_preprocessor``, ``token_type``,
                ``token_list``, ``bpemodel``, ``non_linguistic_symbols``,
                ``cleaner`` and ``g2p``.
            train: Forwarded to ``CommonPreprocessor`` as ``train``.

        Returns:
            A ``CommonPreprocessor`` when ``args.use_preprocessor`` is truthy,
            otherwise ``None``.
        """
        assert check_argument_types()
        retval = None
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
            )
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys that must be present.

        Args:
            train: Not read here.
            inference: When true only ``"text"`` is required; otherwise both
                ``"text"`` and ``"speech"`` are.
        """
        if inference:
            return ("text",)
        return ("text", "speech")

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys that may additionally be supplied.

        Args:
            train: Not read here.
            inference: Selects the inference-time set of optional keys, in
                which ``"speech"`` is optional and pitch/energy are absent.
        """
        if inference:
            return ("spembs", "speech", "durations")
        return ("spembs", "durations", "pitch", "energy")

    @staticmethod
    def _sync_reduction_factor(extract_conf: dict, tts_conf: dict) -> None:
        """Make ``extract_conf["reduction_factor"]`` agree with ``tts_conf``.

        If the extractor config already sets a non-None ``reduction_factor``
        it must equal the TTS one (default 1); otherwise the TTS value is
        written into ``extract_conf`` in place.
        """
        expected = tts_conf.get("reduction_factor", 1)
        current = extract_conf.get("reduction_factor", None)
        if current is not None:
            assert current == expected, (
                f"reduction_factor mismatch: extractor has {current}, "
                f"tts has {expected}"
            )
        else:
            extract_conf["reduction_factor"] = expected

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetTTSModel:
        """Construct the ``ESPnetTTSModel`` described by ``args``.

        Args:
            args: Parsed arguments. ``args.token_list`` is either a path to a
                UTF-8 text file with one token per line, or a list/tuple of
                tokens. Several attributes are updated in place (see the
                inline comments).

        Returns:
            The assembled ``ESPnetTTSModel``.

        Raises:
            RuntimeError: If ``args.token_list`` is not a str, list or tuple.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Replace the file path stored on args with the loaded tokens.
            args.token_list = token_list.copy()
        elif isinstance(args.token_list, (tuple, list)):
            # list() rather than .copy(): tuples have no .copy() method.
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str, list or tuple")

        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. Feature extractor and output dimension.
        if args.odim is None:
            # Derive odim from the configured feature extractor.
            feats_extract_class = feats_extractor_choices.get_class(args.feats_extract)
            feats_extract = feats_extract_class(**args.feats_extract_conf)
            odim = feats_extract.output_size()
        else:
            # odim is given explicitly: no extractor is built and its options
            # are cleared on args (presumably features then come precomputed
            # from the data loader -- confirm against the recipes).
            args.feats_extract = None
            args.feats_extract_conf = None
            feats_extract = None
            odim = args.odim

        # 2. Normalization layer.
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 3. TTS network.
        tts_class = tts_choices.get_class(args.tts)
        tts = tts_class(idim=vocab_size, odim=odim, **args.tts_conf)

        # 4. Optional pitch/energy extractors and normalizers. getattr() is
        #    used so that args lacking these attributes are tolerated.
        pitch_extract = None
        energy_extract = None
        pitch_normalize = None
        energy_normalize = None
        if getattr(args, "pitch_extract", None) is not None:
            pitch_extract_class = pitch_extractor_choices.get_class(args.pitch_extract)
            cls._sync_reduction_factor(args.pitch_extract_conf, args.tts_conf)
            pitch_extract = pitch_extract_class(**args.pitch_extract_conf)
        if getattr(args, "energy_extract", None) is not None:
            cls._sync_reduction_factor(args.energy_extract_conf, args.tts_conf)
            energy_extract_class = energy_extractor_choices.get_class(
                args.energy_extract
            )
            energy_extract = energy_extract_class(**args.energy_extract_conf)
        if getattr(args, "pitch_normalize", None) is not None:
            pitch_normalize_class = pitch_normalize_choices.get_class(
                args.pitch_normalize
            )
            pitch_normalize = pitch_normalize_class(**args.pitch_normalize_conf)
        if getattr(args, "energy_normalize", None) is not None:
            energy_normalize_class = energy_normalize_choices.get_class(
                args.energy_normalize
            )
            energy_normalize = energy_normalize_class(**args.energy_normalize_conf)

        # 5. Assemble the model.
        model = ESPnetTTSModel(
            feats_extract=feats_extract,
            pitch_extract=pitch_extract,
            energy_extract=energy_extract,
            normalize=normalize,
            pitch_normalize=pitch_normalize,
            energy_normalize=energy_normalize,
            tts=tts,
            **args.model_conf,
        )

        assert check_return_type(model)
        return model
|
|