owiedotch commited on
Commit
c99df4b
1 Parent(s): c6e9cf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -1,14 +1,10 @@
1
  import gradio as gr
2
- import spaces
3
  import torch
4
  import dac
5
  import io
6
  from audiotools import AudioSignal
7
  from pydub import AudioSegment
8
 
9
- # Ensure we're using CPU even if GPU is available
10
- torch.set_default_tensor_type(torch.FloatTensor)
11
-
12
  class DACApi:
13
  def __init__(self, model_type="44khz", model_bitrate="16kbps"):
14
  self.model_type = model_type
@@ -16,8 +12,10 @@ class DACApi:
16
  self.model_path = dac.utils.download(model_type, model_bitrate)
17
  print("Loading DAC model...")
18
  self.model = dac.DAC.load(self.model_path)
19
- self.model.to('cpu')
 
20
 
 
21
  def encode_audio(self, audio):
22
  # Convert audio to WAV
23
  audio = AudioSegment.from_file(audio.name)
@@ -27,6 +25,7 @@ class DACApi:
27
 
28
  # Load audio signal
29
  signal = AudioSignal(wav_io)
 
30
 
31
  # Compress audio
32
  print("Compressing audio...")
@@ -39,9 +38,11 @@ class DACApi:
39
 
40
  return output
41
 
 
42
  def decode_audio(self, compressed_file):
43
  # Load compressed audio
44
  compressed = dac.DACFile.load(compressed_file)
 
45
 
46
  # Decompress audio
47
  print("Decompressing audio...")
@@ -56,12 +57,10 @@ class DACApi:
56
 
57
  dac_api = DACApi()
58
 
59
- @spaces.CPU
60
  def encode(audio):
61
  compressed = dac_api.encode_audio(audio)
62
  return compressed
63
 
64
- @spaces.CPU
65
  def decode(compressed_file):
66
  decompressed = dac_api.decode_audio(compressed_file)
67
  return decompressed
@@ -84,4 +83,4 @@ with gr.Blocks() as demo:
84
 
85
  decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)
86
 
87
- demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
  import dac
4
  import io
5
  from audiotools import AudioSignal
6
  from pydub import AudioSegment
7
 
 
 
 
8
  class DACApi:
9
  def __init__(self, model_type="44khz", model_bitrate="16kbps"):
10
  self.model_type = model_type
 
12
  self.model_path = dac.utils.download(model_type, model_bitrate)
13
  print("Loading DAC model...")
14
  self.model = dac.DAC.load(self.model_path)
15
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+ self.model.to(self.device)
17
 
18
+ @spaces.GPU
19
  def encode_audio(self, audio):
20
  # Convert audio to WAV
21
  audio = AudioSegment.from_file(audio.name)
 
25
 
26
  # Load audio signal
27
  signal = AudioSignal(wav_io)
28
+ signal = signal.to(self.device)
29
 
30
  # Compress audio
31
  print("Compressing audio...")
 
38
 
39
  return output
40
 
41
+ @spaces.GPU
42
  def decode_audio(self, compressed_file):
43
  # Load compressed audio
44
  compressed = dac.DACFile.load(compressed_file)
45
+ compressed = compressed.to(self.device)
46
 
47
  # Decompress audio
48
  print("Decompressing audio...")
 
57
 
58
  dac_api = DACApi()
59
 
 
60
  def encode(audio):
61
  compressed = dac_api.encode_audio(audio)
62
  return compressed
63
 
 
64
  def decode(compressed_file):
65
  decompressed = dac_api.decode_audio(compressed_file)
66
  return decompressed
 
83
 
84
  decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)
85
 
86
+ demo.launch()