divakaivan committed on
Commit
4f76169
1 Parent(s): 014aba2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -137,13 +137,15 @@ def predict(text, speaker):
137
 
138
  ### ### ###
139
  example = dataset['test'][11]
140
- # speaker_embedding = torch.tensor(example["speaker_embeddings"]).unsqueeze(0)
141
- speaker_embedding = torch.tensor(example["speaker_embeddings"]).unsqueeze(0).unsqueeze(0).to(device)
142
- speaker_embedding = speaker_embedding.expand(-1, inputs["input_ids"].size(1), -1)
143
- spectrogram = model.generate_speech(inputs["input_ids"].to(device), speaker_embedding)
 
 
 
144
 
145
- # speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
146
- # spectrogram = model.generate_speech(inputs["input_ids"], speaker_embedding)
147
  with torch.no_grad():
148
  speech = vocoder(spectrogram)
149
  # speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
 
137
 
138
  ### ### ###
139
  example = dataset['test'][11]
140
+ speaker_embedding = torch.tensor(example["speaker_embeddings"]).unsqueeze(0).to(device)
141
+
142
+ # Ensure the speaker_embedding has the correct dimensions
143
+ if speaker_embedding.dim() == 2:
144
+ speaker_embedding = speaker_embedding.unsqueeze(1).expand(-1, inputs["input_ids"].size(1), -1)
145
+ elif speaker_embedding.dim() == 3:
146
+ speaker_embedding = speaker_embedding.expand(-1, inputs["input_ids"].size(1), -1)
147
 
148
+ spectrogram = model.generate_speech(inputs["input_ids"].to(device), speaker_embedding)
 
149
  with torch.no_grad():
150
  speech = vocoder(spectrogram)
151
  # speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)