Update custom_interface_app.py

custom_interface_app.py  CHANGED  (+87 -4)

@@ -8,7 +8,7 @@ class ASR(Pretrained):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def
+    def encode_batch_w2v2(self, device, wavs, wav_lens=None, normalize=False):
         wavs = wavs.to(device)
         wav_lens = wav_lens.to(device)
 
@@ -33,6 +33,22 @@ class ASR(Pretrained):
         predicted_words = prediction
         return predicted_words
 
+
+    def encode_batch_whisper(self, device, wavs, wav_lens=None, normalize=False):
+        wavs = wavs.to(device)
+        wav_lens = wav_lens.to(device)
+
+        # Forward encoder + decoder
+        tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
+        tokens = tokens.to(device)
+        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
+        log_probs = self.hparams.log_softmax(logits)
+
+        hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        predicted_words = [self.mods.whisper.tokenizer.decode(token, skip_special_tokens=True).strip() for token in hyps]
+        return predicted_words
+
+
     def filter_repetitions(self, seq, max_repetition_length):
         seq = list(seq)
         output = []
@@ -110,7 +126,74 @@ class ASR(Pretrained):
         return waveform
 
 
-    def
+    def classify_file_w2v2(self, path, device):
+        # Load the audio file
+        # path = "long_sample.wav"
+        waveform, sr = librosa.load(path, sr=16000)
+
+        # increase the volume if needed
+        # waveform = self.increase_volume(waveform)
+
+        # Get audio length in seconds
+        audio_length = len(waveform) / sr
+
+        if audio_length >= 20:
+            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
+            # Detect non-silent segments
+
+            non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
+
+            segments = []
+            current_segment = []
+            current_length = 0
+            max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
+
+
+            for interval in non_silent_intervals:
+                start, end = interval
+                segment_part = waveform[start:end]
+
+                # If adding the next part exceeds max duration, store the segment and start a new one
+                if current_length + len(segment_part) > max_duration:
+                    segments.append(np.concatenate(current_segment))
+                    current_segment = []
+                    current_length = 0
+
+                current_segment.append(segment_part)
+                current_length += len(segment_part)
+
+            # Append the last segment if it's not empty
+            if current_segment:
+                segments.append(np.concatenate(current_segment))
+
+            # Process each segment
+            outputs = []
+            for i, segment in enumerate(segments):
+                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
+
+                # import soundfile as sf
+                # sf.write(f"outputs/segment_{i}.wav", segment, sr)
+
+                segment_tensor = torch.tensor(segment).to(device)
+
+                # Fake a batch for the segment
+                batch = segment_tensor.unsqueeze(0).to(device)
+                rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
+
+                # Pass the segment through the ASR model
+                segment_output = self.encode_batch_w2v2(device, batch, rel_length)
+                yield segment_output
+        else:
+            waveform = torch.tensor(waveform).to(device)
+            waveform = waveform.to(device)
+            # Fake a batch:
+            batch = waveform.unsqueeze(0)
+            rel_length = torch.tensor([1.0]).to(device)
+            outputs = self.encode_batch_w2v2(device, batch, rel_length)
+            yield outputs
+
+
+    def classify_file_whisper_mkd(self, path, device):
         # Load the audio file
         # path = "long_sample.wav"
         waveform, sr = librosa.load(path, sr=16000)
@@ -165,7 +248,7 @@ class ASR(Pretrained):
                 rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
 
                 # Pass the segment through the ASR model
-                segment_output = self.
+                segment_output = self.encode_batch_whisper(device, batch, rel_length)
                 yield segment_output
         else:
             waveform = torch.tensor(waveform).to(device)
@@ -173,7 +256,7 @@ class ASR(Pretrained):
             # Fake a batch:
             batch = waveform.unsqueeze(0)
             rel_length = torch.tensor([1.0]).to(device)
-            outputs = self.
+            outputs = self.encode_batch_whisper(device, batch, rel_length)
             yield outputs

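For context, a minimal usage sketch of the updated interface (not part of the commit). It assumes the file ships in a SpeechBrain-style model repo loaded with foreign_class; the repo id "<user>/<repo>" and the audio path are placeholders:

import torch
from speechbrain.inference.interfaces import foreign_class  # speechbrain.pretrained.interfaces on older SpeechBrain versions

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the custom interface; "<user>/<repo>" is a hypothetical repo id.
asr = foreign_class(
    source="<user>/<repo>",
    pymodule_file="custom_interface_app.py",
    classname="ASR",
)

# classify_file_w2v2 is a generator: it yields one transcription per
# (up to ~20 s) non-silent segment, so results are consumed by iterating.
for segment_output in asr.classify_file_w2v2("long_sample.wav", device):
    print(segment_output)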
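A note on the rel_length = torch.tensor([1.0]) pattern used by both classify_file_* methods: SpeechBrain passes wav_lens as lengths relative to the longest (padded) item in the batch, so a single-file "fake batch" is always 1.0. A short sketch of that convention for a padded two-file batch (synthetic tensors, for illustration only):

import torch

# Two waveforms of different lengths, zero-padded to a common size.
wav_a = torch.randn(16000 * 3)  # 3 s at 16 kHz
wav_b = torch.randn(16000 * 5)  # 5 s at 16 kHz
batch = torch.nn.utils.rnn.pad_sequence([wav_a, wav_b], batch_first=True)

# Relative lengths: true samples over padded samples.
rel_length = torch.tensor([wav_a.numel(), wav_b.numel()]) / batch.shape[1]
# -> tensor([0.6000, 1.0000])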