File size: 5,153 Bytes
1ddb890
 
 
 
99f4e90
 
8153f3b
d4b0a5f
8153f3b
b072ff1
1ddb890
99f4e90
1ddb890
b072ff1
1ddb890
b072ff1
8153f3b
1ddb890
b072ff1
 
 
 
 
ff4e567
1cfa594
 
ff4e567
58ff621
 
 
 
 
 
 
ff4e567
58ff621
 
8153f3b
58ff621
b072ff1
7cbf8ae
97a9102
 
7cbf8ae
97a9102
7d031a6
7cbf8ae
b072ff1
1ddb890
99f4e90
b072ff1
99f4e90
8153f3b
99f4e90
b072ff1
99f4e90
b072ff1
1ddb890
99f4e90
0479c04
b072ff1
0479c04
99f4e90
b072ff1
0479c04
 
 
99f4e90
 
 
 
0479c04
eb814f3
b072ff1
99f4e90
 
b072ff1
eb814f3
b072ff1
 
1ddb890
7c43d6c
 
ff064e2
 
7c43d6c
ff064e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b072ff1
ff064e2
7c43d6c
 
 
 
 
 
 
 
 
 
 
1ddb890
7c43d6c
1ddb890
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import torch
import os
import spaces
from pydub import AudioSegment
from typing import Tuple, Dict, List
from demucs.apply import apply_model
from demucs.separate import load_track
from demucs.pretrained import get_model
from demucs.audio import save_audio

# Torch device used for separation: the GPU when CUDA is available, else the CPU.
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# Define the inference function
@spaces.GPU
def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, gr.HTML]:
    """Separate *audio_file* into stems with Demucs and mix the selected ones.

    This is a generator: every ``yield`` produces an ``(audio, log)`` pair.
    Intermediate yields carry ``None`` for the audio and an HTML progress
    line for the log; the final yield carries the path of the mixed file and
    a success message.

    Args:
        audio_file: Path of the audio file to process.
        model_name: Name of the pretrained Demucs model passed to ``get_model``.
        vocals: Include the "vocals" stem in the mix.
        drums: Include the "drums" stem in the mix.
        bass: Include the "bass" stem in the mix.
        other: Include the "other" stem in the mix.
        mp3: Also export the mix as MP3 and return that file instead of the WAV.
        mp3_bitrate: MP3 bitrate in kbit/s (only used when ``mp3`` is true).

    Raises:
        gr.Error: If no audio file was provided or no stem was selected.
    """
    import html  # function-scope import: only needed to escape log text

    def stream_log(message):
        # The message may contain the uploaded file's path, so escape it
        # before embedding it in HTML.
        safe_message = html.escape(str(message), quote=False)
        return f"<pre style='margin-bottom: 0;'>[{model_name}] {safe_message}</pre>"

    # Validate the request before doing any expensive work.
    if not audio_file:
        raise gr.Error("Please upload an audio file.")
    wanted_stems: List[str] = [
        name
        for name, include in zip(("vocals", "drums", "bass", "other"), (vocals, drums, bass, other))
        if include
    ]
    if not wanted_stems:
        raise gr.Error("Please select at least one stem to mix.")

    yield None, stream_log("Starting separation process...")
    separator = get_model(name=model_name)
    yield None, stream_log(f"Loading audio file: {audio_file}")

    # Load the audio file with the correct samplerate and audio channels
    wav, sr = load_track(audio_file, samplerate=separator.samplerate, audio_channels=2)

    # Defensive channel handling: make sure we end up with exactly 2 channels.
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)  # Add channel dimension if mono
    if wav.shape[0] == 1:
        wav = wav.repeat(2, 1)  # If mono, duplicate to stereo
    elif wav.shape[0] > 2:
        wav = wav[:2]  # If more than 2 channels, keep only the first two

    wav = wav.to(device)

    # Normalise with the scalar mean/std of the mono reference, as the Demucs
    # command line does; the epsilon guards against all-silent input.
    ref = wav.mean(0)
    ref_mean = ref.mean()
    ref_std = ref.std() + 1e-8
    wav = (wav - ref_mean) / ref_std
    yield None, stream_log("Audio loaded successfully. Applying model...")

    # apply_model works on batches shaped (batch, channels, length): add a
    # batch dimension of 1 and strip it from the result again.
    sources = apply_model(separator, wav[None], device=device)[0]

    # Undo the normalisation so the stems are back at the input's scale.
    sources = sources * ref_std + ref_mean

    yield None, stream_log("Model applied. Processing stems...")

    output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
    os.makedirs(output_dir, exist_ok=True)
    stems: Dict[str, str] = {}
    for stem, source in zip(separator.sources, sources):
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        # Move the tensor to the CPU before writing it to disk.
        save_audio(source.cpu(), stem_path, separator.samplerate)
        stems[stem] = stem_path
        yield None, stream_log(f"Saved {stem} stem")

    selected_stems: List[str] = [stems[name] for name in wanted_stems]

    output_file: str = os.path.join(output_dir, "mixed.wav")
    yield None, stream_log("Mixing selected stems...")
    # Overlay the stems so they play simultaneously; `+` on AudioSegments
    # would append them one after another instead. Building the segment here
    # also guarantees `mixed_audio` exists for the MP3 export below, even
    # when only a single stem was selected.
    mixed_audio: AudioSegment = AudioSegment.from_wav(selected_stems[0])
    for stem_path in selected_stems[1:]:
        mixed_audio = mixed_audio.overlay(AudioSegment.from_wav(stem_path))
    # export() returns the open output file handle; close it explicitly.
    mixed_audio.export(output_file, format="wav").close()

    if mp3:
        yield None, stream_log(f"Converting to MP3 (bitrate: {mp3_bitrate}k)...")
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        # int() keeps the bitrate well-formed if the slider delivers a float.
        mixed_audio.export(mp3_output_file, format="mp3", bitrate=f"{int(mp3_bitrate)}k").close()
        output_file = mp3_output_file

    yield None, stream_log("Process completed successfully!")
    yield output_file, gr.HTML("<pre style='color: green;'>Separation and mixing completed successfully!</pre>")

# Define the Gradio interface
# Layout: one row with two equally scaled columns - inputs and options on the
# left, the processed audio and the progress log on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Demucs Music Source Separation and Mixing")
    gr.Markdown("Separate vocals, drums, bass, and other instruments from your music using Demucs and mix the selected stems.")
    
    with gr.Row():
        with gr.Column(scale=1):
            # The upload is handed to `inference` as a file path.
            audio_input = gr.Audio(type="filepath", label="Upload Audio File")
            # Model names offered to the user; the selection is passed to
            # `inference` as `model_name`.
            model_dropdown = gr.Dropdown(
                ["htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi", "mdx", "mdx_extra", "mdx_q", "mdx_extra_q"],
                label="Model Name",
                value="htdemucs_ft"
            )
            # Stem selection, laid out as two rows of two checkboxes; all
            # four stems are selected by default.
            with gr.Row():
                vocals_checkbox = gr.Checkbox(label="Vocals", value=True)
                drums_checkbox = gr.Checkbox(label="Drums", value=True)
            with gr.Row():
                bass_checkbox = gr.Checkbox(label="Bass", value=True)
                other_checkbox = gr.Checkbox(label="Other", value=True)
            mp3_checkbox = gr.Checkbox(label="Save as MP3", value=False)
            # Created hidden; the `mp3_checkbox.change` handler below toggles
            # its visibility.
            mp3_bitrate = gr.Slider(128, 320, step=32, label="MP3 Bitrate", visible=False)
            submit_btn = gr.Button("Process", variant="primary")

        with gr.Column(scale=1):
            output_audio = gr.Audio(type="filepath", label="Processed Audio")
            separation_log = gr.HTML()

    # The order of `inputs` matches the positional parameters of `inference`;
    # each (audio, log) pair it yields is mapped onto the two `outputs`.
    submit_btn.click(
        fn=inference,
        inputs=[audio_input, model_dropdown, vocals_checkbox, drums_checkbox, bass_checkbox, other_checkbox, mp3_checkbox, mp3_bitrate],
        outputs=[output_audio, separation_log]
    )
    
    # Show the bitrate slider only while "Save as MP3" is ticked.
    mp3_checkbox.change(
        fn=lambda mp3: gr.update(visible=mp3),
        inputs=mp3_checkbox,
        outputs=mp3_bitrate
    )

# Launch the Gradio interface
# NOTE: there is no `__main__` guard, so this runs whenever the module is executed or imported.
iface.launch()