tyler committed
Commit 1cd0a28 · 1 Parent(s): a009b73
.gitignore ADDED
@@ -0,0 +1,139 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # PyCharm
+ .idea/
+
+ # my ignore
+ pretrain_models/
+ deprecated/
+ preprocess/
+ **/.DS_Store
+ **/._.DS_Store
speak_detect/snakers4_silero-vad_master/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from importlib.metadata import version
+ try:
+     __version__ = version(__name__)
+ except:
+     pass
+
+ from silero_vad.model import load_silero_vad
+ from silero_vad.utils_vad import (get_speech_timestamps,
+                                   save_audio,
+                                   read_audio,
+                                   VADIterator,
+                                   collect_chunks)
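
Usage note (editor's sketch, not part of this commit): assuming this directory is importable as the `silero_vad` package that these imports expect, the exported helpers are typically combined as follows; the WAV path is illustrative.

    from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

    model = load_silero_vad(onnx=True)                    # the commit ships only the ONNX weights
    wav = read_audio('example.wav', sampling_rate=16000)  # hypothetical input file
    timestamps = get_speech_timestamps(wav, model, return_seconds=True)
    print(timestamps)                                     # [{'start': ..., 'end': ...}, ...]
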
speak_detect/snakers4_silero-vad_master/data/__init__.py ADDED
File without changes
speak_detect/snakers4_silero-vad_master/data/silero_vad.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2623a2953f6ff3d2c1e61740c6cdb7168133479b267dfef114a4a3cc5bdd788f
+ size 2327524
speak_detect/snakers4_silero-vad_master/hubconf.py ADDED
@@ -0,0 +1,51 @@
+ dependencies = ['torch', 'torchaudio']
+ import torch
+ import os
+ import sys
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+ from utils_vad import (init_jit_model,
+                        get_speech_timestamps,
+                        save_audio,
+                        read_audio,
+                        VADIterator,
+                        collect_chunks,
+                        drop_chunks,
+                        OnnxWrapper)
+
+
+ def versiontuple(v):
+     splitted = v.split('+')[0].split(".")
+     version_list = []
+     for i in splitted:
+         try:
+             version_list.append(int(i))
+         except:
+             version_list.append(0)
+     return tuple(version_list)
+
+
+ def silero_vad(onnx=False, force_onnx_cpu=False):
+     """Silero Voice Activity Detector
+     Returns a model with a set of utils
+     Please see https://github.com/snakers4/silero-vad for usage examples
+     """
+
+     if not onnx:
+         installed_version = torch.__version__
+         supported_version = '1.12.0'
+         if versiontuple(installed_version) < versiontuple(supported_version):
+             raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
+
+     model_dir = os.path.join(os.path.dirname(__file__), 'data')
+     if onnx:
+         model = OnnxWrapper(os.path.join(model_dir, 'silero_vad.onnx'), force_onnx_cpu)
+     else:
+         model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
+     utils = (get_speech_timestamps,
+              save_audio,
+              read_audio,
+              VADIterator,
+              collect_chunks,
+              drop_chunks)
+
+     return model, utils
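
Usage note (editor's sketch, not part of this commit): hubconf.py makes this directory loadable through torch.hub. Assuming the local path below, and passing onnx=True because only the ONNX weights are committed, loading would look roughly like this; it is untested against this exact layout.

    import torch

    model, utils = torch.hub.load(repo_or_dir='speak_detect/snakers4_silero-vad_master',
                                  model='silero_vad',
                                  source='local',
                                  onnx=True)
    (get_speech_timestamps, save_audio, read_audio,
     VADIterator, collect_chunks, drop_chunks) = utils
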
speak_detect/snakers4_silero-vad_master/model.py ADDED
@@ -0,0 +1,25 @@
+ from .utils_vad import init_jit_model, OnnxWrapper
+ import torch
+ torch.set_num_threads(1)
+
+ def load_silero_vad(onnx=False):
+     model_name = 'silero_vad.onnx' if onnx else 'silero_vad.jit'
+     package_path = "silero_vad.data"
+
+     try:
+         import importlib_resources as impresources
+         model_file_path = str(impresources.files(package_path).joinpath(model_name))
+     except:
+         from importlib import resources as impresources
+         try:
+             with impresources.path(package_path, model_name) as f:
+                 model_file_path = f
+         except:
+             model_file_path = str(impresources.files(package_path).joinpath(model_name))
+
+     if onnx:
+         model = OnnxWrapper(model_file_path, force_onnx_cpu=True)
+     else:
+         model = init_jit_model(model_file_path)
+
+     return model
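
Usage note (editor's sketch, not part of this commit): load_silero_vad resolves the model file from the package's data directory via importlib resources, so it assumes the code is installed as the `silero_vad` package. With onnx=True the result is an OnnxWrapper pinned to CPU, which can score a whole waveform window by window; the file name below is illustrative.

    from silero_vad.model import load_silero_vad
    from silero_vad.utils_vad import read_audio

    vad = load_silero_vad(onnx=True)                      # OnnxWrapper, CPUExecutionProvider only
    wav = read_audio('example.wav', sampling_rate=16000)  # hypothetical input file
    probs = vad.audio_forward(wav, sr=16000)              # one speech probability per 512-sample window
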
speak_detect/snakers4_silero-vad_master/utils_vad.py ADDED
@@ -0,0 +1,489 @@
+ import torch
+ import torchaudio
+ from typing import Callable, List
+ import warnings
+
+ languages = ['ru', 'en', 'de', 'es']
+
+
+ class OnnxWrapper():
+
+     def __init__(self, path, force_onnx_cpu=False):
+         import numpy as np
+         global np
+         import onnxruntime
+
+         opts = onnxruntime.SessionOptions()
+         opts.inter_op_num_threads = 1
+         opts.intra_op_num_threads = 1
+
+         if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+             self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
+         else:
+             self.session = onnxruntime.InferenceSession(path, sess_options=opts)
+
+         self.reset_states()
+         self.sample_rates = [8000, 16000]
+
+     def _validate_input(self, x, sr: int):
+         if x.dim() == 1:
+             x = x.unsqueeze(0)
+         if x.dim() > 2:
+             raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+
+         if sr != 16000 and (sr % 16000 == 0):
+             step = sr // 16000
+             x = x[:, ::step]
+             sr = 16000
+
+         if sr not in self.sample_rates:
+             raise ValueError(f"Supported sampling rates: {self.sample_rates} (or a multiple of 16000)")
+         if sr / x.shape[1] > 31.25:
+             raise ValueError("Input audio chunk is too short")
+
+         return x, sr
+
+     def reset_states(self, batch_size=1):
+         self._state = torch.zeros((2, batch_size, 128)).float()
+         self._context = torch.zeros(0)
+         self._last_sr = 0
+         self._last_batch_size = 0
+
+     def __call__(self, x, sr: int):
+
+         x, sr = self._validate_input(x, sr)
+         num_samples = 512 if sr == 16000 else 256
+
+         if x.shape[-1] != num_samples:
+             raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
+         batch_size = x.shape[0]
+         context_size = 64 if sr == 16000 else 32
+
+         if not self._last_batch_size:
+             self.reset_states(batch_size)
+         if (self._last_sr) and (self._last_sr != sr):
+             self.reset_states(batch_size)
+         if (self._last_batch_size) and (self._last_batch_size != batch_size):
+             self.reset_states(batch_size)
+
+         if not len(self._context):
+             self._context = torch.zeros(batch_size, context_size)
+
+         x = torch.cat([self._context, x], dim=1)
+         if sr in [8000, 16000]:
+             ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
+             ort_outs = self.session.run(None, ort_inputs)
+             out, state = ort_outs
+             self._state = torch.from_numpy(state)
+         else:
+             raise ValueError()
+
+         self._context = x[..., -context_size:]
+         self._last_sr = sr
+         self._last_batch_size = batch_size
+
+         out = torch.from_numpy(out)
+         return out
+
+     def audio_forward(self, x, sr: int):
+         outs = []
+         x, sr = self._validate_input(x, sr)
+         self.reset_states()
+         num_samples = 512 if sr == 16000 else 256
+
+         if x.shape[1] % num_samples:
+             pad_num = num_samples - (x.shape[1] % num_samples)
+             x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
+
+         for i in range(0, x.shape[1], num_samples):
+             wavs_batch = x[:, i:i+num_samples]
+             out_chunk = self.__call__(wavs_batch, sr)
+             outs.append(out_chunk)
+
+         stacked = torch.cat(outs, dim=1)
+         return stacked.cpu()
+
+
+ class Validator():
+     def __init__(self, url, force_onnx_cpu):
+         self.onnx = True if url.endswith('.onnx') else False
+         torch.hub.download_url_to_file(url, 'inf.model')
+         if self.onnx:
+             import onnxruntime
+             if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+                 self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
+             else:
+                 self.model = onnxruntime.InferenceSession('inf.model')
+         else:
+             self.model = init_jit_model(model_path='inf.model')
+
+     def __call__(self, inputs: torch.Tensor):
+         with torch.no_grad():
+             if self.onnx:
+                 ort_inputs = {'input': inputs.cpu().numpy()}
+                 outs = self.model.run(None, ort_inputs)
+                 outs = [torch.Tensor(x) for x in outs]
+             else:
+                 outs = self.model(inputs)
+
+         return outs
+
+
+ def read_audio(path: str,
+                sampling_rate: int = 16000):
+     list_backends = torchaudio.list_audio_backends()
+
+     assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
+                                     \n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
+
+     try:
+         effects = [
+             ['channels', '1'],
+             ['rate', str(sampling_rate)]
+         ]
+
+         wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
+     except:
+         wav, sr = torchaudio.load(path)
+
+     if wav.size(0) > 1:
+         wav = wav.mean(dim=0, keepdim=True)
+
+     if sr != sampling_rate:
+         transform = torchaudio.transforms.Resample(orig_freq=sr,
+                                                    new_freq=sampling_rate)
+         wav = transform(wav)
+         sr = sampling_rate
+
+     assert sr == sampling_rate
+     return wav.squeeze(0)
+
+
+ def save_audio(path: str,
+                tensor: torch.Tensor,
+                sampling_rate: int = 16000):
+     torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
+
+
+ def init_jit_model(model_path: str,
+                    device=torch.device('cpu')):
+     model = torch.jit.load(model_path, map_location=device)
+     model.eval()
+     return model
+
+
+ def make_visualization(probs, step):
+     import pandas as pd
+     pd.DataFrame({'probs': probs},
+                  index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
+                                                                    kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
+                                                                    xlabel='seconds',
+                                                                    ylabel='speech probability',
+                                                                    colormap='tab20')
+
+
+ @torch.no_grad()
+ def get_speech_timestamps(audio: torch.Tensor,
+                           model,
+                           threshold: float = 0.5,
+                           sampling_rate: int = 16000,
+                           min_speech_duration_ms: int = 250,
+                           max_speech_duration_s: float = float('inf'),
+                           min_silence_duration_ms: int = 100,
+                           speech_pad_ms: int = 30,
+                           return_seconds: bool = False,
+                           visualize_probs: bool = False,
+                           progress_tracking_callback: Callable[[float], None] = None,
+                           window_size_samples: int = 512,):
+
+     """
+     This method is used for splitting long audios into speech chunks using silero VAD
+
+     Parameters
+     ----------
+     audio: torch.Tensor, one dimensional
+         One dimensional float torch.Tensor, other types are cast to torch if possible
+
+     model: preloaded .jit/.onnx silero VAD model
+
+     threshold: float (default - 0.5)
+         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
+         It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+     sampling_rate: int (default - 16000)
+         Currently silero VAD models support 8000 and 16000 (or a multiple of 16000) sample rates
+
+     min_speech_duration_ms: int (default - 250 milliseconds)
+         Final speech chunks shorter than min_speech_duration_ms are thrown out
+
+     max_speech_duration_s: float (default - inf)
+         Maximum duration of speech chunks in seconds
+         Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent aggressive cutting.
+         Otherwise, they will be split aggressively just before max_speech_duration_s.
+
+     min_silence_duration_ms: int (default - 100 milliseconds)
+         At the end of each speech chunk, wait for min_silence_duration_ms before separating it
+
+     speech_pad_ms: int (default - 30 milliseconds)
+         Final speech chunks are padded by speech_pad_ms on each side
+
+     return_seconds: bool (default - False)
+         whether to return timestamps in seconds (default - samples)
+
+     visualize_probs: bool (default - False)
+         whether to draw the probability histogram or not
+
+     progress_tracking_callback: Callable[[float], None] (default - None)
+         callback function taking progress in percent as an argument
+
+     window_size_samples: int (default - 512 samples)
+         !!! DEPRECATED, DOES NOTHING !!!
+
+     Returns
+     ----------
+     speeches: list of dicts
+         list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
+     """
+
+     if not torch.is_tensor(audio):
+         try:
+             audio = torch.Tensor(audio)
+         except:
+             raise TypeError("Audio cannot be cast to tensor. Cast it manually")
+
+     if len(audio.shape) > 1:
+         for i in range(len(audio.shape)):  # trying to squeeze empty dimensions
+             audio = audio.squeeze(0)
+         if len(audio.shape) > 1:
+             raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
+
+     if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
+         step = sampling_rate // 16000
+         sampling_rate = 16000
+         audio = audio[::step]
+         warnings.warn('Sampling rate is a multiple of 16000, casting to 16000 manually!')
+     else:
+         step = 1
+
+     if sampling_rate not in [8000, 16000]:
+         raise ValueError("Currently silero VAD models support 8000 and 16000 (or a multiple of 16000) sample rates")
+
+     window_size_samples = 512 if sampling_rate == 16000 else 256
+
+     model.reset_states()
+     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
+     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+     max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
+     min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+     min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+
+     audio_length_samples = len(audio)
+
+     speech_probs = []
+     for current_start_sample in range(0, audio_length_samples, window_size_samples):
+         chunk = audio[current_start_sample: current_start_sample + window_size_samples]
+         if len(chunk) < window_size_samples:
+             chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
+         speech_prob = model(chunk, sampling_rate).item()
+         speech_probs.append(speech_prob)
+         # calculate progress and send it to the callback function
+         progress = current_start_sample + window_size_samples
+         if progress > audio_length_samples:
+             progress = audio_length_samples
+         progress_percent = (progress / audio_length_samples) * 100
+         if progress_tracking_callback:
+             progress_tracking_callback(progress_percent)
+
+     triggered = False
+     speeches = []
+     current_speech = {}
+     neg_threshold = threshold - 0.15
+     temp_end = 0  # to save potential segment end (and tolerate some silence)
+     prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
+
+     for i, speech_prob in enumerate(speech_probs):
+         if (speech_prob >= threshold) and temp_end:
+             temp_end = 0
+             if next_start < prev_end:
+                 next_start = window_size_samples * i
+
+         if (speech_prob >= threshold) and not triggered:
+             triggered = True
+             current_speech['start'] = window_size_samples * i
+             continue
+
+         if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
+             if prev_end:
+                 current_speech['end'] = prev_end
+                 speeches.append(current_speech)
+                 current_speech = {}
+                 if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
+                     triggered = False
+                 else:
+                     current_speech['start'] = next_start
+                 prev_end = next_start = temp_end = 0
+             else:
+                 current_speech['end'] = window_size_samples * i
+                 speeches.append(current_speech)
+                 current_speech = {}
+                 prev_end = next_start = temp_end = 0
+                 triggered = False
+                 continue
+
+         if (speech_prob < neg_threshold) and triggered:
+             if not temp_end:
+                 temp_end = window_size_samples * i
+             if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
+                 prev_end = temp_end
+             if (window_size_samples * i) - temp_end < min_silence_samples:
+                 continue
+             else:
+                 current_speech['end'] = temp_end
+                 if (current_speech['end'] - current_speech['start']) > min_speech_samples:
+                     speeches.append(current_speech)
+                 current_speech = {}
+                 prev_end = next_start = temp_end = 0
+                 triggered = False
+                 continue
+
+     if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
+         current_speech['end'] = audio_length_samples
+         speeches.append(current_speech)
+
+     for i, speech in enumerate(speeches):
+         if i == 0:
+             speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
+         if i != len(speeches) - 1:
+             silence_duration = speeches[i+1]['start'] - speech['end']
+             if silence_duration < 2 * speech_pad_samples:
+                 speech['end'] += int(silence_duration // 2)
+                 speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
+             else:
+                 speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+                 speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
+         else:
+             speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+
+     if return_seconds:
+         for speech_dict in speeches:
+             speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
+             speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
+     elif step > 1:
+         for speech_dict in speeches:
+             speech_dict['start'] *= step
+             speech_dict['end'] *= step
+
+     if visualize_probs:
+         make_visualization(speech_probs, window_size_samples / sampling_rate)
+
+     return speeches
+
+
+ class VADIterator:
+     def __init__(self,
+                  model,
+                  threshold: float = 0.5,
+                  sampling_rate: int = 16000,
+                  min_silence_duration_ms: int = 100,
+                  speech_pad_ms: int = 30
+                  ):
+
+         """
+         Class for stream imitation
+
+         Parameters
+         ----------
+         model: preloaded .jit/.onnx silero VAD model
+
+         threshold: float (default - 0.5)
+             Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
+             It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+         sampling_rate: int (default - 16000)
+             Currently silero VAD models support 8000 and 16000 sample rates
+
+         min_silence_duration_ms: int (default - 100 milliseconds)
+             At the end of each speech chunk, wait for min_silence_duration_ms before separating it
+
+         speech_pad_ms: int (default - 30 milliseconds)
+             Final speech chunks are padded by speech_pad_ms on each side
+         """
+
+         self.model = model
+         self.threshold = threshold
+         self.sampling_rate = sampling_rate
+
+         if sampling_rate not in [8000, 16000]:
+             raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+
+         self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+         self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+         self.reset_states()
+
+     def reset_states(self):
+
+         self.model.reset_states()
+         self.triggered = False
+         self.temp_end = 0
+         self.current_sample = 0
+
+     @torch.no_grad()
+     def __call__(self, x, return_seconds=False):
+         """
+         x: torch.Tensor
+             audio chunk (see examples in repo)
+
+         return_seconds: bool (default - False)
+             whether to return timestamps in seconds (default - samples)
+         """
+
+         if not torch.is_tensor(x):
+             try:
+                 x = torch.Tensor(x)
+             except:
+                 raise TypeError("Audio cannot be cast to tensor. Cast it manually")
+
+         window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
+         self.current_sample += window_size_samples
+
+         speech_prob = self.model(x, self.sampling_rate).item()
+
+         if (speech_prob >= self.threshold) and self.temp_end:
+             self.temp_end = 0
+
+         if (speech_prob >= self.threshold) and not self.triggered:
+             self.triggered = True
+             speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
+             return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
+
+         if (speech_prob < self.threshold - 0.15) and self.triggered:
+             if not self.temp_end:
+                 self.temp_end = self.current_sample
+             if self.current_sample - self.temp_end < self.min_silence_samples:
+                 return None
+             else:
+                 speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
+                 self.temp_end = 0
+                 self.triggered = False
+                 return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
+
+         return None
+
+
+ def collect_chunks(tss: List[dict],
+                    wav: torch.Tensor):
+     chunks = []
+     for i in tss:
+         chunks.append(wav[i['start']: i['end']])
+     return torch.cat(chunks)
+
+
+ def drop_chunks(tss: List[dict],
+                 wav: torch.Tensor):
+     chunks = []
+     cur_start = 0
+     for i in tss:
+         chunks.append((wav[cur_start: i['start']]))
+         cur_start = i['end']
+     return torch.cat(chunks)
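
Usage note (editor's sketch, not part of this commit): VADIterator is the streaming counterpart of get_speech_timestamps; it consumes fixed-size chunks (512 samples at 16 kHz, 256 at 8 kHz) and emits {'start': ...} / {'end': ...} events as speech begins and ends. A minimal sketch, assuming the package imports as `silero_vad` and an illustrative WAV path:

    from silero_vad.model import load_silero_vad
    from silero_vad.utils_vad import read_audio, VADIterator

    model = load_silero_vad(onnx=True)
    vad_iterator = VADIterator(model, sampling_rate=16000)
    wav = read_audio('example.wav', sampling_rate=16000)  # hypothetical input file

    window = 512  # samples per chunk at 16 kHz
    for i in range(0, len(wav) - window, window):
        event = vad_iterator(wav[i: i + window], return_seconds=True)
        if event is not None:
            print(event)             # {'start': ...} or {'end': ...}
    vad_iterator.reset_states()      # reset internal state before processing another file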