Spaces:
Runtime error
Runtime error
Adding support for word timestamps
Browse files- app.py +28 -12
- cli.py +14 -2
- config.json5 +10 -1
- src/config.py +11 -1
- src/utils.py +117 -8
- src/vad.py +8 -0
- src/whisper/whisperContainer.py +3 -2
app.py
CHANGED
@@ -100,13 +100,17 @@ class WhisperTranscriber:
|
|
100 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
101 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
102 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
103 |
-
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float
|
|
|
|
|
|
|
104 |
|
105 |
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
106 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
107 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
108 |
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
109 |
-
compression_ratio_threshold, logprob_threshold, no_speech_threshold
|
|
|
110 |
|
111 |
# Entry function for the full tab with progress
|
112 |
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
@@ -114,6 +118,9 @@ class WhisperTranscriber:
|
|
114 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
115 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
116 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
|
|
|
|
|
|
117 |
progress=gr.Progress()):
|
118 |
|
119 |
# Handle temperature_increment_on_fallback
|
@@ -128,13 +135,15 @@ class WhisperTranscriber:
|
|
128 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
129 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
130 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
|
|
131 |
progress=progress)
|
132 |
|
133 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
134 |
-
vadOptions: VadOptions, progress: gr.Progress = None,
|
|
|
135 |
try:
|
136 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
137 |
-
|
138 |
try:
|
139 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
140 |
selectedModel = modelName if modelName is not None else "base"
|
@@ -185,7 +194,7 @@ class WhisperTranscriber:
|
|
185 |
# Update progress
|
186 |
current_progress += source_audio_duration
|
187 |
|
188 |
-
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
|
189 |
|
190 |
if len(sources) > 1:
|
191 |
# Add new line separators
|
@@ -359,7 +368,7 @@ class WhisperTranscriber:
|
|
359 |
|
360 |
return config
|
361 |
|
362 |
-
def write_result(self, result: dict, source_name: str, output_dir: str):
|
363 |
if not os.path.exists(output_dir):
|
364 |
os.makedirs(output_dir)
|
365 |
|
@@ -368,8 +377,8 @@ class WhisperTranscriber:
|
|
368 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
369 |
|
370 |
print("Max line width " + str(languageMaxLineWidth))
|
371 |
-
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
|
372 |
-
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
|
373 |
|
374 |
output_files = []
|
375 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
@@ -394,13 +403,13 @@ class WhisperTranscriber:
|
|
394 |
# 80 latin characters should fit on a 1080p/720p screen
|
395 |
return 80
|
396 |
|
397 |
-
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
|
398 |
segmentStream = StringIO()
|
399 |
|
400 |
if format == 'vtt':
|
401 |
-
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
402 |
elif format == 'srt':
|
403 |
-
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
404 |
else:
|
405 |
raise Exception("Unknown format " + format)
|
406 |
|
@@ -501,7 +510,14 @@ def create_ui(app_config: ApplicationConfig):
|
|
501 |
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
502 |
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
503 |
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
504 |
-
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
505 |
], outputs=[
|
506 |
gr.File(label="Download"),
|
507 |
gr.Text(label="Transcription"),
|
|
|
100 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
101 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
102 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
103 |
+
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
104 |
+
# Word timestamps
|
105 |
+
word_timestamps: bool, prepend_punctuations: str,
|
106 |
+
append_punctuations: str, highlight_words: bool = False):
|
107 |
|
108 |
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
109 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
110 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
111 |
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
112 |
+
compression_ratio_threshold, logprob_threshold, no_speech_threshold,
|
113 |
+
word_timestamps, prepend_punctuations, append_punctuations, highlight_words)
|
114 |
|
115 |
# Entry function for the full tab with progress
|
116 |
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
|
|
118 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
119 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
120 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
121 |
+
# Word timestamps
|
122 |
+
word_timestamps: bool, prepend_punctuations: str,
|
123 |
+
append_punctuations: str, highlight_words: bool = False,
|
124 |
progress=gr.Progress()):
|
125 |
|
126 |
# Handle temperature_increment_on_fallback
|
|
|
135 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
136 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
137 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
138 |
+
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
139 |
progress=progress)
|
140 |
|
141 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
142 |
+
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
143 |
+
**decodeOptions: dict):
|
144 |
try:
|
145 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
146 |
+
|
147 |
try:
|
148 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
149 |
selectedModel = modelName if modelName is not None else "base"
|
|
|
194 |
# Update progress
|
195 |
current_progress += source_audio_duration
|
196 |
|
197 |
+
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
198 |
|
199 |
if len(sources) > 1:
|
200 |
# Add new line separators
|
|
|
368 |
|
369 |
return config
|
370 |
|
371 |
+
def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
|
372 |
if not os.path.exists(output_dir):
|
373 |
os.makedirs(output_dir)
|
374 |
|
|
|
377 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
378 |
|
379 |
print("Max line width " + str(languageMaxLineWidth))
|
380 |
+
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
381 |
+
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
382 |
|
383 |
output_files = []
|
384 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
|
|
403 |
# 80 latin characters should fit on a 1080p/720p screen
|
404 |
return 80
|
405 |
|
406 |
+
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
|
407 |
segmentStream = StringIO()
|
408 |
|
409 |
if format == 'vtt':
|
410 |
+
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
411 |
elif format == 'srt':
|
412 |
+
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
413 |
else:
|
414 |
raise Exception("Unknown format " + format)
|
415 |
|
|
|
510 |
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
511 |
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
512 |
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
513 |
+
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
|
514 |
+
|
515 |
+
# Word timestamps
|
516 |
+
gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
|
517 |
+
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
|
518 |
+
gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
|
519 |
+
gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
|
520 |
+
|
521 |
], outputs=[
|
522 |
gr.File(label="Download"),
|
523 |
gr.Text(label="Transcription"),
|
cli.py
CHANGED
@@ -95,6 +95,17 @@ def cli():
|
|
95 |
parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
|
96 |
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
args = parser.parse_args().__dict__
|
99 |
model_name: str = args.pop("model")
|
100 |
model_dir: str = args.pop("model_dir")
|
@@ -126,6 +137,7 @@ def cli():
|
|
126 |
auto_parallel = args.pop("auto_parallel")
|
127 |
|
128 |
compute_type = args.pop("compute_type")
|
|
|
129 |
|
130 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
131 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
@@ -133,7 +145,7 @@ def cli():
|
|
133 |
|
134 |
model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
|
135 |
device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
|
136 |
-
|
137 |
if (transcriber._has_parallel_devices()):
|
138 |
print("Using parallel devices:", transcriber.parallel_device_list)
|
139 |
|
@@ -158,7 +170,7 @@ def cli():
|
|
158 |
|
159 |
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
160 |
|
161 |
-
transcriber.write_result(result, source_name, output_dir)
|
162 |
|
163 |
transcriber.close()
|
164 |
|
|
|
95 |
parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
|
96 |
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
97 |
|
98 |
+
parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
|
99 |
+
help="(experimental) extract word-level timestamps and refine the results based on them")
|
100 |
+
parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
|
101 |
+
help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
102 |
+
parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
|
103 |
+
help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
104 |
+
parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
|
105 |
+
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
106 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
107 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
108 |
+
|
109 |
args = parser.parse_args().__dict__
|
110 |
model_name: str = args.pop("model")
|
111 |
model_dir: str = args.pop("model_dir")
|
|
|
137 |
auto_parallel = args.pop("auto_parallel")
|
138 |
|
139 |
compute_type = args.pop("compute_type")
|
140 |
+
highlight_words = args.pop("highlight_words")
|
141 |
|
142 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
143 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
|
|
145 |
|
146 |
model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
|
147 |
device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
|
148 |
+
|
149 |
if (transcriber._has_parallel_devices()):
|
150 |
print("Using parallel devices:", transcriber.parallel_device_list)
|
151 |
|
|
|
170 |
|
171 |
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
172 |
|
173 |
+
transcriber.write_result(result, source_name, output_dir, highlight_words)
|
174 |
|
175 |
transcriber.close()
|
176 |
|
config.json5
CHANGED
@@ -128,5 +128,14 @@
|
|
128 |
// If the average log probability is lower than this value, treat the decoding as failed
|
129 |
"logprob_threshold": -1.0,
|
130 |
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
131 |
-
"no_speech_threshold": 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
}
|
|
|
128 |
// If the average log probability is lower than this value, treat the decoding as failed
|
129 |
"logprob_threshold": -1.0,
|
130 |
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
131 |
+
"no_speech_threshold": 0.6,
|
132 |
+
|
133 |
+
// (experimental) extract word-level timestamps and refine the results based on them
|
134 |
+
"word_timestamps": false,
|
135 |
+
// if word_timestamps is True, merge these punctuation symbols with the next word
|
136 |
+
"prepend_punctuations": "\"\'“¿([{-",
|
137 |
+
// if word_timestamps is True, merge these punctuation symbols with the previous word
|
138 |
+
"append_punctuations": "\"\'.。,,!!??::”)]}、",
|
139 |
+
// (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
|
140 |
+
"highlight_words": false,
|
141 |
}
|
src/config.py
CHANGED
@@ -58,7 +58,11 @@ class ApplicationConfig:
|
|
58 |
condition_on_previous_text: bool = True, fp16: bool = True,
|
59 |
compute_type: str = "float16",
|
60 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
61 |
-
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6
|
|
|
|
|
|
|
|
|
62 |
|
63 |
self.models = models
|
64 |
|
@@ -104,6 +108,12 @@ class ApplicationConfig:
|
|
104 |
self.logprob_threshold = logprob_threshold
|
105 |
self.no_speech_threshold = no_speech_threshold
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
def get_model_names(self):
|
108 |
return [ x.name for x in self.models ]
|
109 |
|
|
|
58 |
condition_on_previous_text: bool = True, fp16: bool = True,
|
59 |
compute_type: str = "float16",
|
60 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
61 |
+
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
|
62 |
+
# Word timestamp settings
|
63 |
+
word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
|
64 |
+
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
|
65 |
+
highlight_words: bool = False):
|
66 |
|
67 |
self.models = models
|
68 |
|
|
|
108 |
self.logprob_threshold = logprob_threshold
|
109 |
self.no_speech_threshold = no_speech_threshold
|
110 |
|
111 |
+
# Word timestamp settings
|
112 |
+
self.word_timestamps = word_timestamps
|
113 |
+
self.prepend_punctuations = prepend_punctuations
|
114 |
+
self.append_punctuations = append_punctuations
|
115 |
+
self.highlight_words = highlight_words
|
116 |
+
|
117 |
def get_model_names(self):
|
118 |
return [ x.name for x in self.models ]
|
119 |
|
src/utils.py
CHANGED
@@ -3,7 +3,7 @@ import unicodedata
|
|
3 |
import re
|
4 |
|
5 |
import zlib
|
6 |
-
from typing import Iterator, TextIO
|
7 |
import tqdm
|
8 |
|
9 |
import urllib3
|
@@ -56,10 +56,14 @@ def write_txt(transcript: Iterator[dict], file: TextIO):
|
|
56 |
print(segment['text'].strip(), file=file, flush=True)
|
57 |
|
58 |
|
59 |
-
def write_vtt(transcript: Iterator[dict], file: TextIO,
|
|
|
|
|
|
|
60 |
print("WEBVTT\n", file=file)
|
61 |
-
|
62 |
-
|
|
|
63 |
|
64 |
print(
|
65 |
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
@@ -68,8 +72,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
68 |
flush=True,
|
69 |
)
|
70 |
|
71 |
-
|
72 |
-
|
73 |
"""
|
74 |
Write a transcript to a file in SRT format.
|
75 |
Example usage:
|
@@ -81,8 +85,10 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
81 |
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
82 |
write_srt(result["segments"], file=srt)
|
83 |
"""
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
|
87 |
# write srt lines
|
88 |
print(
|
@@ -94,6 +100,109 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
94 |
flush=True,
|
95 |
)
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
def process_text(text: str, maxLineWidth=None):
|
98 |
if (maxLineWidth is None or maxLineWidth < 0):
|
99 |
return text
|
|
|
3 |
import re
|
4 |
|
5 |
import zlib
|
6 |
+
from typing import Iterator, TextIO, Union
|
7 |
import tqdm
|
8 |
|
9 |
import urllib3
|
|
|
56 |
print(segment['text'].strip(), file=file, flush=True)
|
57 |
|
58 |
|
59 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO,
|
60 |
+
maxLineWidth=None, highlight_words: bool = False):
|
61 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
62 |
+
|
63 |
print("WEBVTT\n", file=file)
|
64 |
+
|
65 |
+
for segment in iterator:
|
66 |
+
text = segment['text'].replace('-->', '->')
|
67 |
|
68 |
print(
|
69 |
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
|
|
72 |
flush=True,
|
73 |
)
|
74 |
|
75 |
+
def write_srt(transcript: Iterator[dict], file: TextIO,
|
76 |
+
maxLineWidth=None, highlight_words: bool = False):
|
77 |
"""
|
78 |
Write a transcript to a file in SRT format.
|
79 |
Example usage:
|
|
|
85 |
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
86 |
write_srt(result["segments"], file=srt)
|
87 |
"""
|
88 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
89 |
+
|
90 |
+
for i, segment in enumerate(iterator, start=1):
|
91 |
+
text = segment['text'].replace('-->', '->')
|
92 |
|
93 |
# write srt lines
|
94 |
print(
|
|
|
100 |
flush=True,
|
101 |
)
|
102 |
|
103 |
+
def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
|
104 |
+
for segment in transcript:
|
105 |
+
words = segment.get('words', [])
|
106 |
+
|
107 |
+
if len(words) == 0:
|
108 |
+
# Yield the segment as-is
|
109 |
+
if maxLineWidth is None or maxLineWidth < 0:
|
110 |
+
yield segment
|
111 |
+
|
112 |
+
# Yield the segment with processed text
|
113 |
+
yield {
|
114 |
+
'start': segment['start'],
|
115 |
+
'end': segment['end'],
|
116 |
+
'text': process_text(segment['text'].strip(), maxLineWidth)
|
117 |
+
}
|
118 |
+
|
119 |
+
subtitle_start = segment['start']
|
120 |
+
subtitle_end = segment['end']
|
121 |
+
|
122 |
+
text_words = [ this_word["word"] for this_word in words ]
|
123 |
+
subtitle_text = __join_words(text_words, maxLineWidth)
|
124 |
+
|
125 |
+
# Iterate over the words in the segment
|
126 |
+
if highlight_words:
|
127 |
+
last = subtitle_start
|
128 |
+
|
129 |
+
for i, this_word in enumerate(words):
|
130 |
+
start = this_word['start']
|
131 |
+
end = this_word['end']
|
132 |
+
|
133 |
+
if last != start:
|
134 |
+
# Display the text up to this point
|
135 |
+
yield {
|
136 |
+
'start': last,
|
137 |
+
'end': start,
|
138 |
+
'text': subtitle_text
|
139 |
+
}
|
140 |
+
|
141 |
+
# Display the text with the current word highlighted
|
142 |
+
yield {
|
143 |
+
'start': start,
|
144 |
+
'end': end,
|
145 |
+
'text': __join_words(
|
146 |
+
[
|
147 |
+
{
|
148 |
+
"word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
149 |
+
if j == i
|
150 |
+
else word,
|
151 |
+
# The HTML tags <u> and </u> are not displayed,
|
152 |
+
# so they should not be counted in the word length
|
153 |
+
"length": len(word)
|
154 |
+
} for j, word in enumerate(text_words)
|
155 |
+
], maxLineWidth)
|
156 |
+
}
|
157 |
+
last = end
|
158 |
+
|
159 |
+
if last != subtitle_end:
|
160 |
+
# Display the last part of the text
|
161 |
+
yield {
|
162 |
+
'start': last,
|
163 |
+
'end': subtitle_end,
|
164 |
+
'text': subtitle_text
|
165 |
+
}
|
166 |
+
|
167 |
+
# Just return the subtitle text
|
168 |
+
else:
|
169 |
+
yield {
|
170 |
+
'start': subtitle_start,
|
171 |
+
'end': subtitle_end,
|
172 |
+
'text': subtitle_text
|
173 |
+
}
|
174 |
+
|
175 |
+
def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
|
176 |
+
if maxLineWidth is None or maxLineWidth < 0:
|
177 |
+
return " ".join(words)
|
178 |
+
|
179 |
+
lines = []
|
180 |
+
current_line = ""
|
181 |
+
current_length = 0
|
182 |
+
|
183 |
+
for entry in words:
|
184 |
+
# Either accept a string or a dict with a 'word' and 'length' field
|
185 |
+
if isinstance(entry, dict):
|
186 |
+
word = entry['word']
|
187 |
+
word_length = entry['length']
|
188 |
+
else:
|
189 |
+
word = entry
|
190 |
+
word_length = len(word)
|
191 |
+
|
192 |
+
if current_length > 0 and current_length + word_length > maxLineWidth:
|
193 |
+
lines.append(current_line)
|
194 |
+
current_line = ""
|
195 |
+
current_length = 0
|
196 |
+
|
197 |
+
current_length += word_length
|
198 |
+
# The word will be prefixed with a space by Whisper, so we don't need to add one here
|
199 |
+
current_line += word
|
200 |
+
|
201 |
+
if len(current_line) > 0:
|
202 |
+
lines.append(current_line)
|
203 |
+
|
204 |
+
return "\n".join(lines)
|
205 |
+
|
206 |
def process_text(text: str, maxLineWidth=None):
|
207 |
if (maxLineWidth is None or maxLineWidth < 0):
|
208 |
return text
|
src/vad.py
CHANGED
@@ -404,6 +404,14 @@ class AbstractTranscription(ABC):
|
|
404 |
# Add to start and end
|
405 |
new_segment['start'] = segment_start + adjust_seconds
|
406 |
new_segment['end'] = segment_end + adjust_seconds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
result.append(new_segment)
|
408 |
return result
|
409 |
|
|
|
404 |
# Add to start and end
|
405 |
new_segment['start'] = segment_start + adjust_seconds
|
406 |
new_segment['end'] = segment_end + adjust_seconds
|
407 |
+
|
408 |
+
# Handle words
|
409 |
+
if ('words' in new_segment):
|
410 |
+
for word in new_segment['words']:
|
411 |
+
# Adjust start and end
|
412 |
+
word['start'] = word['start'] + adjust_seconds
|
413 |
+
word['end'] = word['end'] + adjust_seconds
|
414 |
+
|
415 |
result.append(new_segment)
|
416 |
return result
|
417 |
|
src/whisper/whisperContainer.py
CHANGED
@@ -203,8 +203,9 @@ class WhisperCallback(AbstractWhisperCallback):
|
|
203 |
|
204 |
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
205 |
|
206 |
-
|
207 |
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
-
)
|
|
|
|
203 |
|
204 |
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
205 |
|
206 |
+
result = model.transcribe(audio, \
|
207 |
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
+
)
|
211 |
+
return result
|