owiedotch committed on
Commit
086a0ea
1 Parent(s): d87908f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -66,11 +66,12 @@ def decode_audio(encoded_file_path):
66
  compressed_data = temp_file.read()
67
  tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
68
  tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
69
- tokens = torch.from_numpy(tokens_numpy).to(torch_device)
70
 
71
- # Ensure tokens has the right dimensions
72
- if tokens.ndimension() == 2: # If tokens have only 2 dimensions
73
- tokens = tokens.unsqueeze(0) # Add batch dimension
 
 
74
 
75
  # Debugging prints to check tensor shapes
76
  print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")
@@ -94,6 +95,11 @@ async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]
94
  compressed_data = temp_file.read()
95
  tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
96
  tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
 
 
 
 
 
97
  tokens = torch.from_numpy(tokens_numpy).to(torch_device)
98
 
99
  # Ensure tokens has the right dimensions
 
66
  compressed_data = temp_file.read()
67
  tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
68
  tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
 
69
 
70
+ # Check if tokens are 1D and reshape to 3D
71
+ if tokens_numpy.ndim == 1:
72
+ tokens_numpy = tokens_numpy.reshape(1, -1, 1) # Reshape to [batch_size, token_length, 1]
73
+
74
+ tokens = torch.from_numpy(tokens_numpy).to(torch_device)
75
 
76
  # Debugging prints to check tensor shapes
77
  print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")
 
95
  compressed_data = temp_file.read()
96
  tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
97
  tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64) # Ensure dtype matches encoder's output
98
+
99
+ # Check if tokens are 1D and reshape to 3D
100
+ if tokens_numpy.ndim == 1:
101
+ tokens_numpy = tokens_numpy.reshape(1, -1, 1) # Reshape to [batch_size, token_length, 1]
102
+
103
  tokens = torch.from_numpy(tokens_numpy).to(torch_device)
104
 
105
  # Ensure tokens has the right dimensions