Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
import logging | |
from pathlib import Path | |
import platform | |
import shutil | |
import zipfile | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import numpy as np | |
import log | |
from project_settings import environment, project_path, log_directory | |
from toolbox.os.command import Command | |
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet | |
# Configure logging through the project's `log` helper first, so the
# module-level logger created below picks up that configuration.
log.setup_size_rotating(log_directory=log_directory)

logger = logging.getLogger("main")
def get_args():
    """
    Parse the command-line options of the demo.

    Recognised options (all optional):
      --examples_dir       directory the example audio files are unpacked into.
      --models_repo_id     Hugging Face repo id the models are downloaded from.
      --trained_model_dir  local directory holding the downloaded models.
      --hf_token           Hugging Face token; default read from ``environment``.
      --server_port        port for the web server; default read from
                           ``environment`` (7860 when unset).

    :return: the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    # (flag, default, type) triples, registered in this order.
    option_specs = (
        ("--examples_dir", (project_path / "data/examples").as_posix(), str),
        ("--models_repo_id", "qgyd2021/nx_denoise", str),
        ("--trained_model_dir", (project_path / "trained_models").as_posix(), str),
        ("--hf_token", environment.get("hf_token"), str),
        ("--server_port", environment.get("server_port", 7860), int),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type)

    return parser.parse_args()
def shell(cmd: str):
    """
    Run ``cmd`` through ``Command.popen`` and return its result unchanged.

    NOTE(review): this function is the click handler of the "shell" tab, so
    whatever a UI user types is passed straight to ``Command.popen``.
    Confirm that exposing command execution to every visitor is intended.

    :param cmd: command text taken from the UI text box.
    :return: whatever ``Command.popen`` returns for ``cmd``.
    """
    output = Command.popen(cmd)
    return output
# Registry of engine name -> inference engine. Starts empty; main() rebinds it
# and when_click_denoise_button() looks engines up in it.
denoise_engines = {}
def when_click_denoise_button(noisy_audio_t, engine: str):
    """
    Gradio callback: denoise the submitted audio with the selected engine.

    :param noisy_audio_t: ``(sample_rate, signal)`` tuple coming from the
        audio component, or ``None`` when no audio was provided.
    :param engine: key into the module-level ``denoise_engines`` registry.
    :return: ``(sample_rate, enhanced_audio)`` where ``enhanced_audio`` is an
        int16 ndarray.
    :raises gr.Error: when no audio was provided, the engine name is unknown,
        or the enhancement itself fails.
    """
    # An empty audio component yields None; report it as a readable UI error
    # instead of failing on the tuple unpacking below.
    if noisy_audio_t is None:
        raise gr.Error("noisy audio is required.")

    sample_rate, signal = noisy_audio_t
    logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")

    # NOTE(review): dividing by 2**15 assumes 16-bit PCM samples - confirm
    # that the audio component always delivers int16 data.
    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)

    infer_engine = denoise_engines.get(engine)
    if infer_engine is None:
        raise gr.Error(f"invalid denoise engine: {engine}.")

    try:
        enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
        # Clip to the int16 range before casting: without it, samples at or
        # beyond +/-1.0 wrap around and become loud artifacts.
        int16_info = np.iinfo(np.int16)
        scaled_audio = np.asarray(enhanced_audio * (1 << 15))
        enhanced_audio = np.clip(scaled_audio, int16_info.min, int16_info.max).astype(np.int16)
    except Exception as e:
        # Keep the traceback in the log; the UI only gets the short message.
        logger.exception("enhancement failed")
        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.") from e

    enhanced_audio_t = (sample_rate, enhanced_audio)
    return enhanced_audio_t
def main():
    """
    Entry point: prepare models and examples, build the Gradio UI, serve it.

    Steps:
      1. Parse the command-line arguments.
      2. Download the model repository into ``--trained_model_dir`` when that
         directory does not exist yet.
      3. Build the denoise engines and publish them in the module-level
         ``denoise_engines`` registry.
      4. Unpack ``examples.zip`` from the model directory into
         ``--examples_dir`` (the directory is recreated from scratch).
      5. Build the UI ("denoise" tab and "shell" tab) and launch the server.
    """
    global denoise_engines

    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # download models
    if not trained_model_dir.exists():
        trained_model_dir.mkdir(parents=True, exist_ok=True)
        _ = snapshot_download(
            repo_id=args.models_repo_id,
            local_dir=trained_model_dir.as_posix(),
            token=args.hf_token,
        )

    # engines
    # Load the archives from the directory they were downloaded into, so that
    # a custom --trained_model_dir is honoured instead of a hard-coded path.
    engine_names = (
        "mpnet-aishell-1-epoch",
        "mpnet-aishell-11-epoch",
    )
    denoise_engines = {
        engine_name: InferenceMPNet(
            pretrained_model_path_or_zip_file=(trained_model_dir / f"{engine_name}.zip").as_posix(),
        )
        for engine_name in engine_names
    }

    # choices
    denoise_engine_choices = list(denoise_engines.keys())

    # examples
    example_zip_file = trained_model_dir / "examples.zip"
    with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
        out_root = examples_dir
        # Start from an empty directory so stale examples never linger.
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # Sort the files: glob order depends on the file system, and a stable
    # order keeps the example list identical between runs.
    examples = [
        [filename.as_posix(), denoise_engine_choices[0]]
        for filename in sorted(examples_dir.glob("**/*.wav"))
    ]

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="nx denoise.")
        with gr.Tabs():
            with gr.TabItem("denoise"):
                with gr.Row():
                    with gr.Column(variant="panel", scale=5):
                        dn_noisy_audio = gr.Audio(label="noisy_audio")
                        dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
                        dn_button = gr.Button(variant="primary")
                    with gr.Column(variant="panel", scale=5):
                        dn_enhanced_audio = gr.Audio(label="enhanced_audio")

                dn_button.click(
                    when_click_denoise_button,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio]
                )
                gr.Examples(
                    examples=examples,
                    inputs=[dn_noisy_audio, dn_engine],
                    outputs=[dn_enhanced_audio],
                    fn=when_click_denoise_button,
                    cache_examples=True,
                    cache_mode="lazy",
                )
            # NOTE(review): this tab hands the text box content to shell(),
            # i.e. every visitor can run commands on the host. Confirm this is
            # intended before serving on a public interface.
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")

                shell_button.click(
                    shell,
                    inputs=[shell_text,],
                    outputs=[shell_output],
                )

    # Bind to localhost on Windows (local development) and to all interfaces
    # elsewhere; a public share link is never requested.
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()