Update app.py
Browse files
app.py
CHANGED
@@ -66,11 +66,12 @@ def decode_audio(encoded_file_path):
|
|
66 |
compressed_data = temp_file.read()
|
67 |
tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
|
68 |
tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
|
69 |
-
tokens = torch.from_numpy(tokens_numpy).to(torch_device)
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
74 |
|
75 |
# Debugging prints to check tensor shapes
|
76 |
print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")
|
@@ -94,6 +95,11 @@ async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]
|
|
94 |
compressed_data = temp_file.read()
|
95 |
tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
|
96 |
tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
|
|
|
|
|
|
|
|
|
|
|
97 |
tokens = torch.from_numpy(tokens_numpy).to(torch_device)
|
98 |
|
99 |
# Ensure tokens has the right dimensions
|
|
|
66 |
compressed_data = temp_file.read()
|
67 |
tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
|
68 |
tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
|
|
|
69 |
|
70 |
+
# Check if tokens are 1D and reshape to 3D
|
71 |
+
if tokens_numpy.ndim == 1:
|
72 |
+
tokens_numpy = tokens_numpy.reshape(1, -1, 1) # Reshape to [batch_size, token_length, 1]
|
73 |
+
|
74 |
+
tokens = torch.from_numpy(tokens_numpy).to(torch_device)
|
75 |
|
76 |
# Debugging prints to check tensor shapes
|
77 |
print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")
|
|
|
95 |
compressed_data = temp_file.read()
|
96 |
tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
|
97 |
tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
|
98 |
+
|
99 |
+
# Check if tokens are 1D and reshape to 3D
|
100 |
+
if tokens_numpy.ndim == 1:
|
101 |
+
tokens_numpy = tokens_numpy.reshape(1, -1, 1) # Reshape to [batch_size, token_length, 1]
|
102 |
+
|
103 |
tokens = torch.from_numpy(tokens_numpy).to(torch_device)
|
104 |
|
105 |
# Ensure tokens has the right dimensions
|