Porjaz committed
Commit 8214844
1 Parent(s): 433d102

Update custom_interface_app.py

Files changed (1)
  1. custom_interface_app.py +8 -5
custom_interface_app.py CHANGED
@@ -41,7 +41,7 @@ class ASR(Pretrained):
         # Forward encoder + decoder
         tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
         tokens = tokens.to(device)
-        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
+        enc_out, logits, _ = self.mods.whisper(wavs.detach(), tokens.detach())
         log_probs = self.hparams.log_softmax(logits)
 
         hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
@@ -174,7 +174,6 @@ class ASR(Pretrained):
 
     def classify_file_whisper_mkd(self, path, device):
         # Load the audio file
-        # path = "long_sample.wav"
         waveform, sr = librosa.load(path, sr=16000)
 
         # Get audio length in seconds
@@ -202,7 +201,8 @@ class ASR(Pretrained):
 
                 # Fake a batch for the segment
                 batch = segment_tensor.unsqueeze(0).to(device)
-                rel_length = torch.tensor([1.0]).to(device)
+                batch = batch.to(torch.float16)
+                rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)
 
                 # Pass the segment through the ASR model
                 segment_output = self.encode_batch_whisper(device, batch, rel_length)
@@ -210,13 +210,14 @@ class ASR(Pretrained):
         else:
             waveform = torch.tensor(waveform).to(device)
             waveform = waveform.to(device)
-            # Fake a batch:
             batch = waveform.unsqueeze(0)
-            rel_length = torch.tensor([1.0]).to(device)
+            batch = batch.to(torch.float16)
+            rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)
             outputs = self.encode_batch_whisper(device, batch, rel_length)
             yield outputs
 
 
+
     def classify_file_whisper(self, path, pipe, device):
         waveform, sr = librosa.load(path, sr=16000)
         transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
@@ -252,6 +253,7 @@ class ASR(Pretrained):
 
             # Pass the segment through the ASR model
             inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
+            inputs['input_values'] = inputs['input_values'].to(torch.float16)
             outputs = model(**inputs).logits
             ids = torch.argmax(outputs, dim=-1)[0]
             segment_output = processor.decode(ids)
@@ -259,6 +261,7 @@ class ASR(Pretrained):
         else:
             waveform = torch.tensor(waveform).to(device)
             inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
+            inputs['input_values'] = inputs['input_values'].to(torch.float16)
             outputs = model(**inputs).logits
             ids = torch.argmax(outputs, dim=-1)[0]
             transcription = processor.decode(ids)
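A note on the recurring float16 casts: they make the inputs match weights that were presumably loaded in half precision, since PyTorch raises a dtype-mismatch error when float32 tensors reach float16 layers. A minimal sketch of a variant that avoids hard-coding the dtype, assuming an ordinary PyTorch module as in the diff (the helper name to_model_dtype is hypothetical, not part of this repo):

    import torch

    def to_model_dtype(batch: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
        # Read the dtype of the loaded weights (torch.float16 for a
        # half-precision checkpoint) and cast the input batch to match.
        model_dtype = next(model.parameters()).dtype
        return batch.to(model_dtype)

Written this way, the call sites keep working unchanged if the checkpoint is later loaded in full precision.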
 
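On the first hunk: .detach() removes wavs and tokens from any autograd graph before the Whisper forward. Assuming this code path is inference-only, as an app interface suggests, an equivalent and slightly cheaper alternative is to run the whole forward under torch.no_grad(), sketched here with the diff's names (asr_model stands in for self):

    import torch

    def whisper_forward_inference(asr_model, wavs, tokens):
        # no_grad() disables autograd for the entire call, which subsumes
        # the per-tensor .detach() calls and also skips storing activations
        # that the later beam search would never use.
        with torch.no_grad():
            enc_out, logits, _ = asr_model.mods.whisper(wavs, tokens)
        return enc_out, logits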
 
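The last two hunks apply the same half-precision cast to the Hugging Face processor output before greedy decoding. A self-contained sketch of that decode path under the same fp16 assumption, using the diff's processor/model pair (which looks like a wav2vec2-style CTC setup; the function name is illustrative):

    import torch

    def transcribe_ctc_fp16(waveform, processor, model, device):
        # Feature-extract, cast the features to match the half-precision
        # weights, then greedy-decode the CTC logits, as in the diff.
        inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
        inputs["input_values"] = inputs["input_values"].to(torch.float16)
        with torch.no_grad():
            logits = model(**inputs).logits
        ids = torch.argmax(logits, dim=-1)[0]
        return processor.decode(ids)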