Update custom_interface_app.py
Browse files- custom_interface_app.py +8 -5
custom_interface_app.py
CHANGED
@@ -41,7 +41,7 @@ class ASR(Pretrained):
|
|
41 |
# Forward encoder + decoder
|
42 |
tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
|
43 |
tokens = tokens.to(device)
|
44 |
-
enc_out, logits, _ = self.mods.whisper(wavs, tokens)
|
45 |
log_probs = self.hparams.log_softmax(logits)
|
46 |
|
47 |
hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
|
@@ -174,7 +174,6 @@ class ASR(Pretrained):
|
|
174 |
|
175 |
def classify_file_whisper_mkd(self, path, device):
|
176 |
# Load the audio file
|
177 |
-
# path = "long_sample.wav"
|
178 |
waveform, sr = librosa.load(path, sr=16000)
|
179 |
|
180 |
# Get audio length in seconds
|
@@ -202,7 +201,8 @@ class ASR(Pretrained):
|
|
202 |
|
203 |
# Fake a batch for the segment
|
204 |
batch = segment_tensor.unsqueeze(0).to(device)
|
205 |
-
|
|
|
206 |
|
207 |
# Pass the segment through the ASR model
|
208 |
segment_output = self.encode_batch_whisper(device, batch, rel_length)
|
@@ -210,13 +210,14 @@ class ASR(Pretrained):
|
|
210 |
else:
|
211 |
waveform = torch.tensor(waveform).to(device)
|
212 |
waveform = waveform.to(device)
|
213 |
-
# Fake a batch:
|
214 |
batch = waveform.unsqueeze(0)
|
215 |
-
|
|
|
216 |
outputs = self.encode_batch_whisper(device, batch, rel_length)
|
217 |
yield outputs
|
218 |
|
219 |
|
|
|
220 |
def classify_file_whisper(self, path, pipe, device):
|
221 |
waveform, sr = librosa.load(path, sr=16000)
|
222 |
transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
|
@@ -252,6 +253,7 @@ class ASR(Pretrained):
|
|
252 |
|
253 |
# Pass the segment through the ASR model
|
254 |
inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
|
|
|
255 |
outputs = model(**inputs).logits
|
256 |
ids = torch.argmax(outputs, dim=-1)[0]
|
257 |
segment_output = processor.decode(ids)
|
@@ -259,6 +261,7 @@ class ASR(Pretrained):
|
|
259 |
else:
|
260 |
waveform = torch.tensor(waveform).to(device)
|
261 |
inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
|
|
|
262 |
outputs = model(**inputs).logits
|
263 |
ids = torch.argmax(outputs, dim=-1)[0]
|
264 |
transcription = processor.decode(ids)
|
|
|
41 |
# Forward encoder + decoder
|
42 |
tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
|
43 |
tokens = tokens.to(device)
|
44 |
+
enc_out, logits, _ = self.mods.whisper(wavs.detach(), tokens.detach())
|
45 |
log_probs = self.hparams.log_softmax(logits)
|
46 |
|
47 |
hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
|
|
|
174 |
|
175 |
def classify_file_whisper_mkd(self, path, device):
|
176 |
# Load the audio file
|
|
|
177 |
waveform, sr = librosa.load(path, sr=16000)
|
178 |
|
179 |
# Get audio length in seconds
|
|
|
201 |
|
202 |
# Fake a batch for the segment
|
203 |
batch = segment_tensor.unsqueeze(0).to(device)
|
204 |
+
batch = batch.to(torch.float16)
|
205 |
+
rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)
|
206 |
|
207 |
# Pass the segment through the ASR model
|
208 |
segment_output = self.encode_batch_whisper(device, batch, rel_length)
|
|
|
210 |
else:
|
211 |
waveform = torch.tensor(waveform).to(device)
|
212 |
waveform = waveform.to(device)
|
|
|
213 |
batch = waveform.unsqueeze(0)
|
214 |
+
batch = batch.to(torch.float16)
|
215 |
+
rel_length = torch.tensor([1.0], dtype=torch.float16).to(device)
|
216 |
outputs = self.encode_batch_whisper(device, batch, rel_length)
|
217 |
yield outputs
|
218 |
|
219 |
|
220 |
+
|
221 |
def classify_file_whisper(self, path, pipe, device):
|
222 |
waveform, sr = librosa.load(path, sr=16000)
|
223 |
transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
|
|
|
253 |
|
254 |
# Pass the segment through the ASR model
|
255 |
inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
|
256 |
+
inputs['input_values'] = inputs['input_values'].to(torch.float16)
|
257 |
outputs = model(**inputs).logits
|
258 |
ids = torch.argmax(outputs, dim=-1)[0]
|
259 |
segment_output = processor.decode(ids)
|
|
|
261 |
else:
|
262 |
waveform = torch.tensor(waveform).to(device)
|
263 |
inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
|
264 |
+
inputs['input_values'] = inputs['input_values'].to(torch.float16)
|
265 |
outputs = model(**inputs).logits
|
266 |
ids = torch.argmax(outputs, dim=-1)[0]
|
267 |
transcription = processor.decode(ids)
|