owiedotch's picture
Update app.py
b072ff1 verified
raw
history blame
No virus
4.6 kB
import html
import os
from typing import Dict, Iterator, List, Optional, Tuple

import gradio as gr
import spaces
import torch
from demucs.apply import apply_model
from demucs.audio import save_audio
from demucs.pretrained import get_model
from demucs.separate import load_track
from pydub import AudioSegment
# Torch device used for inference: "cuda" when torch reports a usable GPU,
# otherwise "cpu".
device: str = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Define the inference function
@spaces.GPU
def inference(
    audio_file: str,
    model_name: str,
    vocals: bool,
    drums: bool,
    bass: bool,
    other: bool,
    mp3: bool,
    mp3_bitrate: int,
) -> Iterator[Tuple[Optional[str], str]]:
    """Separate ``audio_file`` with a Demucs model and mix the selected stems.

    This is a generator so Gradio can stream progress: every intermediate step
    yields ``(None, log_html)``, and the final yield is
    ``(output_path, log_html)`` where ``output_path`` is the mixed WAV (or MP3)
    file.

    Args:
        audio_file: Path of the uploaded audio file.
        model_name: Name of a pretrained Demucs model, passed to ``get_model``.
        vocals, drums, bass, other: Which stems to include in the mix.
        mp3: If True, also encode the mix as MP3 and return that file instead.
        mp3_bitrate: MP3 bitrate in kbit/s; only used when ``mp3`` is True.

    Raises:
        gr.Error: If no file was uploaded, no stem is selected, or the chosen
            model does not produce one of the selected stems.
    """
    # Validate cheap inputs before the expensive model load and separation.
    if not audio_file:
        raise gr.Error("Please upload an audio file.")
    requested: List[str] = [
        name
        for name, include in zip(("vocals", "drums", "bass", "other"), (vocals, drums, bass, other))
        if include
    ]
    if not requested:
        raise gr.Error("Please select at least one stem to mix.")

    log_lines: List[str] = []

    def stream_log(message: str) -> str:
        """Append ``message`` to the running log and return the whole log as HTML."""
        # Escape so a file name containing HTML cannot inject markup into the page.
        entry = html.escape(f"[{model_name}] {message}")
        log_lines.append(f"<pre style='margin-bottom: 0;'>{entry}</pre>")
        return "".join(log_lines)

    yield None, stream_log("Starting separation process...")
    separator = get_model(name=model_name)
    missing = [name for name in requested if name not in separator.sources]
    if missing:
        raise gr.Error(f"Model {model_name} does not provide stem(s): {', '.join(missing)}")

    yield None, stream_log(f"Loading audio file: {audio_file}")
    # load_track's signature is (track, audio_channels, samplerate); pass the two
    # integers by keyword so they cannot be swapped.
    wav = load_track(audio_file, audio_channels=separator.audio_channels, samplerate=separator.samplerate)

    # Normalize with the mono reference's global mean/std (as demucs.separate does)
    # and undo it on the output. The epsilon keeps silent input from dividing by zero.
    ref = wav.mean(0)
    ref_mean = ref.mean()
    ref_std = ref.std() + 1e-8
    wav = (wav - ref_mean) / ref_std

    yield None, stream_log("Audio loaded successfully. Applying model...")
    # apply_model expects (batch, channels, time): add a batch axis with [None] and
    # drop it again with [0]. `.cpu()` guarantees a host tensor for saving.
    sources = apply_model(separator, wav[None], device=device, progress=True)[0]
    sources = sources.cpu() * ref_std + ref_mean

    yield None, stream_log("Model applied. Processing stems...")
    output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
    os.makedirs(output_dir, exist_ok=True)
    stems: Dict[str, torch.Tensor] = {}
    for stem, source in zip(separator.sources, sources):
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        save_audio(source, stem_path, separator.samplerate)
        stems[stem] = source
        yield None, stream_log(f"Saved {stem} stem")

    output_file: str = os.path.join(output_dir, "mixed.wav")
    yield None, stream_log("Mixing selected stems...")
    # Mixing is a sample-wise sum of the stems. (pydub's `+` would concatenate
    # them end to end instead.) This also covers the single-stem case.
    mixed = torch.stack([stems[name] for name in requested]).sum(dim=0)
    save_audio(mixed, output_file, separator.samplerate)

    if mp3:
        bitrate = int(mp3_bitrate)
        yield None, stream_log(f"Converting to MP3 (bitrate: {bitrate}k)...")
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        # AudioSegment.export returns an open file handle; close it.
        AudioSegment.from_wav(output_file).export(mp3_output_file, format="mp3", bitrate=f"{bitrate}k").close()
        output_file = mp3_output_file

    yield None, stream_log("Process completed successfully!")
    yield output_file, "".join(log_lines) + "<pre style='color: green;'>Separation and mixing completed successfully!</pre>"
# Define the Gradio interface: inputs (audio, model, stem selection, MP3 options)
# on the left, processed audio and the streamed separation log on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Demucs Music Source Separation and Mixing")
    gr.Markdown("Separate vocals, drums, bass, and other instruments from your music using Demucs and mix the selected stems.")
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(type="filepath", label="Upload Audio File")
            model_dropdown = gr.Dropdown(
                ["htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi", "mdx", "mdx_extra", "mdx_q", "mdx_extra_q"],
                label="Model Name",
                value="htdemucs_ft",
            )
            with gr.Row():
                vocals_checkbox = gr.Checkbox(label="Vocals", value=True)
                drums_checkbox = gr.Checkbox(label="Drums", value=True)
            with gr.Row():
                bass_checkbox = gr.Checkbox(label="Bass", value=True)
                other_checkbox = gr.Checkbox(label="Other", value=True)
            mp3_checkbox = gr.Checkbox(label="Save as MP3", value=False)
            # Hidden until "Save as MP3" is ticked (see the change handler below).
            mp3_bitrate = gr.Slider(128, 320, step=32, label="MP3 Bitrate", visible=False)
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column(scale=1):
            output_audio = gr.Audio(type="filepath", label="Processed Audio")
            separation_log = gr.HTML()

    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=inference,
        inputs=[audio_input, model_dropdown, vocals_checkbox, drums_checkbox, bass_checkbox, other_checkbox, mp3_checkbox, mp3_bitrate],
        outputs=[output_audio, separation_log],
    )
    mp3_checkbox.change(
        fn=lambda mp3: gr.update(visible=mp3),
        inputs=mp3_checkbox,
        outputs=mp3_bitrate,
    )

# Launch the Gradio interface only when run as a script. The queue is enabled
# explicitly because `inference` is a generator, which streams its outputs
# through the queue.
if __name__ == "__main__":
    iface.queue().launch()