owiedotch committed on
Commit
ac0c0ce
1 Parent(s): 5e53171

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -65
app.py CHANGED
@@ -1,20 +1,10 @@
1
  import gradio as gr
2
- import torch
3
  import os
4
- import spaces
 
5
  from pydub import AudioSegment
6
- from typing import Tuple, Dict, List
7
- from demucs.apply import apply_model
8
- from demucs.separate import load_track
9
- from demucs.pretrained import get_model
10
- from demucs.audio import save_audio
11
 
12
- device: str = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
- # Define the inference function
15
- @spaces.GPU
16
  def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, gr.HTML]:
17
- separator = get_model(name=model_name)
18
  log_messages = []
19
 
20
  def stream_log(message):
@@ -24,77 +14,60 @@ def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass:
24
  yield None, stream_log("Starting separation process...")
25
  yield None, stream_log(f"Loading audio file: {audio_file}")
26
 
27
- # Check if audio_file is None
28
  if audio_file is None:
29
  yield None, stream_log("Error: No audio file provided")
30
  raise gr.Error("Please upload an audio file")
31
 
32
- # Load the audio file with the correct samplerate and audio channels
33
- try:
34
- wav, sr = load_track(audio_file, samplerate=separator.samplerate, audio_channels=2)
35
- except Exception as e:
36
- yield None, stream_log(f"Error loading audio file: {str(e)}")
37
- raise gr.Error(f"Failed to load audio file: {str(e)}")
38
-
39
- # Check the number of channels and adjust if necessary
40
- if wav.dim() == 1:
41
- wav = wav.unsqueeze(0) # Add channel dimension if mono
42
- if wav.shape[0] == 1:
43
- wav = wav.repeat(2, 1) # If mono, duplicate to stereo
44
- elif wav.shape[0] > 2:
45
- wav = wav[:2] # If more than 2 channels, keep only the first two
46
-
47
- wav = wav.to(device)
48
-
49
- ref = wav.mean(0)
50
- wav = (wav - ref.view(1, -1))
51
- yield None, stream_log("Audio loaded successfully. Applying model...")
52
 
53
- # Use apply_model as a standalone function
54
  try:
55
- result = apply_model(separator, wav.to(device), device=device)
56
- yield None, stream_log(f"Model application result type: {type(result)}")
57
- yield None, stream_log(f"Model application result shape: {result.shape if hasattr(result, 'shape') else 'N/A'}")
58
 
59
- if isinstance(result, tuple) and len(result) == 2:
60
- sources, _ = result
61
- elif isinstance(result, torch.Tensor):
62
- sources = result
63
- else:
64
- raise ValueError(f"Unexpected result type from apply_model: {type(result)}")
65
 
66
- yield None, stream_log(f"Sources shape: {sources.shape}")
67
- except ValueError as e:
68
- yield None, stream_log(f"Error applying model: {str(e)}")
69
- yield None, stream_log(f"Separator sources: {separator.sources}")
70
- yield None, stream_log(f"WAV shape: {wav.shape}")
71
- yield None, stream_log(f"Separator model: {separator.__class__.__name__}")
72
- yield None, stream_log(f"Separator config: {separator.config}")
73
- raise gr.Error(f"Failed to apply model: {str(e)}. This might be due to incompatible audio format or model configuration.")
 
74
  except Exception as e:
75
- yield None, stream_log(f"Unexpected error applying model: {str(e)}")
76
- raise gr.Error(f"An unexpected error occurred while applying the model: {str(e)}")
77
 
78
- # Process the sources
79
- sources = [source * ref.view(1, -1) + ref.view(1, -1) for source in sources]
80
-
81
- yield None, stream_log("Model applied. Processing stems...")
82
 
83
- output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
84
- os.makedirs(output_dir, exist_ok=True)
85
  stems: Dict[str, str] = {}
86
- for stem, source in zip(separator.sources, sources):
87
- stem_path: str = os.path.join(output_dir, f"{stem}.wav")
88
- save_audio(source, stem_path, separator.samplerate)
89
- stems[stem] = stem_path
90
- yield None, stream_log(f"Saved {stem} stem")
91
 
92
- selected_stems: List[str] = [stems[stem] for stem, include in zip(["vocals", "drums", "bass", "other"], [vocals, drums, bass, other]) if include]
93
  if not selected_stems:
94
  raise gr.Error("Please select at least one stem to mix.")
95
 
96
  output_file: str = os.path.join(output_dir, "mixed.wav")
97
  yield None, stream_log("Mixing selected stems...")
 
98
  if len(selected_stems) == 1:
99
  os.rename(selected_stems[0], output_file)
100
  else:
 
1
  import gradio as gr
 
2
  import os
3
+ import subprocess
4
+ from typing import Tuple, List, Dict
5
  from pydub import AudioSegment
 
 
 
 
 
6
 
 
 
 
 
7
  def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass: bool, other: bool, mp3: bool, mp3_bitrate: int) -> Tuple[str, gr.HTML]:
 
8
  log_messages = []
9
 
10
  def stream_log(message):
 
14
  yield None, stream_log("Starting separation process...")
15
  yield None, stream_log(f"Loading audio file: {audio_file}")
16
 
 
17
  if audio_file is None:
18
  yield None, stream_log("Error: No audio file provided")
19
  raise gr.Error("Please upload an audio file")
20
 
21
+ output_dir = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
22
+ os.makedirs(output_dir, exist_ok=True)
23
+
24
+ # Construct the Demucs command
25
+ cmd = [
26
+ "python", "-m", "demucs",
27
+ "--out", output_dir,
28
+ "-n", model_name,
29
+ audio_file
30
+ ]
31
+
32
+ yield None, stream_log(f"Running Demucs command: {' '.join(cmd)}")
 
 
 
 
 
 
 
 
33
 
 
34
  try:
35
+ # Run the Demucs command
36
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
 
37
 
38
+ # Stream the output
39
+ for line in process.stdout:
40
+ yield None, stream_log(line.strip())
 
 
 
41
 
42
+ # Wait for the process to complete
43
+ process.wait()
44
+
45
+ if process.returncode != 0:
46
+ error_output = process.stderr.read()
47
+ yield None, stream_log(f"Error: Demucs command failed with return code {process.returncode}")
48
+ yield None, stream_log(f"Error output: {error_output}")
49
+ raise gr.Error(f"Demucs separation failed. Check the logs for details.")
50
+
51
  except Exception as e:
52
+ yield None, stream_log(f"Unexpected error: {str(e)}")
53
+ raise gr.Error(f"An unexpected error occurred: {str(e)}")
54
 
55
+ yield None, stream_log("Separation completed. Processing stems...")
 
 
 
56
 
 
 
57
  stems: Dict[str, str] = {}
58
+ for stem in ["vocals", "drums", "bass", "other"]:
59
+ stem_path = os.path.join(output_dir, model_name, f"{stem}.wav")
60
+ if os.path.exists(stem_path):
61
+ stems[stem] = stem_path
62
+ yield None, stream_log(f"Found {stem} stem")
63
 
64
+ selected_stems: List[str] = [stems[stem] for stem in stems if locals()[stem]]
65
  if not selected_stems:
66
  raise gr.Error("Please select at least one stem to mix.")
67
 
68
  output_file: str = os.path.join(output_dir, "mixed.wav")
69
  yield None, stream_log("Mixing selected stems...")
70
+
71
  if len(selected_stems) == 1:
72
  os.rename(selected_stems[0], output_file)
73
  else: