#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio web UI for speech denoising with MPNet models.

On startup, pretrained model zips (and an examples.zip) are fetched from the
HuggingFace Hub into ``trained_model_dir``; the app then serves two tabs:
a "denoise" tab that enhances an uploaded audio clip with a selected engine,
and a "shell" tab that runs commands on the host (debug use only).
"""
import argparse
from functools import lru_cache
import logging
from pathlib import Path
import platform
import shutil
import zipfile

import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np

import log
from project_settings import environment, project_path, log_directory
from toolbox.os.command import Command
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet

log.setup_size_rotating(log_directory=log_directory)

logger = logging.getLogger("main")


def get_args():
    """Parse command-line arguments; defaults come from project paths and env vars."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/nx_denoise",
        type=str,
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int,
    )
    args = parser.parse_args()
    return args


def shell(cmd: str):
    """Run a shell command string and return its output.

    SECURITY(review): this executes an arbitrary command string supplied from
    the web UI "shell" tab. Keep the tab removed or auth-gated in any
    deployment reachable by untrusted users.
    """
    return Command.popen(cmd)


# Registry of selectable denoise engines: display name -> inference class and
# the kwargs used to build it. Construction is deferred to load_denoise_model
# so models are only loaded (and cached) on first use.
denoise_engines = {
    "mpnet-nx-speech-1-epoch": {
        "infer_cls": InferenceMPNet,
        "kwargs": {
            "pretrained_model_path_or_zip_file": (
                project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
        }
    },
    "mpnet-aishell-1-epoch": {
        "infer_cls": InferenceMPNet,
        "kwargs": {
            "pretrained_model_path_or_zip_file": (
                project_path / "trained_models/mpnet-aishell-1-epoch.zip").as_posix()
        }
    },
    "mpnet-aishell-11-epoch": {
        "infer_cls": InferenceMPNet,
        "kwargs": {
            "pretrained_model_path_or_zip_file": (
                project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
        }
    },
}


@lru_cache(maxsize=3)
def load_denoise_model(infer_cls, **kwargs):
    """Construct (or return a cached) inference engine instance.

    maxsize=3 matches the number of registered engines, so every engine can
    stay resident once loaded.
    """
    infer_engine = infer_cls(**kwargs)
    return infer_engine


def when_click_denoise_button(noisy_audio_t, engine: str):
    """Enhance one audio clip with the chosen engine.

    :param noisy_audio_t: (sample_rate, signal) tuple as delivered by
        ``gr.Audio``; the signal is assumed to be int16 PCM — TODO confirm
        gradio never delivers float or multi-channel arrays here.
    :param engine: key into ``denoise_engines``.
    :return: (sample_rate, enhanced int16 signal) tuple for ``gr.Audio``.
    :raises gr.Error: on unknown engine or any enhancement failure.
    """
    sample_rate, signal = noisy_audio_t
    logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")

    # int16 PCM -> float32 in [-1, 1); the model operates on normalized audio.
    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = denoise_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)

        enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
        # float [-1, 1) back to int16 PCM for playback in the browser.
        enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
    except Exception as e:
        # Surface the failure to the UI instead of a bare 500.
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")

    enhanced_audio_t = (sample_rate, enhanced_audio)
    return enhanced_audio_t


def main():
    """Download assets if needed, build the Gradio Blocks UI, and serve it."""
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # Download models only when the directory is absent.
    # NOTE(review): an existing-but-incomplete directory skips the download;
    # delete trained_model_dir to force a re-fetch.
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # Dropdown choices come straight from the engine registry.
    denoise_engine_choices = list(denoise_engines.keys())

    # Unpack bundled example clips on first run.
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            # examples_dir is known not to exist here, so no cleanup is needed
            # before creating it (the original rmtree branch was unreachable).
            examples_dir.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=examples_dir)

    # One (audio file, default engine) row per wav found under examples_dir.
    examples = list()
    for filename in examples_dir.glob("**/*.wav"):
        examples.append([
            filename.as_posix(),
            denoise_engine_choices[0],
        ])

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="nx denoise.")
        with gr.Tabs():
            with gr.TabItem("denoise"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        dn_noisy_audio = gr.Audio(label="noisy_audio")
                        dn_engine = gr.Dropdown(
                            choices=denoise_engine_choices,
                            value=denoise_engine_choices[0],
                            label="engine",
                        )
                        dn_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        dn_enhanced_audio = gr.Audio(label="enhanced_audio")

                dn_button.click(
                    when_click_denoise_button,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio],
                )
                # Example rows run through the same handler; caching is
                # intentionally left off (cache_examples/cache_mode).
                gr.Examples(
                    examples=examples,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio],
                    fn=when_click_denoise_button,
                )

            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text,],
                    outputs=[shell_output],
                )

    # http://127.0.0.1:7864/
    blocks.queue().launch(
        # Original expression was `False if Windows else False` — both
        # branches False, so sharing is simply disabled.
        share=False,
        # Bind localhost on Windows dev boxes, all interfaces elsewhere.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
    )
    return


if __name__ == "__main__":
    main()