owiedotch committed on
Commit
763a29b
1 Parent(s): 411ac1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -26
app.py CHANGED
@@ -1,13 +1,14 @@
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
- from agc import AGC
5
  import tempfile
6
  import numpy as np
7
  import lz4.frame
8
  import os
9
  from typing import Generator
10
  import spaces
 
11
 
12
  # Attempt to use GPU, fallback to CPU
13
  try:
@@ -17,14 +18,17 @@ except Exception as e:
17
  print(f"Error detecting GPU. Using CPU. Error: {e}")
18
  torch_device = torch.device("cpu")
19
 
20
def load_agc_model():
    """Return the pretrained ``Audiogen/agc-continuous`` model on ``torch_device``."""
    pretrained = AGC.from_pretrained("Audiogen/agc-continuous")
    return pretrained.to(torch_device)


# Instantiated once at import time; the encode/decode functions in this file
# use this shared instance.
agc = load_agc_model()
 
 
 
25
 
26
@spaces.GPU(duration=180)
def encode_audio(audio_file_path):
    """Encode an audio file with AGC and store the result in a ``.owie`` file.

    The output file holds the sample rate as a 4-byte little-endian integer,
    followed by the LZ4-frame-compressed raw bytes of the encoded array.

    Args:
        audio_file_path: Path of the audio file to load with torchaudio.

    Returns:
        Path of the temporary ``.owie`` file that was written.

    Raises:
        gr.Error: If loading, encoding, or writing fails.
    """
    temp_file_path = None
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Encode the audio; unsqueeze(0) adds a leading dimension
        # (presumably the batch axis the model expects - TODO confirm).
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            z = agc.encode(audio)

        # Serialize before creating the file so a failure here leaves
        # nothing on disk.
        z_numpy = z.detach().cpu().numpy()
        compressed_data = lz4.frame.compress(z_numpy.tobytes())

        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        # Wrap the descriptor returned by mkstemp directly instead of
        # closing it and reopening the path.
        with os.fdopen(temp_fd, "wb") as temp_file:
            # Store the sample rate as the first 4 bytes
            temp_file.write(sample_rate.to_bytes(4, byteorder="little"))
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        # Do not leave a partially written file behind.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        # The return value is wired to a file output, so returning the
        # message as a string would be treated as a file path. Raise a
        # Gradio error so the message is shown to the user instead.
        raise gr.Error(f"Encoding error: {e}") from e
52
 
53
@spaces.GPU(duration=180)
def decode_audio(encoded_file_path):
    """Decode a ``.owie`` file with AGC and write the audio to a WAV file.

    The input file is expected to start with the sample rate as a 4-byte
    little-endian integer, followed by LZ4-frame-compressed float32 data
    that is reshaped to ``(1, 32, -1)`` before decoding.

    Args:
        encoded_file_path: Path of the ``.owie`` file to decode.

    Returns:
        Path of the temporary ``.wav`` file that was written.

    Raises:
        gr.Error: If reading, decoding, or saving fails.
    """
    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, "rb") as encoded_file:
            sample_rate = int.from_bytes(encoded_file.read(4), byteorder="little")
            compressed_data = encoded_file.read()

        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        # np.frombuffer returns a read-only view over the bytes object; copy
        # it so torch.from_numpy receives writable memory.
        z_numpy = (
            np.frombuffer(z_numpy_bytes, dtype=np.float32)
            .reshape(1, 32, -1)
            .copy()
        )
        z = torch.from_numpy(z_numpy).to(torch_device)

        # Decode the audio
        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # mkstemp creates the file atomically; tempfile.mktemp is deprecated
        # because another process can claim the name before it is used.
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)
        torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        # The return value feeds a filepath audio output, so an error string
        # would be treated as a path. Raise a Gradio error instead.
        raise gr.Error(f"Decoding error: {e}") from e
75
 
76
@spaces.GPU(duration=180)
def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
    """Decode a ``.owie`` file chunk by chunk, yielding audio as it is produced.

    Args:
        encoded_file_path: Path of the ``.owie`` file to decode.

    Yields:
        ``(sample_rate, audio_data)`` tuples, where ``audio_data`` is the
        decoded chunk as a transposed NumPy array. If decoding fails after
        the sample rate has been read, one chunk of silence is yielded.
    """
    # Bound up front so the error handler can tell whether the header was
    # read before the failure happened.
    sample_rate = None
    chunk_size = None
    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, "rb") as encoded_file:
            sample_rate = int.from_bytes(encoded_file.read(4), byteorder="little")
            compressed_data = encoded_file.read()

        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        # Copy: np.frombuffer returns a read-only view over the bytes object.
        z_numpy = (
            np.frombuffer(z_numpy_bytes, dtype=np.float32)
            .reshape(1, 32, -1)
            .copy()
        )
        z = torch.from_numpy(z_numpy).to(torch_device)

        if sample_rate <= 0:
            # range() rejects a zero step, so fail with a clear message.
            raise ValueError(f"invalid sample rate in header: {sample_rate}")

        # Decode the audio in chunks
        chunk_size = sample_rate  # Use the stored sample rate as chunk size
        with torch.no_grad():
            for start in range(0, z.shape[2], chunk_size):
                z_chunk = z[:, :, start:start + chunk_size]
                audio_chunk = agc.decode(z_chunk)
                # Convert to numpy array and transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        if chunk_size:
            # Return silence. NOTE(review): a single channel is assumed for
            # the fallback (the previous value of 32 matched the reshape
            # width above, not an audio channel count) - confirm.
            yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))
100
 
 
 
101
 
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with AGC (GPU/CPU)")

    # Tab 1: audio file in, .owie file out.
    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")

        encode_button.click(
            encode_audio,
            inputs=input_audio,
            outputs=encoded_output,
        )

    # Tab 2: .owie file in, decoded audio file out.
    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(
            decode_audio,
            inputs=input_encoded,
            outputs=decoded_output,
        )

    # Tab 3: .owie file in, audio streamed out chunk by chunk.
    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(
            stream_decode_audio,
            inputs=input_encoded_stream,
            outputs=audio_output,
        )

# Enable the request queue, then start the app.
demo.queue()
demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
+ from semanticodec import SemantiCodec
5
  import tempfile
6
  import numpy as np
7
  import lz4.frame
8
  import os
9
  from typing import Generator
10
  import spaces
11
+ import asyncio # Import asyncio for cancellation
12
 
13
  # Attempt to use GPU, fallback to CPU
14
  try:
 
18
  print(f"Error detecting GPU. Using CPU. Error: {e}")
19
  torch_device = torch.device("cpu")
20
 
21
# Build the SemantiCodec model once at import time and move it to the device
# selected above.
semanticodec = SemantiCodec(
    token_rate=100,
    semantic_vocab_size=32768,
).to(torch_device)

# Module-level cancellation flags, one per operation. The Cancel buttons in
# the UI set these to True, and each handler resets its own flag on exit.
cancel_encode = cancel_decode = cancel_stream = False
28
 
29
@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
def encode_audio(audio_file_path):
    """Encode an audio file with SemantiCodec and store it in a ``.owie`` file.

    The output file holds the sample rate as a 4-byte little-endian integer,
    followed by the LZ4-frame-compressed raw bytes of the token array.
    NOTE(review): neither the shape nor the dtype of the token array is
    written to the file, so readers have to assume both.

    Args:
        audio_file_path: Path of the audio file to load with torchaudio.

    Returns:
        Path of the temporary ``.owie`` file that was written.

    Raises:
        gr.Error: If loading, encoding, or writing fails.
    """
    global cancel_encode
    temp_file_path = None
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Encode the audio; unsqueeze(0) adds a leading dimension
        # (presumably the batch axis the model expects - TODO confirm).
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            tokens = semanticodec.encode(audio)

        # Serialize before creating the file so a failure here leaves
        # nothing on disk.
        tokens_numpy = tokens.detach().cpu().numpy()
        compressed_data = lz4.frame.compress(tokens_numpy.tobytes())

        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        # Wrap the descriptor returned by mkstemp directly instead of
        # closing it and reopening the path.
        with os.fdopen(temp_fd, "wb") as temp_file:
            # Store the sample rate as the first 4 bytes
            temp_file.write(sample_rate.to_bytes(4, byteorder="little"))
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        # Do not leave a partially written file behind.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        # The return value is wired to a file output, so returning the
        # message as a string would be treated as a file path. Raise a
        # Gradio error so the message is shown to the user instead.
        raise gr.Error(f"Encoding error: {e}") from e

    finally:
        cancel_encode = False  # Reset cancel flag after encoding
59
+
60
@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
def decode_audio(encoded_file_path):
    """Decode a ``.owie`` file with SemantiCodec and write the audio to a WAV file.

    The input file is expected to start with the sample rate as a 4-byte
    little-endian integer, followed by LZ4-frame-compressed token data that
    is interpreted as int64.

    Args:
        encoded_file_path: Path of the ``.owie`` file to decode.

    Returns:
        Path of the temporary ``.wav`` file that was written.

    Raises:
        gr.Error: If reading, decoding, or saving fails.
    """
    global cancel_decode
    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, "rb") as encoded_file:
            sample_rate = int.from_bytes(encoded_file.read(4), byteorder="little")
            compressed_data = encoded_file.read()

        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # np.frombuffer returns a read-only view over the bytes object; copy
        # it so torch.from_numpy receives writable memory.
        # NOTE(review): the array is 1-D here because the file stores no
        # shape; confirm that semanticodec.decode accepts this layout.
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).copy()
        tokens = torch.from_numpy(tokens_numpy).to(torch_device)

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # mkstemp creates the file atomically; tempfile.mktemp is deprecated
        # because another process can claim the name before it is used.
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        # The return value feeds a filepath audio output, so an error string
        # would be treated as a path. Raise a Gradio error instead.
        raise gr.Error(f"Decoding error: {e}") from e

    finally:
        cancel_decode = False  # Reset cancel flag after decoding
86
+
87
@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
    """Decode a ``.owie`` file chunk by chunk, yielding audio as it is produced.

    The loop stops early when the module-level ``cancel_stream`` flag is set.
    NOTE(review): this is an async generator, so the ``Generator`` return
    annotation is inaccurate; it is left unchanged to keep the signature.

    Args:
        encoded_file_path: Path of the ``.owie`` file to decode.

    Yields:
        ``(sample_rate, audio_data)`` tuples, where ``audio_data`` is the
        decoded chunk as a transposed NumPy array. If decoding fails after
        the sample rate has been read, one chunk of silence is yielded.
    """
    global cancel_stream
    # Bound up front so the error handler can tell whether the header was
    # read before the failure happened.
    sample_rate = None
    chunk_size = None
    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, "rb") as encoded_file:
            sample_rate = int.from_bytes(encoded_file.read(4), byteorder="little")
            compressed_data = encoded_file.read()

        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # Copy: np.frombuffer returns a read-only view over the bytes object.
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).copy()
        tokens = torch.from_numpy(tokens_numpy).to(torch_device)
        # NOTE(review): no reshape is applied, so `tokens` is 1-D and the
        # `tokens.shape[1]` / `tokens[:, ...]` indexing below raises
        # IndexError. The token layout must be restored here; the expected
        # shape is not visible in this file - confirm against SemantiCodec.

        if sample_rate <= 0:
            # range() rejects a zero step, so fail with a clear message.
            raise ValueError(f"invalid sample rate in header: {sample_rate}")

        # Decode the audio in chunks
        chunk_size = sample_rate  # Use the stored sample rate as chunk size
        with torch.no_grad():
            for start in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                tokens_chunk = tokens[:, start:start + chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to numpy array and transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                # Hand control back to the event loop between chunks.
                await asyncio.sleep(0)

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        if chunk_size:
            # Return silence
            yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))

    finally:
        cancel_stream = False  # Reset cancel flag after streaming
119
 
def _request_encode_cancel():
    """Set the module-level ``cancel_encode`` flag."""
    global cancel_encode
    cancel_encode = True


def _request_decode_cancel():
    """Set the module-level ``cancel_decode`` flag."""
    global cancel_decode
    cancel_decode = True


def _request_stream_cancel():
    """Set the module-level ``cancel_stream`` flag."""
    global cancel_stream
    cancel_stream = True


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    # Tab 1: audio file in, .owie file out.
    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")

        encode_button.click(
            encode_audio,
            inputs=input_audio,
            outputs=encoded_output,
        )
        # NOTE(review): encode_audio only resets cancel_encode and never
        # reads it, so this button does not interrupt a running encode.
        cancel_encode_button.click(_request_encode_cancel, outputs=None)

    # Tab 2: .owie file in, decoded audio file out.
    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(
            decode_audio,
            inputs=input_encoded,
            outputs=decoded_output,
        )
        # NOTE(review): decode_audio likewise never reads cancel_decode.
        cancel_decode_button.click(_request_decode_cancel, outputs=None)

    # Tab 3: .owie file in, audio streamed out chunk by chunk.
    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(
            stream_decode_audio,
            inputs=input_encoded_stream,
            outputs=audio_output,
        )
        # The streaming loop checks cancel_stream before decoding each chunk.
        cancel_stream_button.click(_request_stream_cancel, outputs=None)

# Enable the request queue, then start the app.
demo.queue()
demo.launch()