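# NOTE: the next lines are web-page residue from the hosting site, kept as comments so the file parses.
# owiedotch's picture
# Update app.py
# 99f4e90 verified
# raw
# history blame
# No virus
# 4.92 kB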
import gradio as gr
import torch
import demucs.api
import os
import spaces
import subprocess
from pydub import AudioSegment
from typing import Tuple, Dict, List
# check if cuda is available
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# check if sox is installed and install it if necessary
try:
    subprocess.run(["sox", "--version"], check=True, capture_output=True)
except FileNotFoundError:
    # the sox executable could not be found, so try to install it through apt
    print("sox is not installed. trying to install it now...")
    try:
        subprocess.run(["apt-get", "update"], check=True)
        subprocess.run(["apt-get", "install", "-y", "sox"], check=True)
        print("sox has been installed.")
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: apt-get ran but returned a non-zero status.
        # OSError: apt-get itself is missing or cannot be executed
        # (FileNotFoundError / PermissionError are both OSError subclasses).
        print(f"error installing sox: {e}")
        print("please install sox manually or try adding the following repository to your sources list:")
        print("deb http://deb.debian.org/debian stretch main contrib non-free")
        # raise SystemExit directly: the exit() helper comes from the site module
        # and is meant for interactive sessions
        raise SystemExit(1)
# define the inference function
@spaces.GPU
def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, str]:
    """
    performs inference using demucs and mixes the selected stems.

    args:
        audio_file: the audio file to separate.
        model_name: the name of the demucs model to use.
        vocals: whether to include vocals in the mix.
        drums: whether to include drums in the mix.
        bass: whether to include bass in the mix.
        other: whether to include other instruments in the mix.
        mp3: whether to save the output as mp3.
        mp3_bitrate: the bitrate of the output mp3 file, in kbit/s.

    returns:
        a tuple containing the path to the mixed audio file and the separation log.

    raises:
        gr.Error: if no audio file was provided or no stem was selected.
    """
    import contextlib
    import io

    # fail fast on bad input, before any model is loaded or audio is separated
    if not audio_file:
        raise gr.Error("please provide an audio file to separate.")
    requested: List[str] = [
        name
        for name, include in (("vocals", vocals), ("drums", drums), ("bass", bass), ("other", other))
        if include
    ]
    if not requested:
        raise gr.Error("please select at least one stem to mix.")

    # initialize demucs separator; the progress bar is a constructor option,
    # separate_audio_file() itself only takes the file path
    separator: demucs.api.Separator = demucs.api.Separator(model=model_name, progress=True)

    # separate audio file and capture log
    # NOTE(review): assumes the demucs progress bar is written to stderr — confirm.
    # redirect_stderr swaps sys.stderr for the whole process while the block runs.
    log_stream = io.StringIO()
    with contextlib.redirect_stderr(log_stream):
        _origin, separated = separator.separate_audio_file(audio_file)
    separation_log: str = log_stream.getvalue()

    # write every separated stem to separated/<model>/<track name>/<stem>.wav
    track_name: str = os.path.splitext(os.path.basename(audio_file))[0]
    output_dir: str = os.path.join("separated", model_name, track_name)
    os.makedirs(output_dir, exist_ok=True)  # create output directory if it doesn't exist
    stems: Dict[str, str] = {}
    for stem, source in separated.items():
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        demucs.api.save_audio(source, stem_path, samplerate=separator.samplerate)
        stems[stem] = stem_path

    # mix the selected stems: overlay() plays them simultaneously, whereas
    # "+" on AudioSegment would append them one after another
    mixed_audio: AudioSegment = AudioSegment.from_wav(stems[requested[0]])
    for name in requested[1:]:
        mixed_audio = mixed_audio.overlay(AudioSegment.from_wav(stems[name]))

    # a single selected stem goes through the same path, so mixed_audio is
    # always defined for the optional mp3 export below; the per-stem wav files
    # are left in place next to the mix
    output_file: str = os.path.join(output_dir, "mixed.wav")
    # export() returns the open output file; close it so the handle is not leaked
    mixed_audio.export(output_file, format="wav").close()

    # automatically convert to mp3 if requested
    if mp3:
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        # int() guards against the slider delivering a float such as 192.0
        mixed_audio.export(mp3_output_file, format="mp3", bitrate=f"{int(mp3_bitrate)}k").close()
        output_file = mp3_output_file  # update output_file to the mp3 file

    return output_file, separation_log
# define the gradio interface
# the two mp3 controls are created up front so the visibility toggle below can
# refer to them by name instead of by position in the inputs list
mp3_checkbox = gr.Checkbox(label="save as mp3", value=False)  # set default value to false
mp3_bitrate_slider = gr.Slider(128, 320, step=32, label="mp3 bitrate", visible=False)  # set visible to false initially

iface: gr.Interface = gr.Interface(
    fn=inference,
    # the order of inputs must match the positional parameters of inference()
    inputs=[
        gr.Audio(type="filepath"),
        gr.Dropdown(
            ["htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi", "mdx", "mdx_extra", "mdx_q", "mdx_extra_q"],
            label="model name",
            value="htdemucs_ft",  # set default value
        ),
        gr.Checkbox(label="vocals", value=True),
        gr.Checkbox(label="drums", value=True),
        gr.Checkbox(label="bass", value=True),
        gr.Checkbox(label="other", value=True),
        mp3_checkbox,
        mp3_bitrate_slider,
    ],
    # inference() returns (path to mixed audio, separation log)
    outputs=[
        gr.Audio(type="filepath"),
        gr.Textbox(label="separation log", lines=10),
    ],
    title="demucs music source separation and mixing",
    description="separate vocals, drums, bass, and other instruments from your music using demucs and mix the selected stems.",
)

# make mp3 bitrate slider visible only when "save as mp3" is checked
# event listeners have to be registered inside a blocks context; re-entering the
# interface (which is itself a blocks app) provides that context
with iface:
    mp3_checkbox.change(
        fn=lambda mp3: gr.update(visible=mp3),
        inputs=mp3_checkbox,
        outputs=mp3_bitrate_slider,
    )

# launch the gradio interface
iface.launch()