ayaanzaveri aadnk commited on
Commit
298e3b3
0 Parent(s):

Duplicate from aadnk/whisper-webui

Browse files

Co-authored-by: Kristian Stangeland <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ .vscode/
4
+ flagged/
5
+ *.py[cod]
6
+ *$py.class
LICENSE.md ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ ==============
3
+
4
+ _Version 2.0, January 2004_
5
+ _&lt;<http://www.apache.org/licenses/>&gt;_
6
+
7
+ ### Terms and Conditions for use, reproduction, and distribution
8
+
9
+ #### 1. Definitions
10
+
11
+ “License” shall mean the terms and conditions for use, reproduction, and
12
+ distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright
15
+ owner that is granting the License.
16
+
17
+ “Legal Entity” shall mean the union of the acting entity and all other entities
18
+ that control, are controlled by, or are under common control with that entity.
19
+ For the purposes of this definition, “control” means **(i)** the power, direct or
20
+ indirect, to cause the direction or management of such entity, whether by
21
+ contract or otherwise, or **(ii)** ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or **(iii)** beneficial ownership of such entity.
23
+
24
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising
25
+ permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including
28
+ but not limited to software source code, documentation source, and configuration
29
+ files.
30
+
31
+ “Object” form shall mean any form resulting from mechanical transformation or
32
+ translation of a Source form, including but not limited to compiled object code,
33
+ generated documentation, and conversions to other media types.
34
+
35
+ “Work” shall mean the work of authorship, whether in Source or Object form, made
36
+ available under the License, as indicated by a copyright notice that is included
37
+ in or attached to the work (an example is provided in the Appendix below).
38
+
39
+ “Derivative Works” shall mean any work, whether in Source or Object form, that
40
+ is based on (or derived from) the Work and for which the editorial revisions,
41
+ annotations, elaborations, or other modifications represent, as a whole, an
42
+ original work of authorship. For the purposes of this License, Derivative Works
43
+ shall not include works that remain separable from, or merely link (or bind by
44
+ name) to the interfaces of, the Work and Derivative Works thereof.
45
+
46
+ “Contribution” shall mean any work of authorship, including the original version
47
+ of the Work and any modifications or additions to that Work or Derivative Works
48
+ thereof, that is intentionally submitted to Licensor for inclusion in the Work
49
+ by the copyright owner or by an individual or Legal Entity authorized to submit
50
+ on behalf of the copyright owner. For the purposes of this definition,
51
+ “submitted” means any form of electronic, verbal, or written communication sent
52
+ to the Licensor or its representatives, including but not limited to
53
+ communication on electronic mailing lists, source code control systems, and
54
+ issue tracking systems that are managed by, or on behalf of, the Licensor for
55
+ the purpose of discussing and improving the Work, but excluding communication
56
+ that is conspicuously marked or otherwise designated in writing by the copyright
57
+ owner as “Not a Contribution.”
58
+
59
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf
60
+ of whom a Contribution has been received by Licensor and subsequently
61
+ incorporated within the Work.
62
+
63
+ #### 2. Grant of Copyright License
64
+
65
+ Subject to the terms and conditions of this License, each Contributor hereby
66
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
67
+ irrevocable copyright license to reproduce, prepare Derivative Works of,
68
+ publicly display, publicly perform, sublicense, and distribute the Work and such
69
+ Derivative Works in Source or Object form.
70
+
71
+ #### 3. Grant of Patent License
72
+
73
+ Subject to the terms and conditions of this License, each Contributor hereby
74
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
75
+ irrevocable (except as stated in this section) patent license to make, have
76
+ made, use, offer to sell, sell, import, and otherwise transfer the Work, where
77
+ such license applies only to those patent claims licensable by such Contributor
78
+ that are necessarily infringed by their Contribution(s) alone or by combination
79
+ of their Contribution(s) with the Work to which such Contribution(s) was
80
+ submitted. If You institute patent litigation against any entity (including a
81
+ cross-claim or counterclaim in a lawsuit) alleging that the Work or a
82
+ Contribution incorporated within the Work constitutes direct or contributory
83
+ patent infringement, then any patent licenses granted to You under this License
84
+ for that Work shall terminate as of the date such litigation is filed.
85
+
86
+ #### 4. Redistribution
87
+
88
+ You may reproduce and distribute copies of the Work or Derivative Works thereof
89
+ in any medium, with or without modifications, and in Source or Object form,
90
+ provided that You meet the following conditions:
91
+
92
+ * **(a)** You must give any other recipients of the Work or Derivative Works a copy of
93
+ this License; and
94
+ * **(b)** You must cause any modified files to carry prominent notices stating that You
95
+ changed the files; and
96
+ * **(c)** You must retain, in the Source form of any Derivative Works that You distribute,
97
+ all copyright, patent, trademark, and attribution notices from the Source form
98
+ of the Work, excluding those notices that do not pertain to any part of the
99
+ Derivative Works; and
100
+ * **(d)** If the Work includes a “NOTICE” text file as part of its distribution, then any
101
+ Derivative Works that You distribute must include a readable copy of the
102
+ attribution notices contained within such NOTICE file, excluding those notices
103
+ that do not pertain to any part of the Derivative Works, in at least one of the
104
+ following places: within a NOTICE text file distributed as part of the
105
+ Derivative Works; within the Source form or documentation, if provided along
106
+ with the Derivative Works; or, within a display generated by the Derivative
107
+ Works, if and wherever such third-party notices normally appear. The contents of
108
+ the NOTICE file are for informational purposes only and do not modify the
109
+ License. You may add Your own attribution notices within Derivative Works that
110
+ You distribute, alongside or as an addendum to the NOTICE text from the Work,
111
+ provided that such additional attribution notices cannot be construed as
112
+ modifying the License.
113
+
114
+ You may add Your own copyright statement to Your modifications and may provide
115
+ additional or different license terms and conditions for use, reproduction, or
116
+ distribution of Your modifications, or for any such Derivative Works as a whole,
117
+ provided Your use, reproduction, and distribution of the Work otherwise complies
118
+ with the conditions stated in this License.
119
+
120
+ #### 5. Submission of Contributions
121
+
122
+ Unless You explicitly state otherwise, any Contribution intentionally submitted
123
+ for inclusion in the Work by You to the Licensor shall be under the terms and
124
+ conditions of this License, without any additional terms or conditions.
125
+ Notwithstanding the above, nothing herein shall supersede or modify the terms of
126
+ any separate license agreement you may have executed with Licensor regarding
127
+ such Contributions.
128
+
129
+ #### 6. Trademarks
130
+
131
+ This License does not grant permission to use the trade names, trademarks,
132
+ service marks, or product names of the Licensor, except as required for
133
+ reasonable and customary use in describing the origin of the Work and
134
+ reproducing the content of the NOTICE file.
135
+
136
+ #### 7. Disclaimer of Warranty
137
+
138
+ Unless required by applicable law or agreed to in writing, Licensor provides the
139
+ Work (and each Contributor provides its Contributions) on an “AS IS” BASIS,
140
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
141
+ including, without limitation, any warranties or conditions of TITLE,
142
+ NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
143
+ solely responsible for determining the appropriateness of using or
144
+ redistributing the Work and assume any risks associated with Your exercise of
145
+ permissions under this License.
146
+
147
+ #### 8. Limitation of Liability
148
+
149
+ In no event and under no legal theory, whether in tort (including negligence),
150
+ contract, or otherwise, unless required by applicable law (such as deliberate
151
+ and grossly negligent acts) or agreed to in writing, shall any Contributor be
152
+ liable to You for damages, including any direct, indirect, special, incidental,
153
+ or consequential damages of any character arising as a result of this License or
154
+ out of the use or inability to use the Work (including but not limited to
155
+ damages for loss of goodwill, work stoppage, computer failure or malfunction, or
156
+ any and all other commercial damages or losses), even if such Contributor has
157
+ been advised of the possibility of such damages.
158
+
159
+ #### 9. Accepting Warranty or Additional Liability
160
+
161
+ While redistributing the Work or Derivative Works thereof, You may choose to
162
+ offer, and charge a fee for, acceptance of support, warranty, indemnity, or
163
+ other liability obligations and/or rights consistent with this License. However,
164
+ in accepting such obligations, You may act only on Your own behalf and on Your
165
+ sole responsibility, not on behalf of any other Contributor, and only if You
166
+ agree to indemnify, defend, and hold each Contributor harmless for any liability
167
+ incurred by, or claims asserted against, such Contributor by reason of your
168
+ accepting any such warranty or additional liability.
169
+
170
+ _END OF TERMS AND CONDITIONS_
171
+
172
+ ### APPENDIX: How to apply the Apache License to your work
173
+
174
+ To apply the Apache License to your work, attach the following boilerplate
175
+ notice, with the fields enclosed by brackets `[]` replaced with your own
176
+ identifying information. (Don't include the brackets!) The text should be
177
+ enclosed in the appropriate comment syntax for the file format. We also
178
+ recommend that a file or class name and description of purpose be included on
179
+ the same “printed page” as the copyright notice for easier identification within
180
+ third-party archives.
181
+
182
+ Copyright [yyyy] [name of copyright owner]
183
+
184
+ Licensed under the Apache License, Version 2.0 (the "License");
185
+ you may not use this file except in compliance with the License.
186
+ You may obtain a copy of the License at
187
+
188
+ http://www.apache.org/licenses/LICENSE-2.0
189
+
190
+ Unless required by applicable law or agreed to in writing, software
191
+ distributed under the License is distributed on an "AS IS" BASIS,
192
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
193
+ See the License for the specific language governing permissions and
194
+ limitations under the License.
195
+
README.md ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Whisper Webui
3
+ emoji: ⚡
4
+ colorFrom: pink
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.23.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: aadnk/whisper-webui
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ # Running Locally
17
+
18
+ To run this program locally, first install Python 3.9+ and Git. Then install PyTorch 1.10+ and all the other dependencies:
19
+ ```
20
+ pip install -r requirements.txt
21
+ ```
22
+
23
+ You can find detailed instructions for how to install this on Windows 10/11 [here (PDF)](docs/windows/install_win10_win11.pdf).
24
+
25
+ Finally, run the full version (no audio length restrictions) of the app with parallel CPU/GPU enabled:
26
+ ```
27
+ python app.py --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
28
+ ```
29
+
30
+ You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
31
+ ```
32
+ python cli.py \
33
+ [--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
34
+ [--vad_merge_window VAD_MERGE_WINDOW] \
35
+ [--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
36
+ [--vad_padding VAD_PADDING] \
37
+ [--vad_prompt_window VAD_PROMPT_WINDOW]
38
+ [--vad_cpu_cores NUMBER_OF_CORES]
39
+ [--vad_parallel_devices COMMA_DELIMITED_DEVICES]
40
+ [--auto_parallel BOOLEAN]
41
+ ```
42
+ You may also use URLs in addition to file paths as input.
43
+ ```
44
+ python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
45
+ ```
46
+
47
+ Rather than supplying arguments to `app.py` or `cli.py`, you can also use the configuration file [config.json5](config.json5). See that file for more information.
48
+ If you want to use a different configuration file, you can use the `WHISPER_WEBUI_CONFIG` environment variable to specify the path to another file.
49
+
50
+ ## Google Colab
51
+
52
+ You can also run this Web UI directly on [Google Colab](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing), if you haven't got a GPU powerful enough to run the larger models.
53
+
54
+ See the [colab documentation](docs/colab.md) for more information.
55
+
56
+ ## Parallel Execution
57
+
58
+ You can also run both the Web-UI or the CLI on multiple GPUs in parallel, using the `vad_parallel_devices` option. This takes a comma-delimited list of
59
+ device IDs (0, 1, etc.) that Whisper should be distributed to and run on concurrently:
60
+ ```
61
+ python cli.py --model large --vad silero-vad --language Japanese \
62
+ --vad_parallel_devices 0,1 "https://www.youtube.com/watch?v=4cICErqqRSM"
63
+ ```
64
+
65
+ Note that this requires a VAD to function properly, otherwise only the first GPU will be used. Though you could use `periodic-vad` to avoid taking the hit
66
+ of running Silero-Vad, at a slight cost to accuracy.
67
+
68
+ This is achieved by creating N child processes (where N is the number of selected devices), where Whisper is run concurrently. In `app.py`, you can also
69
+ set the `vad_process_timeout` option. This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory.
70
+ The default value is 30 minutes.
71
+
72
+ ```
73
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600
74
+ ```
75
+
76
+ To execute the Silero VAD itself in parallel, use the `vad_cpu_cores` option:
77
+ ```
78
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600 --vad_cpu_cores 4
79
+ ```
80
+
81
+ You may also use `vad_process_timeout` with a single device (`--vad_parallel_devices 0`), if you prefer to always free video memory after a period of time.
82
+
83
+ ### Auto Parallel
84
+
85
+ You can also set `auto_parallel` to `True`. This will set `vad_parallel_devices` to use all the GPU devices on the system, and `vad_cpu_cores` to be equal to the number of
86
+ cores (up to 8):
87
+ ```
88
+ python app.py --input_audio_max_duration -1 --auto_parallel True
89
+ ```
90
+
91
+ ### Multiple Files
92
+
93
+ You can upload multiple files either through the "Upload files" option, or as a playlist on YouTube.
94
+ Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section.
95
+ When more than one file is processed, the UI will also generate an "All_Output" zip file containing all the text output files.
96
+
97
+ # Docker
98
+
99
+ To run it in Docker, first install Docker and optionally the NVIDIA Container Toolkit in order to use the GPU.
100
+ Then either use the GitLab hosted container below, or check out this repository and build an image:
101
+ ```
102
+ sudo docker build -t whisper-webui:1 .
103
+ ```
104
+
105
+ You can then start the WebUI with GPU support like so:
106
+ ```
107
+ sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
108
+ ```
109
+
110
+ Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
111
+ ```
112
+ sudo docker run -d -p 7860:7860 whisper-webui:1
113
+ ```
114
+
115
+ # GitLab Docker Registry
116
+
117
+ This Docker container is also hosted on GitLab:
118
+
119
+ ```
120
+ sudo docker run -d --gpus=all -p 7860:7860 registry.gitlab.com/aadnk/whisper-webui:latest
121
+ ```
122
+
123
+ ## Custom Arguments
124
+
125
+ You can also pass custom arguments to `app.py` in the Docker container, for instance to be able to use all the GPUs in parallel:
126
+ ```
127
+ sudo docker run -d --gpus all -p 7860:7860 \
128
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
129
+ --restart=on-failure:15 registry.gitlab.com/aadnk/whisper-webui:latest \
130
+ app.py --input_audio_max_duration -1 --server_name 0.0.0.0 --auto_parallel True \
131
+ --default_vad silero-vad --default_model_name large
132
+ ```
133
+
134
+ You can also call `cli.py` the same way:
135
+ ```
136
+ sudo docker run --gpus all \
137
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
138
+ --mount type=bind,source=${PWD},target=/app/data \
139
+ registry.gitlab.com/aadnk/whisper-webui:latest \
140
+ cli.py --model large --auto_parallel True --vad silero-vad \
141
+ --output_dir /app/data /app/data/YOUR-FILE-HERE.mp4
142
+ ```
143
+
144
+ ## Caching
145
+
146
+ Note that the models themselves are currently not included in the Docker images, and will be downloaded on demand.
147
+ To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
148
+ prepopulate the directory with the different Whisper models.
149
+ ```
150
+ sudo docker run -d --gpus=all -p 7860:7860 \
151
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
152
+ registry.gitlab.com/aadnk/whisper-webui:latest
153
+ ```
app-local.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Launch the web UI without any limit on input audio length.
from app import create_ui
from src.config import ApplicationConfig

config = ApplicationConfig.create_default(input_audio_max_duration=-1)
create_ui(config)
app-network.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Launch the web UI without any limit on input audio length,
# bound to all interfaces so it is reachable over the network.
from app import create_ui
from src.config import ApplicationConfig

config = ApplicationConfig.create_default(input_audio_max_duration=-1, server_name="0.0.0.0")
create_ui(config)
app-shared.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Launch the web UI without any limit on input audio length,
# with Gradio public link sharing enabled.
from app import create_ui
from src.config import ApplicationConfig

config = ApplicationConfig.create_default(input_audio_max_duration=-1, share=True)
create_ui(config)
app.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import math
3
+ from typing import Iterator, Union
4
+ import argparse
5
+
6
+ from io import StringIO
7
+ import os
8
+ import pathlib
9
+ import tempfile
10
+ import zipfile
11
+ import numpy as np
12
+
13
+ import torch
14
+ from src.config import ApplicationConfig
15
+ from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
16
+ from src.modelCache import ModelCache
17
+ from src.source import get_audio_source_collection
18
+ from src.vadParallel import ParallelContext, ParallelTranscription
19
+
20
+ # External programs
21
+ import ffmpeg
22
+
23
+ # UI
24
+ import gradio as gr
25
+
26
+ from src.download import ExceededMaximumDuration, download_url
27
+ from src.utils import slugify, write_srt, write_vtt
28
+ from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
29
+ from src.whisperContainer import WhisperContainer
30
+
31
# Configure more application defaults in config.json5

# Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
MAX_FILE_PREFIX_LENGTH = 17

# Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
MAX_AUTO_CPU_CORES = 8

# Language names shown in the UI language dropdown; lower-cased before being
# passed to Whisper (see transcribe_webui's selectedLanguage).
# NOTE(review): do not reorder without checking the UI — the dropdown order follows this list.
LANGUAGES = [
    "English", "Chinese", "German", "Spanish", "Russian", "Korean",
    "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
    "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi",
    "Finnish", "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay",
    "Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
    "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin",
    "Maori", "Malayalam", "Welsh", "Slovak", "Telugu", "Persian",
    "Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian",
    "Kannada", "Estonian", "Macedonian", "Breton", "Basque", "Icelandic",
    "Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian",
    "Swahili", "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
    "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan", "Georgian",
    "Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish",
    "Lao", "Uzbek", "Faroese", "Haitian Creole", "Pashto", "Turkmen",
    "Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan",
    "Tagalog", "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
    "Hausa", "Bashkir", "Javanese", "Sundanese"
]

# Model sizes selectable in the UI; passed as model_name to WhisperContainer.
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
60
+
61
+ class WhisperTranscriber:
62
    def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
                 app_config: ApplicationConfig = None):
        """Create a transcriber with the given resource limits.

        Parameters:
            input_audio_max_duration: maximum allowed audio length in seconds
                (presumably None/-1 means unlimited — confirm with callers).
            vad_process_timeout: seconds of inactivity before a parallel
                transcription process is reclaimed.
            vad_cpu_cores: number of CPU cores used to run the VAD in parallel.
            delete_uploaded_files: if True, source files are removed after
                transcription (see transcribe_webui's cleanup block).
            output_dir: directory for result files; a temp directory is used
                when None.
            app_config: application-wide configuration (model definitions etc.).
        """
        # Shared cache of loaded Whisper models, reused across requests
        self.model_cache = ModelCache()
        # GPU device IDs for parallel transcription; populated by
        # set_parallel_devices / set_auto_parallel
        self.parallel_device_list = None
        # Lazily-created parallel execution contexts (GPU and CPU)
        self.gpu_parallel_context = None
        self.cpu_parallel_context = None
        self.vad_process_timeout = vad_process_timeout
        self.vad_cpu_cores = vad_cpu_cores

        # VAD model instance (loaded on demand elsewhere)
        self.vad_model = None
        self.inputAudioMaxDuration = input_audio_max_duration
        self.deleteUploadedFiles = delete_uploaded_files
        self.output_dir = output_dir

        self.app_config = app_config
78
+
79
+ def set_parallel_devices(self, vad_parallel_devices: str):
80
+ self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
81
+
82
+ def set_auto_parallel(self, auto_parallel: bool):
83
+ if auto_parallel:
84
+ if torch.cuda.is_available():
85
+ self.parallel_device_list = [ str(gpu_id) for gpu_id in range(torch.cuda.device_count())]
86
+
87
+ self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
88
+ print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
89
+
90
+ # Entry function for the simple tab
91
+ def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
92
+ progress=gr.Progress()):
93
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
94
+ progress=progress)
95
+
96
+ # Entry function for the full tab
97
+ def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
98
+ initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
99
+ condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
100
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
101
+ progress=gr.Progress()):
102
+
103
+ # Handle temperature_increment_on_fallback
104
+ if temperature_increment_on_fallback is not None:
105
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
106
+ else:
107
+ temperature = [temperature]
108
+
109
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
110
+ initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
111
+ condition_on_previous_text=condition_on_previous_text, fp16=fp16,
112
+ compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
113
+ progress=progress)
114
+
115
    def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
                         progress: gr.Progress = None, **decodeOptions: dict):
        """Transcribe one or more audio sources and return the Gradio outputs.

        Returns a 3-tuple (download, text, vtt): a list of file paths for the
        Download section, the concatenated transcript text, and the
        concatenated VTT text. On an over-long remote video, returns an
        error message instead (see the ExceededMaximumDuration handler).
        """
        try:
            # Resolve URL / uploaded files / microphone input into audio sources
            sources = self.__get_source(urlData, multipleFiles, microphoneData)

            try:
                selectedLanguage = languageName.lower() if len(languageName) > 0 else None
                selectedModel = modelName if modelName is not None else "base"

                model = WhisperContainer(model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)

                # Result
                download = []
                zip_file_lookup = {}
                text = ""
                vtt = ""

                # Write result
                downloadDirectory = tempfile.mkdtemp()
                source_index = 0

                # Fall back to the temp directory when no output_dir is configured
                outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory

                # Progress
                total_duration = sum([source.get_audio_duration() for source in sources])
                current_progress = 0

                # A listener that will report progress to Gradio
                root_progress_listener = self._create_progress_listener(progress)

                # Execute whisper
                for source in sources:
                    source_prefix = ""
                    source_audio_duration = source.get_audio_duration()

                    if (len(sources) > 1):
                        # Prefix (minimum 2 digits)
                        source_index += 1
                        source_prefix = str(source_index).zfill(2) + "_"
                        print("Transcribing ", source.source_path)

                    # Scale this source's progress into the overall task progress
                    scaled_progress_listener = SubTaskProgressListener(root_progress_listener,
                        base_task_total=total_duration,
                        sub_task_start=current_progress,
                        sub_task_total=source_audio_duration)

                    # Transcribe
                    result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, scaled_progress_listener, **decodeOptions)
                    filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)

                    # Update progress
                    current_progress += source_audio_duration

                    source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)

                    if len(sources) > 1:
                        # Add new line separators
                        if (len(source_text) > 0):
                            source_text += os.linesep + os.linesep
                        if (len(source_vtt) > 0):
                            source_vtt += os.linesep + os.linesep

                        # Append file name to source text too
                        source_text = source.get_full_name() + ":" + os.linesep + source_text
                        source_vtt = source.get_full_name() + ":" + os.linesep + source_vtt

                    # Add to result
                    download.extend(source_download)
                    text += source_text
                    vtt += source_vtt

                    if (len(sources) > 1):
                        # Zip files support at least 260 characters, but we'll play it safe and use 200
                        zipFilePrefix = slugify(source_prefix + source.get_short_name(max_length=200), allow_unicode=True)

                        # File names in ZIP file can be longer
                        for source_download_file in source_download:
                            # Get file postfix (after last -)
                            filePostfix = os.path.basename(source_download_file).split("-")[-1]
                            zip_file_name = zipFilePrefix + "-" + filePostfix
                            zip_file_lookup[source_download_file] = zip_file_name

                # Create zip file from all sources
                if len(sources) > 1:
                    downloadAllPath = os.path.join(downloadDirectory, "All_Output-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")

                    # NOTE(review): the local name `zip` shadows the builtin within this block
                    with zipfile.ZipFile(downloadAllPath, 'w', zipfile.ZIP_DEFLATED) as zip:
                        for download_file in download:
                            # Get file name from lookup
                            zip_file_name = zip_file_lookup.get(download_file, os.path.basename(download_file))
                            zip.write(download_file, arcname=zip_file_name)

                    # The combined archive is offered first in the Download list
                    download.insert(0, downloadAllPath)

                return download, text, vtt

            finally:
                # Cleanup source
                if self.deleteUploadedFiles:
                    for source in sources:
                        print("Deleting source file " + source.source_path)

                        try:
                            os.remove(source.source_path)
                        except Exception as e:
                            # Ignore error - it's just a cleanup
                            print("Error deleting source file " + source.source_path + ": " + str(e))

        except ExceededMaximumDuration as e:
            # Remote video exceeded the configured duration limit
            return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
225
+
226
def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
                    vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
                    progressListener: ProgressListener = None, **decodeOptions: dict):
    """
    Transcribe a single audio file with the given Whisper model, optionally pre-segmenting
    it with a VAD (voice activity detection) algorithm first.

    Parameters
    ----------
    model: WhisperContainer
        The model container used to create the transcription callback.
    audio_path: str
        Path of the audio file to transcribe.
    language: str
        Spoken language, or None for auto-detection.
    task: str
        'transcribe' or 'translate'. A 'task' key in decodeOptions takes precedence.
    vad: str
        One of 'silero-vad', 'silero-vad-skip-gaps', 'silero-vad-expand-into-gaps',
        'periodic-vad', or None/anything else for no VAD.
    vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow: float
        VAD tuning parameters (seconds).
    progressListener: ProgressListener
        Receives progress updates; a no-op listener is used when None.

    Returns the Whisper result dict.
    """
    initial_prompt = decodeOptions.pop('initial_prompt', None)

    # Fall back to a no-op listener so downstream code can report progress unconditionally
    if progressListener is None:
        progressListener = ProgressListener()

    # An explicit 'task' in the decode options overrides the task argument
    if 'task' in decodeOptions:
        task = decodeOptions.pop('task')

    # Callable that runs Whisper on (a section of) an audio file
    whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)

    if vad == 'silero-vad':
        # Silero VAD where non-speech gaps are transcribed as their own segments
        vad_config = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
        return self.process_vad(audio_path, whisperCallable, self.vad_model, vad_config, progressListener=progressListener)

    if vad == 'silero-vad-skip-gaps':
        # Silero VAD where non-speech gaps are simply ignored
        vad_config = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
        return self.process_vad(audio_path, whisperCallable, self.vad_model, vad_config, progressListener=progressListener)

    if vad == 'silero-vad-expand-into-gaps':
        # Silero VAD where speech segments are expanded to cover adjacent non-speech gaps
        vad_config = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
        return self.process_vad(audio_path, whisperCallable, self.vad_model, vad_config, progressListener=progressListener)

    if vad == 'periodic-vad':
        # Very simple VAD - mark every vadMaxMergeSize seconds as speech. This makes it less
        # likely that Whisper enters an infinite loop, but it may create a break in the middle
        # of a sentence, causing some artifacts.
        period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
        return self.process_vad(audio_path, whisperCallable, VadPeriodicTranscription(), period_config, progressListener=progressListener)

    if self._has_parallel_devices():
        # No VAD was requested, but parallel execution needs the chunked code path,
        # so use a single infinite "period" spanning the whole file
        period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
        return self.process_vad(audio_path, whisperCallable, VadPeriodicTranscription(), period_config, progressListener=progressListener)

    # Default: run Whisper directly on the whole file, no VAD
    return whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
274
+
275
def _create_progress_listener(self, progress: gr.Progress):
    """
    Wrap a Gradio progress object in a ProgressListener; a no-op listener is
    returned when no Gradio progress is available.
    """
    if progress is None:
        # Nothing to report to - use the default (no-op) listener
        return ProgressListener()

    class GradioProgressForwarder(ProgressListener):
        """Forwards on_progress/on_finished callbacks to a Gradio progress bar."""

        def __init__(self, progress: gr.Progress):
            self._gradio_progress = progress

        def on_progress(self, current: Union[int, float], total: Union[int, float]):
            # Gradio expects a completion fraction in [0, 1]
            self._gradio_progress(current / total)

        def on_finished(self):
            # Mark the bar as fully complete
            self._gradio_progress(1)

    return GradioProgressForwarder(progress)
292
+
293
def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig,
                progressListener: ProgressListener = None):
    """
    Run the given VAD model over the audio file and transcribe the detected segments
    with Whisper, either sequentially or - when parallel devices are configured -
    distributed over per-GPU worker processes plus a CPU pool for VAD pre-processing.

    The GPU/CPU ParallelContext pools are created lazily on first use and cached on
    the instance so subsequent calls reuse the same worker processes.
    """
    if (not self._has_parallel_devices()):
        # No parallel devices, so just run the VAD and Whisper in sequence
        return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)

    gpu_devices = self.parallel_device_list

    if (gpu_devices is None or len(gpu_devices) == 0):
        # No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
        gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]

    # Create parallel context if needed
    if (self.gpu_parallel_context is None):
        # One process per GPU device; the pool is cleared automatically after
        # vad_process_timeout seconds of inactivity
        self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
    # We also need a CPU context for the VAD
    if (self.cpu_parallel_context is None):
        self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)

    parallel_vad = ParallelTranscription()
    return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
                                            config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
                                            cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
                                            progress_listener=progressListener)
318
+
319
def _has_parallel_devices(self):
    """True when work should go through the parallel execution path."""
    # An explicit (non-empty) GPU device list always enables parallelism
    if self.parallel_device_list is not None and len(self.parallel_device_list) > 0:
        return True
    # Otherwise parallelism is implied by using more than one CPU core for VAD
    return self.vad_cpu_cores > 1
321
+
322
def _concat_prompt(self, prompt1, prompt2):
    """Join two prompts with a space; when either is None, return the other unchanged."""
    if (prompt1 is not None) and (prompt2 is not None):
        return prompt1 + " " + prompt2
    # At least one side is missing - return whichever exists (or None if both are missing)
    return prompt2 if prompt1 is None else prompt1
329
+
330
def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
    """
    Build a TranscriptionConfig for Silero VAD with the given non-speech handling
    strategy, lazily creating the shared Silero VAD model on first use.
    """
    # The Silero model is expensive to load, so cache it on the instance
    if self.vad_model is None:
        self.vad_model = VadSileroTranscription()

    return TranscriptionConfig(non_speech_strategy=non_speech_strategy,
                               max_silent_period=vadMergeWindow,
                               max_merge_size=vadMaxMergeSize,
                               segment_padding_left=vadPadding,
                               segment_padding_right=vadPadding,
                               max_prompt_window=vadPromptWindow)
341
+
342
def write_result(self, result: dict, source_name: str, output_dir: str):
    """
    Write a transcription result to SRT, VTT and plain-text files.

    Parameters
    ----------
    result: dict
        Whisper result with "text", "language" and "segments" keys.
    source_name: str
        Prefix for the generated file names.
    output_dir: str
        Directory the files are written to (created if missing).

    Returns
    -------
    (output_files, text, vtt): the list of written file paths, the raw transcript
    text, and the VTT-formatted subtitles.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    text = result["text"]
    language = result["language"]
    # CJK glyphs are wider, so the maximum subtitle line width depends on the language
    languageMaxLineWidth = self.__get_max_line_width(language)

    print("Max line width " + str(languageMaxLineWidth))
    vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
    srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)

    # Fixed: dropped stray ';' statement terminators and collapsed the repeated
    # append calls into a single list literal
    output_files = [
        self.__create_file(srt, output_dir, source_name + "-subs.srt"),
        self.__create_file(vtt, output_dir, source_name + "-subs.vtt"),
        self.__create_file(text, output_dir, source_name + "-transcript.txt"),
    ]

    return output_files, text, vtt
360
+
361
def clear_cache(self):
    """Drop all cached models so they are reloaded on next use."""
    # Release every cached Whisper model...
    self.model_cache.clear()
    # ...and the cached Silero VAD model
    self.vad_model = None
364
+
365
def __get_source(self, urlData, multipleFiles, microphoneData):
    # Collect all requested audio sources (remote URL, uploaded files, microphone
    # recording) into a single list, enforcing the configured maximum audio duration.
    return get_audio_source_collection(urlData, multipleFiles, microphoneData, self.inputAudioMaxDuration)
367
+
368
def __get_max_line_width(self, language: str) -> int:
    """Return the maximum subtitle line width (in characters) for the given language."""
    # Chinese characters and kana are roughly double-width, so allow fewer per line
    wide_glyph_languages = {"japanese", "ja", "chinese", "zh"}

    if language and language.lower() in wide_glyph_languages:
        return 40

    # TODO: Add more languages
    # 80 latin characters should fit on a 1080p/720p screen
    return 80
376
+
377
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
    """
    Render the given segments as subtitles in the requested format.

    format must be 'vtt' or 'srt'; anything else raises an Exception.
    """
    buffer = StringIO()

    # Dispatch to the matching subtitle writer
    writers = {'vtt': write_vtt, 'srt': write_srt}
    writer = writers.get(format)

    if writer is None:
        raise Exception("Unknown format " + format)

    writer(segments, file=buffer, maxLineWidth=maxLineWidth)

    return buffer.getvalue()
389
+
390
def __create_file(self, text: str, directory: str, fileName: str) -> str:
    """
    Write text to directory/fileName (UTF-8) and return the full path of the file.
    """
    file_path = os.path.join(directory, fileName)

    # Fixed: 'w+' opened the file for reading as well, but it is only written;
    # plain 'w' (truncate/create) is sufficient. UTF-8 so transcripts in any
    # language are written correctly.
    with open(file_path, 'w', encoding="utf-8") as file:
        file.write(text)

    return file_path
396
+
397
def close(self):
    """Shut down the transcriber: clear cached models and stop any parallel worker pools."""
    print("Closing parallel contexts")
    self.clear_cache()

    # Terminate the worker process pools, if they were ever created (GPU first, then CPU)
    for context in (self.gpu_parallel_context, self.cpu_parallel_context):
        if context is not None:
            context.close()
405
+
406
+
407
def create_ui(app_config: ApplicationConfig):
    """
    Build and launch the Gradio web UI with two tabs: a "Simple" transcription
    interface and a "Full" interface that also exposes the advanced decode options.
    Blocks until the server is shut down, then cleans up the transcriber.
    """
    ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
                            app_config.delete_uploaded_files, app_config.output_dir, app_config)

    # Specify a list of devices to use for parallel processing
    ui.set_parallel_devices(app_config.vad_parallel_devices)
    ui.set_auto_parallel(app_config.auto_parallel)

    ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
    ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
    ui_description += " as well as speech translation and language identification. "

    ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."

    # Surface the configured upload limit in the UI description
    if app_config.input_audio_max_duration > 0:
        ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"

    ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"

    whisper_models = app_config.get_model_names()

    # Factory (not a shared list) - it is invoked once per interface below,
    # presumably because Gradio components cannot be reused across interfaces
    # (NOTE(review): confirm against the Gradio version in requirements).
    simple_inputs = lambda : [
        gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
        gr.Dropdown(choices=sorted(LANGUAGES), label="Language", value=app_config.language),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.File(label="Upload Files", file_count="multiple"),
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
        gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
        gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
        gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
    ]

    # "Simple" tab: only the basic inputs
    simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple, description=ui_description, article=ui_article, inputs=simple_inputs(), outputs=[
        gr.File(label="Download"),
        gr.Text(label="Transcription"),
        gr.Text(label="Segments")
    ])

    full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."

    # "Full" tab: the simple inputs plus all advanced decode options
    full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
        *simple_inputs(),
        gr.TextArea(label="Initial Prompt"),
        gr.Number(label="Temperature", value=app_config.temperature),
        gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
        gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
        gr.Number(label="Patience - Zero temperature", value=app_config.patience),
        gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
        gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
        gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
        gr.Checkbox(label="FP16", value=app_config.fp16),
        gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
        gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
        gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
        gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
    ], outputs=[
        gr.File(label="Download"),
        gr.Text(label="Transcription"),
        gr.Text(label="Segments")
    ])

    demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])

    # Queue up the demo
    if app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0:
        demo.queue(concurrency_count=app_config.queue_concurrency_count)

    # Blocks until the Gradio server exits
    demo.launch(share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)

    # Clean up
    ui.close()
481
+
482
if __name__ == '__main__':
    # Parse command line overrides on top of the defaults loaded from the config file
    app_config = ApplicationConfig.create_default()
    whisper_models = app_config.get_model_names()

    def _str2bool(value) -> bool:
        # Fixed: argparse's type=bool treats ANY non-empty string (including "False")
        # as True, so boolean flags must be parsed explicitly.
        if isinstance(value, bool):
            return value
        return str(value).lower() in ("yes", "true", "t", "y", "1")

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_audio_max_duration", type=int, default=app_config.input_audio_max_duration, \
                        help="Maximum audio file length in seconds, or -1 for no limit.")  # 600
    parser.add_argument("--share", type=_str2bool, default=app_config.share, \
                        help="True to share the app on HuggingFace.")  # False
    parser.add_argument("--server_name", type=str, default=app_config.server_name, \
                        help="The host or IP to bind to. If None, bind to localhost.")  # None
    parser.add_argument("--server_port", type=int, default=app_config.server_port, \
                        help="The port to bind to.")  # 7860
    parser.add_argument("--queue_concurrency_count", type=int, default=app_config.queue_concurrency_count, \
                        help="The number of concurrent requests to process.")  # 1
    parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=app_config.default_model_name, \
                        help="The default model name.")  # medium
    parser.add_argument("--default_vad", type=str, default=app_config.default_vad, \
                        help="The default VAD.")  # silero-vad
    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
                        help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")  # ""
    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
                        help="The number of CPU cores to use for VAD pre-processing.")  # 1
    parser.add_argument("--vad_process_timeout", type=float, default=app_config.vad_process_timeout, \
                        help="The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or None for no timeout.")  # 1800
    parser.add_argument("--auto_parallel", type=_str2bool, default=app_config.auto_parallel, \
                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")  # False
    parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
                        help="directory to save the outputs")  # None

    args = parser.parse_args().__dict__

    # Apply the command line arguments on top of the configuration defaults and launch the UI
    updated_config = app_config.update(**args)
    create_ui(app_config=updated_config)
cli.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ from urllib.parse import urlparse
5
+ import warnings
6
+ import numpy as np
7
+
8
+ import torch
9
+ from app import LANGUAGES, WhisperTranscriber
10
+ from src.config import ApplicationConfig
11
+ from src.download import download_url
12
+
13
+ from src.utils import optional_float, optional_int, str2bool
14
+ from src.whisperContainer import WhisperContainer
15
+
16
def cli():
    """
    Command line entry point: parse arguments, then transcribe each given audio
    file (or URL, which is downloaded first) with Whisper and write the SRT/VTT/
    text output files to the output directory.
    """
    app_config = ApplicationConfig.create_default()
    whisper_models = app_config.get_model_names()

    # For the CLI, we fallback to saving the output to the current directory
    output_dir = app_config.output_dir if app_config.output_dir is not None else "."

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, \
                        help="audio file(s) to transcribe")
    parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
                        help="name of the Whisper model to use")  # medium
    parser.add_argument("--model_dir", type=str, default=app_config.model_dir, \
                        help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default=app_config.device, \
                        help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
                        help="directory to save the outputs")
    parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
                        help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
                        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(LANGUAGES), \
                        help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
                        help="The voice activity detection algorithm to use")  # silero-vad
    parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
                        help="The window size (in seconds) to merge voice segments")
    parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
                        help="The maximum size (in seconds) of a voice segment")
    parser.add_argument("--vad_padding", type=optional_float, default=app_config.vad_padding, \
                        help="The padding (in seconds) to add to each voice segment")
    parser.add_argument("--vad_prompt_window", type=optional_float, default=app_config.vad_prompt_window, \
                        help="The window size of the prompt to pass to Whisper")
    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
                        help="The number of CPU cores to use for VAD pre-processing.")  # 1
    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
                        help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")  # ""
    parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")  # False

    parser.add_argument("--temperature", type=float, default=app_config.temperature, \
                        help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=app_config.best_of, \
                        help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=app_config.beam_size, \
                        help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=app_config.patience, \
                        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=app_config.length_penalty, \
                        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default=app_config.suppress_tokens, \
                        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=app_config.initial_prompt, \
                        help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=app_config.condition_on_previous_text, \
                        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
                        help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
                        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=app_config.compression_ratio_threshold, \
                        help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=app_config.logprob_threshold, \
                        help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
                        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    # .en models only support English - force the language accordingly
    # (fixed wording: "receipted" -> "received")
    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
        args["language"] = "en"

    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        # Build the fallback schedule: start at `temperature` and step up to 1.0
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
    else:
        temperature = [temperature]

    vad = args.pop("vad")
    vad_merge_window = args.pop("vad_merge_window")
    vad_max_merge_size = args.pop("vad_max_merge_size")
    vad_padding = args.pop("vad_padding")
    vad_prompt_window = args.pop("vad_prompt_window")
    vad_cpu_cores = args.pop("vad_cpu_cores")
    auto_parallel = args.pop("auto_parallel")

    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
    transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
    transcriber.set_auto_parallel(auto_parallel)

    model = WhisperContainer(model_name, device=device, download_root=model_dir, models=app_config.models)

    if (transcriber._has_parallel_devices()):
        print("Using parallel devices:", transcriber.parallel_device_list)

    for audio_path in args.pop("audio"):
        sources = []

        # Detect URL and download the audio
        if (uri_validator(audio_path)):
            # Download from YouTube/URL directly
            for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
                source_name = os.path.basename(source_path)
                sources.append({ "path": source_path, "name": source_name })
        else:
            sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })

        for source in sources:
            source_path = source["path"]
            source_name = source["name"]

            result = transcriber.transcribe_file(model, source_path, temperature=temperature,
                                                 vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
                                                 vadPadding=vad_padding, vadPromptWindow=vad_prompt_window, **args)

            transcriber.write_result(result, source_name, output_dir)

    transcriber.close()
147
def uri_validator(x):
    """
    Return True if x looks like an absolute URL (has both a scheme and a network
    location). Malformed or non-string input returns False instead of raising.
    """
    try:
        parsed = urlparse(x)
    except (ValueError, AttributeError, TypeError):
        # Fixed: replaced a bare 'except:' (which also swallowed KeyboardInterrupt/
        # SystemExit) with the specific exceptions urlparse can raise on bad input.
        return False

    return all([parsed.scheme, parsed.netloc])
153
+
154
if __name__ == '__main__':
    # Run the command line interface when executed as a script
    cli()
config.json5 ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "models": [
3
+ // Configuration for the built-in models. You can remove any of these
4
+ // if you don't want to use the default models.
5
+ {
6
+ "name": "tiny",
7
+ "url": "tiny"
8
+ },
9
+ {
10
+ "name": "base",
11
+ "url": "base"
12
+ },
13
+ {
14
+ "name": "small",
15
+ "url": "small"
16
+ },
17
+ {
18
+ "name": "medium",
19
+ "url": "medium"
20
+ },
21
+ {
22
+ "name": "large",
23
+ "url": "large"
24
+ },
25
+ {
26
+ "name": "large-v2",
27
+ "url": "large-v2"
28
+ },
29
+ // Uncomment to add custom Japanese models
30
+ //{
31
+ // "name": "whisper-large-v2-mix-jp",
32
+ // "url": "vumichien/whisper-large-v2-mix-jp",
33
+ // // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
34
+ // // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
35
+ // "type": "huggingface",
36
+ //},
37
+ //{
38
+ // "name": "local-model",
39
+ // "url": "path/to/local/model",
40
+ //},
41
+ //{
42
+ // "name": "remote-model",
43
+ // "url": "https://example.com/path/to/model",
44
+ //}
45
+ ],
46
+ // Configuration options that will be used if they are not specified in the command line arguments.
47
+
48
+ // * WEBUI options *
49
+
50
+ // Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
51
+ "input_audio_max_duration": 600,
52
+ // True to share the app on HuggingFace.
53
+ "share": false,
54
+ // The host or IP to bind to. If None, bind to localhost.
55
+ "server_name": null,
56
+ // The port to bind to.
57
+ "server_port": 7860,
58
+ // The number of workers to use for the web server. Use -1 to disable queueing.
59
+ "queue_concurrency_count": 1,
60
+ // Whether or not to automatically delete all uploaded files, to save disk space
61
+ "delete_uploaded_files": true,
62
+
63
+ // * General options *
64
+
65
+ // The default model name.
66
+ "default_model_name": "medium",
67
+ // The default VAD.
68
+ "default_vad": "silero-vad",
69
+ // A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
70
+ "vad_parallel_devices": "",
71
+ // The number of CPU cores to use for VAD pre-processing.
72
+ "vad_cpu_cores": 1,
73
+ // The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or None for no timeout.
74
+ "vad_process_timeout": 1800,
75
+ // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
76
+ "auto_parallel": false,
77
+ // Directory to save the outputs (CLI will use the current directory if not specified)
78
+ "output_dir": null,
79
+ // The path to save model files; uses ~/.cache/whisper by default
80
+ "model_dir": null,
81
+ // Device to use for PyTorch inference, or Null to use the default device
82
+ "device": null,
83
+ // Whether to print out the progress and debug messages
84
+ "verbose": true,
85
+ // Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
86
+ "task": "transcribe",
87
+ // Language spoken in the audio, specify None to perform language detection
88
+ "language": null,
89
+ // The window size (in seconds) to merge voice segments
90
+ "vad_merge_window": 5,
91
+ // The maximum size (in seconds) of a voice segment
92
+ "vad_max_merge_size": 30,
93
+ // The padding (in seconds) to add to each voice segment
94
+ "vad_padding": 1,
95
+ // The window size of the prompt to pass to Whisper
96
+ "vad_prompt_window": 3,
97
+ // Temperature to use for sampling
98
+ "temperature": 0,
99
+ // Number of candidates when sampling with non-zero temperature
100
+ "best_of": 5,
101
+ // Number of beams in beam search, only applicable when temperature is zero
102
+ "beam_size": 5,
103
+ // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
104
+ "patience": null,
105
+ // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
106
+ "length_penalty": null,
107
+ // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
108
+ "suppress_tokens": "-1",
109
+ // Optional text to provide as a prompt for the first window
110
+ "initial_prompt": null,
111
+ // If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
112
+ "condition_on_previous_text": true,
113
+ // Whether to perform inference in fp16; True by default
114
+ "fp16": true,
115
+ // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
116
+ "temperature_increment_on_fallback": 0.2,
117
+ // If the gzip compression ratio is higher than this value, treat the decoding as failed
118
+ "compression_ratio_threshold": 2.4,
119
+ // If the average log probability is lower than this value, treat the decoding as failed
120
+ "logprob_threshold": -1.0,
121
+ // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
122
+ "no_speech_threshold": 0.6
123
+ }
dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM huggingface/transformers-pytorch-gpu
EXPOSE 7860

ADD . /opt/whisper-webui/

# Latest version of transformers-pytorch-gpu seems to lack tk.
# Further, pip install fails, so we must upgrade pip first.
# Refresh the package index first - the base image's apt lists may be stale or absent,
# in which case a bare `apt-get install` fails.
RUN apt-get update && apt-get -y install python3-tk
RUN python3 -m pip install --upgrade pip &&\
    python3 -m pip install -r /opt/whisper-webui/requirements.txt

# Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
# You can also bind this directory in the container to somewhere on the host.

# To be able to see logs in real time
ENV PYTHONUNBUFFERED=1

WORKDIR /opt/whisper-webui/
ENTRYPOINT ["python3"]
CMD ["app.py", "--input_audio_max_duration", "-1", "--server_name", "0.0.0.0", "--auto_parallel", "True"]
docs/colab.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running Whisper on Google Colab
2
+
3
+ If you don't have a decent GPU or any experience in running command-line applications, you might want to try this Google Colab instead:
4
+
5
+ * [Google Colab - Whisper WebUI GPU](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing)
6
+ * [Screenshots](https://imgur.com/a/ZfY6uBO)
7
+
8
+ The runtime (Runtime -> Change runtime type -> Hardware accelerator) should already be set to GPU. But if not, change it to GPU.
9
+
10
+ Then, sign in to Google if you haven't already. Next, click on "Connect" at the top right.
11
+
12
+ Under "Checking out WebUI from Git", click on the [play icon](https://imgur.com/a/81gOLyD) that appears in "[ ]" at the left. If you get a warning, click "Run anyway".
13
+
14
+ After this step has completed, it should get a green check mark. Then move on to the next section under "Installing dependencies", and click on "[ ]" again. This might take approximately 30 seconds.
15
+
16
+ Once this has completed, scroll down to the "Run WebUI" section, and click on "[ ]". This will launch the WebUI in a shared link (expires in 72 hours). To open the UI, click on the link next to "Running on public URL", which will be something like https://12xxx.gradio.app/
17
+
18
+ The audio length in this version is not restricted, and it will run much faster as it is backed by a GPU. You can also run it using the "Large" model. Also note that it might take some time to start the model the first time, as it may need to download a 2.8 GB file on Google's servers.
19
+
20
+ Once you're done, you can close the WebUI session by clicking the animated close button under "Run WebUI". You can also do this if you encounter any errors and need to restart the UI. You should also go to "Manage Sessions" and terminate the session, otherwise you may end up using all your free compute credits.
docs/options.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard Options
2
+ To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
3
+ supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
4
+ in the file selector to select any file type, including video files) or use the microphone.
5
+
6
+ For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option, especially if you are using the `large-v1` model. Note that `large-v2` is a lot more forgiving, but you may still want to use a VAD with a slightly higher "VAD - Max Merge Size (s)" (60 seconds or more).
7
+
8
+ ## Model
9
+ Select the model that Whisper will use to transcribe the audio:
10
+
11
+ | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
12
+ |-----------|------------|--------------------|--------------------|---------------|----------------|
13
+ | tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
14
+ | base | 74 M | base.en | base | ~1 GB | ~16x |
15
+ | small | 244 M | small.en | small | ~2 GB | ~6x |
16
+ | medium | 769 M | medium.en | medium | ~5 GB | ~2x |
17
+ | large | 1550 M | N/A | large | ~10 GB | 1x |
18
+ | large-v2 | 1550 M | N/A | large | ~10 GB | 1x |
19
+
20
+ ## Language
21
+
22
+ Select the language, or leave it empty for Whisper to automatically detect it.
23
+
24
+ Note that if the selected language and the language in the audio differs, Whisper may start to translate the audio to the selected
25
+ language. For instance, if the audio is in English but you select Japanese, the model may translate the audio to Japanese.
26
+
27
+ ## Inputs
28
+ The options "URL (YouTube, etc.)", "Upload Files" or "Microphone Input" allow you to send an audio input to the model.
29
+
30
+ ### Multiple Files
31
+ Note that the UI will only process either the given URL or the upload files (including microphone) - not both.
32
+
33
+ But you can upload multiple files either through the "Upload files" option, or as a playlist on YouTube. Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section. When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
34
+
35
+ ## Task
36
+ Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
37
+
38
+ ## Vad
39
+ Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper getting into an infinite
40
+ loop detecting the same sentence over and over again. The downside is that this may be at a cost to text accuracy, especially
41
+ with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
42
+
43
+ Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
44
+ So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long.
45
+
46
+ * none
47
+ * Run whisper on the entire audio input
48
+ * silero-vad
49
+ * Use Silero VAD to detect sections that contain speech, and run Whisper independently on each section. Whisper is also run
50
+ on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
51
+ on the non-speech section.
52
+ * silero-vad-expand-into-gaps
53
+ * Use Silero VAD to detect sections that contain speech, and run Whisper independently on each section. Each speech section will be expanded
54
+ such that they cover any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
55
+ 00:00 - 00:10 (A) and 00:30 - 00:40 (B), the first section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 00:60.
56
+ * silero-vad-skip-gaps
57
+ * As above, but sections that don't contain speech according to Silero will be skipped. This will be slightly faster, but
58
+ may cause dialogue to be skipped.
59
+ * periodic-vad
60
+ * Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
61
+ a sentence or word in two.
62
+
63
+ ## VAD - Merge Window
64
+ If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
65
+
66
+ ## VAD - Max Merge Size (s)
67
+ Disables merging of adjacent speech sections if the merged section would be longer than this number of seconds.
68
+
69
+ ## VAD - Padding (s)
70
+ The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
71
+ larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of
72
+ a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
73
+ to each transcribed line. The default value is 1 second.
74
+
75
+ ## VAD - Prompt Window (s)
76
+ The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
77
+ number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
78
+ 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
79
+
80
+ Note that detected lines in gaps between speech sections will not be included in the prompt
81
+ (if silero-vad or silero-vad-expand-into-gaps) is used.
82
+
83
+ # Command Line Options
84
+
85
+ Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
86
+ CPU/GPU cores, the default model name/VAD and so on. Consult the README in the root folder for more information.
87
+
88
+ # Additional Options
89
+
90
+ In addition to the above, there's also a "Full" options interface that allows you to set all the options available in the Whisper
91
+ model. The options are as follows:
92
+
93
+ ## Initial Prompt
94
+ Optional text to provide as a prompt for the first 30 seconds window. Whisper will attempt to use this as a starting point for the transcription, but you can
95
+ also get creative and specify a style or format for the output of the transcription.
96
+
97
+ For instance, if you use the prompt "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they", Whisper will
98
+ be biased to output lower capital letters and no punctuation, and may also be biased to output the words in the prompt more often.
99
+
100
+ ## Temperature
101
+ The temperature to use when sampling. Default is 0 (zero). A higher temperature will result in more random output, while a lower temperature will be more deterministic.
102
+
103
+ ## Best Of - Non-zero temperature
104
+ The number of candidates to sample from when sampling with non-zero temperature. Default is 5.
105
+
106
+ ## Beam Size - Zero temperature
107
+ The number of beams to use in beam search when sampling with zero temperature. Default is 5.
108
+
109
+ ## Patience - Zero temperature
110
+ The patience value to use in beam search when sampling with zero temperature. As in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search.
111
+
112
+ ## Length Penalty - Any temperature
113
+ The token length penalty coefficient (alpha) to use when sampling with any temperature. As in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.
114
+
115
+ ## Suppress Tokens - Comma-separated list of token IDs
116
+ A comma-separated list of token IDs to suppress during sampling. The default value of "-1" will suppress most special characters except common punctuations.
117
+
118
+ ## Condition on previous text
119
+ If True, provide the previous output of the model as a prompt for the next window. Disabling this may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop.
120
+
121
+ ## FP16
122
+ Whether to perform inference in fp16. True by default.
123
+
124
+ ## Temperature increment on fallback
125
+ The temperature to increase when falling back when the decoding fails to meet either of the thresholds below. Default is 0.2.
126
+
127
+ ## Compression ratio threshold
128
+ If the gzip compression ratio is higher than this value, treat the decoding as failed. Default is 2.4.
129
+
130
+ ## Logprob threshold
131
+ If the average log probability is lower than this value, treat the decoding as failed. Default is -1.0.
132
+
133
+ ## No speech threshold
134
+ If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
docs/windows/install_win10_win11.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b9f4ed547d6534411c17da1ea56707d2ec6e812611b1cbd3098756d5cbb8084
3
+ size 3378789
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers
2
+ git+https://github.com/openai/whisper.git
3
+ transformers
4
+ ffmpeg-python==0.2.0
5
+ gradio==3.23.0
6
+ yt-dlp
7
+ torchaudio
8
+ altair
9
+ json5
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import urllib
2
+
3
+ import os
4
+ from typing import List
5
+ from urllib.parse import urlparse
6
+ import json5
7
+ import torch
8
+
9
+ from tqdm import tqdm
10
+
11
+ from src.conversion.hf_converter import convert_hf_whisper
12
+
13
class ModelConfig:
    """Describes a single Whisper model: where to obtain it and, once resolved, where it lives on disk."""

    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Initialize a model configuration.

        name: Name of the model
        url: URL to download the model from
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        self.name = name
        self.url = url
        self.path = path
        self.type = type

    def download_url(self, root_dir: str):
        """
        Resolve the model to a local path (or a built-in Whisper model name),
        downloading or converting it if necessary. The result is cached in self.path.

        root_dir: Directory to store downloaded/converted models in
                  (defaults to ~/.cache/whisper when None).
        """
        # Already resolved - nothing to do
        if self.path is not None:
            return self.path

        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        model_type = self.type.lower() if self.type is not None else "whisper"

        if model_type in ["huggingface", "hf"]:
            destination_target = os.path.join(root_dir, self.name + ".pt")

            # Convert from HuggingFace format to Whisper format
            if os.path.exists(destination_target):
                print(f"File {destination_target} already exists, skipping conversion")
            else:
                print("Saving HuggingFace model in Whisper format to " + destination_target)
                convert_hf_whisper(self.url, destination_target)

            self.path = destination_target

        elif model_type in ["whisper", "w"]:
            # Imported lazily so resolving a pre-set path does not require whisper at all
            import whisper

            # See if URL is a built-in model name
            if self.url in whisper._MODELS:
                # No need to download anything - Whisper will handle it
                self.path = self.url
            elif self.url.startswith("file://"):
                # Get file path
                self.path = urlparse(self.url).path
            # See if it is an URL
            elif self.url.startswith("http://") or self.url.startswith("https://"):
                # Extension (or file name)
                extension = os.path.splitext(self.url)[-1]
                download_target = os.path.join(root_dir, self.name + extension)

                if os.path.exists(download_target) and not os.path.isfile(download_target):
                    raise RuntimeError(f"{download_target} exists and is not a regular file")

                if not os.path.isfile(download_target):
                    self._download_file(self.url, download_target)
                else:
                    print(f"File {download_target} already exists, skipping download")

                self.path = download_target
            else:
                # Must be a local file
                self.path = self.url

        else:
            raise ValueError(f"Unknown model type {model_type}")

        return self.path

    def _download_file(self, url: str, destination: str):
        """Stream *url* into *destination*, showing a tqdm progress bar."""
        # BUGFIX: the module only does `import urllib`, which does not guarantee that
        # the urllib.request submodule is loaded - import it explicitly before use.
        import urllib.request

        with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
            with tqdm(
                total=int(source.info().get("Content-Length")),
                ncols=80,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))
103
+
104
class ApplicationConfig:
    """Aggregates all WebUI, transcription and VAD settings, typically parsed from config.json5."""

    def __init__(self, models: List["ModelConfig"] = None, input_audio_max_duration: int = 600,
                 share: bool = False, server_name: str = None, server_port: int = 7860,
                 queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
                 default_model_name: str = "medium", default_vad: str = "silero-vad",
                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None,
                 model_dir: str = None, device: str = None,
                 verbose: bool = True, task: str = "transcribe", language: str = None,
                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
                 vad_padding: float = 1, vad_prompt_window: float = 3,
                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
                 patience: float = None, length_penalty: float = None,
                 suppress_tokens: str = "-1", initial_prompt: str = None,
                 condition_on_previous_text: bool = True, fp16: bool = True,
                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):

        # BUGFIX: the default for `models` used to be a mutable `[]` shared by all
        # instances - use None as the sentinel and create a fresh list per instance.
        if models is None:
            models = []

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.models = models

        # WebUI settings
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.queue_concurrency_count = queue_concurrency_count
        self.delete_uploaded_files = delete_uploaded_files

        self.default_model_name = default_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

        self.model_dir = model_dir
        self.device = device
        self.verbose = verbose
        self.task = task
        self.language = language
        self.vad_merge_window = vad_merge_window
        self.vad_max_merge_size = vad_max_merge_size
        self.vad_padding = vad_padding
        self.vad_prompt_window = vad_prompt_window
        self.temperature = temperature
        self.best_of = best_of
        self.beam_size = beam_size
        self.patience = patience
        self.length_penalty = length_penalty
        self.suppress_tokens = suppress_tokens
        self.initial_prompt = initial_prompt
        self.condition_on_previous_text = condition_on_previous_text
        self.fp16 = fp16
        self.temperature_increment_on_fallback = temperature_increment_on_fallback
        self.compression_ratio_threshold = compression_ratio_threshold
        self.logprob_threshold = logprob_threshold
        self.no_speech_threshold = no_speech_threshold

    def get_model_names(self):
        """Return the names of all configured models."""
        return [ x.name for x in self.models ]

    def update(self, **new_values):
        """Return a copy of this configuration with the given attributes overridden."""
        result = ApplicationConfig(**self.__dict__)

        for key, value in new_values.items():
            setattr(result, key, value)
        return result

    @staticmethod
    def create_default(**kwargs):
        """Load the default configuration file (WHISPER_WEBUI_CONFIG or config.json5), applying overrides."""
        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

        # Update with kwargs
        if len(kwargs) > 0:
            app_config = app_config.update(**kwargs)
        return app_config

    @staticmethod
    def parse_file(config_path: str):
        """Parse a json5 configuration file into an ApplicationConfig."""
        # json5 is imported at module level; it allows comments and trailing commas
        with open(config_path, "r") as f:
            data = json5.load(f)
            data_models = data.pop("models", [])

            models = [ ModelConfig(**x) for x in data_models ]

            return ApplicationConfig(models, **data)
src/conversion/hf_converter.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
2
+
3
+ from copy import deepcopy
4
+ import torch
5
+ from transformers import WhisperForConditionalGeneration
6
+
7
# Maps HuggingFace Whisper state-dict key fragments to their OpenAI Whisper equivalents.
# (The original literal listed "layers": "blocks" twice; duplicate removed.)
WHISPER_MAPPING = {
    "layers": "blocks",
    "fc1": "mlp.0",
    "fc2": "mlp.2",
    "final_layer_norm": "mlp_ln",
    ".self_attn.q_proj": ".attn.query",
    ".self_attn.k_proj": ".attn.key",
    ".self_attn.v_proj": ".attn.value",
    ".self_attn_layer_norm": ".attn_ln",
    ".self_attn.out_proj": ".attn.out",
    ".encoder_attn.q_proj": ".cross_attn.query",
    ".encoder_attn.k_proj": ".cross_attn.key",
    ".encoder_attn.v_proj": ".cross_attn.value",
    ".encoder_attn_layer_norm": ".cross_attn_ln",
    ".encoder_attn.out_proj": ".cross_attn.out",
    "decoder.layer_norm.": "decoder.ln.",
    "encoder.layer_norm.": "encoder.ln_post.",
    "embed_tokens": "token_embedding",
    "encoder.embed_positions.weight": "encoder.positional_embedding",
    "decoder.embed_positions.weight": "decoder.positional_embedding",
    "layer_norm": "ln_post",
}
30
+
31
+
32
def rename_keys(s_dict):
    """Rename HuggingFace state-dict keys in place to their Whisper equivalents, returning the dict."""
    for old_key in list(s_dict.keys()):
        renamed = old_key
        # Apply every mapping whose fragment occurs in the key
        for hf_fragment, whisper_fragment in WHISPER_MAPPING.items():
            if hf_fragment in old_key:
                renamed = renamed.replace(hf_fragment, whisper_fragment)

        print(f"{old_key} -> {renamed}")

        s_dict[renamed] = s_dict.pop(old_key)
    return s_dict
44
+
45
+
46
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
    """
    Convert a HuggingFace Whisper checkpoint into the native OpenAI Whisper
    .pt format and save it at whisper_state_path.
    """
    hf_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
    hf_config = hf_model.config

    # Translate the HuggingFace config into Whisper's "dims" header
    dims = {
        'n_mels': hf_config.num_mel_bins,
        'n_vocab': hf_config.vocab_size,
        'n_audio_ctx': hf_config.max_source_positions,
        'n_audio_state': hf_config.d_model,
        'n_audio_head': hf_config.encoder_attention_heads,
        'n_audio_layer': hf_config.encoder_layers,
        'n_text_ctx': hf_config.max_target_positions,
        'n_text_state': hf_config.d_model,
        'n_text_head': hf_config.decoder_attention_heads,
        'n_text_layer': hf_config.decoder_layers
    }

    # Copy the weights first so renaming does not mutate the loaded HF model
    renamed_state = rename_keys(deepcopy(hf_model.model.state_dict()))

    torch.save({"dims": dims, "model_state_dict": renamed_state}, whisper_state_path)
src/download.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tempfile import mkdtemp
2
+ from typing import List
3
+ from yt_dlp import YoutubeDL
4
+
5
+ import yt_dlp
6
+ from yt_dlp.postprocessor import PostProcessor
7
+
8
class FilenameCollectorPP(PostProcessor):
    """yt-dlp post-processor that records the file path of every processed download."""

    def __init__(self):
        super(FilenameCollectorPP, self).__init__(None)
        self.filenames = []

    def run(self, information):
        # Remember the produced file; return no files-to-delete and the info unchanged
        self.filenames.append(information["filepath"])
        return [], information
16
+
17
def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
    """
    Download the audio of a URL (or playlist) via yt-dlp and return the downloaded file paths.

    maxDuration: maximum allowed total duration in seconds (raises ExceededMaximumDuration).
    destinationDirectory: download target; a fresh temporary directory when None.
    playlistItems: yt-dlp playlist_items selector; "1" limits playlists to their first entry.
    """
    try:
        return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
    except yt_dlp.utils.DownloadError as e:
        # If the OS rejected the file name as too long, retry with a truncated output template.
        # BUGFIX: the retry now keeps the caller's destinationDirectory and playlistItems,
        # which were previously dropped.
        if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
            return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s",
                                     destinationDirectory=destinationDirectory, playlistItems=playlistItems)
        # BUGFIX: any other download error used to be silently swallowed (returning None);
        # re-raise so callers can handle or report the failure.
        raise
25
+
26
def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
    """
    Download the audio of *url* via yt-dlp and return the list of downloaded file paths.

    maxDuration: if set (> 0), probe the media first and raise ExceededMaximumDuration
        when the combined duration of all entries reaches this many seconds.
    outputTemplate: optional yt-dlp output template (used by the file-name-too-long retry).
    destinationDirectory: download target; a fresh temporary directory when None.
    playlistItems: yt-dlp playlist_items selector; "1" limits playlists to their first entry.
    """
    # Create a temporary directory to store the downloaded files
    if destinationDirectory is None:
        destinationDirectory = mkdtemp()

    ydl_opts = {
        "format": "bestaudio/best",
        'paths': {
            'home': destinationDirectory
        }
    }
    if (playlistItems):
        ydl_opts['playlist_items'] = playlistItems

    # Add output template if specified
    if outputTemplate:
        ydl_opts['outtmpl'] = outputTemplate

    filename_collector = FilenameCollectorPP()

    with YoutubeDL(ydl_opts) as ydl:
        if maxDuration and maxDuration > 0:
            # Probe without downloading so the duration limit is enforced up front
            info = ydl.extract_info(url, download=False)
            # Playlists expose an "entries" list; a single video is treated as one entry
            entries = "entries" in info and info["entries"] or [info]

            total_duration = 0

            # Compute total duration
            for entry in entries:
                # NOTE(review): assumes every entry reports a "duration" field - live
                # streams or some extractors may omit it; verify before relying on this.
                total_duration += float(entry["duration"])

            if total_duration >= maxDuration:
                raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=maxDuration, message="Video is too long")

        # Collect the final file names produced by the download pipeline
        ydl.add_post_processor(filename_collector)
        ydl.download([url])

    # yt-dlp reported success but produced no files - treat as a failure
    if len(filename_collector.filenames) <= 0:
        raise Exception("Cannot download " + url)

    result = []

    for filename in filename_collector.filenames:
        result.append(filename)
        print("Downloaded " + filename)

    return result
73
+
74
class ExceededMaximumDuration(Exception):
    """Raised when the requested media is longer than the configured maximum duration."""

    def __init__(self, videoDuration, maxDuration, message):
        super().__init__(message)
        # Keep both durations so callers can build a detailed error message
        self.videoDuration = videoDuration
        self.maxDuration = maxDuration
src/hooks/whisperProgressHook.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import threading
3
+ from typing import List, Union
4
+ import tqdm
5
+
6
class ProgressListener:
    """Callback interface for receiving progress updates."""

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        # NOTE(review): the base implementation only records the reported total;
        # subclasses are expected to override this to act on progress updates.
        self.total = total

    def on_finished(self):
        """Called once when the monitored operation completes without error."""
        pass

class ProgressListenerHandle:
    """Context manager that registers *listener* for thread-local progress events
    on entry, unregisters it on exit, and reports completion on a clean exit."""

    def __init__(self, listener: ProgressListener):
        self.listener = listener

    def __enter__(self):
        register_thread_local_progress_listener(self.listener)

    def __exit__(self, exc_type, exc_val, exc_tb):
        unregister_thread_local_progress_listener(self.listener)

        # Only report completion when the with-block exited without an exception
        if exc_type is not None:
            return
        self.listener.on_finished()

class SubTaskProgressListener(ProgressListener):
    """
    A sub task listener that reports the progress of a sub task to a base task listener

    Parameters
    ----------
    base_task_listener : ProgressListener
        The base progress listener to accumulate overall progress in.
    base_task_total : float
        The maximum total progress that will be reported to the base progress listener.
    sub_task_start : float
        The starting progress of a sub task, in respect to the base progress listener.
    sub_task_total : float
        The total amount of progress a sub task will report to the base progress listener.
    """

    def __init__(self, base_task_listener: ProgressListener, base_task_total: float,
                 sub_task_start: float, sub_task_total: float):
        self.base_task_listener = base_task_listener
        self.base_task_total = base_task_total
        self.sub_task_start = sub_task_start
        self.sub_task_total = sub_task_total

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        # Scale the sub task's fractional progress into the base task's range
        completed_fraction = current / total
        scaled_progress = self.sub_task_start + self.sub_task_total * completed_fraction
        self.base_task_listener.on_progress(scaled_progress, self.base_task_total)

    def on_finished(self):
        # A finished sub task counts as its full share of the base task's progress
        self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
60
+
61
class _CustomProgressBar(tqdm.tqdm):
    """tqdm subclass (injected into whisper.transcribe by init_progress_hook) that
    forwards every update to the current thread's registered ProgressListeners."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._current = self.n # Set the initial value

    def update(self, n):
        super().update(n)
        # Because the progress bar might be disabled, we need to manually update the progress
        self._current += n

        # Inform listeners
        listeners = _get_thread_local_listeners()

        for listener in listeners:
            listener.on_progress(self._current, self.total)
76
+
77
+ _thread_local = threading.local()
78
+
79
+ def _get_thread_local_listeners():
80
+ if not hasattr(_thread_local, 'listeners'):
81
+ _thread_local.listeners = []
82
+ return _thread_local.listeners
83
+
84
+ _hooked = False
85
+
86
+ def init_progress_hook():
87
+ global _hooked
88
+
89
+ if _hooked:
90
+ return
91
+
92
+ # Inject into tqdm.tqdm of Whisper, so we can see progress
93
+ import whisper.transcribe
94
+ transcribe_module = sys.modules['whisper.transcribe']
95
+ transcribe_module.tqdm.tqdm = _CustomProgressBar
96
+ _hooked = True
97
+
98
+ def register_thread_local_progress_listener(progress_listener: ProgressListener):
99
+ # This is a workaround for the fact that the progress bar is not exposed in the API
100
+ init_progress_hook()
101
+
102
+ listeners = _get_thread_local_listeners()
103
+ listeners.append(progress_listener)
104
+
105
+ def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
106
+ listeners = _get_thread_local_listeners()
107
+
108
+ if progress_listener in listeners:
109
+ listeners.remove(progress_listener)
110
+
111
+ def create_progress_listener_handle(progress_listener: ProgressListener):
112
+ return ProgressListenerHandle(progress_listener)
113
+
114
# Example usage
if __name__ == '__main__':
    # Minimal listener that just echoes progress to stdout
    class PrintingProgressListener:
        def on_progress(self, current: Union[int, float], total: Union[int, float]):
            print(f"Progress: {current}/{total}")

        def on_finished(self):
            print("Finished")

    import whisper
    model = whisper.load_model("medium")

    # NOTE(review): hard-coded developer-machine path - replace with a real audio
    # file when running this example locally.
    with create_progress_listener_handle(PrintingProgressListener()) as listener:
        # Set verbose to None to disable the progress bar, as we are using our own
        result = model.transcribe("J:\\Dev\\OpenAI\\whisper\\tests\\Noriko\\out.mka", language="Japanese", fp16=False, verbose=None)
        print(result)

    print("Done")
src/modelCache.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class ModelCache:
    """A simple keyed cache of loaded models."""

    def __init__(self):
        self._cache = dict()

    def get(self, model_key: str, model_factory):
        """Return the cached value for *model_key*, creating it via model_factory() on a miss.

        Note: a cached value of None is treated as a miss, so the factory runs again.
        """
        cached = self._cache.get(model_key)
        if cached is not None:
            return cached

        created = model_factory()
        self._cache[model_key] = created
        return created

    def clear(self):
        """Drop every cached entry."""
        self._cache.clear()

# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
GLOBAL_MODEL_CACHE = ModelCache()
src/segments.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
    """
    Merge adjacent VAD segments and pad them with silence.

    timestamps: list of {'start': ..., 'end': ...} dicts, in chronological order.
    merge_window: merge two segments when the gap between them is at most this
        many seconds (None disables the gap check).
    max_merge_size: never let a merged segment grow beyond this many seconds;
        None returns the input untouched.
    padding_left / padding_right: seconds of padding added before/after each
        output segment (clamped so neighbours do not overlap); None means 0.

    Input dicts are deep-copied, so the caller's list is never mutated.
    """
    if not timestamps:
        return []
    if max_merge_size is None:
        # Merging disabled entirely - hand back the original segments
        return timestamps

    padding_left = 0 if padding_left is None else padding_left
    padding_right = 0 if padding_right is None else padding_right

    merged: List[Dict[str, Any]] = []
    open_segment = None
    last_end = 0

    for candidate in timestamps:
        gap = candidate['start'] - last_end

        # A candidate can extend the open segment when the gap is small enough and
        # the combined segment stays within max_merge_size. (Segments may still be
        # individually longer than max_merge_size; they just won't be merged.)
        extendable = (
            open_segment is not None
            and not (merge_window is not None and gap > merge_window)
            and candidate['end'] - open_segment['start'] <= max_merge_size
        )

        if extendable:
            open_segment['end'] = candidate['end']
            last_end = open_segment['end']
            continue

        if open_segment is not None:
            # Close the open segment, right-padding it without colliding with the next one
            if gap < padding_left + padding_right:
                tail_pad = min(padding_right, gap / 2)
            else:
                tail_pad = padding_right
            open_segment['end'] += tail_pad
            gap -= tail_pad
            merged.append(open_segment)

        # Open a fresh segment, left-padded into whatever gap remains
        open_segment = copy.deepcopy(candidate)
        open_segment['start'] -= min(padding_left, gap)
        last_end = open_segment['end']

    # Flush the final open segment with full right padding
    if open_segment is not None:
        open_segment['end'] += padding_right
        merged.append(open_segment)

    return merged
src/source.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
2
+ import os
3
+ import pathlib
4
+ from typing import List
5
+ import zipfile
6
+
7
+ import ffmpeg
8
+ from more_itertools import unzip
9
+
10
+ from src.download import ExceededMaximumDuration, download_url
11
+
12
# Maximum number of stem characters kept by AudioSource.get_short_name
MAX_FILE_PREFIX_LENGTH = 17

class AudioSource:
    """A single audio input file, with a display name and lazily probed duration."""

    def __init__(self, source_path, source_name = None, audio_duration = None):
        self.source_path = source_path
        # Fall back to the file name when no explicit display name is given
        self.source_name = source_name if source_name is not None else pathlib.Path(source_path).name
        self._audio_duration = audio_duration

    def get_audio_duration(self):
        """Return the duration in seconds, probing the file via ffmpeg on first use."""
        if self._audio_duration is None:
            probed = ffmpeg.probe(self.source_path)
            self._audio_duration = float(probed["format"]["duration"])
        return self._audio_duration

    def get_full_name(self):
        """Return the full display name of this source."""
        return self.source_name

    def get_short_name(self, max_length: int = MAX_FILE_PREFIX_LENGTH):
        """Return the name with its stem truncated to max_length characters, keeping the suffix."""
        name = pathlib.Path(self.source_name)
        return name.stem[:max_length] + name.suffix

    def __str__(self) -> str:
        return self.source_path
42
+
43
class AudioSourceCollection:
    """A thin iterable wrapper around a list of AudioSource objects."""

    def __init__(self, sources: List[AudioSource]):
        self.sources = sources

    def __iter__(self):
        # Delegate iteration directly to the underlying list
        yield from self.sources
49
+
50
def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneData: str, input_audio_max_duration: float = -1) -> List[AudioSource]:
    """Build the list of audio sources from the Gradio inputs.

    Parameters
    ----------
    urlData: str
        A URL (video, playlist or channel) to download; takes precedence over files.
    multipleFiles: List
        Uploaded file objects (each with a .name path), or None.
    microphoneData: str
        Path to recorded microphone audio, or None.
    input_audio_max_duration: float
        Maximum total duration in seconds; <= 0 disables the check.

    Returns
    -------
    A list of AudioSource objects, each with its duration pre-populated.

    Raises
    ------
    ExceededMaximumDuration
        If the combined duration exceeds input_audio_max_duration.
    """
    output: List[AudioSource] = []

    if urlData:
        # Download from YouTube. This could also be a playlist or a channel.
        output.extend([ AudioSource(x) for x in download_url(urlData, input_audio_max_duration, playlistItems=None) ])
    else:
        # Add input files
        if (multipleFiles is not None):
            output.extend([ AudioSource(x.name) for x in multipleFiles ])
        if (microphoneData is not None):
            output.append(AudioSource(microphoneData))

    total_duration = 0

    # Calculate total audio length. We do this even if input_audio_max_duration
    # is disabled to ensure that all the audio files are valid. Going through
    # get_audio_duration() probes each file once and caches the result on the
    # source, instead of duplicating the probe and writing its private field.
    for source in output:
        total_duration += source.get_audio_duration()

    # Ensure the total duration of the audio is not too long
    if input_audio_max_duration > 0 and float(total_duration) > input_audio_max_duration:
        raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")

    # Return a list of audio sources
    return output
src/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO
7
+
8
+
9
def exact_div(x, y):
    """Divide x by y, asserting that the division leaves no remainder."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
12
+
13
+
14
def str2bool(string):
    """Parse the literal strings "True"/"False" into booleans; raise ValueError otherwise."""
    str2val = {"True": True, "False": False}
    if string not in str2val:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return str2val[string]
20
+
21
+
22
def optional_int(string):
    """Parse an int from a string, treating the literal "None" as None."""
    if string == "None":
        return None
    return int(string)
24
+
25
+
26
def optional_float(string):
    """Parse a float from a string, treating the literal "None" as None."""
    if string == "None":
        return None
    return float(string)
28
+
29
+
30
def compression_ratio(text) -> float:
    """Ratio of text length to its zlib-compressed size; higher means more repetitive text."""
    encoded = text.encode("utf-8")
    return len(text) / len(zlib.compress(encoded))
32
+
33
+
34
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    """Render a non-negative time in seconds as [HH:]MM:SS<sep>mmm.

    The hour field is omitted unless non-zero or always_include_hours is set;
    fractionalSeperator separates seconds from milliseconds ('.' for VTT, ',' for SRT).
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    # Successively peel off hours, minutes and whole seconds
    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    whole_seconds, millis = divmod(total_ms, 1_000)

    prefix = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{fractionalSeperator}{millis:03d}"
49
+
50
+
51
def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write each segment's stripped text as one line to the given file."""
    for segment in transcript:
        line = segment['text'].strip()
        print(line, file=file, flush=True)
54
+
55
+
56
def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """Write segments as a WebVTT file: header, then one cue per segment."""
    print("WEBVTT\n", file=file)

    for segment in transcript:
        # A literal '-->' inside the text would break the cue syntax, so soften it
        body = process_text(segment['text'], maxLineWidth).replace('-->', '->')
        start = format_timestamp(segment['start'])
        end = format_timestamp(segment['end'])

        print(f"{start} --> {end}\n{body}\n", file=file, flush=True)
67
+
68
+
69
def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """
    Write a transcript to a file in SRT format.
    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt
        result = transcribe(model, audio_path, temperature=temperature, **args)
        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for index, segment in enumerate(transcript, start=1):
        # Strip and wrap the text; a literal '-->' would break the SRT syntax
        body = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
        start = format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')
        end = format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')

        # Cue number, timing line, then the text block
        print(f"{index}\n{start} --> {end}\n{body}\n", file=file, flush=True)
93
+
94
def process_text(text: str, maxLineWidth=None):
    """Wrap text to maxLineWidth columns; pass through unchanged when disabled (None or negative)."""
    if maxLineWidth is None or maxLineWidth < 0:
        return text
    return '\n'.join(textwrap.wrap(text, width=maxLineWidth, tabsize=4))
100
+
101
def slugify(value, allow_unicode=False):
    """
    Convert a value into a filesystem/URL-safe slug.

    Taken from https://github.com/django/django/blob/master/django/utils/text.py
    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    text = str(value)
    if allow_unicode:
        text = unicodedata.normalize('NFKC', text)
    else:
        # Decompose accents, then drop anything outside ASCII
        text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')
    text = re.sub(r'[^\w\s-]', '', text.lower())
    text = re.sub(r'[-\s]+', '-', text)
    return text.strip('-_')
src/vad.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
+ import time
4
+
5
+ from typing import Any, Deque, Iterator, List, Dict
6
+
7
+ from pprint import pprint
8
+ from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
9
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
10
+
11
+ from src.segments import merge_timestamps
12
+ from src.whisperContainer import WhisperCallback
13
+
14
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
15
+ try:
16
+ import tensorflow as tf
17
+ except ModuleNotFoundError:
18
+ # Error handling
19
+ pass
20
+
21
+ import torch
22
+
23
+ import ffmpeg
24
+ import numpy as np
25
+
26
+ from src.utils import format_timestamp
27
+ from enum import Enum
28
+
29
class NonSpeechStrategy(Enum):
    """How to handle audio sections where no speech was detected."""

    # Ignore non-speech sections entirely - only transcribe detected speech
    SKIP = 1
    # Emit non-speech sections as their own segments (treated as speech)
    CREATE_SEGMENT = 2
    # Stretch each speech segment to cover the following non-speech section
    EXPAND_SEGMENT = 3
42
+
43
# Defaults for Silero
# Probability threshold above which a frame counts as speech
# (NOTE: the 'TRESHOLD' misspelling is part of the public name - do not rename casually)
SPEECH_TRESHOLD = 0.3

# Minimum size of segments to process
# (seconds; merged segments shorter than this are skipped entirely)
MIN_SEGMENT_DURATION = 1

# The maximum time for texts from old segments to be used in the next segment
MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this

# Audio is fed to the VAD in chunks of at most this many seconds
VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
54
+
55
class TranscriptionConfig(ABC):
    """Options controlling how VAD segments are merged, padded and prompted.

    Parameters
    ----------
    non_speech_strategy: NonSpeechStrategy
        How to handle audio sections where no speech was detected.
    segment_padding_left: float
        Seconds of padding added before each merged segment (None = 0).
    segment_padding_right: float
        Seconds of padding added after each merged segment (None = 0).
    max_silent_period: float
        Silence gap (seconds) above which adjacent segments are not merged.
    max_merge_size: float
        Maximum length (seconds) of a merged segment.
    max_prompt_window: float
        How far back (seconds) previous text may be reused as a prompt (0/None = disabled).
    initial_segment_index: int
        Index assigned before the first segment (the first segment gets this + 1).
    """
    def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
        segment_padding_left: float = None, segment_padding_right: float = None, max_silent_period: float = None,
        max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index: int = -1):
        self.non_speech_strategy = non_speech_strategy
        self.segment_padding_left = segment_padding_left
        self.segment_padding_right = segment_padding_right
        self.max_silent_period = max_silent_period
        self.max_merge_size = max_merge_size
        self.max_prompt_window = max_prompt_window
        self.initial_segment_index = initial_segment_index
66
+
67
class PeriodicTranscriptionConfig(TranscriptionConfig):
    """TranscriptionConfig for the periodic VAD; adds the fixed segment length.

    periodic_duration is the length in seconds of each generated segment.
    """
    def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
        segment_padding_left: float = None, segment_padding_right: float = None, max_silent_period: float = None,
        max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index: int = -1):
        super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
        # Length in seconds of each generated segment
        self.periodic_duration = periodic_duration
73
+
74
class AbstractTranscription(ABC):
    """Base class for VAD-driven transcription.

    Subclasses decide *which* time ranges of the audio should be transcribed
    (get_transcribe_timestamps); this class merges those ranges, runs Whisper
    on each one via a callback, and stitches the results back together.
    """

    def __init__(self, sampling_rate: int = 16000):
        self.sampling_rate = sampling_rate

    def get_audio_segment(self, str, start_time: str = None, duration: str = None):
        """Load a (mono, resampled) slice of the audio file.

        NOTE: the first parameter shadows the builtin 'str'; the name is kept
        for interface compatibility with existing callers.
        """
        return load_audio(str, self.sampling_rate, start_time, duration)

    def is_transcribe_timestamps_fast(self):
        """
        Determine if get_transcribe_timestamps is fast enough to not need parallelization.
        """
        return False

    @abstractmethod
    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
        """
        Get the start and end timestamps of the sections that should be transcribed by this VAD method.

        Parameters
        ----------
        audio: str
            The audio file.
        config: TranscriptionConfig
            The transcription configuration.
        start_time: float
            Start of the range to analyze, in fractional seconds.
        end_time: float
            End of the range to analyze, in fractional seconds.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds.
        """
        return

    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
        """
        Merge raw VAD timestamps per the configuration, then apply the
        configured non-speech strategy (gap creation or segment expansion).

        Parameters
        ----------
        timestamps: List[Dict[str, Any]]
            The raw start/end timestamps from the VAD.
        config: TranscriptionConfig
            The transcription configuration.
        total_duration: float
            Total duration of the audio in seconds.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds.
        """
        merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
                                  config.segment_padding_left, config.segment_padding_right)

        if config.non_speech_strategy != NonSpeechStrategy.SKIP:
            # Expand segments to include the gaps between them
            if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
                # When we have a prompt window, we create speech segments between each segment if we exceed the merge size
                merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
            elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
                # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
                merged = self.expand_gaps(merged, total_duration=total_duration)
            else:
                raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))

            print("Transcribing non-speech:")
            pprint(merged)
        return merged

    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
                   progressListener: ProgressListener = None):
        """
        Transcribe the given audio file.

        Parameters
        ----------
        audio: str
            The audio file.
        whisperCallable: WhisperCallback
            A callback object to call to transcribe each segment.
        config: TranscriptionConfig
            The transcription configuration.
        progressListener: ProgressListener
            Optional listener; always notified via on_finished(), even on error.

        Returns
        -------
        A dict with 'text', 'segments' and 'language' keys (Whisper result format).
        """

        try:
            max_audio_duration = self.get_audio_duration(audio, config)
            timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)

            # Get speech timestamps from full audio file
            merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)

            # A deque of transcribed segments that is passed to the next segment as a prompt
            prompt_window = deque()

            print("Processing timestamps:")
            pprint(merged)

            result = {
                'text': "",
                'segments': [],
                'language': ""
            }
            languageCounter = Counter()
            detected_language = None

            segment_index = config.initial_segment_index

            # Progress is measured against the total duration of the merged segments
            progress_start_offset = merged[0]['start'] if len(merged) > 0 else 0
            progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])

            # For each time segment, run whisper
            for segment in merged:
                segment_index += 1
                segment_start = segment['start']
                segment_end = segment['end']
                segment_expand_amount = segment.get('expand_amount', 0)
                segment_gap = segment.get('gap', False)

                segment_duration = segment_end - segment_start

                # Skip segments that are too short to transcribe meaningfully
                if segment_duration < MIN_SEGMENT_DURATION:
                    continue

                # Audio to run on Whisper
                segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
                # Previous segments to use as a prompt
                segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None

                # Detected language - the most frequently detected language so far
                detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None

                print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
                      segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)

                scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=progress_total_duration,
                                                                   sub_task_start=segment_start - progress_start_offset, sub_task_total=segment_duration)
                segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)

                # Shift timestamps back to absolute positions in the source audio
                adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)

                # Propagate expand amount to the segments
                if (segment_expand_amount > 0):
                    segment_without_expansion = segment_duration - segment_expand_amount

                    for adjusted_segment in adjusted_segments:
                        adjusted_segment_end = adjusted_segment['end']

                        # Add expand amount if the segment got expanded
                        if (adjusted_segment_end > segment_without_expansion):
                            adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion

                # Append to output
                result['text'] += segment_result['text']
                result['segments'].extend(adjusted_segments)

                # Increment detected language (gaps don't count - they may not contain speech)
                if not segment_gap:
                    languageCounter[segment_result['language']] += 1

                # Update prompt window
                self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)

            if detected_language is not None:
                result['language'] = detected_language
        finally:
            # Notify progress listener that we are done
            if progressListener is not None:
                progressListener.on_finished()
        return result

    def get_audio_duration(self, audio: str, config: TranscriptionConfig):
        """Duration of the audio in seconds (config is unused here, but available to subclasses)."""
        return get_audio_duration(audio)

    def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
        """Maintain the rolling window of recent segments reused as the next prompt."""
        if (config.max_prompt_window is not None and config.max_prompt_window > 0):
            # Add segments to the current prompt window (unless it is a speech gap)
            if not segment_gap:
                for segment in adjusted_segments:
                    # Only high-confidence speech is worth prompting with
                    if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
                        prompt_window.append(segment)

            while (len(prompt_window) > 0):
                first_end_time = prompt_window[0].get('end', 0)
                # Time expanded in the segments should be discounted from the prompt window
                first_expand_time = prompt_window[0].get('expand_amount', 0)

                if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
                    prompt_window.popleft()
                else:
                    break

    def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
        """Return the segments with explicit 'gap' entries inserted for silent stretches
        of at least min_gap_length seconds (None = any length)."""
        result = []
        last_end_time = 0

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            if (last_end_time != segment_start):
                delta = segment_start - last_end_time

                if (min_gap_length is None or delta >= min_gap_length):
                    result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )

            last_end_time = segment_end
            result.append(segment)

        # Also include total duration if specified
        if (total_duration is not None and last_end_time < total_duration):
            # BUGFIX: measure the trailing gap from the last end time, not from the
            # start of the last segment (which is undefined when segments is empty
            # and overstates the gap length otherwise)
            delta = total_duration - last_end_time

            if (min_gap_length is None or delta >= min_gap_length):
                result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )

        return result

    # Expand the end time of each segment to the start of the next segment
    def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
        result = []

        if len(segments) == 0:
            return result

        # Add gap at the beginning if needed
        if (segments[0]['start'] > 0):
            result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )

        for i in range(len(segments) - 1):
            current_segment = segments[i]
            next_segment = segments[i + 1]

            delta = next_segment['start'] - current_segment['end']

            # Expand if the gap actually exists (copy first so the input is not mutated)
            if (delta >= 0):
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']

            result.append(current_segment)

        # Add last segment
        last_segment = segments[-1]
        result.append(last_segment)

        # Also include total duration if specified
        if (total_duration is not None):
            last_segment = result[-1]

            if (last_segment['end'] < total_duration):
                last_segment = last_segment.copy()
                last_segment['end'] = total_duration
                result[-1] = last_segment

        return result

    def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
        """Like expand_gaps, but gaps larger than max_expand_size become their own
        'gap' segments instead of being absorbed into the preceding segment."""
        result = []

        if len(segments) == 0:
            return result

        # Add gap at the beginning if needed
        if (segments[0]['start'] > 0):
            result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )

        for i in range(len(segments) - 1):
            expanded = False
            current_segment = segments[i]
            next_segment = segments[i + 1]

            delta = next_segment['start'] - current_segment['end']

            if (max_expand_size is not None and delta <= max_expand_size):
                # Just expand the current segment
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']
                expanded = True

            result.append(current_segment)

            # Add a gap to the next segment if needed
            if (delta >= 0 and not expanded):
                result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )

        # Add last segment
        last_segment = segments[-1]
        result.append(last_segment)

        # Also include total duration if specified
        if (total_duration is not None):
            last_segment = result[-1]

            delta = total_duration - last_segment['end']

            if (delta > 0):
                if (max_expand_size is not None and delta <= max_expand_size):
                    # Expand the last segment
                    last_segment = last_segment.copy()
                    last_segment['expand_amount'] = delta
                    last_segment['end'] = total_duration
                    result[-1] = last_segment
                else:
                    result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )

        return result

    def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
        """Shift segment timestamps by adjust_seconds, optionally clamping/dropping
        segments that fall past max_source_time (in source-relative seconds)."""
        result = []

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            # Filter segments?
            if (max_source_time is not None):
                if (segment_start > max_source_time):
                    continue
                segment_end = min(max_source_time, segment_end)

            new_segment = segment.copy()

            # Add to start and end
            new_segment['start'] = segment_start + adjust_seconds
            new_segment['end'] = segment_end + adjust_seconds
            result.append(new_segment)
        return result

    def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
        """Scale all start/end values by factor (e.g. samples -> seconds)."""
        result = []

        for entry in timestamps:
            start = entry['start']
            end = entry['end']

            result.append({
                'start': start * factor,
                'end': end * factor
            })
        return result
415
+
416
+
417
class VadSileroTranscription(AbstractTranscription):
    """VAD transcription strategy that uses the Silero VAD model to locate speech."""

    def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
        super().__init__(sampling_rate=sampling_rate)
        self.model = None
        self.cache = cache
        self._initialize_model()

    def _initialize_model(self):
        """Load the Silero model, going through the shared cache when one is configured."""
        if (self.cache is not None):
            model_key = "VadSileroTranscription"
            self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
            # FIX: log messages previously misspelled the model name as "Silerio"
            print("Loaded Silero model from cache.")
        else:
            self.model, self.get_speech_timestamps = self._create_model()
            print("Created Silero model")

    def _create_model(self):
        """Download/load the Silero VAD model and its helper functions from torch.hub."""
        model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')

        # Silero does not benefit from multi-threading
        torch.set_num_threads(1) # JIT
        (get_speech_timestamps, _, _, _, _) = utils

        return model, get_speech_timestamps

    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
        """Run Silero VAD over [start_time, end_time] in bounded chunks and
        return the detected speech ranges in fractional seconds."""
        result = []

        print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
        perf_start_time = time.perf_counter()

        # Divide processing of audio into chunks
        chunk_start = start_time

        while (chunk_start < end_time):
            chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)

            print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
            wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))

            sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
            # Silero reports sample offsets; convert to seconds, then shift by the chunk start
            seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
            adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)

            result.extend(adjusted)
            chunk_start += chunk_duration

        perf_end_time = time.perf_counter()
        print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))

        return result

    def __getstate__(self):
        # We only need the sampling rate - the model is re-created on unpickle
        return { 'sampling_rate': self.sampling_rate }

    def __setstate__(self, state):
        self.sampling_rate = state['sampling_rate']
        self.model = None
        # Use the global cache
        self.cache = GLOBAL_MODEL_CACHE
        self._initialize_model()
481
+
482
# A very simple VAD that just marks every N seconds as speech
class VadPeriodicTranscription(AbstractTranscription):
    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    def is_transcribe_timestamps_fast(self):
        # Pure arithmetic - no need to parallelize it
        return True

    def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
        """Emit back-to-back segments of config.periodic_duration seconds covering
        [start_time, end_time], dropping any trailing sliver shorter than one second."""
        segments = []
        cursor = start_time

        while cursor < end_time:
            stop = min(cursor + config.periodic_duration, end_time)

            # Minimum duration is 1 second
            if stop - cursor >= 1:
                segments.append({ 'start': cursor, 'end': stop })

            cursor = stop

        return segments
508
+
509
def get_audio_duration(file: str):
    """Return the duration of the media file in seconds, as reported by ffprobe."""
    probe_info = ffmpeg.probe(file)
    return float(probe_info["format"]["duration"])
511
+
512
def load_audio(file: str, sample_rate: int = 16000,
        start_time: str = None, duration: str = None):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sample_rate: int
        The sample rate to resample the audio if necessary

    start_time: str
        The start time, using the standard FFMPEG time duration syntax, or None to disable.

    duration: str
        The duration, using the standard FFMPEG time duration syntax, or None to disable.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    try:
        inputArgs = {'threads': 0}

        if (start_time is not None):
            inputArgs['ss'] = start_time
        if (duration is not None):
            inputArgs['t'] = duration

        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        out, _ = (
            ffmpeg.input(file, **inputArgs)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")

    # Convert signed 16-bit PCM samples to float32 in [-1, 1)
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
src/vadParallel.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ from queue import Empty
3
+ import threading
4
+ import time
5
+ from src.hooks.whisperProgressHook import ProgressListener
6
+ from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
7
+ from src.whisperContainer import WhisperCallback
8
+
9
+ from multiprocessing import Pool, Queue
10
+
11
+ from typing import Any, Dict, List, Union
12
+ import os
13
+
14
class _ProgressListenerToQueue(ProgressListener):
    """Forwards progress updates from a worker process to a shared queue.

    Only the increment ("delta") since the previous report is published, so the
    parent process can sum deltas from all workers into one overall progress value.
    """

    def __init__(self, progress_queue: Queue):
        self.progress_queue = progress_queue
        self.progress_total = 0
        self.prev_progress = 0

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        """Publish the progress made since the last call."""
        increment = current - self.prev_progress
        self.prev_progress = current
        self.progress_total = total
        self.progress_queue.put(increment)

    def on_finished(self):
        """Flush any remaining progress that was never reported via on_progress."""
        remaining = self.progress_total - self.prev_progress
        if remaining > 0:
            self.progress_queue.put(remaining)
            self.prev_progress = self.progress_total
31
+
32
class ParallelContext:
    """Owns a lazily-created, reference-counted multiprocessing pool.

    `get_pool` creates the pool on first use and increments the reference
    count; `return_pool` decrements it. When the count reaches zero and an
    auto-cleanup timeout is configured, a timer closes the pool after that
    many seconds of inactivity.
    """

    def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
        self.num_processes = num_processes
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
        self.lock = threading.Lock()

        self.ref_count = 0
        self.pool = None
        self.cleanup_timer = None

    def get_pool(self):
        """Return the shared pool, creating it on first use."""
        if self.pool is None:
            # Use 'spawn' so workers start with a clean interpreter state
            spawn_context = multiprocessing.get_context('spawn')
            self.pool = spawn_context.Pool(self.num_processes)

        self.ref_count += 1

        # A borrower exists again, so cancel any pending cleanup timer
        if self.auto_cleanup_timeout_seconds is not None:
            self._stop_auto_cleanup()

        return self.pool

    def return_pool(self, pool):
        """Release a pool previously obtained from get_pool."""
        if self.pool != pool or self.ref_count <= 0:
            return

        self.ref_count -= 1

        # Last borrower gone: arm the cleanup timer if configured
        if self.ref_count == 0 and self.auto_cleanup_timeout_seconds is not None:
            self._start_auto_cleanup()

    def _start_auto_cleanup(self):
        # Restart the timer from scratch if one was already running
        if self.cleanup_timer is not None:
            self.cleanup_timer.cancel()
        self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
        self.cleanup_timer.start()

        print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")

    def _stop_auto_cleanup(self):
        if self.cleanup_timer is not None:
            self.cleanup_timer.cancel()
            self.cleanup_timer = None

            print("Stopped auto cleanup of pool")

    def _execute_cleanup(self):
        print("Executing cleanup of pool")

        # Only close if nobody re-borrowed the pool while the timer was running
        if self.ref_count == 0:
            self.close()

    def close(self):
        """Shut down the pool (if any) and cancel any pending cleanup."""
        self._stop_auto_cleanup()

        if self.pool is not None:
            print("Closing pool of " + str(self.num_processes) + " processes")
            self.pool.close()
            self.pool.join()
            self.pool = None
92
+
93
class ParallelTranscriptionConfig(TranscriptionConfig):
    """TranscriptionConfig variant handed to a single worker process.

    Carries the device the worker should use and precomputed ("override")
    timestamps so the worker can skip detecting timestamps itself.
    """

    def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
        super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right,
                         copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window,
                         initial_segment_index)
        self.override_timestamps = override_timestamps
        self.device_id = device_id
98
+
99
class ParallelTranscription(AbstractTranscription):
    """Distributes a VAD+Whisper transcription across multiple processes.

    VAD timestamp detection can be fanned out over several CPU processes, and
    the actual Whisper transcription over several GPU devices (one process per
    device, selected via CUDA_VISIBLE_DEVICES).
    """

    # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
    # into smaller segments than 2 minute (min 6 seconds per CPU core)
    MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60

    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
                            cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
                            progress_listener: ProgressListener = None):
        """Transcribe `audio` in parallel across `gpu_devices`.

        Parameters
        ----------
        transcription: AbstractTranscription
            The VAD implementation used to detect speech timestamps.
        audio: str
            Path to the audio file.
        whisperCallable: WhisperCallback
            Picklable callback that runs the Whisper model on a segment.
        config: TranscriptionConfig
            Base configuration; a per-device copy is created for each worker.
        cpu_device_count: int
            Number of CPU processes to use for parallel VAD (if > 1).
        gpu_devices: List[str]
            Device IDs; one worker process is created per entry.
        cpu_parallel_context / gpu_parallel_context: ParallelContext
            Optional shared pools; if None, a temporary context is created
            and closed before returning.
        progress_listener: ProgressListener
            Receives aggregated progress across all workers.

        Returns
        -------
        A dict with 'text', 'segments' and 'language' merged from all workers.
        """
        total_duration = get_audio_duration(audio)

        # First, get the timestamps for the original audio
        if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
            merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
        else:
            timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
            merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)

        # We must make sure the whisper model is downloaded
        if (len(gpu_devices) > 1):
            whisperCallable.model_container.ensure_downloaded()

        # Split into a list for each device
        # TODO: Split by time instead of by number of chunks
        merged_split = list(self._split(merged, len(gpu_devices)))

        # Parameters that will be passed to the transcribe function
        parameters = []
        segment_index = config.initial_segment_index

        # Manager-backed queue so worker processes can report progress deltas
        processing_manager = multiprocessing.Manager()
        progress_queue = processing_manager.Queue()

        for i in range(len(gpu_devices)):
            # Note that device_segment_list can be empty. But we will still create a process for it,
            # as otherwise we run the risk of assigning the same device to multiple processes.
            device_segment_list = list(merged_split[i]) if i < len(merged_split) else []
            device_id = gpu_devices[i]

            print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")

            # Create a new config with the given device ID
            device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
            segment_index += len(device_segment_list)

            progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
            parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue]);

        # Accumulator for the merged result across all devices
        merged = {
            'text': '',
            'segments': [],
            'language': None
        }

        created_context = False

        perf_start_gpu = time.perf_counter()

        # Spawn a separate process for each device
        try:
            if (gpu_parallel_context is None):
                gpu_parallel_context = ParallelContext(len(gpu_devices))
                created_context = True

            # Get a pool of processes
            # NOTE(review): if get_pool() raises, `pool` is unbound and the
            # finally block below will raise NameError — confirm intent.
            pool = gpu_parallel_context.get_pool()

            # Run the transcription in parallel
            results_async = pool.starmap_async(self.transcribe, parameters)
            total_progress = 0

            # Drain progress deltas while the workers run
            while not results_async.ready():
                try:
                    delta = progress_queue.get(timeout=5) # Set a timeout of 5 seconds
                except Empty:
                    continue

                total_progress += delta
                if progress_listener is not None:
                    progress_listener.on_progress(total_progress, total_duration)

            results = results_async.get()

            # Call the finished callback
            if progress_listener is not None:
                progress_listener.on_finished()

            for result in results:
                # Merge the results
                if (result['text'] is not None):
                    merged['text'] += result['text']
                if (result['segments'] is not None):
                    merged['segments'].extend(result['segments'])
                if (result['language'] is not None):
                    merged['language'] = result['language']

        finally:
            # Return the pool to the context
            if (gpu_parallel_context is not None):
                gpu_parallel_context.return_pool(pool)
            # Always close the context if we created it
            if (created_context):
                gpu_parallel_context.close()

        perf_end_gpu = time.perf_counter()
        print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")

        return merged

    def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
                                        cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
        """Run VAD timestamp detection in parallel over CPU processes and merge the results.

        The audio is split into chunks of at least MIN_CPU_CHUNK_SIZE_SECONDS,
        each processed by a separate process; the flattened timestamps are then
        merged with `transcription.get_merged_timestamps`.
        """
        parameters = []

        chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
        chunk_start = 0
        cpu_device_id = 0

        perf_start_time = time.perf_counter()

        # Create chunks that will be processed on the CPU
        while (chunk_start < total_duration):
            chunk_end = min(chunk_start + chunk_size, total_duration)

            if (chunk_end - chunk_start < 1):
                # No need to process chunks that are less than 1 second
                break

            print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
                  str(chunk_end) + " on CPU device " + str(cpu_device_id))
            parameters.append([audio, config, chunk_start, chunk_end]);

            cpu_device_id += 1
            chunk_start = chunk_end

        created_context = False

        # Spawn a separate process for each device
        try:
            if (cpu_parallel_context is None):
                cpu_parallel_context = ParallelContext(cpu_device_count)
                created_context = True

            # Get a pool of processes
            pool = cpu_parallel_context.get_pool()

            # Run the transcription in parallel. Note that transcription must be picklable.
            results = pool.starmap(transcription.get_transcribe_timestamps, parameters)

            timestamps = []

            # Flatten the results
            for result in results:
                timestamps.extend(result)

            merged = transcription.get_merged_timestamps(timestamps, config, total_duration)

            perf_end_time = time.perf_counter()
            print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
            return merged

        finally:
            # Return the pool to the context
            if (cpu_parallel_context is not None):
                cpu_parallel_context.return_pool(pool)
            # Always close the context if we created it
            if (created_context):
                cpu_parallel_context.close()

    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
        # Workers never detect timestamps themselves — they are supplied via
        # config.override_timestamps (see get_merged_timestamps below).
        return []

    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
        # Override timestamps that will be processed
        if (config.override_timestamps is not None):
            print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
            return config.override_timestamps
        return super().get_merged_timestamps(timestamps, config, total_duration)

    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig,
                   progressListener: ProgressListener = None):
        """Worker entry point: pin this process to its assigned device, then transcribe."""
        # Override device ID the first time
        if (os.environ.get("INITIALIZED", None) is None):
            os.environ["INITIALIZED"] = "1"

            # Note that this may be None if the user didn't specify a device. In that case, Whisper will
            # just use the default GPU device.
            if (config.device_id is not None):
                print("Using device " + config.device_id)
                os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id

        return super().transcribe(audio, whisperCallable, config, progressListener)

    def _split(self, a, n):
        """Split a list into n approximately equal parts."""
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
297
+
src/whisperContainer.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External programs
2
+ import os
3
+ import sys
4
+ from typing import List
5
+
6
+ import whisper
7
+ from whisper import Whisper
8
+
9
+ from src.config import ModelConfig
10
+ from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
11
+
12
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
13
+
14
class WhisperContainer:
    """Lazily-loading, picklable wrapper around a Whisper model.

    Holds the model name/device/download location and creates the actual model
    on first use (optionally via a shared ModelCache). Instances can be sent to
    subprocesses: pickling drops the loaded model, which is re-created on demand
    in the worker.
    """

    def __init__(self, model_name: str, device: str = None, download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        self.model_name = model_name
        self.device = device
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

        # List of known models.
        # Fix: the original used a mutable default argument (models=[]), which is
        # shared between all calls; default to None and normalize here instead.
        self.models = models if models is not None else []

    def get_model(self):
        """Return the loaded Whisper model, creating (or fetching from cache) on first call."""
        if self.model is None:

            if (self.cache is None):
                self.model = self._create_model()
            else:
                # Cache key includes the device so the same model on different devices is distinct
                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.

        Returns True on success, False if the (private) download API failed.
        """
        # Warning: Using private API here
        try:
            root_dir = self.download_root
            model_config = self.get_model_config()

            if root_dir is None:
                root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            else:
                # If the model is not in the official list, see if it needs to be downloaded
                model_config.download_url(root_dir)
            return True

        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program
            print("Error pre-downloading model: " + str(e))
            return False

    def get_model_config(self) -> ModelConfig:
        """
        Get the model configuration for the model, or None if the name is unknown.
        """
        for model in self.models:
            if model.name == self.model_name:
                return model
        return None

    def _create_model(self):
        """Download (if needed) and load the Whisper model onto the configured device."""
        print("Loading whisper model " + self.model_name)

        model_config = self.get_model_config()
        # Note that the model will not be downloaded in the case of an official Whisper model
        model_path = model_config.download_url(self.download_root)

        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcript audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)

    # This is required for multiprocessing — exclude the (unpicklable) model and cache
    def __getstate__(self):
        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
114
+
115
+
116
class WhisperCallback:
    """Picklable callable that runs a Whisper transcription with preset options.

    Bundles a WhisperContainer with the language/task/prompt settings so the
    actual `model.transcribe` call can be performed later (possibly in another
    process).
    """

    def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            Index of the segment being transcribed (the initial prompt is only
            prepended for segment 0).
        prompt: str
            The prompt to use for the transcription.
        detected_language: str
            The detected language of the audio file, used when no language was configured.
        progress_listener: ProgressListener
            Optional listener that receives transcription progress updates.

        Returns
        -------
        The result of the Whisper call.
        """
        model = self.model_container.get_model()

        if progress_listener is None:
            return self._transcribe(model, audio, segment_index, prompt, detected_language)

        # Install the progress hook only for the duration of this call
        with create_progress_listener_handle(progress_listener):
            return self._transcribe(model, audio, segment_index, prompt, detected_language)

    def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
        """Run model.transcribe with the effective language and prompt for this segment."""
        effective_language = self.language if self.language else detected_language

        # The configured initial prompt only applies to the very first segment
        if segment_index == 0:
            effective_prompt = self._concat_prompt(self.initial_prompt, prompt)
        else:
            effective_prompt = prompt

        return model.transcribe(
            audio,
            language=effective_language,
            task=self.task,
            initial_prompt=effective_prompt,
            **self.decodeOptions
        )

    def _concat_prompt(self, prompt1, prompt2):
        """Join two prompts with a space, tolerating either (or both) being None."""
        if prompt1 is None:
            return prompt2
        if prompt2 is None:
            return prompt1
        return prompt1 + " " + prompt2
tests/segments_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import unittest
3
+
4
+ sys.path.append('../whisper-webui')
5
+
6
+ from src.segments import merge_timestamps
7
+
8
class TestSegments(unittest.TestCase):
    """Unit tests for the merge_timestamps segment-merging helper.

    Note: the boilerplate __init__ override that only delegated to
    unittest.TestCase.__init__ has been removed — it added nothing.
    """

    def test_merge_segments(self):
        """Nearby segments merge (within merge_window, capped at max_merge_size) and get padded."""
        segments = [
            {'start': 10.0, 'end': 20.0},
            {'start': 22.0, 'end': 27.0},
            {'start': 31.0, 'end': 35.0},
            {'start': 45.0, 'end': 60.0},
            {'start': 61.0, 'end': 65.0},
            {'start': 68.0, 'end': 98.0},
            {'start': 100.0, 'end': 102.0},
            {'start': 110.0, 'end': 112.0}
        ]

        result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)

        self.assertListEqual(result, [
            {'start': 9.0, 'end': 36.0},
            {'start': 44.0, 'end': 66.0},
            {'start': 67.0, 'end': 99.0},
            {'start': 99.0, 'end': 103.0},
            {'start': 109.0, 'end': 113.0}
        ])

    def test_overlap_next(self):
        """Padding must not make a merged segment overlap the next one."""
        segments = [
            {'start': 5.0, 'end': 39.182},
            {'start': 39.986, 'end': 40.814}
        ]

        result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)

        self.assertListEqual(result, [
            {'start': 4.0, 'end': 39.584},
            {'start': 39.584, 'end': 41.814}
        ])
46
+
47
# Allow running this test module directly (in addition to via a test runner).
if __name__ == '__main__':
    unittest.main()
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ import unittest
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append('../whisper-webui')
7
+
8
+ from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
+
10
class TestVad(unittest.TestCase):
    """Tests AbstractTranscription's segment handling using a mocked VAD implementation."""

    def __init__(self, *args, **kwargs):
        super(TestVad, self).__init__(*args, **kwargs)
        # Records the [start_time, duration] pairs passed to the transcription callback
        self.transcribe_calls = []

    def test_transcript(self):
        mock = MockVadTranscription()

        self.transcribe_calls.clear()
        result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))

        # The two mocked VAD windows (30-60 and 100-200) should each trigger one
        # transcription call. NOTE(review): expected [start, duration] values imply
        # the base class caps each call at 30 seconds — confirm against
        # AbstractTranscription's chunking logic.
        self.assertListEqual(self.transcribe_calls, [
            [30, 30],
            [100, 100]
        ])

        # Segment times from the callback (10-20 within each chunk) are expected to be
        # shifted to absolute positions in the original audio.
        self.assertListEqual(result['segments'],
            [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
             {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
        )

    def transcribe_segments(self, segment):
        """Mock whisper callback: records the requested segment and returns a fixed transcript."""
        self.transcribe_calls.append(segment.tolist())

        # Dummy text
        return {
            'text': "Hello world ",
            'segments': [
                {
                    "start": 10.0,
                    "end": 20.0,
                    "text": "Hello world "
                }
            ],
            'language': ""
        }
+ }
46
+
47
class MockVadTranscription(AbstractTranscription):
    """AbstractTranscription stub that fabricates audio segments and fixed VAD timestamps."""

    def __init__(self):
        super().__init__()

    def get_audio_segment(self, audio: str, start_time: str = None, duration: str = None):
        """Return a fake "audio segment" encoding the requested window.

        Fix: the first parameter was named `str`, shadowing the builtin type;
        renamed to `audio` to match the base class convention. It is positional
        in callers, so this is backward-compatible.
        """
        # Inputs use FFMPEG-style duration strings such as "30s"
        start_time_seconds = float(start_time.removesuffix("s"))
        duration_seconds = float(duration.removesuffix("s"))

        # For mocking, this just returns a simple numpy array of [start, duration]
        return np.array([start_time_seconds, duration_seconds], dtype=np.float64)

    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
        """Return two fixed speech windows, ignoring the actual audio."""
        result = []

        result.append( { 'start': 30, 'end': 60 } )
        result.append( { 'start': 100, 'end': 200 } )
        return result
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()