owiedotch commited on
Commit
ab12e78
1 Parent(s): d07be48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -23
app.py CHANGED
@@ -3,48 +3,46 @@ import torch
3
  from datasets import load_dataset, Audio
4
  from transformers import EncodecModel, AutoProcessor
5
  import spaces
 
 
6
 
7
  # Load the Encodec model and processor
8
  model = EncodecModel.from_pretrained("facebook/encodec_48khz")
9
  processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
10
 
11
  @spaces.GPU
12
- def encode(audio_file_path): # Change argument name to reflect it's a path
13
  try:
14
  # Open the audio file
15
  with open(audio_file_path, "rb") as audio_file:
16
  # Load and preprocess the audio
17
  audio_sample, sampling_rate = load_dataset("audiofolder", data_dir=audio_file_path, split="train")[0]["audio"]
18
- inputs = processor(raw_audio=audio_sample, sampling_rate=sampling_rate, return_tensors="pt")
19
 
20
- # Encode the audio
21
- with torch.no_grad():
22
- encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
23
 
24
- # Extract the encoded codes and scales
25
- audio_codes = encoder_outputs.audio_codes
26
- audio_scales = encoder_outputs.audio_scales
27
-
28
- # Return the encoded data
29
- return {"codes": audio_codes.tolist(), "scales": audio_scales.tolist()}
30
 
31
  except Exception as e:
32
  gr.Warning(f"An error occurred during encoding: {e}")
33
  return None
34
 
35
  @spaces.GPU
36
- def decode(encoded_data):
37
  try:
38
- # Convert the encoded data back to tensors
39
- audio_codes = torch.tensor(encoded_data["codes"])
40
- audio_scales = torch.tensor(encoded_data["scales"])
41
-
42
- # Decode the audio
43
- with torch.no_grad():
44
- audio_values = model.decode(audio_codes, audio_scales)[0]
45
 
 
 
 
46
  # Convert the decoded audio to a numpy array for Gradio output
47
- decoded_audio = audio_values.cpu().numpy()
48
 
49
  return decoded_audio
50
 
@@ -61,17 +59,17 @@ with gr.Blocks() as demo:
61
  audio_input = gr.Audio(type="filepath", label="Input Audio")
62
  encode_button = gr.Button("Encode", variant="primary")
63
  with gr.Row():
64
- encoded_output = gr.JSON(label="Encoded Audio")
65
 
66
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
67
 
68
  with gr.Tab("Decode"):
69
  with gr.Row():
70
- encoded_input = gr.JSON(label="Encoded Audio")
71
  decode_button = gr.Button("Decode", variant="primary")
72
  with gr.Row():
73
  decoded_output = gr.Audio(label="Decompressed Audio")
74
 
75
- decode_button.click(decode, inputs=encoded_input, outputs=decoded_output)
76
 
77
  demo.queue().launch()
 
3
  from datasets import load_dataset, Audio
4
  from transformers import EncodecModel, AutoProcessor
5
  import spaces
6
+ from encodec import compress, decompress
7
+ import io
8
 
9
  # Load the Encodec model and processor
10
  model = EncodecModel.from_pretrained("facebook/encodec_48khz")
11
  processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
12
 
13
  @spaces.GPU
14
+ def encode(audio_file_path):
15
  try:
16
  # Open the audio file
17
  with open(audio_file_path, "rb") as audio_file:
18
  # Load and preprocess the audio
19
  audio_sample, sampling_rate = load_dataset("audiofolder", data_dir=audio_file_path, split="train")[0]["audio"]
20
+ wav = torch.tensor(audio_sample).unsqueeze(0)
21
 
22
+ # Compress to ecdc
23
+ compressed_audio = compress(model, wav)
 
24
 
25
+ # Save compressed audio to BytesIO
26
+ output = io.BytesIO(compressed_audio)
27
+ output.seek(0)
28
+
29
+ return output
 
30
 
31
  except Exception as e:
32
  gr.Warning(f"An error occurred during encoding: {e}")
33
  return None
34
 
35
  @spaces.GPU
36
+ def decode(compressed_audio_file):
37
  try:
38
+ # Load compressed audio
39
+ compressed_audio = compressed_audio_file.read()
 
 
 
 
 
40
 
41
+ # Decompress audio
42
+ wav, sr = decompress(compressed_audio)
43
+
44
  # Convert the decoded audio to a numpy array for Gradio output
45
+ decoded_audio = wav.cpu().numpy()
46
 
47
  return decoded_audio
48
 
 
59
  audio_input = gr.Audio(type="filepath", label="Input Audio")
60
  encode_button = gr.Button("Encode", variant="primary")
61
  with gr.Row():
62
+ encoded_output = gr.File(label="Compressed Audio (.ecdc)")
63
 
64
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
65
 
66
  with gr.Tab("Decode"):
67
  with gr.Row():
68
+ compressed_input = gr.File(label="Compressed Audio (.ecdc)")
69
  decode_button = gr.Button("Decode", variant="primary")
70
  with gr.Row():
71
  decoded_output = gr.Audio(label="Decompressed Audio")
72
 
73
+ decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)
74
 
75
  demo.queue().launch()