File size: 5,663 Bytes
1ddb890
 
 
 
99f4e90
 
8153f3b
d4b0a5f
8153f3b
b072ff1
1ddb890
99f4e90
1ddb890
b072ff1
1ddb890
b072ff1
8153f3b
1ddb890
b072ff1
 
 
 
 
ff4e567
2a231ba
 
 
 
 
1cfa594
2a231ba
 
 
 
 
 
58ff621
 
 
 
 
 
 
ff4e567
58ff621
 
8153f3b
58ff621
b072ff1
7cbf8ae
97a9102
2a231ba
 
 
 
 
 
97a9102
7d031a6
7cbf8ae
b072ff1
1ddb890
99f4e90
b072ff1
99f4e90
8153f3b
99f4e90
b072ff1
99f4e90
b072ff1
1ddb890
99f4e90
0479c04
b072ff1
0479c04
99f4e90
b072ff1
0479c04
 
 
99f4e90
 
 
 
0479c04
eb814f3
b072ff1
99f4e90
 
b072ff1
eb814f3
b072ff1
 
1ddb890
7c43d6c
 
ff064e2
 
7c43d6c
ff064e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b072ff1
ff064e2
7c43d6c
 
 
 
 
 
 
 
 
 
 
1ddb890
7c43d6c
1ddb890
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
from typing import Dict, Iterator, List, Optional, Tuple

import gradio as gr
import torch
import spaces
from pydub import AudioSegment
from demucs.apply import apply_model
from demucs.separate import load_track
from demucs.pretrained import get_model
from demucs.audio import save_audio

# Target device for separation: the GPU when PyTorch reports one, else the CPU.
device: str
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Define the inference function
@spaces.GPU
def inference(
    audio_file: str,
    model_name: str,
    vocals: bool,
    drums: bool,
    bass: bool,
    other: bool,
    mp3: bool,
    mp3_bitrate: int,
) -> Iterator[Tuple[Optional[str], object]]:
    """Separate a track into stems with Demucs and mix the selected stems.

    This is a generator: each yielded ``(audio_path, log)`` pair is streamed to
    the UI. Intermediate yields carry ``None`` for the audio output plus an HTML
    log line; the final yield carries the path of the mixed (or MP3) file.

    Args:
        audio_file: Path of the uploaded audio file, or ``None`` if nothing
            was uploaded.
        model_name: Pretrained Demucs model name passed to ``get_model``.
        vocals, drums, bass, other: Which stems to include in the mix.
        mp3: If true, also encode the mix as MP3 and return that file instead.
        mp3_bitrate: MP3 bitrate in kbit/s.

    Raises:
        gr.Error: if no file was uploaded, no stem is selected, or loading /
            separation fails.
    """

    def stream_log(message: str) -> str:
        return f"<pre style='margin-bottom: 0;'>[{model_name}] {message}</pre>"

    # Validate the cheap inputs before loading the model or running separation.
    if audio_file is None:
        yield None, stream_log("Error: No audio file provided")
        raise gr.Error("Please upload an audio file")

    stem_flags = {"vocals": vocals, "drums": drums, "bass": bass, "other": other}
    selected_names: List[str] = [name for name, include in stem_flags.items() if include]
    if not selected_names:
        raise gr.Error("Please select at least one stem to mix.")

    yield None, stream_log("Starting separation process...")
    separator = get_model(name=model_name)

    yield None, stream_log(f"Loading audio file: {audio_file}")
    try:
        # load_track returns ONE (channels, length) tensor, already resampled to
        # the model's rate -- not a (wav, sr) tuple. Unpacking it as a tuple
        # would silently keep only the first channel.
        wav = load_track(audio_file, samplerate=separator.samplerate, audio_channels=2)
    except (Exception, SystemExit) as e:
        # NOTE: demucs' load_track may call sys.exit() when no backend can decode
        # the file; SystemExit is caught as well so the user gets a UI error.
        yield None, stream_log(f"Error loading audio file: {str(e)}")
        raise gr.Error(f"Failed to load audio file: {str(e)}") from e

    # Defensive channel handling: force a (2, length) stereo tensor.
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)  # Add channel dimension if mono
    if wav.shape[0] == 1:
        wav = wav.repeat(2, 1)  # If mono, duplicate to stereo
    elif wav.shape[0] > 2:
        wav = wav[:2]  # If more than 2 channels, keep only the first two

    wav = wav.to(device)

    # Standardize the mixture by the scalar mean/std of its mono reference (as
    # demucs' own CLI does) and undo it on the outputs. Subtracting the
    # per-sample mono signal instead would cancel centre-panned content.
    ref = wav.mean(0)
    ref_mean = ref.mean()
    ref_std = ref.std() + 1e-8  # avoid division by zero on digital silence
    wav = (wav - ref_mean) / ref_std
    yield None, stream_log("Audio loaded successfully. Applying model...")

    try:
        # apply_model expects (batch, channels, length); add and strip the batch axis.
        sources = apply_model(separator, wav[None], device=device)[0]
    except Exception as e:
        yield None, stream_log(f"Error applying model: {str(e)}")
        raise gr.Error(f"Failed to apply model: {str(e)}") from e

    sources = sources * ref_std + ref_mean

    yield None, stream_log("Model applied. Processing stems...")

    output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): stems other than the four selectable names (e.g. from a
    # 6-source model) are saved below but never mixed -- confirm intended.
    stem_tensors: Dict[str, torch.Tensor] = {}
    for stem, source in zip(separator.sources, sources):
        source = source.cpu()  # saving goes through CPU tensors
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        save_audio(source, stem_path, separator.samplerate)
        stem_tensors[stem] = source
        yield None, stream_log(f"Saved {stem} stem")

    output_file: str = os.path.join(output_dir, "mixed.wav")
    yield None, stream_log("Mixing selected stems...")
    # Mixing is the sample-wise sum of the stems. (pydub's `+` would concatenate
    # the clips end to end instead of overlaying them.) Works for one stem too.
    mixed = torch.stack([stem_tensors[name] for name in selected_names]).sum(dim=0)
    save_audio(mixed, output_file, separator.samplerate)

    if mp3:
        # The slider value may arrive as a float; ffmpeg wants e.g. "192k".
        bitrate_kbps = int(mp3_bitrate)
        yield None, stream_log(f"Converting to MP3 (bitrate: {bitrate_kbps}k)...")
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        AudioSegment.from_wav(output_file).export(mp3_output_file, format="mp3", bitrate=f"{bitrate_kbps}k")
        output_file = mp3_output_file

    yield None, stream_log("Process completed successfully!")
    yield output_file, gr.HTML("<pre style='color: green;'>Separation and mixing completed successfully!</pre>")

# Define the Gradio interface: inputs on the left, result and streamed log on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Demucs Music Source Separation and Mixing")
    gr.Markdown(
        "Separate vocals, drums, bass, and other instruments from your music "
        "using Demucs and mix the selected stems."
    )

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="Upload Audio File", type="filepath")
            model_dropdown = gr.Dropdown(
                choices=[
                    "htdemucs",
                    "htdemucs_ft",
                    "htdemucs_6s",
                    "hdemucs_mmi",
                    "mdx",
                    "mdx_extra",
                    "mdx_q",
                    "mdx_extra_q",
                ],
                value="htdemucs_ft",
                label="Model Name",
            )
            # Stem selectors, two per row; all enabled by default.
            with gr.Row():
                vocals_checkbox = gr.Checkbox(value=True, label="Vocals")
                drums_checkbox = gr.Checkbox(value=True, label="Drums")
            with gr.Row():
                bass_checkbox = gr.Checkbox(value=True, label="Bass")
                other_checkbox = gr.Checkbox(value=True, label="Other")
            mp3_checkbox = gr.Checkbox(value=False, label="Save as MP3")
            # Hidden until "Save as MP3" is ticked.
            mp3_bitrate = gr.Slider(minimum=128, maximum=320, step=32, label="MP3 Bitrate", visible=False)
            submit_btn = gr.Button(value="Process", variant="primary")

        with gr.Column(scale=1):
            output_audio = gr.Audio(label="Processed Audio", type="filepath")
            separation_log = gr.HTML()

    stem_checkboxes = [vocals_checkbox, drums_checkbox, bass_checkbox, other_checkbox]
    submit_btn.click(
        fn=inference,
        inputs=[audio_input, model_dropdown, *stem_checkboxes, mp3_checkbox, mp3_bitrate],
        outputs=[output_audio, separation_log],
    )

    def _toggle_bitrate_visibility(show_bitrate):
        """Show the bitrate slider only while MP3 output is enabled."""
        return gr.update(visible=show_bitrate)

    mp3_checkbox.change(
        fn=_toggle_bitrate_visibility,
        inputs=mp3_checkbox,
        outputs=mp3_bitrate,
    )

# Launch the Gradio interface
iface.launch()