owiedotch committed on
Commit
8153f3b
1 Parent(s): 5136e79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -1,11 +1,14 @@
1
  import gradio as gr
2
  import torch
3
- import demucs.api
4
  import os
5
  import spaces
6
  import subprocess
7
  from pydub import AudioSegment
8
  from typing import Tuple, Dict, List
 
 
 
9
 
10
  # check if cuda is available
11
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
@@ -46,21 +49,28 @@ def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass:
46
  """
47
 
48
  # initialize demucs separator
49
- separator: demucs.api.Separator = demucs.api.Separator(model=model_name)
 
50
 
51
  # separate audio file and capture log
52
  import io
53
  log_stream = io.StringIO()
54
- origin, separated = separator.separate_audio_file(audio_file, progress=True, log_stream=log_stream)
 
 
 
 
 
55
  separation_log = log_stream.getvalue()
56
 
57
  # get the output file paths
58
  output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
59
  os.makedirs(output_dir, exist_ok=True) # create output directory if it doesn't exist
60
  stems: Dict[str, str] = {}
61
- for stem, source in separated.items():
62
  stem_path: str = os.path.join(output_dir, f"{stem}.wav")
63
- demucs.api.save_audio(source, stem_path, samplerate=separator.samplerate)
 
64
  stems[stem] = stem_path
65
 
66
  # mix the selected stems
 
1
  import gradio as gr
2
  import torch
3
+ # import demucs.api
4
  import os
5
  import spaces
6
  import subprocess
7
  from pydub import AudioSegment
8
  from typing import Tuple, Dict, List
9
+ from demucs.apply import apply_model
10
+ from demucs.separate import load_track, save_tracks
11
+ from demucs.pretrained import get_model
12
 
13
  # check if cuda is available
14
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
 
49
  """
50
 
51
  # initialize demucs separator
52
+ # separator: demucs.api.Separator = demucs.api.Separator(model=model_name)
53
+ separator = get_model(name=model_name)
54
 
55
  # separate audio file and capture log
56
  import io
57
  log_stream = io.StringIO()
58
+ # origin, separated = separator.separate_audio_file(audio_file, progress=True, log_stream=log_stream)
59
+ wav = load_track(audio_file, separator.samplerate, channels=separator.audio_channels)
60
+ ref = wav.mean(0)
61
+ wav = (wav - ref.view(1, -1)).to(device)
62
+ sources = apply_model(separator, wav, device=device, progress=True, log_stream=log_stream)
63
+ sources = sources * ref.view(1, -1) + ref.view(1, -1)
64
  separation_log = log_stream.getvalue()
65
 
66
  # get the output file paths
67
  output_dir: str = os.path.join("separated", model_name, os.path.splitext(os.path.basename(audio_file))[0])
68
  os.makedirs(output_dir, exist_ok=True) # create output directory if it doesn't exist
69
  stems: Dict[str, str] = {}
70
+ for stem, source in zip(separator.sources, sources):
71
  stem_path: str = os.path.join(output_dir, f"{stem}.wav")
72
+ # demucs.api.save_audio(source, stem_path, samplerate=separator.samplerate)
73
+ save_tracks(source, stem_path, separator.samplerate)
74
  stems[stem] = stem_path
75
 
76
  # mix the selected stems