owiedotch committed on
Commit
d87908f
1 Parent(s): 4ca3581

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -69,7 +69,11 @@ def decode_audio(encoded_file_path):
69
  tokens = torch.from_numpy(tokens_numpy).to(torch_device)
70
 
71
  # Ensure tokens has the right dimensions
72
- tokens = tokens.unsqueeze(0) if tokens.ndimension() == 1 else tokens
 
 
 
 
73
 
74
  # Decode the audio
75
  with torch.no_grad():
@@ -93,7 +97,8 @@ async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]
93
  tokens = torch.from_numpy(tokens_numpy).to(torch_device)
94
 
95
  # Ensure tokens has the right dimensions
96
- tokens = tokens.unsqueeze(0) if tokens.ndimension() == 1 else tokens
 
97
 
98
  # Decode the audio in chunks
99
  chunk_size = sample_rate # Use the stored sample rate as chunk size
 
69
  tokens = torch.from_numpy(tokens_numpy).to(torch_device)
70
 
71
  # Ensure tokens has the right dimensions
72
+ if tokens.ndimension() == 2: # If tokens have only 2 dimensions
73
+ tokens = tokens.unsqueeze(0) # Add batch dimension
74
+
75
+ # Debugging prints to check tensor shapes
76
+ print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")
77
 
78
  # Decode the audio
79
  with torch.no_grad():
 
97
  tokens = torch.from_numpy(tokens_numpy).to(torch_device)
98
 
99
  # Ensure tokens has the right dimensions
100
+ if tokens.ndimension() == 2: # If tokens have only 2 dimensions
101
+ tokens = tokens.unsqueeze(0) # Add batch dimension
102
 
103
  # Decode the audio in chunks
104
  chunk_size = sample_rate # Use the stored sample rate as chunk size