owiedotch committed
Commit cefd33c
1 Parent(s): c07b48c

Update app.py

Files changed (1)
  1. app.py +25 -51
app.py CHANGED
@@ -27,7 +27,7 @@ cancel_encode = False
 cancel_decode = False
 cancel_stream = False
 
-@spaces.GPU(duration=30) # Changed from 250 to 30
+@spaces.GPU(duration=30)
 def encode_audio(audio_file_path):
     global cancel_encode
 
@@ -51,22 +51,10 @@ def encode_audio(audio_file_path):
         # Encode the audio
         tokens = semanticodec.encode(temp_wav_file_path)
 
-        # Convert tokens to NumPy and save to .owie file
+        # Convert tokens to NumPy
         tokens_numpy = tokens.detach().cpu().numpy()
 
-        print(f"Original tokens shape: {tokens_numpy.shape}")
-
-        # Ensure tokens_numpy is 2D
-        if tokens_numpy.ndim == 1:
-            tokens_numpy = tokens_numpy.reshape(1, -1)
-        elif tokens_numpy.ndim == 2:
-            pass # Already 2D
-        elif tokens_numpy.ndim == 3 and tokens_numpy.shape[0] == 1:
-            tokens_numpy = tokens_numpy.squeeze(0)
-        else:
-            raise ValueError(f"Unexpected tokens array shape: {tokens_numpy.shape}")
-
-        print(f"Reshaped tokens shape: {tokens_numpy.shape}")
+        print(f"Tokens shape: {tokens_numpy.shape}")
 
         # Create temporary .owie file
         temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
@@ -82,12 +70,12 @@ def encode_audio(audio_file_path):
 
     except Exception as e:
         print(f"Encoding error: {e}")
-        return None # Return None instead of the error message
+        return None
 
     finally:
-        cancel_encode = False # Reset cancel flag after encoding
+        cancel_encode = False
         if 'temp_wav_file_path' in locals():
-            os.remove(temp_wav_file_path) # Clean up temporary WAV file
+            os.remove(temp_wav_file_path)
 
 # Add this function to handle the output
 def handle_encode_output(file_path):
@@ -95,7 +83,7 @@ def handle_encode_output(file_path):
         return None, gr.Markdown("Encoding failed. Please ensure you've uploaded an audio file and try again.", visible=True)
     return file_path, gr.Markdown(visible=False)
 
-@spaces.GPU(duration=30) # Changed from 250 to 30
+@spaces.GPU(duration=30)
 def decode_audio(encoded_file_path):
     global cancel_decode
 
@@ -105,18 +93,11 @@ def decode_audio(encoded_file_path):
             sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
             compressed_data = temp_file.read()
             tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
-            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)
-
-            # Reshape tokens to match the original shape
-            tokens_numpy = tokens_numpy.reshape(1, -1, 2)
-
-            # Create a writable copy of the numpy array
-            tokens_numpy = np.array(tokens_numpy, copy=True)
+            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(-1)
 
             # Move the tensor to the same device as the model
             tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
 
-            # Debugging prints to check tensor shapes and device
             print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
             print(f"Model device: {semanticodec.device}")
 
@@ -124,23 +105,20 @@ def decode_audio(encoded_file_path):
             with torch.no_grad():
                 waveform = semanticodec.decode(tokens)
 
-            # Move waveform to CPU for saving
-            waveform_cpu = waveform.cpu()
-
             # Save to a temporary WAV file
             temp_wav_path = tempfile.mktemp(suffix=".wav")
-            torchaudio.save(temp_wav_path, waveform_cpu.squeeze(0), sample_rate)
+            torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
             return temp_wav_path
 
     except Exception as e:
         print(f"Decoding error: {e}")
         print(f"Traceback: {traceback.format_exc()}")
-        return str(e) # Return error message as string
+        return str(e)
 
     finally:
-        cancel_decode = False # Reset cancel flag after decoding
+        cancel_decode = False
 
-@spaces.GPU(duration=30) # Changed from 250 to 30
+@spaces.GPU(duration=30)
 async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
     global cancel_stream
 
@@ -150,11 +128,7 @@ async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]
             sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
             compressed_data = temp_file.read()
             tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
-            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)
-            tokens_numpy = tokens_numpy.reshape(1, -1, 2)
-
-            # Create a writable copy of the numpy array
-            tokens_numpy = np.array(tokens_numpy, copy=True)
+            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(-1)
 
             # Move the tensor to the same device as the model
             tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
@@ -163,13 +137,13 @@ async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]
             print(f"Model device: {semanticodec.device}")
 
             # Decode the audio in chunks
-            chunk_size = sample_rate // 2 # Adjust chunk size to account for the new shape
+            chunk_size = sample_rate * 2 # Adjust chunk size as needed
             with torch.no_grad():
-                for i in range(0, tokens.shape[1], chunk_size):
+                for i in range(0, tokens.shape[0], chunk_size):
                     if cancel_stream:
                         break # Exit the loop if cancellation is requested
 
-                    tokens_chunk = tokens[:, i:i+chunk_size, :]
+                    tokens_chunk = tokens[i:i+chunk_size]
                     audio_chunk = semanticodec.decode(tokens_chunk)
                     # Convert to numpy array and transpose
                     audio_data = audio_chunk.squeeze(0).cpu().numpy().T
@@ -182,17 +156,17 @@ async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]
         yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32)) # Return silence
 
     finally:
-        cancel_stream = False # Reset cancel flag after streaming
+        cancel_stream = False
 
 # Gradio Interface
 with gr.Blocks() as demo:
     gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")
 
     with gr.Tab("Encode"):
-        input_audio = gr.Audio(label="Input Audio", type="filepath") # Using "filepath" mode
+        input_audio = gr.Audio(label="Input Audio", type="filepath")
         encode_button = gr.Button("Encode")
         cancel_encode_button = gr.Button("Cancel")
-        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath") # Using "filepath" mode
+        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
         encode_error_message = gr.Markdown(visible=False)
 
         def encode_wrapper(audio):
@@ -205,24 +179,24 @@ with gr.Blocks() as demo:
             inputs=input_audio,
             outputs=[encoded_output, encode_error_message]
         )
-        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None) # Set cancel_encode flag
+        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)
 
     with gr.Tab("Decode"):
-        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath") # Using "filepath" mode
+        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
         decode_button = gr.Button("Decode")
         cancel_decode_button = gr.Button("Cancel")
-        decoded_output = gr.Audio(label="Decoded Audio", type="filepath") # Using "filepath" mode
+        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
 
         decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
-        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None) # Set cancel_decode flag
+        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None)
 
     with gr.Tab("Streaming"):
-        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath") # Using "filepath" mode
+        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
         stream_button = gr.Button("Start Streaming")
         cancel_stream_button = gr.Button("Cancel")
         audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
 
         stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
-        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None) # Set cancel_stream flag
+        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)
 
 demo.queue().launch()
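
For reference, the `.owie` container that both sides of this diff agree on is minimal: a 4-byte little-endian sample rate, followed by the LZ4-compressed bytes of a flat int64 token array, which is exactly what the simplified `np.frombuffer(...).reshape(-1)` path above assumes. A standalone sketch of the format (the `write_owie`/`read_owie` helper names are illustrative, not part of app.py):

```python
import lz4.frame
import numpy as np

def write_owie(path, sample_rate, tokens_numpy):
    """Write a flat int64 token array in the .owie layout used above."""
    with open(path, "wb") as f:
        f.write(int(sample_rate).to_bytes(4, byteorder="little"))  # 4-byte LE header
        f.write(lz4.frame.compress(tokens_numpy.astype(np.int64).tobytes()))

def read_owie(path):
    """Read an .owie file back into (sample_rate, flat int64 token array)."""
    with open(path, "rb") as f:
        sample_rate = int.from_bytes(f.read(4), byteorder="little")
        tokens = np.frombuffer(lz4.frame.decompress(f.read()), dtype=np.int64).reshape(-1)
    return sample_rate, tokens
```

Round-tripping an array through these helpers should reproduce the bytes the app writes, assuming the same dtype (`np.int64`) and the same 4-byte header width.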
 
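A note on the cancel wiring: the buttons use `lambda: globals().update(cancel_encode=True)` because a lambda body cannot contain an assignment statement, and `globals().update` lets the module-level flag be set from a single expression. A named callback is the more conventional spelling; a minimal equivalent sketch (the function name is illustrative):

```python
def request_cancel_encode():
    # Plain global assignment; equivalent to the one-liner
    # lambda: globals().update(cancel_encode=True) used above.
    global cancel_encode
    cancel_encode = True

# Hypothetical wiring, mirroring the Encode tab above:
cancel_encode_button.click(request_cancel_encode, outputs=None)
```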