Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
from argparse import ArgumentParser | |
import os | |
from models.tta.ldm.audioldm_inference import AudioLDMInference | |
from utils.util import save_config, load_model_config, load_config | |
import numpy as np | |
import torch | |
def build_inference(args, cfg):
    """Instantiate the inference runner selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, forwarded unchanged to the
            inference class constructor.
        cfg: Loaded configuration object; its ``model_type`` attribute
            selects which inference class is used.

    Returns:
        An instance of the inference class registered for
        ``cfg.model_type``, constructed as ``inference_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` has no registered inference class.
            The message lists the supported model types.
    """
    supported_inference = {
        "AudioLDM": AudioLDMInference,
    }
    model_type = cfg.model_type
    try:
        inference_class = supported_inference[model_type]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what is
        # actually supported instead of echoing only the bad key.
        supported = ", ".join(sorted(supported_inference))
        raise KeyError(
            f"Unsupported model_type {model_type!r}; supported: {supported}"
        ) from None
    return inference_class(args, cfg)
def build_parser():
    """Build the command-line parser for text-to-audio inference.

    Returns:
        argparse.ArgumentParser: A parser with the following arguments.
            ``--config`` is required. ``--text``, ``--checkpoint_path``,
            ``--vocoder_path``, ``--vocoder_config_path``, ``--output_dir``,
            ``--num_steps``, ``--guidance_scale`` and ``--local_rank`` are
            optional.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--text",
        help="Text to be synthesized",
        type=str,
        default="Text to be synthesized.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Checkpoint path of the model used for inference",
    )
    parser.add_argument(
        "--vocoder_path", type=str, help="Checkpoint path of the vocoder"
    )
    parser.add_argument(
        "--vocoder_config_path", type=str, help="Config path of the vocoder"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    parser.add_argument(
        "--num_steps",
        type=int,
        default=200,
        help="The total number of denoising steps",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=4.0,
        help="The scale of classifier-free guidance",
    )
    parser.add_argument(
        "--local_rank",
        default=-1,
        type=int,
        help="Local process rank (default: -1)",
    )
    return parser
def _resolve_device(local_rank):
    """Pick the torch device for inference from availability and rank.

    Returns ``torch.device("cpu")`` when CUDA is unavailable. Otherwise a
    non-negative ``local_rank`` selects that CUDA device index, and any
    negative value (the parser default is -1) gives ``torch.device("cuda")``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if local_rank >= 0:
        return torch.device("cuda", local_rank)
    return torch.device("cuda")


def main():
    """Script entry point: parse CLI args, load the config, run inference.

    ``args.local_rank`` is replaced by a ``torch.device`` before the
    inference object is built. As a result, the inference class receives
    a device in that field rather than the raw integer rank.
    """
    args = build_parser().parse_args()
    cfg = load_config(args.config)

    # NOTE(review): downstream code presumably reads the device from
    # ``args.local_rank`` -- confirm against the inference class.
    args.local_rank = _resolve_device(args.local_rank)
    print("args: ", args)

    inferencer = build_inference(args, cfg)
    inferencer.inference()


if __name__ == "__main__":
    main()