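# Gradio app: separate a track into stems with Demucs and remix the selected stems.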
import gradio as gr
import torch
import os
import spaces
from pydub import AudioSegment
from typing import Tuple, Dict, List, Optional, Iterator
from demucs.apply import apply_model
from demucs.separate import load_track
from demucs.pretrained import get_model
from demucs.audio import save_audio
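# Run on the GPU when one is available, otherwise fall back to the CPU.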
device: str = "cuda" if torch.cuda.is_available() else "cpu"
# Define the inference function
@spaces.GPU
def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Iterator[Tuple[Optional[str], gr.HTML]]:
    separator = get_model(name=model_name)
    log_messages = []

    def stream_log(message):
        log_messages.append(f"[{model_name}] {message}")
        return gr.HTML("<pre style='margin-bottom: 0;'>" + "<br>".join(log_messages) + "</pre>")

    yield None, stream_log("Starting separation process...")
    yield None, stream_log(f"Loading audio file: {audio_file}")
    # Check if audio_file is None
    if audio_file is None:
        yield None, stream_log("Error: No audio file provided")
        raise gr.Error("Please upload an audio file")
    # Load the audio file with the correct samplerate and audio channels.
    # demucs.separate.load_track returns just the waveform tensor (channels, length).
    try:
        wav = load_track(audio_file, audio_channels=2, samplerate=separator.samplerate)
    except Exception as e:
        yield None, stream_log(f"Error loading audio file: {str(e)}")
        raise gr.Error(f"Failed to load audio file: {str(e)}")
    # Check the number of channels and adjust if necessary
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)  # Add channel dimension if mono
    if wav.shape[0] == 1:
        wav = wav.repeat(2, 1)  # If mono, duplicate to stereo
    elif wav.shape[0] > 2:
        wav = wav[:2]  # If more than 2 channels, keep only the first two

    wav = wav.to(device)

    # Normalize the mix before separation and keep the mono reference so the
    # stems can be de-normalized afterwards (the standard Demucs recipe).
    ref = wav.mean(0)
    wav = (wav - ref.mean()) / ref.std()

    yield None, stream_log("Audio loaded successfully. Applying model...")
    # Use apply_model as a standalone function; it expects a batched input of
    # shape (batch, channels, length) and returns (batch, stems, channels, length).
    try:
        result = apply_model(separator, wav[None], device=device)
        yield None, stream_log(f"Model application result type: {type(result)}")
        yield None, stream_log(f"Model application result shape: {result.shape if hasattr(result, 'shape') else 'N/A'}")

        if isinstance(result, tuple) and len(result) == 2:
            sources, _ = result
        elif isinstance(result, torch.Tensor):
            sources = result
        else:
            raise ValueError(f"Unexpected result type from apply_model: {type(result)}")

        if sources.dim() == 4:
            sources = sources[0]  # Drop the batch dimension -> (stems, channels, length)

        yield None, stream_log(f"Sources shape: {sources.shape}")
    except ValueError as e:
        yield None, stream_log(f"Error applying model: {str(e)}")
        yield None, stream_log(f"Separator sources: {separator.sources}")
        yield None, stream_log(f"WAV shape: {wav.shape}")
        yield None, stream_log(f"Separator model: {separator.__class__.__name__}")
        yield None, stream_log(f"Separator config: {getattr(separator, 'config', 'N/A')}")
        raise gr.Error(f"Failed to apply model: {str(e)}. This might be due to an incompatible audio format or model configuration.")
    except Exception as e:
        yield None, stream_log(f"Unexpected error applying model: {str(e)}")
        raise gr.Error(f"An unexpected error occurred while applying the model: {str(e)}")
    # De-normalize each stem back to the original scale and move it to the CPU for saving.
    sources = [(source * ref.std() + ref.mean()).cpu() for source in sources]
    yield None, stream_log("Model applied. Processing stems...")
    output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
    os.makedirs(output_dir, exist_ok=True)

    stems: Dict[str, str] = {}
    for stem, source in zip(separator.sources, sources):
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        save_audio(source, stem_path, separator.samplerate)
        stems[stem] = stem_path
        yield None, stream_log(f"Saved {stem} stem")
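    # Keep only the stems the user ticked, in a fixed vocals/drums/bass/other order.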
    selected_stems: List[str] = [stems[stem] for stem, include in zip(["vocals", "drums", "bass", "other"], [vocals, drums, bass, other]) if include]
    if not selected_stems:
        raise gr.Error("Please select at least one stem to mix.")
    output_file: str = os.path.join(output_dir, "mixed.wav")
    yield None, stream_log("Mixing selected stems...")

    if len(selected_stems) == 1:
        os.rename(selected_stems[0], output_file)
        mixed_audio: AudioSegment = AudioSegment.from_wav(output_file)
    else:
        # Overlay the stems on top of each other; AudioSegment's "+" operator
        # concatenates segments end-to-end, so it must not be used for mixing.
        mixed_audio = AudioSegment.from_wav(selected_stems[0])
        for stem_path in selected_stems[1:]:
            mixed_audio = mixed_audio.overlay(AudioSegment.from_wav(stem_path))
        mixed_audio.export(output_file, format="wav")

    if mp3:
        yield None, stream_log(f"Converting to MP3 (bitrate: {mp3_bitrate}k)...")
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        mixed_audio.export(mp3_output_file, format="mp3", bitrate=f"{int(mp3_bitrate)}k")
        output_file = mp3_output_file
    yield None, stream_log("Process completed successfully!")
    yield output_file, gr.HTML("<pre style='color: green;'>Separation and mixing completed successfully!</pre>")
# Define the Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Demucs Music Source Separation and Mixing")
    gr.Markdown("Separate vocals, drums, bass, and other instruments from your music using Demucs and mix the selected stems.")

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(type="filepath", label="Upload Audio File")
            model_dropdown = gr.Dropdown(
                ["htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi", "mdx", "mdx_extra", "mdx_q", "mdx_extra_q"],
                label="Model Name",
                value="htdemucs_ft"
            )
            with gr.Row():
                vocals_checkbox = gr.Checkbox(label="Vocals", value=True)
                drums_checkbox = gr.Checkbox(label="Drums", value=True)
            with gr.Row():
                bass_checkbox = gr.Checkbox(label="Bass", value=True)
                other_checkbox = gr.Checkbox(label="Other", value=True)
            mp3_checkbox = gr.Checkbox(label="Save as MP3", value=False)
            mp3_bitrate = gr.Slider(128, 320, step=32, label="MP3 Bitrate", visible=False)
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column(scale=1):
            output_audio = gr.Audio(type="filepath", label="Processed Audio")
            separation_log = gr.HTML()

    submit_btn.click(
        fn=inference,
        inputs=[audio_input, model_dropdown, vocals_checkbox, drums_checkbox, bass_checkbox, other_checkbox, mp3_checkbox, mp3_bitrate],
        outputs=[output_audio, separation_log]
    )
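    # Show the bitrate slider only when MP3 export is selected.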
    mp3_checkbox.change(
        fn=lambda mp3: gr.update(visible=mp3),
        inputs=mp3_checkbox,
        outputs=mp3_bitrate
    )
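# Note: on older Gradio releases, generator handlers only stream their intermediate
# yields when the queue is enabled (e.g. iface.queue().launch()); recent releases
# enable queueing by default.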
# Launch the Gradio interface
iface.launch()