import gradio as gr
import torch
# import demucs.api
import os
import spaces
import subprocess
from pydub import AudioSegment
from typing import Tuple, Dict, List
from demucs.apply import apply_model
from demucs.separate import load_track
from demucs.pretrained import get_model
from demucs.audio import save_audio # Import save_audio from demucs.audio
# check if cuda is available
device: str = "cuda" if torch.cuda.is_available() else "cpu"
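# inference() below passes this device to demucs's apply_model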
# check if sox is installed and install it if necessary
try:
    subprocess.run(["sox", "--version"], check=True, capture_output=True)
except FileNotFoundError:
    print("sox is not installed. trying to install it now...")
    try:
        subprocess.run(["apt-get", "update"], check=True)
        subprocess.run(["apt-get", "install", "-y", "sox"], check=True)
        print("sox has been installed.")
    except subprocess.CalledProcessError as e:
        print(f"error installing sox: {e}")
        print("please install sox manually or try adding the following repository to your sources list:")
        print("deb http://deb.debian.org/debian stretch main contrib non-free")
        exit(1)
# define the inference function
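# (on hugging face zerogpu spaces, the @spaces.GPU decorator below requests a gpu
# for the duration of each call to the decorated function)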
@spaces.GPU
def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, str]:
"""
performs inference using demucs and mixes the selected stems.
args:
audio_file: the audio file to separate.
model_name: the name of the demucs model to use.
vocals: whether to include vocals in the mix.
drums: whether to include drums in the mix.
bass: whether to include bass in the mix.
other: whether to include other instruments in the mix.
mp3: whether to save the output as mp3.
mp3_bitrate: the bitrate of the output mp3 file.
returns:
a tuple containing the path to the mixed audio file and the separation log.
"""
    # initialize demucs separator
    # separator: demucs.api.Separator = demucs.api.Separator(model=model_name)
    separator = get_model(name=model_name)
    # load the track at the model's channel count and sample rate
    # (load_track takes the channel count before the sample rate)
    wav = load_track(audio_file, separator.audio_channels, separator.samplerate)
    # normalize the mixture the same way demucs.separate does
    ref = wav.mean(0)
    wav = (wav - ref.mean()) / ref.std()
    # apply_model expects a batch dimension and does not expose a text log,
    # so a short summary is returned to the ui instead
    sources = apply_model(separator, wav[None], device=device, progress=True)[0]
    sources = sources * ref.std() + ref.mean()
    sources = sources.cpu()  # move the stems to cpu before writing them to disk
    separation_log: str = f"separated '{os.path.basename(audio_file)}' with model '{model_name}' on {device}."
    # get the output file paths
    output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
    os.makedirs(output_dir, exist_ok=True)  # create output directory if it doesn't exist
    stems: Dict[str, str] = {}
    for stem, source in zip(separator.sources, sources):
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        # demucs.api.save_audio(source, stem_path, samplerate=separator.samplerate)
        save_audio(source, stem_path, separator.samplerate)  # use save_audio from demucs.audio
        stems[stem] = stem_path
    # mix the selected stems
    selected_stems: List[str] = [stems[stem] for stem, include in zip(["vocals", "drums", "bass", "other"], [vocals, drums, bass, other]) if include]
    if not selected_stems:
        raise gr.Error("please select at least one stem to mix.")
    output_file: str = os.path.join(output_dir, "mixed.wav")
    if len(selected_stems) == 1:
        # if only one stem is selected, just copy it
        os.rename(selected_stems[0], output_file)
        mixed_audio: AudioSegment = AudioSegment.from_wav(output_file)
    else:
        # otherwise, use pydub to mix the stems: overlay() layers them on top of
        # each other, whereas += would concatenate them end to end
        mixed_audio = AudioSegment.from_wav(selected_stems[0])
        for stem_path in selected_stems[1:]:
            mixed_audio = mixed_audio.overlay(AudioSegment.from_wav(stem_path))
        mixed_audio.export(output_file, format="wav")
    # automatically convert to mp3 if requested
    if mp3:
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        mixed_audio.export(mp3_output_file, format="mp3", bitrate=f"{int(mp3_bitrate)}k")
        output_file = mp3_output_file  # update output_file to the mp3 file
    return output_file, separation_log
# define the gradio interface
iface: gr.Interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(type="filepath"),
        gr.Dropdown(["htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi", "mdx", "mdx_extra", "mdx_q", "mdx_extra_q"], label="model name", value="htdemucs_ft"),  # set default value
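        # note: htdemucs_ft is the fine-tuned htdemucs variant (slower, usually cleaner);
        # htdemucs_6s also separates guitar and piano, but only the four stems below are mixed here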
gr.Checkbox(label="vocals", value=True),
gr.Checkbox(label="drums", value=True),
gr.Checkbox(label="bass", value=True),
gr.Checkbox(label="other", value=True),
gr.Checkbox(label="save as mp3", value=False), # set default value to false
gr.Slider(128, 320, step=32, label="mp3 bitrate", visible=False), # set visible to false initially
],
outputs=[
gr.Audio(type="filepath"),
gr.Textbox(label="separation log", lines=10),
],
title="demucs music source separation and mixing",
description="separate vocals, drums, bass, and other instruments from your music using demucs and mix the selected stems.",
)
# make mp3 bitrate slider visible only when "save as mp3" is checked
# (widgets live in iface.input_components; the event is attached inside the blocks context)
with iface:
    iface.input_components[-2].change(fn=lambda mp3: gr.update(visible=mp3), inputs=iface.input_components[-2], outputs=iface.input_components[-1])
# launch the gradio interface
iface.launch()