# Hecheng0625's picture
# Upload 167 files
# 8c92a11 verified
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
from models.tts.fastspeech2.fs2_trainer import FastSpeech2Trainer
from models.tts.vits.vits_trainer import VITSTrainer
from models.tts.valle.valle_trainer import VALLETrainer
from models.tts.naturalspeech2.ns2_trainer import NS2Trainer
from models.tts.valle_v2.valle_ar_trainer import ValleARTrainer as VALLE_V2_AR
from models.tts.valle_v2.valle_nar_trainer import ValleNARTrainer as VALLE_V2_NAR
from models.tts.jets.jets_trainer import JetsTrainer
from utils.util import load_config
def build_trainer(args, cfg):
    """Instantiate the trainer class registered for ``cfg.model_type``.

    Args:
        args: Parsed command-line arguments; forwarded unchanged to the
            trainer constructor.
        cfg: Configuration object whose ``model_type`` attribute selects
            which trainer class to use.

    Returns:
        The object produced by calling the selected class as
        ``trainer_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the registered names.
    """
    # (model_type name, trainer class) pairs; the name is what configs use.
    registry = dict(
        (
            ("FastSpeech2", FastSpeech2Trainer),
            ("VITS", VITSTrainer),
            ("VALLE", VALLETrainer),
            ("NaturalSpeech2", NS2Trainer),
            ("VALLE_V2_AR", VALLE_V2_AR),
            ("VALLE_V2_NAR", VALLE_V2_NAR),
            ("Jets", JetsTrainer),
        )
    )
    return registry[cfg.model_type](args, cfg)
def cuda_relevant(deterministic=False):
    """Set the process-wide PyTorch CUDA/cuDNN backend flags.

    Args:
        deterministic: Value written to ``torch.backends.cudnn.deterministic``
            and passed to ``torch.use_deterministic_algorithms``. cuDNN
            benchmark mode is set to its negation, so autotuning is on
            exactly when determinism is not requested.
    """
    torch.cuda.empty_cache()

    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    torch.backends.cudnn.enabled = True

    # Deterministic
    # NOTE: ``benchmark`` is assigned once, here. An earlier unconditional
    # ``benchmark = False`` was always overwritten by this line and has been
    # removed, so the effective default (deterministic=False) is benchmark=True.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)
def _build_arg_parser():
    """Create the command-line parser for this training entry point."""
    parser = argparse.ArgumentParser()
    # ``--config`` and ``--exp_name`` are required, so argparse rejects the
    # command line when they are missing and their defaults are never used.
    parser.add_argument(
        "--config",
        default="config.json",
        help="json files for configurations.",
        required=True,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="random seed",
        required=False,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default="exp_name",
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume", action="store_true", help="The model name to restore"
    )
    parser.add_argument(
        "--test", action="store_true", default=False, help="Test the model"
    )
    parser.add_argument(
        "--log_level", default="warning", help="logging level (debug, info, warning)"
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        default="resume",
        help="Resume training or finetuning.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Checkpoint for resume training or finetuning.",
    )
    parser.add_argument(
        "--resume_from_ckpt_path",
        type=str,
        default="",
        help="Checkpoint for resume training or finetuning.",
    )
    return parser


def _extend_with_augmented_datasets(cfg):
    """Append augmented dataset names to ``cfg.dataset`` in place.

    Does nothing unless ``cfg.preprocess.data_augment`` exists and is a
    non-empty list. For each dataset named there, one ``"{dataset}_{suffix}"``
    entry is added per enabled ``cfg.preprocess.use_*`` flag.
    """
    if not hasattr(cfg, "preprocess"):
        return
    preprocess = cfg.preprocess
    if not hasattr(preprocess, "data_augment"):
        return
    augment_sources = preprocess.data_augment
    if not isinstance(augment_sources, list) or not augment_sources:
        return

    # The flags do not depend on the dataset, so read them once. The order
    # here fixes the order of the generated names for each dataset.
    suffix_flags = (
        ("pitch_shift", preprocess.use_pitch_shift),
        ("formant_shift", preprocess.use_formant_shift),
        ("equalizer", preprocess.use_equalizer),
        ("time_stretch", preprocess.use_time_stretch),
    )
    enabled_suffixes = [suffix for suffix, enabled in suffix_flags if enabled]

    # Build the full list before extending, so ``cfg.dataset`` is never
    # mutated while the source list is still being iterated.
    augmented_names = [
        f"{dataset}_{suffix}"
        for dataset in augment_sources
        for suffix in enabled_suffixes
    ]
    cfg.dataset.extend(augmented_names)


def main():
    """Parse CLI arguments, load the config, and run the selected trainer.

    Loads the configuration named by ``--config``, extends ``cfg.dataset``
    with augmented dataset names when data augmentation is configured,
    applies the CUDA backend settings, builds the trainer for
    ``cfg.model_type`` and then runs ``test_loop()`` if ``--test`` was given,
    otherwise ``train_loop()``.
    """
    args = _build_arg_parser().parse_args()
    cfg = load_config(args.config)

    # Data Augmentation
    _extend_with_augmented_datasets(cfg)

    print("experiment name: ", args.exp_name)

    # CUDA settings
    cuda_relevant()

    # Build trainer
    print(f"Building {cfg.model_type} trainer")
    trainer = build_trainer(args, cfg)

    print(f"Start training {cfg.model_type} model")
    if args.test:
        trainer.test_loop()
    else:
        trainer.train_loop()
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()