Porjaz committed on
Commit c31791e
1 Parent(s): abf03fe

Update custom_interface_app.py

Files changed (1)
  1. custom_interface_app.py +87 -4
custom_interface_app.py CHANGED
@@ -8,7 +8,7 @@ class ASR(Pretrained):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-    def encode_batch(self, device, wavs, wav_lens=None, normalize=False):
+    def encode_batch_w2v2(self, device, wavs, wav_lens=None, normalize=False):
         wavs = wavs.to(device)
         wav_lens = wav_lens.to(device)

@@ -33,6 +33,22 @@ class ASR(Pretrained):
         predicted_words = prediction
         return predicted_words

+
+    def encode_batch_whisper(self, device, wavs, wav_lens=None, normalize=False):
+        wavs = wavs.to(device)
+        wav_lens = wav_lens.to(device)
+
+        # Forward encoder + decoder
+        tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
+        tokens = tokens.to(device)
+        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
+        log_probs = self.hparams.log_softmax(logits)
+
+        hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        predicted_words = [self.mods.whisper.tokenizer.decode(token, skip_special_tokens=True).strip() for token in hyps]
+        return predicted_words
+
+
     def filter_repetitions(self, seq, max_repetition_length):
         seq = list(seq)
         output = []
@@ -110,7 +126,74 @@ class ASR(Pretrained):
         return waveform


-    def classify_file(self, path, device):
+    def classify_file_w2v2(self, path, device):
+        # Load the audio file
+        # path = "long_sample.wav"
+        waveform, sr = librosa.load(path, sr=16000)
+
+        # increase the volume if needed
+        # waveform = self.increase_volume(waveform)
+
+        # Get audio length in seconds
+        audio_length = len(waveform) / sr
+
+        if audio_length >= 20:
+            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
+            # Detect non-silent segments
+
+            non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
+
+            segments = []
+            current_segment = []
+            current_length = 0
+            max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
+
+
+            for interval in non_silent_intervals:
+                start, end = interval
+                segment_part = waveform[start:end]
+
+                # If adding the next part exceeds max duration, store the segment and start a new one
+                if current_length + len(segment_part) > max_duration:
+                    segments.append(np.concatenate(current_segment))
+                    current_segment = []
+                    current_length = 0
+
+                current_segment.append(segment_part)
+                current_length += len(segment_part)
+
+            # Append the last segment if it's not empty
+            if current_segment:
+                segments.append(np.concatenate(current_segment))
+
+            # Process each segment
+            outputs = []
+            for i, segment in enumerate(segments):
+                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
+
+                # import soundfile as sf
+                # sf.write(f"outputs/segment_{i}.wav", segment, sr)
+
+                segment_tensor = torch.tensor(segment).to(device)
+
+                # Fake a batch for the segment
+                batch = segment_tensor.unsqueeze(0).to(device)
+                rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
+
+                # Pass the segment through the ASR model
+                segment_output = self.encode_batch_w2v2(device, batch, rel_length)
+                yield segment_output
+        else:
+            waveform = torch.tensor(waveform).to(device)
+            waveform = waveform.to(device)
+            # Fake a batch:
+            batch = waveform.unsqueeze(0)
+            rel_length = torch.tensor([1.0]).to(device)
+            outputs = self.encode_batch_w2v2(device, batch, rel_length)
+            yield outputs
+
+
+    def classify_file_whisper_mkd(self, path, device):
         # Load the audio file
         # path = "long_sample.wav"
         waveform, sr = librosa.load(path, sr=16000)
@@ -165,7 +248,7 @@ class ASR(Pretrained):
                 rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary

                 # Pass the segment through the ASR model
-                segment_output = self.encode_batch(device, batch, rel_length)
+                segment_output = self.encode_batch_whisper(device, batch, rel_length)
                 yield segment_output
         else:
             waveform = torch.tensor(waveform).to(device)
@@ -173,7 +256,7 @@ class ASR(Pretrained):
             # Fake a batch:
             batch = waveform.unsqueeze(0)
             rel_length = torch.tensor([1.0]).to(device)
-            outputs = self.encode_batch(device, batch, rel_length)
+            outputs = self.encode_batch_whisper(device, batch, rel_length)
             yield outputs

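
For context, a minimal usage sketch of the updated interface (not part of the commit). It assumes the model repo ships this file alongside a hyperparams.yaml that defines the modules these methods reference (mods.whisper, hparams.test_search, hparams.log_softmax, and so on); the source repo id below is a placeholder, not the actual model id. foreign_class is SpeechBrain's documented loader for custom Pretrained subclasses such as this ASR class.

# Minimal usage sketch under the assumptions above; "username/asr-model"
# is a placeholder repo id, not the real one.
import torch
from speechbrain.inference.interfaces import foreign_class  # speechbrain.pretrained in older versions

device = "cuda" if torch.cuda.is_available() else "cpu"

asr = foreign_class(
    source="username/asr-model",              # placeholder repo id
    pymodule_file="custom_interface_app.py",
    classname="ASR",
)

# classify_file_w2v2 and classify_file_whisper_mkd are generators: audio
# longer than 20 s is split on silence into chunks, and one transcription
# is yielded per chunk, so results arrive incrementally.
for words in asr.classify_file_w2v2("long_sample.wav", device):
    print(words)

The generator design is what makes this suitable for an app frontend: partial transcripts can be displayed as each segment finishes instead of waiting for the whole file.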