erastorgueva-nv commited on
Commit
6ffdd29
·
1 Parent(s): c7e8b60

Initial commit

Browse files
align.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import math
17
+ import os
18
+ from dataclasses import dataclass, field, is_dataclass
19
+ from pathlib import Path
20
+ from typing import List, Optional
21
+
22
+ import torch
23
+ from omegaconf import OmegaConf
24
+ from utils.data_prep import (
25
+ add_t_start_end_to_utt_obj,
26
+ get_batch_starts_ends,
27
+ get_batch_variables,
28
+ get_manifest_lines_batch,
29
+ is_entry_in_all_lines,
30
+ is_entry_in_any_lines,
31
+ )
32
+ from utils.make_ass_files import make_ass_files
33
+ from utils.make_ctm_files import make_ctm_files
34
+ from utils.make_output_manifest import write_manifest_out_line
35
+ from utils.viterbi_decoding import viterbi_decoding
36
+
37
+ from nemo.collections.asr.models.ctc_models import EncDecCTCModel
38
+ from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
39
+ from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
40
+ from nemo.collections.asr.parts.utils.transcribe_utils import setup_model
41
+ from nemo.core.config import hydra_runner
42
+ from nemo.utils import logging
43
+
44
+ """
45
+ Align the utterances in manifest_filepath.
46
+ Results are saved in ctm files in output_dir.
47
+
48
+ Arguments:
49
+ pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded
50
+ from NGC and used for generating the log-probs which we will use to do alignment.
51
+ Note: NFA can only use CTC models (not Transducer models) at the moment.
52
+ model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the
53
+ log-probs which we will use to do alignment.
54
+ Note: NFA can only use CTC models (not Transducer models) at the moment.
55
+ Note: if a model_path is provided, it will override the pretrained_name.
56
+ manifest_filepath: filepath to the manifest of the data you want to align,
57
+ containing 'audio_filepath' and 'text' fields.
58
+ output_dir: the folder where output CTM files and new JSON manifest will be saved.
59
+ align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription
60
+ as the reference text for the forced alignment.
61
+ transcribe_device: None, or a string specifying the device that will be used for generating log-probs (i.e. "transcribing").
62
+ The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
63
+ (otherwise will set it to 'cpu').
64
+ viterbi_device: None, or string specifying the device that will be used for doing Viterbi decoding.
65
+ The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
66
+ (otherwise will set it to 'cpu').
67
+ batch_size: int specifying batch size that will be used for generating log-probs and doing Viterbi decoding.
68
+ use_local_attention: boolean flag specifying whether to try to use local attention for the ASR Model (will only
69
+ work if the ASR Model is a Conformer model). If local attention is used, we will set the local attention context
70
+ size to [64,64].
71
+ additional_segment_grouping_separator: an optional string used to separate the text into smaller segments.
72
+ If this is not specified, then the whole text will be treated as a single segment.
73
+ remove_blank_tokens_from_ctm: a boolean denoting whether to remove <blank> tokens from token-level output CTMs.
74
+ audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath
75
+ we will use (starting from the final part of the audio_filepath) to determine the
76
+ utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath
77
+ will be replaced with dashes, so as not to change the number of space-separated elements in the
78
+ CTM files.
79
+ e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e-1"
80
+ e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e-1"
81
+ e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e-1"
82
+ use_buffered_chunked_streaming: False by default. If set to True, buffered chunked streaming is used to get the logits for alignment.
83
+ This flag is useful when aligning large audio files.
84
+ However, currently the chunk streaming inference does not support batch inference,
85
+ which means that even if you set batch_size > 1, it will only infer one by one instead of doing
86
+ the whole batch inference together.
87
+ chunk_len_in_secs: float chunk length in seconds
88
+ total_buffer_in_secs: float Length of buffer (chunk + left and right padding) in seconds
89
+ chunk_batch_size: int batch size for buffered chunk inference,
90
+ which will cut one audio into segments and do inference on chunk_batch_size segments at a time
91
+
92
+ simulate_cache_aware_streaming: False by default. If set to True, cache-aware streaming is used to get the logits for alignment.
93
+
94
+ save_output_file_formats: List of strings specifying what type of output files to save (default: ["ctm", "ass"])
95
+ ctm_file_config: CTMFileConfig to specify the configuration of the output CTM files
96
+ ass_file_config: ASSFileConfig to specify the configuration of the output ASS files
97
+ """
98
+
99
+
100
@dataclass
class CTMFileConfig:
    """Options for the CTM files written when "ctm" is in save_output_file_formats."""

    # presumably controls whether <blank> tokens are dropped from the CTM output
    # (consumed by make_ctm_files, which is not defined in this file)
    remove_blank_tokens: bool = False
    # minimum duration (in seconds) for timestamps in the CTM. If any line in the CTM has a
    # duration lower than this, it will be enlarged from the middle outwards until it
    # meets the minimum_timestamp_duration, or reaches the beginning or end of the audio file.
    # Note that this may cause timestamps to overlap.
    minimum_timestamp_duration: float = 0


@dataclass
class ASSFileConfig:
    """Options for the ASS subtitle files written when "ass" is in save_output_file_formats."""

    fontsize: int = 20
    # main() only accepts "top", "center" or "bottom" here
    vertical_alignment: str = "center"
    # if resegment_text_to_fill_space is True, the ASS files will use new segments
    # such that each segment will not take up more than (approximately) max_lines_per_segment
    # when the ASS file is applied to a video
    resegment_text_to_fill_space: bool = False
    max_lines_per_segment: int = 2
    # Each colour is an [R, G, B] list; main() requires exactly 3 elements.
    text_already_spoken_rgb: List[int] = field(default_factory=lambda: [49, 46, 61])  # dark gray
    text_being_spoken_rgb: List[int] = field(default_factory=lambda: [57, 171, 9])  # dark green
    text_not_yet_spoken_rgb: List[int] = field(default_factory=lambda: [194, 193, 199])  # light gray


@dataclass
class AlignmentConfig:
    """Top-level configuration consumed by main()."""

    # Required configs
    # (exactly one of pretrained_name / model_path must be set; manifest_filepath
    # and output_dir must both be set - this is enforced in main())
    pretrained_name: Optional[str] = None
    model_path: Optional[str] = None
    manifest_filepath: Optional[str] = None
    output_dir: Optional[str] = None

    # General configs
    align_using_pred_text: bool = False
    transcribe_device: Optional[str] = None
    viterbi_device: Optional[str] = None
    batch_size: int = 1
    use_local_attention: bool = True
    additional_segment_grouping_separator: Optional[str] = None
    audio_filepath_parts_in_utt_id: int = 1

    # Buffered chunked streaming configs
    use_buffered_chunked_streaming: bool = False
    chunk_len_in_secs: float = 1.6
    total_buffer_in_secs: float = 4.0
    chunk_batch_size: int = 32

    # Cache aware streaming configs
    simulate_cache_aware_streaming: Optional[bool] = False

    # Output file configs
    save_output_file_formats: List[str] = field(default_factory=lambda: ["ctm", "ass"])
    # Nested configs must be built per instance: a dataclass instance used directly as a
    # default is rejected as a mutable default on Python >= 3.11 and would otherwise be
    # shared between every AlignmentConfig.
    ctm_file_config: CTMFileConfig = field(default_factory=CTMFileConfig)
    ass_file_config: ASSFileConfig = field(default_factory=ASSFileConfig)
154
+
155
+
156
@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig)
def main(cfg: AlignmentConfig):
    """
    Align the utterances in cfg.manifest_filepath and save the results in cfg.output_dir.

    The config and manifest are validated, the ASR model is loaded (and optionally wrapped
    for buffered chunked streaming), and the manifest is then processed batch by batch:
    log-probs are obtained, Viterbi decoding is run, timestamps are attached to each
    utterance, the requested output files are written, and one line per utterance is
    appended to "<manifest stem>_with_output_file_paths.json" in cfg.output_dir.

    Args:
        cfg: an AlignmentConfig (dataclass instance or the equivalent OmegaConf config).

    Raises:
        ValueError: if a config value is missing or invalid.
        RuntimeError: if the manifest lacks required entries, or contains 'pred_text'
            entries while cfg.align_using_pred_text is True.
        NotImplementedError: if the loaded model is neither an EncDecCTCModel nor an
            EncDecHybridRNNTCTCModel.
    """

    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    # Validate config
    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None")

    if cfg.model_path is not None and cfg.pretrained_name is not None:
        raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None")

    if cfg.manifest_filepath is None:
        raise ValueError("cfg.manifest_filepath must be specified")

    if cfg.output_dir is None:
        raise ValueError("cfg.output_dir must be specified")

    if cfg.batch_size < 1:
        raise ValueError("cfg.batch_size cannot be zero or a negative number")

    if cfg.additional_segment_grouping_separator == "" or cfg.additional_segment_grouping_separator == " ":
        raise ValueError("cfg.additional_segment_grouping_separator cannot be empty string or space character")

    if cfg.ctm_file_config.minimum_timestamp_duration < 0:
        raise ValueError("cfg.ctm_file_config.minimum_timestamp_duration cannot be a negative number")

    if cfg.ass_file_config.vertical_alignment not in ["top", "center", "bottom"]:
        raise ValueError("cfg.ass_file_config.vertical_alignment must be one of 'top', 'center' or 'bottom'")

    # Check each of the three colour lists (not the same one three times).
    for rgb_list in [
        cfg.ass_file_config.text_already_spoken_rgb,
        cfg.ass_file_config.text_being_spoken_rgb,
        cfg.ass_file_config.text_not_yet_spoken_rgb,
    ]:
        if len(rgb_list) != 3:
            raise ValueError(
                "cfg.ass_file_config.text_already_spoken_rgb,"
                " cfg.ass_file_config.text_being_spoken_rgb,"
                " and cfg.ass_file_config.text_not_yet_spoken_rgb all need to contain"
                " exactly 3 elements."
            )

    # Validate manifest contents
    if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"):
        raise RuntimeError(
            "At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. "
            "All lines must contain an 'audio_filepath' entry."
        )

    if cfg.align_using_pred_text:
        if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"):
            raise RuntimeError(
                "Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath "
                "contains 'pred_text' entries. This is because the audio will be transcribed and may produce "
                "a different 'pred_text'. This may cause confusion."
            )
    else:
        if not is_entry_in_all_lines(cfg.manifest_filepath, "text"):
            raise RuntimeError(
                "At least one line in cfg.manifest_filepath does not contain a 'text' entry. "
                "NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=False."
            )

    # init devices
    if cfg.transcribe_device is None:
        transcribe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        transcribe_device = torch.device(cfg.transcribe_device)
    logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}")

    if cfg.viterbi_device is None:
        viterbi_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        viterbi_device = torch.device(cfg.viterbi_device)
    logging.info(f"Device to be used for viterbi step (`viterbi_device`) is {viterbi_device}")

    if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda':
        logging.warning(
            'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors '
            'it may help to change both devices to be the CPU.'
        )

    # load model
    model, _ = setup_model(cfg, transcribe_device)
    model.eval()

    if isinstance(model, EncDecHybridRNNTCTCModel):
        model.change_decoding_strategy(decoder_type="ctc")

    if cfg.use_local_attention:
        logging.info(
            "Flag use_local_attention is set to True => will try to use local attention for model if it allows it"
        )
        model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[64, 64])

    if not (isinstance(model, EncDecCTCModel) or isinstance(model, EncDecHybridRNNTCTCModel)):
        raise NotImplementedError(
            "Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel."
            " Currently only instances of these models are supported"
        )

    if cfg.ctm_file_config.minimum_timestamp_duration > 0:
        logging.warning(
            f"cfg.ctm_file_config.minimum_timestamp_duration has been set to {cfg.ctm_file_config.minimum_timestamp_duration} seconds. "
            "This may cause the alignments for some tokens/words/additional segments to be overlapping."
        )

    buffered_chunk_params = {}
    if cfg.use_buffered_chunked_streaming:
        model_cfg = copy.deepcopy(model._cfg)

        OmegaConf.set_struct(model_cfg.preprocessor, False)
        # some changes for streaming scenario
        model_cfg.preprocessor.dither = 0.0
        model_cfg.preprocessor.pad_to = 0

        if model_cfg.preprocessor.normalize != "per_feature":
            logging.error(
                "Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently"
            )
        # Disable config overwriting
        OmegaConf.set_struct(model_cfg.preprocessor, True)

        feature_stride = model_cfg.preprocessor['window_stride']
        # NOTE(review): model_downsample_factor is not a field of AlignmentConfig, so it
        # has to be supplied as an extra config key when buffered streaming is enabled.
        model_stride_in_secs = feature_stride * cfg.model_downsample_factor
        total_buffer = cfg.total_buffer_in_secs
        chunk_len = float(cfg.chunk_len_in_secs)
        tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
        mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
        logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

        model = FrameBatchASR(
            asr_model=model,
            frame_len=chunk_len,
            total_buffer=cfg.total_buffer_in_secs,
            batch_size=cfg.chunk_batch_size,
        )
        buffered_chunk_params = {
            "delay": mid_delay,
            "model_stride_in_secs": model_stride_in_secs,
            "tokens_per_chunk": tokens_per_chunk,
        }

    # get start and end line IDs of batches
    starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size)

    # init output_timestep_duration = None and we will calculate and update it during the first batch
    output_timestep_duration = None

    # init the output manifest path
    os.makedirs(cfg.output_dir, exist_ok=True)
    tgt_manifest_name = str(Path(cfg.manifest_filepath).stem) + "_with_output_file_paths.json"
    tgt_manifest_filepath = str(Path(cfg.output_dir) / tgt_manifest_name)

    # The context manager closes (and flushes) the manifest even if a batch raises.
    with open(tgt_manifest_filepath, 'w') as f_manifest_out:

        # get alignment and save in CTM batch-by-batch
        for start, end in zip(starts, ends):
            manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end)

            (
                log_probs_batch,
                y_batch,
                T_batch,
                U_batch,
                utt_obj_batch,
                output_timestep_duration,
            ) = get_batch_variables(
                manifest_lines_batch,
                model,
                cfg.additional_segment_grouping_separator,
                cfg.align_using_pred_text,
                cfg.audio_filepath_parts_in_utt_id,
                output_timestep_duration,
                cfg.simulate_cache_aware_streaming,
                cfg.use_buffered_chunked_streaming,
                buffered_chunk_params,
            )

            alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device)

            for utt_obj, alignment_utt in zip(utt_obj_batch, alignments_batch):

                utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration)

                if "ctm" in cfg.save_output_file_formats:
                    utt_obj = make_ctm_files(utt_obj, cfg.output_dir, cfg.ctm_file_config,)

                if "ass" in cfg.save_output_file_formats:
                    utt_obj = make_ass_files(utt_obj, cfg.output_dir, cfg.ass_file_config)

                write_manifest_out_line(
                    f_manifest_out, utt_obj,
                )

    return None
349
+
350
+
351
+ if __name__ == "__main__":
352
+ main()
app.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import librosa
3
+ import soundfile
4
+ import tempfile
5
+ import os
6
+ import uuid
7
+ import json
8
+
9
+ import jieba
10
+
11
+ import nemo.collections.asr as nemo_asr
12
+ from nemo.collections.asr.models import ASRModel
13
+ from nemo.utils import logging
14
+
15
+ from align import main, AlignmentConfig, ASSFileConfig
16
+
17
+
18
+ SAMPLE_RATE = 16000
19
+
20
+ # Pre-download and cache the model in disk space
21
+ logging.setLevel(logging.ERROR)
22
+ for tmp_model_name in [
23
+ "stt_en_fastconformer_hybrid_large_pc",
24
+ "stt_de_fastconformer_hybrid_large_pc",
25
+ "stt_es_fastconformer_hybrid_large_pc",
26
+ "stt_fr_conformer_ctc_large",
27
+ "stt_zh_citrinet_1024_gamma_0_25",
28
+ ]:
29
+ tmp_model = ASRModel.from_pretrained(tmp_model_name, map_location='cpu')
30
+ del tmp_model
31
+ logging.setLevel(logging.INFO)
32
+
33
+
34
def get_audio_data_and_duration(file):
    """
    Load an audio file as a mono signal at SAMPLE_RATE and measure its duration.

    Args:
        file: path of the audio file to load.

    Returns:
        A (data, duration) tuple: the mono samples at SAMPLE_RATE and the duration
        in seconds as reported by librosa.get_duration.
    """
    # Ask librosa for the target rate directly. The default load resamples to
    # 22050 Hz (and downmixes to mono), which forced a second, lossy resample
    # to SAMPLE_RATE afterwards.
    data, _ = librosa.load(file, sr=SAMPLE_RATE, mono=True)

    duration = librosa.get_duration(y=data, sr=SAMPLE_RATE)
    return data, duration
45
+
46
+
47
def get_char_tokens(text, model):
    """
    Convert text into token ids for a character-based model.

    Args:
        text: the string to tokenize, one token per character.
        model: an object exposing the character vocabulary as model.decoder.vocabulary.

    Returns:
        A list with one id per character: the character's (first) index in the
        vocabulary, or len(vocabulary) for characters that are not in it.
    """
    vocabulary = model.decoder.vocabulary
    unk_id = len(vocabulary)  # unk token (same as blank token)

    # Build the lookup once instead of scanning the vocabulary twice per character.
    # setdefault keeps the first index of a duplicated symbol, like list.index does.
    char_to_id = {}
    for idx, symbol in enumerate(vocabulary):
        char_to_id.setdefault(symbol, idx)

    return [char_to_id.get(character, unk_id) for character in text]
56
+
57
+
58
def get_S_prime_and_T(text, model_name, model, audio_duration):
    """
    Estimate whether `text` can be aligned to audio of length `audio_duration`.

    Args:
        text: the reference text.
        model_name: name of the ASR model; the output timestep duration is chosen
            from substrings of this name.
        model: the ASR model; tokens come from model.tokenizer.text_to_ids if the model
            has a tokenizer, otherwise from the characters in model.decoder.vocabulary.
        audio_duration: duration of the audio in seconds.

    Returns:
        A (S_prime, T) tuple, where S_prime is the number of tokens plus the number of
        adjacent repeated tokens, and T is the estimated number of output timesteps.

    Raises:
        RuntimeError: if the model name is not recognised, or if tokens cannot be
            obtained from the model.
    """

    # estimate T
    if "citrinet" in model_name or "_fastconformer_" in model_name:
        output_timestep_duration = 0.08
    elif "_conformer_" in model_name:
        output_timestep_duration = 0.04
    elif "quartznet" in model_name:
        output_timestep_duration = 0.02
    else:
        raise RuntimeError("unexpected model name")

    T = int(audio_duration / output_timestep_duration) + 1

    # calculate S_prime = num tokens + num repetitions
    if hasattr(model, 'tokenizer'):
        all_tokens = model.tokenizer.text_to_ids(text)
    elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based
        all_tokens = get_char_tokens(text, model)
    else:
        raise RuntimeError("cannot obtain tokens from this model")

    # count positions where a token equals its predecessor
    n_token_repetitions = sum(
        1 for previous, current in zip(all_tokens, all_tokens[1:]) if previous == current
    )

    S_prime = len(all_tokens) + n_token_repetitions

    # Diagnostics go through the logger rather than being printed on every request.
    logging.debug(
        f"all_tokens: {all_tokens}, n_tokens: {len(all_tokens)}, n_token_repetitions: {n_token_repetitions}"
    )

    return S_prime, T
92
+
93
+
94
def hex_to_rgb_list(hex_string):
    """
    Convert a hex colour string into an [r, g, b] list of ints in 0-255.

    Accepts "rrggbb" with or without a leading "#". The 3/4-digit shorthand
    ("#rgb", "#rgba") is expanded, and a trailing alpha component is ignored.

    Args:
        hex_string: the colour, e.g. "#fcba03".

    Returns:
        [r, g, b] as a list of three ints.

    Raises:
        ValueError: if the string is not a valid hex colour.
    """
    digits = hex_string.strip().lstrip("#")

    # expand shorthand such as "#abc" -> "aabbcc"
    if len(digits) in (3, 4):
        digits = "".join(ch * 2 for ch in digits)

    is_hex = all(ch in "0123456789abcdefABCDEF" for ch in digits)
    if len(digits) not in (6, 8) or not is_hex:
        raise ValueError(f"'{hex_string}' is not a valid hex color string (expected e.g. '#fcba03')")

    # only the first three byte pairs are colour channels; a 4th pair would be alpha
    return [int(digits[i : i + 2], 16) for i in (0, 2, 4)]
100
+
101
def delete_mp4s_except_given_filepath(filepath):
    """
    Remove every ".mp4" file in the current working directory except `filepath`.

    Args:
        filepath: name of the MP4 file to keep; it is compared against the bare
            directory entries returned by os.listdir().
    """
    for entry in os.listdir():
        # skip anything that is not an MP4, and the one file we were asked to keep
        if not entry.endswith(".mp4") or entry == filepath:
            continue
        print('deleting', entry)
        os.remove(entry)
108
+
109
+
110
+
111
+
112
def align(lang, Microphone, File_Upload, text, col1, col2, col3, progress=gr.Progress()):
    """
    Align `text` to the given audio and render a video with the text highlighted as it is spoken.

    Args:
        lang: language code used to pick the ASR model ("en", "de", "es", "fr" or "zh").
        Microphone: the microphone recording, or None if it was not used.
        File_Upload: the uploaded audio file, or None if it was not used.
            Exactly one of Microphone / File_Upload must be provided.
        text: reference text; if empty, the ASR model's transcription is used instead.
        col1: hex colour string for text that has already been spoken.
        col2: hex colour string for text that is being spoken.
        col3: hex colour string for text that has not yet been spoken.
        progress: gradio progress tracker.

    Returns:
        A tuple (output video filepath, gradio update for the info textbox,
        output video filepath).

    Raises:
        gr.Error: on invalid input (unsupported language, zero or two audio inputs,
            audio longer than 4 minutes, text too long for the audio, no speech
            detected) or if ffmpeg fails to produce the video.
    """
    import subprocess

    # Create utt_id, specify output_video_filepath and delete any MP4s
    # that are not that filepath. These stray MP4s can be created
    # if a user refreshes or exits the page while this 'align' function is executing.
    # This deletion will not delete any other users' video as long as this 'align' function
    # is run one at a time.
    utt_id = uuid.uuid4()
    output_video_filepath = f"{utt_id}.mp4"
    delete_mp4s_except_given_filepath(output_video_filepath)

    output_info = ""

    progress(0, desc="Validating input")

    # choose model
    if lang in ["en", "de", "es"]:
        model_name = f"stt_{lang}_fastconformer_hybrid_large_pc"
    elif lang in ["fr"]:
        model_name = f"stt_{lang}_conformer_ctc_large"
    elif lang in ["zh"]:
        model_name = f"stt_{lang}_citrinet_1024_gamma_0_25"
    else:
        # without this branch model_name would be unbound further down
        raise gr.Error(f"Unsupported language '{lang}' - please choose one of: de, en, es, fr, zh")

    # decide which of Mic / File_Upload is used as input & do error handling
    if (Microphone is not None) and (File_Upload is not None):
        raise gr.Error("Please use either the microphone or file upload input - not both")

    elif (Microphone is None) and (File_Upload is None):
        raise gr.Error("You have to either use the microphone or upload an audio file")

    elif Microphone is not None:
        file = Microphone
    else:
        file = File_Upload

    # check audio is not too long
    audio_data, duration = get_audio_data_and_duration(file)

    if duration > 4 * 60:
        raise gr.Error(
            f"Detected that uploaded audio has duration {duration/60:.1f} mins - please only upload audio of less than 4 mins duration"
        )

    # loading model
    progress(0.1, desc="Loading speech recognition model")
    model = ASRModel.from_pretrained(model_name)

    if text:  # check input text is not too long compared to audio
        S_prime, T = get_S_prime_and_T(text, model_name, model, duration)

        if S_prime > T:
            raise gr.Error(
                f"The number of tokens in the input text is too long compared to the duration of the audio."
                f" This model can handle {T} tokens + token repetitions at most. You have provided {S_prime} tokens + token repetitions. "
                f" (Adjacent tokens that are not in the model's vocabulary are also counted as a token repetition.)"
            )

    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, f'{utt_id}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)

        # getting the text if it hasn't been provided
        if not text:
            progress(0.2, desc="Transcribing audio")
            text = model.transcribe([audio_path])[0]
            if 'hybrid' in model_name:
                text = text[0]
            print('transcribed text:', text)

            if text == "":
                raise gr.Error(
                    "ERROR: the ASR model did not detect any speech in the input audio. Please upload audio with speech."
                )

            output_info += (
                "You did not enter any input text, so the ASR model's transcription will be used:\n"
                "--------------------------\n"
                f"{text}\n"
                "--------------------------\n"
                f"You could try pasting the transcription into the text input box, correcting any"
                " transcription errors, and clicking 'Submit' again."
            )

        if lang == "zh" and " " not in text:
            # use jieba to add spaces between zh characters
            text = " ".join(jieba.cut(text))

        data = {
            "audio_filepath": audio_path,
            "text": text,
        }
        manifest_path = os.path.join(tmpdir, f"{utt_id}_manifest.json")
        with open(manifest_path, 'w') as fout:
            fout.write(f"{json.dumps(data)}\n")

        # run alignment
        # (text is only resegmented to fill the space when the user gave no '|' separators)
        resegment_text_to_fill_space = "|" not in text

        alignment_config = AlignmentConfig(
            pretrained_name=model_name,
            manifest_filepath=manifest_path,
            output_dir=f"{tmpdir}/nfa_output/",
            audio_filepath_parts_in_utt_id=1,
            batch_size=1,
            use_local_attention=True,
            additional_segment_grouping_separator="|",
            # transcribe_device='cpu',
            # viterbi_device='cpu',
            save_output_file_formats=["ass"],
            ass_file_config=ASSFileConfig(
                fontsize=45,
                resegment_text_to_fill_space=resegment_text_to_fill_space,
                max_lines_per_segment=4,
                text_already_spoken_rgb=hex_to_rgb_list(col1),
                text_being_spoken_rgb=hex_to_rgb_list(col2),
                text_not_yet_spoken_rgb=hex_to_rgb_list(col3),
            ),
        )

        progress(0.5, desc="Aligning audio")

        main(alignment_config)

        progress(0.95, desc="Saving generated alignments")

        if lang == "zh":
            # make video file from the token-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/tokens/{utt_id}.ass"
        else:
            # make video file from the word-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/words/{utt_id}.ass"

        # Pass ffmpeg an argument list (no shell involved) and check that it succeeded,
        # so that a failed render is reported instead of returning a missing video.
        ffmpeg_command = [
            "ffmpeg",
            "-y",
            "-i",
            audio_path,
            "-f",
            "lavfi",
            "-i",
            "color=c=white:s=1280x720:r=50",
            "-crf",
            "1",
            "-shortest",
            "-vcodec",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-vf",
            f"ass={ass_file_for_video}",
            output_video_filepath,
        ]

        try:
            completed = subprocess.run(ffmpeg_command, check=False)
        except OSError as err:
            raise gr.Error("Could not run ffmpeg to generate the output video") from err

        if completed.returncode != 0:
            raise gr.Error("ffmpeg failed to generate the output video")

    return output_video_filepath, gr.update(value=output_info, visible=True), output_video_filepath
258
+
259
+
260
def delete_non_tmp_video(video_path):
    """
    Delete the file at `video_path` if a path was given and the file exists.

    Args:
        video_path: path of the video to delete; falsy values are ignored.

    Returns:
        Always None.
    """
    # nothing to do for an empty value or a file that is already gone
    if not video_path or not os.path.exists(video_path):
        return None

    os.remove(video_path)
    return None
265
+
266
+
267
+ with gr.Blocks(title="NeMo Forced Aligner", theme="huggingface") as demo:
268
+ non_tmp_output_video_filepath = gr.State([])
269
+
270
+ with gr.Row():
271
+ with gr.Column():
272
+ gr.Markdown("# NeMo Forced Aligner")
273
+ gr.Markdown(
274
+ "Demo for [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) (NFA). "
275
+ "Upload audio and (optionally) the text spoken in the audio to generate a video where each part of the text will be highlighted as it is spoken. ",
276
+ )
277
+
278
+ with gr.Row():
279
+
280
+ with gr.Column(scale=1):
281
+ gr.Markdown("## Input")
282
+ lang_drop = gr.Dropdown(choices=["de", "en", "es", "fr", "zh"], value="en", label="Audio language",)
283
+
284
+ mic_in = gr.Audio(source="microphone", type='filepath', label="Microphone input (max 4 mins)")
285
+ audio_file_in = gr.Audio(source="upload", type='filepath', label="File upload (max 4 mins)")
286
+ ref_text = gr.Textbox(
287
+ label="[Optional] The reference text. Use '|' separators to specify which text will appear together. "
288
+ "Leave this field blank to use an ASR model's transcription as the reference text instead."
289
+ )
290
+
291
+ gr.Markdown("[Optional] For fun - adjust the colors of the text in the output video")
292
+ with gr.Row():
293
+ col1 = gr.ColorPicker(label="text already spoken", value="#fcba03")
294
+ col2 = gr.ColorPicker(label="text being spoken", value="#bf45bf")
295
+ col3 = gr.ColorPicker(label="text to be spoken", value="#3e1af0")
296
+
297
+ submit_button = gr.Button("Submit")
298
+
299
+ with gr.Column(scale=1):
300
+ gr.Markdown("## Output")
301
+ video_out = gr.Video(label="output video")
302
+ text_out = gr.Textbox(label="output info", visible=False)
303
+
304
+ submit_button.click(
305
+ fn=align,
306
+ inputs=[lang_drop, mic_in, audio_file_in, ref_text, col1, col2, col3,],
307
+ outputs=[video_out, text_out, non_tmp_output_video_filepath],
308
+ ).then(
309
+ fn=delete_non_tmp_video, inputs=[non_tmp_output_video_filepath], outputs=None,
310
+ )
311
+
312
+ demo.queue()
313
+ demo.launch()
314
+
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ffmpeg
2
+ libsndfile1
3
+ build-essential
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Cython
2
+ torch
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ nemo_toolkit[all]
utils/constants.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ BLANK_TOKEN = "<b>"
16
+
17
+ SPACE_TOKEN = "<space>"
18
+
19
+ V_NEGATIVE_NUM = -3.4e38 # this is just above the most negative number in torch.float32
utils/data_prep.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ from dataclasses import dataclass, field
17
+ from pathlib import Path
18
+ from typing import List, Union
19
+
20
+ import soundfile as sf
21
+ import torch
22
+ from tqdm.auto import tqdm
23
+ from utils.constants import BLANK_TOKEN, SPACE_TOKEN, V_NEGATIVE_NUM
24
+
25
+ from nemo.utils import logging
26
+
27
+
28
+ def _get_utt_id(audio_filepath, audio_filepath_parts_in_utt_id):
29
+ fp_parts = Path(audio_filepath).parts[-audio_filepath_parts_in_utt_id:]
30
+ utt_id = Path("_".join(fp_parts)).stem
31
+ utt_id = utt_id.replace(" ", "-") # replace any spaces in the filepath with dashes
32
+ return utt_id
33
+
34
+
35
def get_batch_starts_ends(manifest_filepath, batch_size):
    """
    Get the start and end ids of the lines we will use for each 'batch'.

    Args:
        manifest_filepath: path to the manifest file (one entry per line).
        batch_size: number of manifest lines per batch; must be positive.

    Returns:
        starts, ends: two lists of equal length. starts[i] is the 0-based index of the
            first line of batch i. ends[i] is the index of its last line (inclusive),
            except for the final batch, whose end is the total number of lines in the
            manifest (i.e. one past the last line index).
            Both lists are empty if the manifest has no lines.

    Raises:
        ValueError: if batch_size is not a positive number.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, but got {batch_size}")

    # use the same encoding as get_manifest_lines_batch so that both functions
    # accept the same manifest files
    with open(manifest_filepath, 'r', encoding="utf-8-sig") as f:
        num_lines_in_manifest = sum(1 for _ in f)

    starts = list(range(0, num_lines_in_manifest, batch_size))
    if not starts:
        # empty manifest => no batches
        return [], []

    # every batch except the last one ends on the line before the next batch starts
    ends = [next_start - 1 for next_start in starts[1:]]
    ends.append(num_lines_in_manifest)

    return starts, ends
49
+
50
+
51
def is_entry_in_any_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in any of the JSON lines in manifest_filepath.

    The file is read with the "utf-8-sig" encoding (the same one used by
    get_manifest_lines_batch) so that a manifest starting with a BOM can be parsed.
    Reading stops at the first line that contains the entry.
    """
    with open(manifest_filepath, 'r', encoding="utf-8-sig") as f:
        return any(entry in json.loads(line) for line in f)
66
+
67
+
68
def is_entry_in_all_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in all of the JSON lines in manifest_filepath.

    The file is read with the "utf-8-sig" encoding (the same one used by
    get_manifest_lines_batch) so that a manifest starting with a BOM can be parsed.
    Reading stops at the first line that does not contain the entry.
    A manifest with no lines yields True.
    """
    with open(manifest_filepath, 'r', encoding="utf-8-sig") as f:
        return all(entry in json.loads(line) for line in f)
80
+
81
+
82
def get_manifest_lines_batch(manifest_filepath, start, end):
    """
    Load the manifest entries found on lines start..end (0-based, both inclusive).

    Each selected line is parsed as JSON. If the entry has a "text" field, any BOM
    characters are removed from it and every run of whitespace (including newline
    chars) is collapsed into a single space.

    Returns:
        a list with one dict per selected manifest line, in file order.
    """
    batch = []
    with open(manifest_filepath, "r", encoding="utf-8-sig") as manifest:
        for line_no, raw_line in enumerate(manifest):
            if start <= line_no <= end:
                entry = json.loads(raw_line)
                if "text" in entry:
                    without_bom = entry["text"].replace("\ufeff", "")
                    entry["text"] = " ".join(without_bom.split())
                batch.append(entry)

            # nothing of interest after the final requested line
            if line_no == end:
                break
    return batch
98
+
99
+
100
def get_char_tokens(text, model):
    """
    Convert text to a list of token ids for a character-based model.

    Args:
        text: the string to convert, processed one character at a time.
        model: object exposing the character vocabulary as model.decoder.vocabulary.

    Returns:
        a list with one id per character of text: the index of the character in the
        vocabulary, or len(vocabulary) if the character is not in the vocabulary
        (i.e. the unk token, which is the same as the blank token).
    """
    vocabulary = model.decoder.vocabulary
    unk_id = len(vocabulary)  # unk token (same as blank token)

    # build the lookup table once instead of scanning the vocabulary for every character;
    # setdefault keeps the first index of a repeated symbol, like list.index would
    char_to_id = {}
    for symbol_id, symbol in enumerate(vocabulary):
        char_to_id.setdefault(symbol, symbol_id)

    return [char_to_id.get(character, unk_id) for character in text]
109
+
110
+
111
def is_sub_or_superscript_pair(ref_text, text):
    """returns True if text is a subscript or superscript digit whose plain digit is ref_text"""
    # superscripts ⁰-⁹ followed by subscripts ₀-₉, each mapped to its plain digit
    plain_digit_of = dict(zip("⁰¹²³⁴⁵⁶⁷⁸⁹₀₁₂₃₄₅₆₇₈₉", "0123456789" * 2))

    return text in plain_digit_of and plain_digit_of[text] == ref_text
140
+
141
+
142
def restore_token_case(word, word_tokens):
    """
    Recover the capitalization of word_tokens by walking through the characters of word.

    Args:
        word: the original (cased) word.
        word_tokens: list of token strings produced for this word; their characters are
            compared one by one against the characters of word.

    Returns:
        a list with the same length as word_tokens, holding the tokens with the
        capitalization found in word.

    Raises:
        RuntimeError: if a token character cannot be matched to the current character
            of word.
    """

    # remove repeated "▁" and "_" from word as that is what the tokenizer will do
    while "▁▁" in word:
        word = word.replace("▁▁", "▁")

    while "__" in word:
        word = word.replace("__", "_")

    word_tokens_cased = []
    word_char_pointer = 0  # index of the character of word we expect to match next

    for token in word_tokens:
        token_cased = ""

        for token_char in token:
            word_char = word[word_char_pointer]

            if token_char == word_char:
                # exact match - keep the character as it is
                token_cased += token_char
                word_char_pointer += 1

            elif token_char.upper() == word_char or is_sub_or_superscript_pair(token_char, word_char):
                # the word has the upper-case version of this character (or a
                # sub/superscript version of this digit)
                token_cased += token_char.upper()
                word_char_pointer += 1

            elif token_char == "▁" or token_char == "_":
                if word_char == "▁" or word_char == "_":
                    token_cased += token_char
                    word_char_pointer += 1
                elif word_char_pointer == 0:
                    # separator char at the very start of the word that is not present
                    # in the word itself: keep it without consuming a word character
                    token_cased += token_char
                # otherwise the separator char is dropped

            else:
                raise RuntimeError(
                    f"Unexpected error - failed to recover capitalization of tokens for word {word}"
                )

        word_tokens_cased.append(token_cased)

    return word_tokens_cased
184
+
185
+
186
@dataclass
class Token:
    """A single token (blank, space or text token) of an utterance."""

    # token text
    text: str = None
    # token text with the capitalization of the original word
    text_cased: str = None
    # index of the token in the utterance's token_ids_with_blanks list
    # (s_start and s_end are equal for a Token, as it occupies one element of that list)
    s_start: int = None
    s_end: int = None
    # start/end time in seconds, filled in after alignment (-1 if the token was skipped)
    t_start: float = None
    t_end: float = None
194
+
195
+
196
@dataclass
class Word:
    """A word (space-separated sub-string of a segment) together with its tokens."""

    # word text
    text: str = None
    # indices (in the utterance's token_ids_with_blanks list) of the first and final
    # token of the word
    s_start: int = None
    s_end: int = None
    # start/end time in seconds, filled in after alignment
    t_start: float = None
    t_end: float = None
    # Token objects of the word: its text tokens and the blank tokens in between them
    tokens: List[Token] = field(default_factory=list)
204
+
205
+
206
@dataclass
class Segment:
    """A segment of the utterance text together with its words and tokens."""

    # segment text
    text: str = None
    # indices (in the utterance's token_ids_with_blanks list) of the first and final
    # token of the segment
    s_start: int = None
    s_end: int = None
    # start/end time in seconds, filled in after alignment
    t_start: float = None
    t_end: float = None
    # Word objects of the segment, and Token objects for the tokens in between words
    words_and_tokens: List[Union[Word, Token]] = field(default_factory=list)
214
+
215
+
216
@dataclass
class Utterance:
    """All the information we keep about one utterance (one manifest line)."""

    # token ids of the text to align, with blank token ids around the text tokens
    token_ids_with_blanks: List[int] = field(default_factory=list)
    # Segment objects of the utterance, and Token objects for the tokens in between segments
    segments_and_tokens: List[Union[Segment, Token]] = field(default_factory=list)
    # reference text of the utterance
    text: str = None
    # text predicted by the ASR model
    pred_text: str = None
    # path to the audio file of the utterance
    audio_filepath: str = None
    # id of the utterance
    utt_id: str = None
    # starts as an empty dict; nothing in this file adds entries to it
    saved_output_files: dict = field(default_factory=dict)
225
+
226
+
227
def get_utt_obj(
    text, model, separator, T, audio_filepath, utt_id,
):
    """
    Create an Utterance object holding everything except the timings of its
    segments / words / tokens - those are added later, by a different function, once the
    alignment has been computed.

    Structure that is built up here:
      * Utterance.segments_and_tokens contains Segment objects and Token objects
        (for the tokens in between segments).
      * Segment.words_and_tokens contains Word objects and Token objects
        (for the tokens in between words).
      * Word.tokens contains Token objects for blank and non-blank tokens.
    Utterance.token_ids_with_blanks is filled in at the same time, and the s_start / s_end
    of every object are indices into that list.

    Args:
        text: the text to align.
        model: ASR model. If it has a 'tokenizer' attribute, the tokenizer is used to get the
            tokens; otherwise model.decoder.vocabulary is used as a character vocabulary.
        separator: string on which text is split into segments. If it is not defined,
            the whole text is treated as one segment.
        T: number of timesteps available for this utterance.
        audio_filepath: stored in the Utterance object.
        utt_id: stored in the Utterance object, and used in log messages.

    Returns:
        the Utterance object. If text is empty, or has too many tokens compared to T, the
        object is returned early with token_ids_with_blanks containing only a single blank
        and segments_and_tokens left empty.

    Raises:
        RuntimeError: if the model has neither a tokenizer nor a character vocabulary.
    """

    # split text into segments, dropping surrounding spaces and empty segments
    raw_segments = text.split(separator) if separator else [text]
    segments = [stripped for stripped in (raw.strip() for raw in raw_segments) if len(stripped) > 0]

    utt = Utterance(text=text, audio_filepath=audio_filepath, utt_id=utt_id,)

    def _special_token(label, offset_from_end):
        # Token for a blank/space located offset_from_end elements before the end of
        # utt.token_ids_with_blanks (offset 1 => the final element of the list).
        # s_end is same as s_start since the token only occupies one element in the list.
        s_index = len(utt.token_ids_with_blanks) - offset_from_end
        return Token(text=label, text_cased=label, s_start=s_index, s_end=s_index)

    def _exceeds_timesteps(token_ids):
        # True (and logs a message) if # tokens + token repetitions is > T
        n_repetitions = sum(1 for i in range(1, len(token_ids)) if token_ids[i] == token_ids[i - 1])
        if len(token_ids) + n_repetitions > T:
            logging.info(
                f"Utterance {utt_id} has too many tokens compared to the audio file duration."
                " Will not generate output alignment files for this utterance."
            )
            return True
        return False

    # build up lists: token_ids_with_blanks, segments_and_tokens.
    # The code for these is different depending on whether we use char-based tokens or not
    if hasattr(model, 'tokenizer'):
        blank_id = model.blank_id if hasattr(model, 'blank_id') else len(model.tokenizer.vocab)  # TODO: check

        utt.token_ids_with_blanks = [blank_id]

        # nothing to align if the text has 0 length
        if len(text) == 0:
            return utt

        if _exceeds_timesteps(model.tokenizer.text_to_ids(text)):
            return utt

        # the utterance starts with a blank at s=0
        utt.segments_and_tokens.append(Token(text=BLANK_TOKEN, text_cased=BLANK_TOKEN, s_start=0, s_end=0,))

        segment_s_pointer = 1  # first segment will start at s=1 because s=0 is a blank
        word_s_pointer = 1  # first word will start at s=1 because s=0 is a blank

        for segment in segments:
            n_segment_tokens = len(model.tokenizer.text_to_tokens(segment))
            # the tokenizer output does not contain blanks => need to multiply by 2.
            # segment_s_pointer + n_segment_tokens * 2 is the index of the first token of the
            # next segment => subtract 2 to get the index of the final token of this segment
            current_segment = Segment(
                text=segment, s_start=segment_s_pointer, s_end=segment_s_pointer + n_segment_tokens * 2 - 2,
            )
            utt.segments_and_tokens.append(current_segment)
            segment_s_pointer += n_segment_tokens * 2

            words = segment.split(" ")  # we define words to be space-separated sub-strings
            for word_i, word in enumerate(words):

                word_tokens = model.tokenizer.text_to_tokens(word)
                word_token_ids = model.tokenizer.text_to_ids(word)
                word_tokens_cased = restore_token_case(word, word_tokens)

                # same reasoning for s_end as for the segment above
                current_word = Word(
                    text=word, s_start=word_s_pointer, s_end=word_s_pointer + len(word_tokens) * 2 - 2
                )
                current_segment.words_and_tokens.append(current_word)
                word_s_pointer += len(word_tokens) * 2

                for token_i, (token, token_id, token_cased) in enumerate(
                    zip(word_tokens, word_token_ids, word_tokens_cased)
                ):
                    # add the text token and the blank after it to the token ids
                    utt.token_ids_with_blanks.extend([token_id, blank_id])

                    # utt.token_ids_with_blanks has the form [...., <this non-blank token>, <blank>]
                    # => the non-blank token is at index len(utt.token_ids_with_blanks) - 2
                    s_index = len(utt.token_ids_with_blanks) - 2
                    current_word.tokens.append(
                        Token(text=token, text_cased=token_cased, s_start=s_index, s_end=s_index)
                    )

                    # the blank belongs to the word only if it is in between two of its tokens
                    if token_i < len(word_tokens) - 1:
                        current_word.tokens.append(_special_token(BLANK_TOKEN, 1))

                # blank in between words of this segment (not after the final word)
                if word_i < len(words) - 1:
                    current_segment.words_and_tokens.append(_special_token(BLANK_TOKEN, 1))

            # blank in between segments / after the final segment
            utt.segments_and_tokens.append(_special_token(BLANK_TOKEN, 1))

        return utt

    elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based

        blank_id = len(model.decoder.vocabulary)  # TODO: check this is correct
        space_id = model.decoder.vocabulary.index(" ")

        utt.token_ids_with_blanks = [blank_id]

        # nothing to align if the text has 0 length
        if len(text) == 0:
            return utt

        if _exceeds_timesteps(get_char_tokens(text, model)):
            return utt

        # the utterance starts with a blank at s=0
        utt.segments_and_tokens.append(Token(text=BLANK_TOKEN, text_cased=BLANK_TOKEN, s_start=0, s_end=0,))

        segment_s_pointer = 1  # first segment will start at s=1 because s=0 is a blank
        word_s_pointer = 1  # first word will start at s=1 because s=0 is a blank

        for i_segment, segment in enumerate(segments):
            n_segment_tokens = len(get_char_tokens(segment, model))
            # the char tokens do not contain blanks => need to multiply by 2, and subtract 2
            # to get the index of the final token of this segment
            current_segment = Segment(
                text=segment, s_start=segment_s_pointer, s_end=segment_s_pointer + n_segment_tokens * 2 - 2,
            )
            utt.segments_and_tokens.append(current_segment)

            # + 2 to account for [<token for space in between segments>, <blank token after that space token>]
            segment_s_pointer += n_segment_tokens * 2 + 2

            words = segment.split(" ")  # we define words to be space-separated substrings
            for i_word, word in enumerate(words):

                # the tokens of a word are its characters
                word_tokens = list(word)
                word_token_ids = get_char_tokens(word, model)

                # same reasoning for s_end as for the segment above
                current_word = Word(
                    text=word, s_start=word_s_pointer, s_end=word_s_pointer + len(word_tokens) * 2 - 2
                )
                current_segment.words_and_tokens.append(current_word)

                # + 2 to account for [<token for space in between words>, <blank token after that space token>]
                word_s_pointer += len(word_tokens) * 2 + 2

                for token_i, (token, token_id) in enumerate(zip(word_tokens, word_token_ids)):
                    utt.token_ids_with_blanks.extend([token_id])

                    # the token we just added is the final element of utt.token_ids_with_blanks
                    s_index = len(utt.token_ids_with_blanks) - 1
                    current_word.tokens.append(Token(text=token, text_cased=token, s_start=s_index, s_end=s_index))

                    if token_i < len(word_tokens) - 1:  # only add blank tokens that are in the middle of words
                        utt.token_ids_with_blanks.extend([blank_id])
                        current_word.tokens.append(_special_token(BLANK_TOKEN, 1))

                # add space token (and the blanks around it) unless this is the final word in a segment
                if i_word < len(words) - 1:
                    # utt.token_ids_with_blanks now has the form
                    # [..., <final token of previous word>, <blank token>, <space token>, <blank token>]
                    utt.token_ids_with_blanks.extend([blank_id, space_id, blank_id])
                    current_segment.words_and_tokens.append(_special_token(BLANK_TOKEN, 3))
                    current_segment.words_and_tokens.append(_special_token(SPACE_TOKEN, 2))
                    current_segment.words_and_tokens.append(_special_token(BLANK_TOKEN, 1))

            # add a blank after the segment
            utt.token_ids_with_blanks.extend([blank_id])
            utt.segments_and_tokens.append(_special_token(BLANK_TOKEN, 1))

            # and a space (followed by its blank) if this is not the final segment
            if i_segment < len(segments) - 1:
                # utt.token_ids_with_blanks now has the form [..., <space token>, <blank token>]
                utt.token_ids_with_blanks.extend([space_id, blank_id])
                utt.segments_and_tokens.append(_special_token(SPACE_TOKEN, 2))
                utt.segments_and_tokens.append(_special_token(BLANK_TOKEN, 1))

        return utt

    else:
        raise RuntimeError("Cannot get tokens of this model.")
580
+
581
+
582
def add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration):
    """
    Function to add t_start and t_end (representing time in seconds) to the Utterance object utt_obj.
    Args:
        utt_obj: Utterance object to which we will add t_start and t_end for its
            constituent segments/words/tokens.
        alignment_utt: a list of ints indicating which token does the alignment pass through at each
            timestep (will take the form [0, 0, 1, 1, ..., <num of tokens including blanks in utterance>]).
        output_timestep_duration: a float indicating the duration of a single output timestep from
            the ASR Model.

    Returns:
        utt_obj: updated Utterance object.
    """

    # Idea of the algorithm:
    # a token s starts at the timestep where s_start appears in alignment_utt for the first time,
    # and ends at the timestep where s_end appears in alignment_utt for the final time.
    # We record these timesteps in two dictionaries, and then use them to update all of
    # the t_start and t_end values in utt_obj.
    # Tokens that are skipped by the alignment (should only be blanks) get t_start = t_end = -1.
    first_timestep_of = dict()
    last_timestep_of = dict()

    previous_s = -1  # -1 means that we have not seen any s yet
    for timestep, s in enumerate(alignment_utt):
        if s > previous_s:
            first_timestep_of[s] = timestep

            if previous_s >= 0:  # dont record previous_s = -1
                last_timestep_of[previous_s] = timestep - 1
        previous_s = s
    # the final s lasts until the end of the alignment
    last_timestep_of[previous_s] = len(alignment_utt) - 1

    def _set_span_times(span):
        # segments and words: their s_start and s_end are looked up directly
        span.t_start = first_timestep_of[span.s_start] * output_timestep_duration
        span.t_end = (last_timestep_of[span.s_end] + 1) * output_timestep_duration

    def _set_token_times(token):
        # tokens may be skipped by the alignment, in which case their time is set to -1
        if token.s_start in first_timestep_of:
            token.t_start = first_timestep_of[token.s_start] * output_timestep_duration
        else:
            token.t_start = -1

        if token.s_end in last_timestep_of:
            token.t_end = (last_timestep_of[token.s_end] + 1) * output_timestep_duration
        else:
            token.t_end = -1

    # update all the t_start and t_end in utt_obj
    for segment_or_token in utt_obj.segments_and_tokens:
        if type(segment_or_token) is not Segment:
            _set_token_times(segment_or_token)
            continue

        _set_span_times(segment_or_token)

        for word_or_token in segment_or_token.words_and_tokens:
            if type(word_or_token) is Word:
                _set_span_times(word_or_token)
                for token in word_or_token.tokens:
                    _set_token_times(token)
            else:
                _set_token_times(word_or_token)

    return utt_obj
669
+
670
+
671
def get_batch_variables(
    manifest_lines_batch,
    model,
    separator,
    align_using_pred_text,
    audio_filepath_parts_in_utt_id,
    output_timestep_duration,
    simulate_cache_aware_streaming=False,
    use_buffered_chunked_streaming=False,
    buffered_chunk_params=None,
):
    """
    Run the ASR model on the audio files of the batch and prepare everything that is
    needed for Viterbi decoding.

    Args:
        manifest_lines_batch: list of dicts (one per manifest line of the batch), each with an
            "audio_filepath" entry, and a "text" entry if align_using_pred_text is False.
        model: the ASR model used to obtain the log probs (and the predicted text).
        separator: string used to split the text into segments (passed on to get_utt_obj).
        align_using_pred_text: if True, the text predicted by the ASR model is aligned instead
            of the "text" of the manifest.
        audio_filepath_parts_in_utt_id: number of trailing parts of the audio filepath that
            are used to build the utterance id.
        output_timestep_duration: duration of a single output timestep of the ASR model,
            or None to have it calculated from the first audio file of the batch.
        simulate_cache_aware_streaming: if True, use
            model.transcribe_simulate_cache_aware_streaming instead of model.transcribe.
        use_buffered_chunked_streaming: if True, the model is driven one file at a time, through
            its reset / read_audio_file / transcribe methods.
        buffered_chunk_params: dict with the keys "delay", "model_stride_in_secs" and
            "tokens_per_chunk"; only read if use_buffered_chunked_streaming is True.
            None (the default) is treated as an empty dict.

    Returns:
        log_probs, y, T, U (y and U are s.t. every other token is a blank) - these are the tensors we will need
            during Viterbi decoding.
        utt_obj_batch: a list of Utterance objects for every utterance in the batch.
        output_timestep_duration: a float indicating the duration of a single output timestep from
            the ASR Model.

    Raises:
        ValueError: if output_timestep_duration is None and 'window_stride' or 'sample_rate' is
            missing from model.cfg.preprocessor.
    """

    # a dict literal as default argument would be shared between calls => use None as default
    if buffered_chunk_params is None:
        buffered_chunk_params = {}

    # get hypotheses by calling 'transcribe'
    # we will use the output log_probs, the duration of the log_probs,
    # and (optionally) the predicted ASR text from the hypotheses
    audio_filepaths_batch = [line["audio_filepath"] for line in manifest_lines_batch]
    B = len(audio_filepaths_batch)
    log_probs_list_batch = []
    T_list_batch = []
    pred_text_batch = []

    if not use_buffered_chunked_streaming:
        with torch.no_grad():
            if not simulate_cache_aware_streaming:
                hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)
            else:
                hypotheses = model.transcribe_simulate_cache_aware_streaming(
                    audio_filepaths_batch, return_hypotheses=True, batch_size=B
                )

        # if hypotheses form a tuple (from Hybrid model), extract just "best" hypothesis
        if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
            hypotheses = hypotheses[0]

        for hypothesis in hypotheses:
            log_probs_list_batch.append(hypothesis.y_sequence)
            T_list_batch.append(hypothesis.y_sequence.shape[0])
            pred_text_batch.append(hypothesis.text)
    else:
        delay = buffered_chunk_params["delay"]
        model_stride_in_secs = buffered_chunk_params["model_stride_in_secs"]
        tokens_per_chunk = buffered_chunk_params["tokens_per_chunk"]
        for audio_filepath in tqdm(audio_filepaths_batch, desc="Sample:"):
            model.reset()
            model.read_audio_file(audio_filepath, delay, model_stride_in_secs)
            hyp, logits = model.transcribe(tokens_per_chunk, delay, keep_logits=True)
            log_probs_list_batch.append(logits)
            T_list_batch.append(logits.shape[0])
            pred_text_batch.append(hyp)

    # we loop over every line in the manifest that is in our current batch,
    # and record the y (list of tokens, including blanks), U (list of lengths of y) and
    # the Utterance object of every line
    y_list_batch = []
    U_list_batch = []
    utt_obj_batch = []

    for i_line, line in enumerate(manifest_lines_batch):
        if align_using_pred_text:
            # remove any duplicated spaces from the predicted text
            gt_text_for_alignment = " ".join(pred_text_batch[i_line].split())
        else:
            gt_text_for_alignment = line["text"]
        utt_obj = get_utt_obj(
            gt_text_for_alignment,
            model,
            separator,
            T_list_batch[i_line],
            audio_filepaths_batch[i_line],
            _get_utt_id(audio_filepaths_batch[i_line], audio_filepath_parts_in_utt_id),
        )

        # update utt_obj.pred_text or utt_obj.text
        if align_using_pred_text:
            utt_obj.pred_text = pred_text_batch[i_line]
            if len(utt_obj.pred_text) == 0:
                logging.info(
                    f"'pred_text' of utterance {utt_obj.utt_id} is empty - we will not generate"
                    " any output alignment files for this utterance"
                )
            if "text" in line:
                utt_obj.text = line["text"]  # keep the text as we will save it in the output manifest
        else:
            utt_obj.text = line["text"]
            if len(utt_obj.text) == 0:
                logging.info(
                    f"'text' of utterance {utt_obj.utt_id} is empty - we will not generate"
                    " any output alignment files for this utterance"
                )

        y_list_batch.append(utt_obj.token_ids_with_blanks)
        U_list_batch.append(len(utt_obj.token_ids_with_blanks))
        utt_obj_batch.append(utt_obj)

    # turn log_probs, y, T, U into dense tensors for fast computation during Viterbi decoding
    T_max = max(T_list_batch)
    U_max = max(U_list_batch)
    # V = the number of tokens in the vocabulary + 1 for the blank token.
    if hasattr(model, 'tokenizer'):
        V = len(model.tokenizer.vocab) + 1
    else:
        V = len(model.decoder.vocabulary) + 1
    T_batch = torch.tensor(T_list_batch)
    U_batch = torch.tensor(U_list_batch)

    # make log_probs_batch tensor of shape (B x T_max x V)
    # the 'padding' timesteps keep the very negative value V_NEGATIVE_NUM
    log_probs_batch = V_NEGATIVE_NUM * torch.ones((B, T_max, V))
    for b, log_probs_utt in enumerate(log_probs_list_batch):
        t = log_probs_utt.shape[0]
        log_probs_batch[b, :t, :] = log_probs_utt

    # make y tensor of shape (B x U_max)
    # populate it initially with all 'V' numbers so that the 'V's will remain in the areas that
    # are 'padding'. This will be useful for when we make 'log_probs_reorderd' during Viterbi decoding
    # in a different function.
    y_batch = V * torch.ones((B, U_max), dtype=torch.int64)
    for b, y_utt in enumerate(y_list_batch):
        U_utt = U_batch[b]
        y_batch[b, :U_utt] = torch.tensor(y_utt)

    # calculate output_timestep_duration if it is None
    if output_timestep_duration is None:
        if 'window_stride' not in model.cfg.preprocessor:
            raise ValueError(
                "Don't have attribute 'window_stride' in 'model.cfg.preprocessor' => cannot calculate "
                " model_downsample_factor => stopping process"
            )

        if 'sample_rate' not in model.cfg.preprocessor:
            raise ValueError(
                "Don't have attribute 'sample_rate' in 'model.cfg.preprocessor' => cannot calculate start "
                " and end time of segments => stopping process"
            )

        # the downsample factor is estimated from the first utterance of the batch:
        # (number of input frames of the audio) / (number of output timesteps of the model)
        with sf.SoundFile(audio_filepaths_batch[0]) as f:
            audio_dur = f.frames / f.samplerate
        n_input_frames = audio_dur / model.cfg.preprocessor.window_stride
        model_downsample_factor = round(n_input_frames / int(T_batch[0]))

        output_timestep_duration = (
            model.preprocessor.featurizer.hop_length * model_downsample_factor / model.cfg.preprocessor.sample_rate
        )

        logging.info(
            f"Calculated that the model downsample factor is {model_downsample_factor}"
            f" and therefore the ASR model output timestep duration is {output_timestep_duration}"
            " -- will use this for all batches"
        )

    return (
        log_probs_batch,
        y_batch,
        T_batch,
        U_batch,
        utt_obj_batch,
        output_timestep_duration,
    )
utils/make_ass_files.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This file contains functions for making ASS-format subtitle files based on the generated alignment.
17
+ ASS files can be generated highlighting token-level alignments or word-level alignments.
18
+ In both cases, 'segment' boundaries will be used to determine which parts of the text will appear
19
+ at the same time.
20
+ For the token-level ASS files, the text will be highlighted token-by-token, with the timings determined
21
+ by the NFA alignments.
22
+ For the word-level ASS files, the text will be highlighted word-by-word, with the timings determined
23
+ by the NFA alignments.
24
+ """
25
+
26
+ import os
27
+
28
+ from utils.constants import BLANK_TOKEN, SPACE_TOKEN
29
+ from utils.data_prep import Segment, Token, Word
30
+
31
+ PLAYERRESX = 384
32
+ PLAYERRESY = 288
33
+ MARGINL = 10
34
+ MARGINR = 10
35
+ MARGINV = 20
36
+
37
+
38
+ def seconds_to_ass_format(seconds_float):
39
+ seconds_float = float(seconds_float)
40
+ mm, ss_decimals = divmod(seconds_float, 60)
41
+ hh, mm = divmod(mm, 60)
42
+
43
+ hh = str(round(hh))
44
+ if len(hh) == 1:
45
+ hh = '0' + hh
46
+
47
+ mm = str(round(mm))
48
+ if len(mm) == 1:
49
+ mm = '0' + mm
50
+
51
+ ss_decimals = f"{ss_decimals:.2f}"
52
+ if len(ss_decimals.split(".")[0]) == 1:
53
+ ss_decimals = "0" + ss_decimals
54
+
55
+ srt_format_time = f"{hh}:{mm}:{ss_decimals}"
56
+
57
+ return srt_format_time
58
+
59
+
60
+ def rgb_list_to_hex_bgr(rgb_list):
61
+ r, g, b = rgb_list
62
+ return f"{b:x}{g:x}{r:x}"
63
+
64
+
65
+ def make_ass_files(
66
+ utt_obj, output_dir_root, ass_file_config,
67
+ ):
68
+
69
+ # don't try to make files if utt_obj.segments_and_tokens is empty, which will happen
70
+ # in the case of the ground truth text being empty or the number of tokens being too large vs audio duration
71
+ if not utt_obj.segments_and_tokens:
72
+ return utt_obj
73
+
74
+ if ass_file_config.resegment_text_to_fill_space:
75
+ utt_obj = resegment_utt_obj(utt_obj, ass_file_config)
76
+
77
+ utt_obj = make_word_level_ass_file(utt_obj, output_dir_root, ass_file_config,)
78
+ utt_obj = make_token_level_ass_file(utt_obj, output_dir_root, ass_file_config,)
79
+
80
+ return utt_obj
81
+
82
+
83
+ def _get_word_n_chars(word):
84
+ n_chars = 0
85
+ for token in word.tokens:
86
+ if token.text != BLANK_TOKEN:
87
+ n_chars += len(token.text)
88
+ return n_chars
89
+
90
+
91
+ def _get_segment_n_chars(segment):
92
+ n_chars = 0
93
+ for word_or_token in segment.words_and_tokens:
94
+ if word_or_token.text == SPACE_TOKEN:
95
+ n_chars += 1
96
+ elif word_or_token.text != BLANK_TOKEN:
97
+ n_chars += len(word_or_token.text)
98
+ return n_chars
99
+
100
+
101
+ def resegment_utt_obj(utt_obj, ass_file_config):
102
+
103
+ # get list of just all words and tokens
104
+ all_words_and_tokens = []
105
+ for segment_or_token in utt_obj.segments_and_tokens:
106
+ if type(segment_or_token) is Segment:
107
+ all_words_and_tokens.extend(segment_or_token.words_and_tokens)
108
+ else:
109
+ all_words_and_tokens.append(segment_or_token)
110
+
111
+ # figure out how many chars will fit into one 'slide' and thus should be the max
112
+ # size of a segment
113
+ approx_chars_per_line = (PLAYERRESX - MARGINL - MARGINR) / (
114
+ ass_file_config.fontsize * 0.6
115
+ ) # assume chars 0.6 as wide as they are tall
116
+ approx_lines_per_segment = (PLAYERRESY - MARGINV) / (
117
+ ass_file_config.fontsize * 1.15
118
+ ) # assume line spacing is 1.15
119
+ if approx_lines_per_segment > ass_file_config.max_lines_per_segment:
120
+ approx_lines_per_segment = ass_file_config.max_lines_per_segment
121
+
122
+ max_chars_per_segment = int(approx_chars_per_line * approx_lines_per_segment)
123
+
124
+ new_segments_and_tokens = []
125
+ all_words_and_tokens_pointer = 0
126
+ for word_or_token in all_words_and_tokens:
127
+ if type(word_or_token) is Token:
128
+ new_segments_and_tokens.append(word_or_token)
129
+ all_words_and_tokens_pointer += 1
130
+ else:
131
+ break
132
+
133
+ new_segments_and_tokens.append(Segment())
134
+
135
+ while all_words_and_tokens_pointer < len(all_words_and_tokens):
136
+ word_or_token = all_words_and_tokens[all_words_and_tokens_pointer]
137
+ if type(word_or_token) is Word:
138
+
139
+ # if this is going to be the first word in the segment, we definitely want
140
+ # to add it to the segment
141
+ if not new_segments_and_tokens[-1].words_and_tokens:
142
+ new_segments_and_tokens[-1].words_and_tokens.append(word_or_token)
143
+
144
+ else:
145
+ # if not the first word, check what the new length of the segment will be
146
+ # if short enough - add this word to this segment;
147
+ # if too long - add to a new segment
148
+ this_word_n_chars = _get_word_n_chars(word_or_token)
149
+ segment_so_far_n_chars = _get_segment_n_chars(new_segments_and_tokens[-1])
150
+ if this_word_n_chars + segment_so_far_n_chars < max_chars_per_segment:
151
+ new_segments_and_tokens[-1].words_and_tokens.append(word_or_token)
152
+ else:
153
+ new_segments_and_tokens.append(Segment())
154
+ new_segments_and_tokens[-1].words_and_tokens.append(word_or_token)
155
+
156
+ else: # i.e. word_or_token is a token
157
+ # currently this breaks the convention of tokens at the end/beginning
158
+ # of segments being listed as separate tokens in segment.word_and_tokens
159
+ # TODO: change code so we follow this convention
160
+ new_segments_and_tokens[-1].words_and_tokens.append(word_or_token)
161
+
162
+ all_words_and_tokens_pointer += 1
163
+
164
+ utt_obj.segments_and_tokens = new_segments_and_tokens
165
+
166
+ return utt_obj
167
+
168
+
169
+ def make_word_level_ass_file(
170
+ utt_obj, output_dir_root, ass_file_config,
171
+ ):
172
+
173
+ default_style_dict = {
174
+ "Name": "Default",
175
+ "Fontname": "Arial",
176
+ "Fontsize": str(ass_file_config.fontsize),
177
+ "PrimaryColour": "&Hffffff",
178
+ "SecondaryColour": "&Hffffff",
179
+ "OutlineColour": "&H0",
180
+ "BackColour": "&H0",
181
+ "Bold": "0",
182
+ "Italic": "0",
183
+ "Underline": "0",
184
+ "StrikeOut": "0",
185
+ "ScaleX": "100",
186
+ "ScaleY": "100",
187
+ "Spacing": "0",
188
+ "Angle": "0",
189
+ "BorderStyle": "1",
190
+ "Outline": "1",
191
+ "Shadow": "0",
192
+ "Alignment": None, # will specify below
193
+ "MarginL": str(MARGINL),
194
+ "MarginR": str(MARGINR),
195
+ "MarginV": str(MARGINV),
196
+ "Encoding": "0",
197
+ }
198
+
199
+ if ass_file_config.vertical_alignment == "top":
200
+ default_style_dict["Alignment"] = "8" # text will be 'center-justified' and in the top of the screen
201
+ elif ass_file_config.vertical_alignment == "center":
202
+ default_style_dict["Alignment"] = "5" # text will be 'center-justified' and in the middle of the screen
203
+ elif ass_file_config.vertical_alignment == "bottom":
204
+ default_style_dict["Alignment"] = "2" # text will be 'center-justified' and in the bottom of the screen
205
+ else:
206
+ raise ValueError(f"got an unexpected value for ass_file_config.vertical_alignment")
207
+
208
+ output_dir = os.path.join(output_dir_root, "ass", "words")
209
+ os.makedirs(output_dir, exist_ok=True)
210
+ output_file = os.path.join(output_dir, f"{utt_obj.utt_id}.ass")
211
+
212
+ already_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_already_spoken_rgb) + r"&}"
213
+ being_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_being_spoken_rgb) + r"&}"
214
+ not_yet_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_not_yet_spoken_rgb) + r"&}"
215
+
216
+ with open(output_file, 'w') as f:
217
+ default_style_top_line = "Format: " + ", ".join(default_style_dict.keys())
218
+ default_style_bottom_line = "Style: " + ",".join(default_style_dict.values())
219
+
220
+ f.write(
221
+ (
222
+ "[Script Info]\n"
223
+ "ScriptType: v4.00+\n"
224
+ f"PlayResX: {PLAYERRESX}\n"
225
+ f"PlayResY: {PLAYERRESY}\n"
226
+ "\n"
227
+ "[V4+ Styles]\n"
228
+ f"{default_style_top_line}\n"
229
+ f"{default_style_bottom_line}\n"
230
+ "\n"
231
+ "[Events]\n"
232
+ "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n"
233
+ )
234
+ )
235
+
236
+ # write first set of subtitles for text before speech starts to be spoken
237
+ words_in_first_segment = []
238
+ for segment_or_token in utt_obj.segments_and_tokens:
239
+ if type(segment_or_token) is Segment:
240
+ first_segment = segment_or_token
241
+
242
+ for word_or_token in first_segment.words_and_tokens:
243
+ if type(word_or_token) is Word:
244
+ words_in_first_segment.append(word_or_token)
245
+ break
246
+
247
+ text_before_speech = not_yet_spoken_color_code + " ".join([x.text for x in words_in_first_segment]) + r"{\r}"
248
+ subtitle_text = (
249
+ f"Dialogue: 0,{seconds_to_ass_format(0)},{seconds_to_ass_format(words_in_first_segment[0].t_start)},Default,,0,0,0,,"
250
+ + text_before_speech.rstrip()
251
+ )
252
+
253
+ f.write(subtitle_text + '\n')
254
+
255
+ for segment_or_token in utt_obj.segments_and_tokens:
256
+ if type(segment_or_token) is Segment:
257
+ segment = segment_or_token
258
+
259
+ words_in_segment = []
260
+ for word_or_token in segment.words_and_tokens:
261
+ if type(word_or_token) is Word:
262
+ words_in_segment.append(word_or_token)
263
+
264
+ for word_i, word in enumerate(words_in_segment):
265
+
266
+ text_before = " ".join([x.text for x in words_in_segment[:word_i]])
267
+ if text_before != "":
268
+ text_before += " "
269
+ text_before = already_spoken_color_code + text_before + r"{\r}"
270
+
271
+ if word_i < len(words_in_segment) - 1:
272
+ text_after = " " + " ".join([x.text for x in words_in_segment[word_i + 1 :]])
273
+ else:
274
+ text_after = ""
275
+ text_after = not_yet_spoken_color_code + text_after + r"{\r}"
276
+
277
+ aligned_text = being_spoken_color_code + word.text + r"{\r}"
278
+ aligned_text_off = already_spoken_color_code + word.text + r"{\r}"
279
+
280
+ subtitle_text = (
281
+ f"Dialogue: 0,{seconds_to_ass_format(word.t_start)},{seconds_to_ass_format(word.t_end)},Default,,0,0,0,,"
282
+ + text_before
283
+ + aligned_text
284
+ + text_after.rstrip()
285
+ )
286
+ f.write(subtitle_text + '\n')
287
+
288
+ # add subtitles without word-highlighting for when words are not being spoken
289
+ if word_i < len(words_in_segment) - 1:
290
+ last_word_end = float(words_in_segment[word_i].t_end)
291
+ next_word_start = float(words_in_segment[word_i + 1].t_start)
292
+ if next_word_start - last_word_end > 0.001:
293
+ subtitle_text = (
294
+ f"Dialogue: 0,{seconds_to_ass_format(last_word_end)},{seconds_to_ass_format(next_word_start)},Default,,0,0,0,,"
295
+ + text_before
296
+ + aligned_text_off
297
+ + text_after.rstrip()
298
+ )
299
+ f.write(subtitle_text + '\n')
300
+
301
+ utt_obj.saved_output_files[f"words_level_ass_filepath"] = output_file
302
+
303
+ return utt_obj
304
+
305
+
306
+ def make_token_level_ass_file(
307
+ utt_obj, output_dir_root, ass_file_config,
308
+ ):
309
+
310
+ default_style_dict = {
311
+ "Name": "Default",
312
+ "Fontname": "Arial",
313
+ "Fontsize": str(ass_file_config.fontsize),
314
+ "PrimaryColour": "&Hffffff",
315
+ "SecondaryColour": "&Hffffff",
316
+ "OutlineColour": "&H0",
317
+ "BackColour": "&H0",
318
+ "Bold": "0",
319
+ "Italic": "0",
320
+ "Underline": "0",
321
+ "StrikeOut": "0",
322
+ "ScaleX": "100",
323
+ "ScaleY": "100",
324
+ "Spacing": "0",
325
+ "Angle": "0",
326
+ "BorderStyle": "1",
327
+ "Outline": "1",
328
+ "Shadow": "0",
329
+ "Alignment": None, # will specify below
330
+ "MarginL": str(MARGINL),
331
+ "MarginR": str(MARGINR),
332
+ "MarginV": str(MARGINV),
333
+ "Encoding": "0",
334
+ }
335
+
336
+ if ass_file_config.vertical_alignment == "top":
337
+ default_style_dict["Alignment"] = "8" # text will be 'center-justified' and in the top of the screen
338
+ elif ass_file_config.vertical_alignment == "center":
339
+ default_style_dict["Alignment"] = "5" # text will be 'center-justified' and in the middle of the screen
340
+ elif ass_file_config.vertical_alignment == "bottom":
341
+ default_style_dict["Alignment"] = "2" # text will be 'center-justified' and in the bottom of the screen
342
+ else:
343
+ raise ValueError(f"got an unexpected value for ass_file_config.vertical_alignment")
344
+
345
+ output_dir = os.path.join(output_dir_root, "ass", "tokens")
346
+ os.makedirs(output_dir, exist_ok=True)
347
+ output_file = os.path.join(output_dir, f"{utt_obj.utt_id}.ass")
348
+
349
+ already_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_already_spoken_rgb) + r"&}"
350
+ being_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_being_spoken_rgb) + r"&}"
351
+ not_yet_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_not_yet_spoken_rgb) + r"&}"
352
+
353
+ with open(output_file, 'w') as f:
354
+ default_style_top_line = "Format: " + ", ".join(default_style_dict.keys())
355
+ default_style_bottom_line = "Style: " + ",".join(default_style_dict.values())
356
+
357
+ f.write(
358
+ (
359
+ "[Script Info]\n"
360
+ "ScriptType: v4.00+\n"
361
+ f"PlayResX: {PLAYERRESX}\n"
362
+ f"PlayResY: {PLAYERRESY}\n"
363
+ "ScaledBorderAndShadow: yes\n"
364
+ "\n"
365
+ "[V4+ Styles]\n"
366
+ f"{default_style_top_line}\n"
367
+ f"{default_style_bottom_line}\n"
368
+ "\n"
369
+ "[Events]\n"
370
+ "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n"
371
+ )
372
+ )
373
+
374
+ # write first set of subtitles for text before speech starts to be spoken
375
+ tokens_in_first_segment = []
376
+ for segment_or_token in utt_obj.segments_and_tokens:
377
+ if type(segment_or_token) is Segment:
378
+ for word_or_token in segment_or_token.words_and_tokens:
379
+ if type(word_or_token) is Token:
380
+ if word_or_token.text != BLANK_TOKEN:
381
+ tokens_in_first_segment.append(word_or_token)
382
+ else:
383
+ for token in word_or_token.tokens:
384
+ if token.text != BLANK_TOKEN:
385
+ tokens_in_first_segment.append(token)
386
+
387
+ break
388
+
389
+ for token in tokens_in_first_segment:
390
+ token.text_cased = token.text_cased.replace(
391
+ "▁", " "
392
+ ) # replace underscores used in subword tokens with spaces
393
+ token.text_cased = token.text_cased.replace(SPACE_TOKEN, " ") # space token with actual space
394
+
395
+ text_before_speech = (
396
+ not_yet_spoken_color_code + "".join([x.text_cased for x in tokens_in_first_segment]) + r"{\r}"
397
+ )
398
+ subtitle_text = (
399
+ f"Dialogue: 0,{seconds_to_ass_format(0)},{seconds_to_ass_format(tokens_in_first_segment[0].t_start)},Default,,0,0,0,,"
400
+ + text_before_speech.rstrip()
401
+ )
402
+
403
+ f.write(subtitle_text + '\n')
404
+
405
+ for segment_or_token in utt_obj.segments_and_tokens:
406
+ if type(segment_or_token) is Segment:
407
+ segment = segment_or_token
408
+
409
+ tokens_in_segment = [] # make list of (non-blank) tokens
410
+ for word_or_token in segment.words_and_tokens:
411
+ if type(word_or_token) is Token:
412
+ if word_or_token.text != BLANK_TOKEN:
413
+ tokens_in_segment.append(word_or_token)
414
+ else:
415
+ for token in word_or_token.tokens:
416
+ if token.text != BLANK_TOKEN:
417
+ tokens_in_segment.append(token)
418
+
419
+ for token in tokens_in_segment:
420
+ token.text_cased = token.text_cased.replace(
421
+ "▁", " "
422
+ ) # replace underscores used in subword tokens with spaces
423
+ token.text_cased = token.text_cased.replace(SPACE_TOKEN, " ") # space token with actual space
424
+
425
+ for token_i, token in enumerate(tokens_in_segment):
426
+
427
+ text_before = "".join([x.text_cased for x in tokens_in_segment[:token_i]])
428
+ text_before = already_spoken_color_code + text_before + r"{\r}"
429
+
430
+ if token_i < len(tokens_in_segment) - 1:
431
+ text_after = "".join([x.text_cased for x in tokens_in_segment[token_i + 1 :]])
432
+ else:
433
+ text_after = ""
434
+ text_after = not_yet_spoken_color_code + text_after + r"{\r}"
435
+
436
+ aligned_text = being_spoken_color_code + token.text_cased + r"{\r}"
437
+ aligned_text_off = already_spoken_color_code + token.text_cased + r"{\r}"
438
+
439
+ subtitle_text = (
440
+ f"Dialogue: 0,{seconds_to_ass_format(token.t_start)},{seconds_to_ass_format(token.t_end)},Default,,0,0,0,,"
441
+ + text_before
442
+ + aligned_text
443
+ + text_after.rstrip()
444
+ )
445
+ f.write(subtitle_text + '\n')
446
+
447
+ # add subtitles without word-highlighting for when words are not being spoken
448
+ if token_i < len(tokens_in_segment) - 1:
449
+ last_token_end = float(tokens_in_segment[token_i].t_end)
450
+ next_token_start = float(tokens_in_segment[token_i + 1].t_start)
451
+ if next_token_start - last_token_end > 0.001:
452
+ subtitle_text = (
453
+ f"Dialogue: 0,{seconds_to_ass_format(last_token_end)},{seconds_to_ass_format(next_token_start)},Default,,0,0,0,,"
454
+ + text_before
455
+ + aligned_text_off
456
+ + text_after.rstrip()
457
+ )
458
+ f.write(subtitle_text + '\n')
459
+
460
+ utt_obj.saved_output_files[f"tokens_level_ass_filepath"] = output_file
461
+
462
+ return utt_obj
utils/make_ctm_files.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import soundfile as sf
18
+ from utils.constants import BLANK_TOKEN, SPACE_TOKEN
19
+ from utils.data_prep import Segment, Word
20
+
21
+
22
+ def make_ctm_files(
23
+ utt_obj, output_dir_root, ctm_file_config,
24
+ ):
25
+ """
26
+ Function to save CTM files for all the utterances in the incoming batch.
27
+ """
28
+
29
+ # don't try to make files if utt_obj.segments_and_tokens is empty, which will happen
30
+ # in the case of the ground truth text being empty or the number of tokens being too large vs audio duration
31
+ if not utt_obj.segments_and_tokens:
32
+ return utt_obj
33
+
34
+ # get audio file duration if we will need it later
35
+ if ctm_file_config.minimum_timestamp_duration > 0:
36
+ with sf.SoundFile(utt_obj.audio_filepath) as f:
37
+ audio_file_duration = f.frames / f.samplerate
38
+ else:
39
+ audio_file_duration = None
40
+
41
+ utt_obj = make_ctm("tokens", utt_obj, output_dir_root, audio_file_duration, ctm_file_config,)
42
+ utt_obj = make_ctm("words", utt_obj, output_dir_root, audio_file_duration, ctm_file_config,)
43
+ utt_obj = make_ctm("segments", utt_obj, output_dir_root, audio_file_duration, ctm_file_config,)
44
+
45
+ return utt_obj
46
+
47
+
48
+ def make_ctm(
49
+ alignment_level, utt_obj, output_dir_root, audio_file_duration, ctm_file_config,
50
+ ):
51
+ output_dir = os.path.join(output_dir_root, "ctm", alignment_level)
52
+ os.makedirs(output_dir, exist_ok=True)
53
+
54
+ boundary_info_utt = []
55
+ for segment_or_token in utt_obj.segments_and_tokens:
56
+ if type(segment_or_token) is Segment:
57
+ segment = segment_or_token
58
+ if alignment_level == "segments":
59
+ boundary_info_utt.append(segment)
60
+
61
+ for word_or_token in segment.words_and_tokens:
62
+ if type(word_or_token) is Word:
63
+ word = word_or_token
64
+ if alignment_level == "words":
65
+ boundary_info_utt.append(word)
66
+
67
+ for token in word.tokens:
68
+ if alignment_level == "tokens":
69
+ boundary_info_utt.append(token)
70
+
71
+ else:
72
+ token = word_or_token
73
+ if alignment_level == "tokens":
74
+ boundary_info_utt.append(token)
75
+
76
+ else:
77
+ token = segment_or_token
78
+ if alignment_level == "tokens":
79
+ boundary_info_utt.append(token)
80
+
81
+ with open(os.path.join(output_dir, f"{utt_obj.utt_id}.ctm"), "w") as f_ctm:
82
+ for boundary_info_ in boundary_info_utt: # loop over every token/word/segment
83
+
84
+ # skip if t_start = t_end = negative number because we used it as a marker to skip some blank tokens
85
+ if not (boundary_info_.t_start < 0 or boundary_info_.t_end < 0):
86
+ text = boundary_info_.text
87
+ start_time = boundary_info_.t_start
88
+ end_time = boundary_info_.t_end
89
+
90
+ if (
91
+ ctm_file_config.minimum_timestamp_duration > 0
92
+ and ctm_file_config.minimum_timestamp_duration > end_time - start_time
93
+ ):
94
+ # make the predicted duration of the token/word/segment longer, growing it outwards equal
95
+ # amounts from the predicted center of the token/word/segment
96
+ token_mid_point = (start_time + end_time) / 2
97
+ start_time = max(token_mid_point - ctm_file_config.minimum_timestamp_duration / 2, 0)
98
+ end_time = min(
99
+ token_mid_point + ctm_file_config.minimum_timestamp_duration / 2, audio_file_duration
100
+ )
101
+
102
+ if not (
103
+ text == BLANK_TOKEN and ctm_file_config.remove_blank_tokens
104
+ ): # don't save blanks if we don't want to
105
+ # replace any spaces with <space> so we dont introduce extra space characters to our CTM files
106
+ text = text.replace(" ", SPACE_TOKEN)
107
+
108
+ f_ctm.write(f"{utt_obj.utt_id} 1 {start_time:.2f} {end_time - start_time:.2f} {text}\n")
109
+
110
+ utt_obj.saved_output_files[f"{alignment_level}_level_ctm_filepath"] = os.path.join(
111
+ output_dir, f"{utt_obj.utt_id}.ctm"
112
+ )
113
+
114
+ return utt_obj
utils/make_output_manifest.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+
17
+
18
+ def write_manifest_out_line(
19
+ f_manifest_out, utt_obj,
20
+ ):
21
+
22
+ data = {"audio_filepath": utt_obj.audio_filepath}
23
+ if not utt_obj.text is None:
24
+ data["text"] = utt_obj.text
25
+
26
+ if not utt_obj.pred_text is None:
27
+ data["pred_text"] = utt_obj.pred_text
28
+
29
+ for key, val in utt_obj.saved_output_files.items():
30
+ data[key] = val
31
+
32
+ new_line = json.dumps(data)
33
+ f_manifest_out.write(f"{new_line}\n")
34
+
35
+ return None
utils/viterbi_decoding.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from utils.constants import V_NEGATIVE_NUM
17
+
18
+
19
def viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device):
    """
    Do Viterbi decoding with an efficient algorithm (the only for-loop in the 'forward pass' is over the time dimension).
    Args:
        log_probs_batch: tensor of shape (B, T_max, V). The parts of log_probs_batch which are 'padding' are filled
            with 'V_NEGATIVE_NUM' - a large negative number which represents a very low probability.
        y_batch: tensor of shape (B, U_max) - contains token IDs including blanks in every other position. The parts of
            y_batch which are padding are filled with the number 'V'. V = the number of tokens in the vocabulary + 1 for
            the blank token.
        T_batch: 1-D tensor of shape (B,) - contains the durations of the log_probs_batch (so we can ignore the
            parts of log_probs_batch which are padding). It must be 1-D: the code unsqueezes it to (B, 1) before
            combining it with (B, U_max) tensors.
        U_batch: 1-D tensor of shape (B,) - contains the lengths of y_batch (so we can ignore the parts of y_batch
            which are padding). It must be 1-D for the same reason as T_batch.
        viterbi_device: the torch device on which Viterbi decoding will be done.

    Returns:
        alignments_batch: list of lists containing locations for the tokens we align to at each timestep.
            Looks like: [[0, 0, 1, 2, 2, 3, 3, ...,  ], ..., [0, 1, 2, 2, 2, 3, 4, ....]].
            Each list inside alignments_batch is of length T_batch[location of utt in batch].
    """

    B, T_max, _ = log_probs_batch.shape
    U_max = y_batch.shape[1]

    # transfer all tensors to viterbi_device
    log_probs_batch = log_probs_batch.to(viterbi_device)
    y_batch = y_batch.to(viterbi_device)
    T_batch = T_batch.to(viterbi_device)
    U_batch = U_batch.to(viterbi_device)

    # make tensor that we will put at timesteps beyond the duration of the audio
    padding_for_log_probs = V_NEGATIVE_NUM * torch.ones((B, T_max, 1), device=viterbi_device)
    # make log_probs_padded tensor of shape (B, T_max, V + 1) where all of
    # log_probs_padded[:, :, -1] is 'V_NEGATIVE_NUM', so that gathering with the padding token ID 'V'
    # returns a very low probability
    log_probs_padded = torch.cat((log_probs_batch, padding_for_log_probs), dim=2)

    # initialize v_prev - tensor of previous timestep's viterbi probabilities, of shape (B, U_max).
    # Only the first two token positions are given their t=0 emission scores.
    v_prev = V_NEGATIVE_NUM * torch.ones((B, U_max), device=viterbi_device)
    v_prev[:, :2] = torch.gather(input=log_probs_padded[:, 0, :], dim=1, index=y_batch[:, :2])

    # initialize backpointers_rel - which contains values like 0 to indicate the backpointer is to the same u index,
    # 1 to indicate the backpointer pointing to the u-1 index and 2 to indicate the backpointer is pointing to the
    # u-2 index. -99 is a placeholder; row t=0 is never written and never read by the backtrace.
    backpointers_rel = -99 * torch.ones((B, T_max, U_max), dtype=torch.int8, device=viterbi_device)

    # Make a letter_repetition_mask the same shape as y_batch.
    # The letter_repetition_mask will have 'True' where the token (including blanks) is the same
    # as the token two places before it in the ground truth (and 'False' everywhere else).
    # We will use letter_repetition_mask to decide whether the transition from two token positions back is allowed.
    # NOTE: torch.roll with a positive shift moves entries towards higher indices, i.e. y_shifted_right[:, u]
    # holds y_batch[:, u - 2] (with wrap-around in the first two columns, which are overwritten below).
    y_shifted_right = torch.roll(y_batch, shifts=2, dims=1)
    letter_repetition_mask = y_batch - y_shifted_right
    letter_repetition_mask[:, :2] = 1  # make sure we don't apply the mask to the first 2 tokens
    letter_repetition_mask = letter_repetition_mask == 0

    # U_can_be_final does not depend on t, so build it once instead of on every timestep.
    # It is True at token positions U_b and U_b - 1 of each utterance b (position U_b is the first
    # padding entry of y_batch whenever U_b < U_max).
    u_positions = torch.arange(0, U_max, device=viterbi_device).unsqueeze(0)  # shape (1, U_max)
    U_can_be_final = torch.logical_or(
        u_positions == (U_batch.unsqueeze(1) - 0), u_positions == (U_batch.unsqueeze(1) - 1),
    )

    for t in range(1, T_max):

        # e_current is a tensor of shape (B, U_max) of the log probs of every possible token at the current timestep
        e_current = torch.gather(input=log_probs_padded[:, t, :], dim=1, index=y_batch)

        # apply a mask to e_current to cope with the fact that we do not keep the whole v_matrix and continue
        # calculating viterbi probabilities during some 'padding' timesteps: for utterances whose audio has
        # already ended (t >= T_b), the emission score at the U_can_be_final positions is replaced by 0
        t_exceeded_T_batch = t >= T_batch

        mask = torch.logical_not(torch.logical_and(t_exceeded_T_batch.unsqueeze(1), U_can_be_final,)).long()

        e_current = e_current * mask

        # v_prev_shifted is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and
        # 1 token position back
        v_prev_shifted = torch.roll(v_prev, shifts=1, dims=1)
        # by doing a roll shift of size 1, we have brought the viterbi probability in the final token position to the
        # first token position - let's overcome this by 'zeroing out' the probabilities in the first token position
        v_prev_shifted[:, 0] = V_NEGATIVE_NUM

        # v_prev_shifted2 is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and
        # 2 token positions back
        v_prev_shifted2 = torch.roll(v_prev, shifts=2, dims=1)
        v_prev_shifted2[:, :2] = V_NEGATIVE_NUM  # zero out as we did for v_prev_shifted
        # use our letter_repetition_mask to remove the connections between 2 blanks (so we don't skip over a letter)
        # and to remove the connections between 2 consecutive identical letters (so we don't skip over a blank)
        v_prev_shifted2.masked_fill_(letter_repetition_mask, V_NEGATIVE_NUM)

        # we need this v_prev_dup tensor so we can calculate the viterbi probability of every possible
        # token position simultaneously; index k along the last dim means 'came from token position u - k'
        v_prev_dup = torch.cat(
            (v_prev.unsqueeze(2), v_prev_shifted.unsqueeze(2), v_prev_shifted2.unsqueeze(2),), dim=2,
        )

        # candidates_v_current are our candidate viterbi probabilities for every token position, from which
        # we will pick the max and record the argmax
        candidates_v_current = v_prev_dup + e_current.unsqueeze(2)
        # we straight away save results in v_prev instead of v_current, so that the variable v_prev will be ready
        # for the next iteration of the for-loop
        v_prev, bp_relative = torch.max(candidates_v_current, dim=2)

        backpointers_rel[:, t, :] = bp_relative

    # trace backpointers
    alignments_batch = []
    for b in range(B):
        T_b = int(T_batch[b])
        U_b = int(U_batch[b])

        if U_b == 1:  # i.e. we put only a blank token in the reference text because the reference text is empty
            current_u = 0  # set initial u to 0 and let the rest of the code block run as usual
        else:
            # start from whichever of the last two token positions has the higher final score
            current_u = int(torch.argmax(v_prev[b, U_b - 2 : U_b])) + U_b - 2

        # Walk from the last timestep back to the first. The path is collected in reverse time order and
        # reversed once at the end, which avoids the quadratic cost of inserting at the front of a list.
        alignment_b = [current_u]
        for t in range(T_max - 1, 0, -1):
            current_u = current_u - int(backpointers_rel[b, t, current_u])
            alignment_b.append(current_u)
        alignment_b.reverse()

        # drop the timesteps which are beyond the duration of this utterance
        alignment_b = alignment_b[:T_b]
        alignments_batch.append(alignment_b)

    return alignments_batch