owiedotch committed on
Commit
58ff621
1 Parent(s): 15e2eb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -22,14 +22,21 @@ def inference(audio_file: str, model_name: str, vocals: bool, drums: bool, bass:
22
  yield None, stream_log("Starting separation process...")
23
  yield None, stream_log(f"Loading audio file: {audio_file}")
24
 
25
- # Load the audio file with the correct parameters
26
- wav = load_track(audio_file, device, audio_channels=2, samplerate=separator.samplerate)
27
 
28
- # The load_track function already handles channel conversion and resampling,
29
- # so we can remove the channel adjustment code here.
 
 
 
 
 
30
 
 
 
31
  ref = wav.mean(0)
32
- wav = (wav - ref.view(1, -1)).to(device)
33
  yield None, stream_log("Audio loaded successfully. Applying model...")
34
  sources = apply_model(separator, wav, device=device, progress=True)
35
  sources = sources * ref.view(1, -1) + ref.view(1, -1)
 
22
  yield None, stream_log("Starting separation process...")
23
  yield None, stream_log(f"Loading audio file: {audio_file}")
24
 
25
+ # Load the audio file with the correct samplerate
26
+ wav, sr = load_track(audio_file, separator.samplerate)
27
 
28
+ # Check the number of channels and adjust if necessary
29
+ if wav.dim() == 1:
30
+ wav = wav.unsqueeze(0) # Add channel dimension if mono
31
+ if wav.shape[0] == 1:
32
+ wav = wav.repeat(2, 1) # If mono, duplicate to stereo
33
+ elif wav.shape[0] > 2:
34
+ wav = wav[:2] # If more than 2 channels, keep only the first two
35
 
36
+ wav = wav.to(device)
37
+
38
  ref = wav.mean(0)
39
+ wav = (wav - ref.view(1, -1))
40
  yield None, stream_log("Audio loaded successfully. Applying model...")
41
  sources = apply_model(separator, wav, device=device, progress=True)
42
  sources = sources * ref.view(1, -1) + ref.view(1, -1)