owiedotch committed
Commit: b072ff1
Parent: ff064e2

Update app.py

Files changed (1):
  1. app.py +22 -60
app.py CHANGED
@@ -1,102 +1,65 @@
 import gradio as gr
 import torch
-# import demucs.api
 import os
 import spaces
-import subprocess
 from pydub import AudioSegment
 from typing import Tuple, Dict, List
 from demucs.apply import apply_model
 from demucs.separate import load_track
 from demucs.pretrained import get_model
-from demucs.audio import save_audio  # Import save_audio from demucs.audio
+from demucs.audio import save_audio
 
-# check if cuda is available
 device: str = "cuda" if torch.cuda.is_available() else "cpu"
 
-# check if sox is installed and install it if necessary
-try:
-    subprocess.run(["sox", "--version"], check=True, capture_output=True)
-except FileNotFoundError:
-    print("sox is not installed. trying to install it now...")
-    try:
-        subprocess.run(["apt-get", "update"], check=True)
-        subprocess.run(["apt-get", "install", "-y", "sox"], check=True)
-        print("sox has been installed.")
-    except subprocess.CalledProcessError as e:
-        print(f"error installing sox: {e}")
-        print("please install sox manually or try adding the following repository to your sources list:")
-        print("deb http://deb.debian.org/debian stretch main contrib non-free")
-        exit(1)
-
-# define the inference function
+# Define the inference function
 @spaces.GPU
-def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, str]:
-    """
-    performs inference using demucs and mixes the selected stems.
-
-    args:
-        audio_file: the audio file to separate.
-        model_name: the name of the demucs model to use.
-        vocals: whether to include vocals in the mix.
-        drums: whether to include drums in the mix.
-        bass: whether to include bass in the mix.
-        other: whether to include other instruments in the mix.
-        mp3: whether to save the output as mp3.
-        mp3_bitrate: the bitrate of the output mp3 file.
-
-    returns:
-        a tuple containing the path to the mixed audio file and the separation log.
-    """
-
-    # initialize demucs separator
-    # separator: demucs.api.Separator = demucs.api.Separator(model=model_name)
+def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, gr.HTML]:
     separator = get_model(name=model_name)
 
-    # separate audio file and capture log
-    import io
-    log_stream = io.StringIO()
-    # origin, separated = separator.separate_audio_file(audio_file, progress=True, log_stream=log_stream)
-    wav = load_track(audio_file, separator.samplerate, channels=separator.audio_channels)
+    def stream_log(message):
+        return f"<pre style='margin-bottom: 0;'>[{model_name}] {message}</pre>"
+
+    yield None, stream_log("Starting separation process...")
+    yield None, stream_log(f"Loading audio file: {audio_file}")
+    wav = load_track(audio_file, separator.samplerate, separator.audio_channels)
     ref = wav.mean(0)
     wav = (wav - ref.view(1, -1)).to(device)
-    sources = apply_model(separator, wav, device=device, progress=True, log_stream=log_stream)
+    yield None, stream_log("Audio loaded successfully. Applying model...")
+    sources = apply_model(separator, wav, device=device, progress=True)
     sources = sources * ref.view(1, -1) + ref.view(1, -1)
-    separation_log = log_stream.getvalue()
+    yield None, stream_log("Model applied. Processing stems...")
 
-    # get the output file paths
     output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
-    os.makedirs(output_dir, exist_ok=True)  # create output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
     stems: Dict[str, str] = {}
     for stem, source in zip(separator.sources, sources):
         stem_path: str = os.path.join(output_dir, f"{stem}.wav")
-        # demucs.api.save_audio(source, stem_path, samplerate=separator.samplerate)
-        save_audio(source, stem_path, separator.samplerate)  # Use save_audio
+        save_audio(source, stem_path, separator.samplerate)
         stems[stem] = stem_path
+        yield None, stream_log(f"Saved {stem} stem")
 
-    # mix the selected stems
     selected_stems: List[str] = [stems[stem] for stem, include in zip(["vocals", "drums", "bass", "other"], [vocals, drums, bass, other]) if include]
     if not selected_stems:
-        raise gr.Error("please select at least one stem to mix.")
+        raise gr.Error("Please select at least one stem to mix.")
 
     output_file: str = os.path.join(output_dir, "mixed.wav")
+    yield None, stream_log("Mixing selected stems...")
     if len(selected_stems) == 1:
-        # if only one stem is selected, just copy it
         os.rename(selected_stems[0], output_file)
     else:
-        # otherwise, use pydub to mix the stems
         mixed_audio: AudioSegment = AudioSegment.empty()
         for stem_path in selected_stems:
             mixed_audio += AudioSegment.from_wav(stem_path)
         mixed_audio.export(output_file, format="wav")
 
-    # automatically convert to mp3 if requested
     if mp3:
+        yield None, stream_log(f"Converting to MP3 (bitrate: {mp3_bitrate}k)...")
         mp3_output_file: str = os.path.splitext(output_file)[0] + ".mp3"
         mixed_audio.export(mp3_output_file, format="mp3", bitrate=str(mp3_bitrate) + "k")
-        output_file = mp3_output_file  # update output_file to the mp3 file
+        output_file = mp3_output_file
 
-    return output_file, separation_log
+    yield None, stream_log("Process completed successfully!")
+    yield output_file, gr.HTML("<pre style='color: green;'>Separation and mixing completed successfully!</pre>")
 
 # Define the Gradio interface
 with gr.Blocks() as iface:
@@ -123,7 +86,7 @@ with gr.Blocks() as iface:
 
     with gr.Column(scale=1):
         output_audio = gr.Audio(type="filepath", label="Processed Audio")
-        separation_log = gr.Textbox(label="Separation Log", lines=10)
+        separation_log = gr.HTML()
 
     submit_btn.click(
         fn=inference,
@@ -131,7 +94,6 @@ with gr.Blocks() as iface:
         outputs=[output_audio, separation_log]
     )
 
-    # Make MP3 bitrate slider visible only when "Save as MP3" is checked
     mp3_checkbox.change(
        fn=lambda mp3: gr.update(visible=mp3),
        inputs=mp3_checkbox,
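Note on the streaming pattern introduced here: Gradio runs a generator event handler as a stream, pushing each yielded tuple to the bound outputs as it arrives. That is what lets the rewritten inference report progress through stream_log instead of trying to capture demucs output into an io.StringIO (apply_model does not appear to accept a log_stream keyword, which the old call passed). A minimal, self-contained sketch of the same pattern; the names here (slow_task, demo) are illustrative, not part of the app:

import time

import gradio as gr

def slow_task(steps: float):
    # Generator handler: each yield streams an intermediate update to the
    # outputs bound in .click() below.
    total = int(steps)
    for i in range(total):
        time.sleep(0.5)  # stand-in for real work (separation, encoding, ...)
        # First output (the result) stays empty for now; the second output
        # (the HTML log) is replaced with the latest message.
        yield None, f"<pre style='margin-bottom: 0;'>step {i + 1}/{total} done</pre>"
    # The final yield fills in the real result.
    yield "all done", "<pre style='color: green;'>finished</pre>"

with gr.Blocks() as demo:
    steps = gr.Slider(1, 10, value=3, step=1, label="Steps")
    result = gr.Textbox(label="Result")
    log = gr.HTML()
    gr.Button("Run").click(fn=slow_task, inputs=steps, outputs=[result, log])

if __name__ == "__main__":
    demo.launch()

As in the commit, each yield replaces the previous log value rather than appending to it, so the UI shows only the most recent stream_log message.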
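One pydub detail relevant to the mixing branch: the + operator on AudioSegment concatenates clips end to end, while AudioSegment.overlay plays them simultaneously. A short sketch of the difference, with hypothetical stem paths mirroring the app's separated/<model>/<track>/<stem>.wav layout:

from pydub import AudioSegment

# Hypothetical paths for illustration only.
vocals = AudioSegment.from_wav("separated/htdemucs/song/vocals.wav")
drums = AudioSegment.from_wav("separated/htdemucs/song/drums.wav")

sequential = vocals + drums        # concatenation: drums play after vocals
layered = vocals.overlay(drums)    # overlay: both stems play at once
layered.export("mixed.wav", format="wav")

overlay keeps the length of the base segment, so equal-length stems line up exactly.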