# nx_denoise / main.py — Hugging Face Space demo (author: HoneyTian, commit 637d40c)
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from functools import lru_cache
import logging
from pathlib import Path
import platform
import shutil
import zipfile
import gradio as gr
from huggingface_hub import snapshot_download
import numpy as np
import log
from project_settings import environment, project_path, log_directory
from toolbox.os.command import Command
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
# Configure size-rotating file logging into the project's log directory
# (handler details live in the project-local `log` module).
log.setup_size_rotating(log_directory=log_directory)
logger = logging.getLogger("main")
def get_args():
    """Parse command-line options for the demo service.

    Defaults come from the project layout (``project_path``) and the
    ``environment`` settings object (``hf_token``, ``server_port``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        type=str,
        default=(project_path / "data/examples").as_posix(),
    )
    parser.add_argument(
        "--models_repo_id",
        type=str,
        default="qgyd2021/nx_denoise",
    )
    parser.add_argument(
        "--trained_model_dir",
        type=str,
        default=(project_path / "trained_models").as_posix(),
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=environment.get("hf_token"),
    )
    parser.add_argument(
        "--server_port",
        type=int,
        default=environment.get("server_port", 7860),
    )
    return parser.parse_args()
def shell(cmd: str):
    """Run *cmd* through the project's ``Command.popen`` wrapper and return its result."""
    result = Command.popen(cmd)
    return result
# Registry of available denoise engines, keyed by the name shown in the UI
# dropdown. Each entry names the inference class and the kwargs used to
# construct it lazily (see `load_denoise_model`). All current entries wrap
# MPNet checkpoints stored as zip files under trained_models/.
denoise_engines = {
    "mpnet-nx-speech-1-epoch": {
        "infer_cls": InferenceMPNet,
        "kwargs": {
            "pretrained_model_path_or_zip_file": (
                project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
        }
    },
    "mpnet-aishell-1-epoch": {
        "infer_cls": InferenceMPNet,
        "kwargs": {
            "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-1-epoch.zip").as_posix()
        }
    },
    "mpnet-aishell-11-epoch": {
        "infer_cls": InferenceMPNet,
        "kwargs": {
            "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
        }
    },
}
@lru_cache(maxsize=3)
def load_denoise_model(infer_cls, **kwargs):
    """Instantiate ``infer_cls(**kwargs)``, memoizing up to 3 distinct configurations.

    NOTE: all kwargs values must be hashable for ``lru_cache`` to key on them
    (the registry only passes file-path strings, which are).
    """
    return infer_cls(**kwargs)
def when_click_denoise_button(noisy_audio_t, engine: str):
    """Gradio callback: denoise the uploaded audio with the selected engine.

    :param noisy_audio_t: ``(sample_rate, signal)`` tuple from ``gr.Audio``;
        the scaling below assumes ``signal`` is int16 PCM — TODO confirm
        for all upload paths.
    :param engine: key into the module-level ``denoise_engines`` registry.
    :return: ``(sample_rate, enhanced_signal)`` tuple with int16 samples.
    :raises gr.Error: on an unknown engine name or any inference failure.
    """
    sample_rate, signal = noisy_audio_t
    # Lazy %-style args: formatting is skipped when INFO is disabled.
    logger.info(
        "run denoise; engine: %s, sample_rate: %s, signal dtype: %s, signal shape: %s",
        engine, sample_rate, signal.dtype, signal.shape,
    )

    # Rescale int16 PCM to float32 in [-1, 1) for the model.
    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine_param = denoise_engines.get(engine)
    if infer_engine_param is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")

    try:
        infer_cls = infer_engine_param["infer_cls"]
        kwargs = infer_engine_param["kwargs"]
        # Cached per configuration (lru_cache), so repeated clicks reuse the engine.
        infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)

        enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
        # Back to int16 PCM for the gr.Audio output component.
        enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
    except Exception as e:
        # Log the full traceback server-side; surface a concise message in the UI.
        # `from e` preserves the cause chain instead of discarding it.
        logger.exception("enhancement failed")
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.") from e

    enhanced_audio_t = (sample_rate, enhanced_audio)
    return enhanced_audio_t
def main():
    """Entry point: download models/examples if missing, build the Gradio UI, serve it."""
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # Download model checkpoints from the Hub on first run only.
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # Engine dropdown choices mirror the registry keys.
    denoise_engine_choices = list(denoise_engines.keys())

    # Unpack bundled example audio on first run.
    if not examples_dir.exists():
        example_zip_file = trained_model_dir / "examples.zip"
        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
            out_root = examples_dir
            if out_root.exists():
                shutil.rmtree(out_root.as_posix())
            out_root.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=out_root)

    # Examples table rows: [audio path, default engine].
    examples = [
        [filename.as_posix(), denoise_engine_choices[0]]
        for filename in examples_dir.glob("**/*.wav")
    ]

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="nx denoise.")

        with gr.Tabs():
            with gr.TabItem("denoise"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        dn_noisy_audio = gr.Audio(label="noisy_audio")
                        dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
                        dn_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        dn_enhanced_audio = gr.Audio(label="enhanced_audio")

                dn_button.click(
                    when_click_denoise_button,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio]
                )
                gr.Examples(
                    examples=examples,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio],
                    fn=when_click_denoise_button,
                    # cache_examples=True,
                    # cache_mode="lazy",
                )

            with gr.TabItem("shell"):
                # SECURITY NOTE(review): this tab executes arbitrary shell commands
                # submitted through the UI — acceptable only in a sandboxed Space.
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text,],
                    outputs=[shell_output],
                )

    # Bind to localhost on Windows (dev), all interfaces otherwise (container).
    blocks.queue().launch(
        share=False,  # original ternary evaluated to False on both branches
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return
# Script entry point.
if __name__ == "__main__":
    main()