import os
import subprocess
from typing import Tuple, Dict, List

import gradio as gr
import spaces
import torch
from pydub import AudioSegment

from demucs.apply import apply_model
from demucs.audio import save_audio
from demucs.pretrained import get_model
from demucs.separate import load_track

# check if cuda is available
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# check if sox is installed and install it if necessary
try:
    subprocess.run(["sox", "--version"], check=True, capture_output=True)
except FileNotFoundError:
    print("sox is not installed. trying to install it now...")
    try:
        subprocess.run(["apt-get", "update"], check=True)
        subprocess.run(["apt-get", "install", "-y", "sox"], check=True)
        print("sox has been installed.")
    except subprocess.CalledProcessError as e:
        print(f"error installing sox: {e}")
        print("please install sox manually or try adding the following repository to your sources list:")
        print("deb http://deb.debian.org/debian stretch main contrib non-free")
        raise SystemExit(1)
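# note: sox is assumed to be needed only as an audio decoding backend (e.g. via
# torchaudio) for some input formats; if it is already installed, the check above
# succeeds and this whole block does nothing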

# define the inference function
@spaces.GPU
def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, str]:
    """
    performs inference using demucs and mixes the selected stems.

    args:
        audio_file: the audio file to separate.
        model_name: the name of the demucs model to use.
        vocals: whether to include vocals in the mix.
        drums: whether to include drums in the mix.
        bass: whether to include bass in the mix.
        other: whether to include other instruments in the mix.
        mp3: whether to save the output as mp3.
        mp3_bitrate: the bitrate of the output mp3 file.

    returns:
        a tuple containing the path to the mixed audio file and the separation log.
    """

    # load the pretrained demucs model
    separator = get_model(name=model_name)

    # load the track and normalize it the same way demucs.separate does
    wav = load_track(audio_file, separator.audio_channels, separator.samplerate)
    ref = wav.mean(0)
    wav -= ref.mean()
    wav /= ref.std()

    # apply the model; apply_model expects a batch dimension, so add it and strip it again
    sources = apply_model(separator, wav[None], device=device, progress=True)[0]
    sources = sources * ref.std() + ref.mean()

    # apply_model does not expose a log stream, so build a short summary instead
    separation_log = (
        f"model: {model_name}\n"
        f"device: {device}\n"
        f"stems: {', '.join(separator.sources)}"
    )

    # write each separated stem to its own wav file
    output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
    os.makedirs(output_dir, exist_ok=True)
    stems: Dict[str, str] = {}
    for stem, source in zip(separator.sources, sources):
        stem_path: str = os.path.join(output_dir, f"{stem}.wav")
        save_audio(source, stem_path, separator.samplerate)
        stems[stem] = stem_path

    # collect the stems selected via the checkboxes
    selected_stems: List[str] = [
        stems[stem]
        for stem, include in zip(["vocals", "drums", "bass", "other"], [vocals, drums, bass, other])
        if include
    ]
    if not selected_stems:
        raise gr.Error("Please select at least one stem to mix.")

    output_file: str = os.path.join(output_dir, "mixed.wav")
    if len(selected_stems) == 1:
        # only one stem selected: just rename it to the output file
        os.rename(selected_stems[0], output_file)
        mixed_audio: AudioSegment = AudioSegment.from_wav(output_file)
    else:
        # overlay the stems with pydub; plain `+` would concatenate them instead of mixing
        mixed_audio = AudioSegment.from_wav(selected_stems[0])
        for stem_path in selected_stems[1:]:
            mixed_audio = mixed_audio.overlay(AudioSegment.from_wav(stem_path))
        mixed_audio.export(output_file, format="wav")

    # convert to mp3 if requested
    if mp3:
        mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
        mixed_audio.export(mp3_output_file, format="mp3", bitrate=f"{int(mp3_bitrate)}k")
        output_file = mp3_output_file

    return output_file, separation_log

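# optional local smoke test (hypothetical file name; not part of the Space's normal
# flow, which goes through the Gradio UI below):
# if __name__ == "__main__":
#     mixed_path, log = inference("example.wav", "htdemucs_ft", True, True, True, True, False, 192)
#     print(mixed_path)
#     print(log)
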
# Define the Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Demucs Music Source Separation and Mixing")
    gr.Markdown("Separate vocals, drums, bass, and other instruments from your music using Demucs and mix the selected stems.")
    
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(type="filepath", label="Upload Audio File")
            model_dropdown = gr.Dropdown(
                ["htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi", "mdx", "mdx_extra", "mdx_q", "mdx_extra_q"],
                label="Model Name",
                value="htdemucs_ft"
            )
            with gr.Row():
                vocals_checkbox = gr.Checkbox(label="Vocals", value=True)
                drums_checkbox = gr.Checkbox(label="Drums", value=True)
            with gr.Row():
                bass_checkbox = gr.Checkbox(label="Bass", value=True)
                other_checkbox = gr.Checkbox(label="Other", value=True)
            mp3_checkbox = gr.Checkbox(label="Save as MP3", value=False)
            mp3_bitrate = gr.Slider(128, 320, step=32, label="MP3 Bitrate", visible=False)
            submit_btn = gr.Button("Process", variant="primary")

        with gr.Column(scale=1):
            output_audio = gr.Audio(type="filepath", label="Processed Audio")
            separation_log = gr.Textbox(label="Separation Log", lines=10)

    submit_btn.click(
        fn=inference,
        inputs=[audio_input, model_dropdown, vocals_checkbox, drums_checkbox, bass_checkbox, other_checkbox, mp3_checkbox, mp3_bitrate],
        outputs=[output_audio, separation_log]
    )
    
    # Make MP3 bitrate slider visible only when "Save as MP3" is checked
    mp3_checkbox.change(
        fn=lambda mp3: gr.update(visible=mp3),
        inputs=mp3_checkbox,
        outputs=mp3_bitrate
    )

# Launch the Gradio interface
iface.launch()