erastorgueva-nv
commited on
Commit
·
6ffdd29
1
Parent(s):
c7e8b60
Initial commit
Browse files- align.py +352 -0
- app.py +314 -0
- packages.txt +3 -0
- pre-requirements.txt +2 -0
- requirements.txt +1 -0
- utils/constants.py +19 -0
- utils/data_prep.py +835 -0
- utils/make_ass_files.py +462 -0
- utils/make_ctm_files.py +114 -0
- utils/make_output_manifest.py +35 -0
- utils/viterbi_decoding.py +136 -0
align.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import copy
|
16 |
+
import math
|
17 |
+
import os
|
18 |
+
from dataclasses import dataclass, field, is_dataclass
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import List, Optional
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from omegaconf import OmegaConf
|
24 |
+
from utils.data_prep import (
|
25 |
+
add_t_start_end_to_utt_obj,
|
26 |
+
get_batch_starts_ends,
|
27 |
+
get_batch_variables,
|
28 |
+
get_manifest_lines_batch,
|
29 |
+
is_entry_in_all_lines,
|
30 |
+
is_entry_in_any_lines,
|
31 |
+
)
|
32 |
+
from utils.make_ass_files import make_ass_files
|
33 |
+
from utils.make_ctm_files import make_ctm_files
|
34 |
+
from utils.make_output_manifest import write_manifest_out_line
|
35 |
+
from utils.viterbi_decoding import viterbi_decoding
|
36 |
+
|
37 |
+
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
|
38 |
+
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
|
39 |
+
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
|
40 |
+
from nemo.collections.asr.parts.utils.transcribe_utils import setup_model
|
41 |
+
from nemo.core.config import hydra_runner
|
42 |
+
from nemo.utils import logging
|
43 |
+
|
44 |
+
"""
|
45 |
+
Align the utterances in manifest_filepath.
|
46 |
+
Results are saved in ctm files in output_dir.
|
47 |
+
|
48 |
+
Arguments:
|
49 |
+
pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded
|
50 |
+
from NGC and used for generating the log-probs which we will use to do alignment.
|
51 |
+
Note: NFA can only use CTC models (not Transducer models) at the moment.
|
52 |
+
model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the
|
53 |
+
log-probs which we will use to do alignment.
|
54 |
+
Note: NFA can only use CTC models (not Transducer models) at the moment.
|
55 |
+
Note: if a model_path is provided, it will override the pretrained_name.
|
56 |
+
manifest_filepath: filepath to the manifest of the data you want to align,
|
57 |
+
containing 'audio_filepath' and 'text' fields.
|
58 |
+
output_dir: the folder where output CTM files and new JSON manifest will be saved.
|
59 |
+
align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription
|
60 |
+
as the reference text for the forced alignment.
|
61 |
+
transcribe_device: None, or a string specifying the device that will be used for generating log-probs (i.e. "transcribing").
|
62 |
+
The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
|
63 |
+
(otherwise will set it to 'cpu').
|
64 |
+
viterbi_device: None, or string specifying the device that will be used for doing Viterbi decoding.
|
65 |
+
The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
|
66 |
+
(otherwise will set it to 'cpu').
|
67 |
+
batch_size: int specifying batch size that will be used for generating log-probs and doing Viterbi decoding.
|
68 |
+
use_local_attention: boolean flag specifying whether to try to use local attention for the ASR Model (will only
|
69 |
+
work if the ASR Model is a Conformer model). If local attention is used, we will set the local attention context
|
70 |
+
size to [64,64].
|
71 |
+
additional_segment_grouping_separator: an optional string used to separate the text into smaller segments.
|
72 |
+
If this is not specified, then the whole text will be treated as a single segment.
|
73 |
+
remove_blank_tokens_from_ctm: a boolean denoting whether to remove <blank> tokens from token-level output CTMs.
|
74 |
+
audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath
|
75 |
+
we will use (starting from the final part of the audio_filepath) to determine the
|
76 |
+
utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath
|
77 |
+
will be replaced with dashes, so as not to change the number of space-separated elements in the
|
78 |
+
CTM files.
|
79 |
+
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e1"
|
80 |
+
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e1"
|
81 |
+
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e1"
|
82 |
+
use_buffered_chunked_streaming: False by default. If set to True, buffered chunked streaming is used to get the logits for alignment.
|
83 |
+
This flag is useful when aligning large audio file.
|
84 |
+
However, currently the chunk streaming inference does not support batch inference,
|
85 |
+
which means that even if you set batch_size > 1, it will only infer one by one instead of doing
|
86 |
+
the whole batch inference together.
|
87 |
+
chunk_len_in_secs: float chunk length in seconds
|
88 |
+
total_buffer_in_secs: float Length of buffer (chunk + left and right padding) in seconds
|
89 |
+
chunk_batch_size: int batch size for buffered chunk inference,
|
90 |
+
which will cut one audio into segments and do inference on chunk_batch_size segments at a time
|
91 |
+
|
92 |
+
simulate_cache_aware_streaming: False by default. If set to True, cache-aware streaming is used to get the logits for alignment.
|
93 |
+
|
94 |
+
save_output_file_formats: List of strings specifying what type of output files to save (default: ["ctm", "ass"])
|
95 |
+
ctm_file_config: CTMFileConfig to specify the configuration of the output CTM files
|
96 |
+
ass_file_config: ASSFileConfig to specify the configuration of the output ASS files
|
97 |
+
"""
|
98 |
+
|
99 |
+
|
100 |
+
@dataclass
class CTMFileConfig:
    """Options controlling how the output CTM files are written."""

    # When True, <blank> tokens are left out of the token-level CTM output.
    remove_blank_tokens: bool = False

    # Shortest duration (in seconds) allowed for a CTM timestamp. A line whose
    # duration is below this value is widened symmetrically around its midpoint
    # until it reaches this minimum, or until it hits the start or end of the
    # audio file. Widening may make neighbouring timestamps overlap.
    minimum_timestamp_duration: float = 0
|
108 |
+
|
109 |
+
|
110 |
+
@dataclass
class ASSFileConfig:
    """Options controlling the appearance and segmentation of the output ASS files."""

    fontsize: int = 20
    vertical_alignment: str = "center"

    # With resegment_text_to_fill_space=True the ASS files are built from new
    # segments, each sized so that it occupies (approximately) no more than
    # max_lines_per_segment lines once the ASS file is applied to a video.
    resegment_text_to_fill_space: bool = False
    max_lines_per_segment: int = 2

    # Text colours, each given as a 3-element RGB list. default_factory gives
    # every config instance its own list object.
    text_already_spoken_rgb: List[int] = field(default_factory=lambda: [49, 46, 61])  # dark gray
    text_being_spoken_rgb: List[int] = field(default_factory=lambda: [57, 171, 9])  # dark green
    text_not_yet_spoken_rgb: List[int] = field(default_factory=lambda: [194, 193, 199])  # light gray
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass
class AlignmentConfig:
    """Configuration schema consumed by the alignment entry point `main`."""

    # Required configs
    # Exactly one of pretrained_name / model_path must be set (validated in main).
    pretrained_name: Optional[str] = None
    model_path: Optional[str] = None
    manifest_filepath: Optional[str] = None
    output_dir: Optional[str] = None

    # General configs
    align_using_pred_text: bool = False
    # None means "pick 'cuda' if available, otherwise 'cpu'" (resolved in main).
    transcribe_device: Optional[str] = None
    viterbi_device: Optional[str] = None
    batch_size: int = 1
    use_local_attention: bool = True
    additional_segment_grouping_separator: Optional[str] = None
    audio_filepath_parts_in_utt_id: int = 1

    # Buffered chunked streaming configs
    use_buffered_chunked_streaming: bool = False
    chunk_len_in_secs: float = 1.6
    total_buffer_in_secs: float = 4.0
    chunk_batch_size: int = 32

    # Cache aware streaming configs
    simulate_cache_aware_streaming: Optional[bool] = False

    # Output file configs
    save_output_file_formats: List[str] = field(default_factory=lambda: ["ctm", "ass"])
    # The nested configs are created per instance through default_factory. A
    # bare `CTMFileConfig()` default is one object shared by every
    # AlignmentConfig, and dataclasses reject such unhashable defaults with a
    # ValueError on Python 3.11+.
    ctm_file_config: CTMFileConfig = field(default_factory=CTMFileConfig)
    ass_file_config: ASSFileConfig = field(default_factory=ASSFileConfig)
|
154 |
+
|
155 |
+
|
156 |
+
def _validate_alignment_config(cfg):
    """Raise ValueError if a field of the alignment config is missing or out of range."""
    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None")

    if cfg.model_path is not None and cfg.pretrained_name is not None:
        raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None")

    if cfg.manifest_filepath is None:
        raise ValueError("cfg.manifest_filepath must be specified")

    if cfg.output_dir is None:
        raise ValueError("cfg.output_dir must be specified")

    if cfg.batch_size < 1:
        raise ValueError("cfg.batch_size cannot be zero or a negative number")

    if cfg.additional_segment_grouping_separator in ("", " "):
        raise ValueError("cfg.additional_segment_grouping_separator cannot be empty string or space character")

    if cfg.ctm_file_config.minimum_timestamp_duration < 0:
        raise ValueError("cfg.ctm_file_config.minimum_timestamp_duration cannot be a negative number")

    if cfg.ass_file_config.vertical_alignment not in ["top", "center", "bottom"]:
        raise ValueError("cfg.ass_file_config.vertical_alignment must be one of 'top', 'center' or 'bottom'")

    # Check each of the three colour lists (the same list used to be checked three times).
    for rgb_field in ("text_already_spoken_rgb", "text_being_spoken_rgb", "text_not_yet_spoken_rgb"):
        if len(getattr(cfg.ass_file_config, rgb_field)) != 3:
            raise ValueError(
                "cfg.ass_file_config.text_already_spoken_rgb,"
                " cfg.ass_file_config.text_being_spoken_rgb,"
                " and cfg.ass_file_config.text_not_yet_spoken_rgb all need to contain"
                " exactly 3 elements."
            )

    # The buffered chunked streaming path multiplies the feature stride by
    # cfg.model_downsample_factor. Fail here, before the model is loaded, when the
    # config does not provide it. getattr with a default is used so that a key
    # missing from the config is treated the same as a None value.
    if cfg.use_buffered_chunked_streaming and getattr(cfg, "model_downsample_factor", None) is None:
        raise ValueError(
            "cfg.model_downsample_factor must be specified when cfg.use_buffered_chunked_streaming=True"
        )


def _resolve_device(requested_device):
    """Return torch.device(requested_device), or 'cuda' if available else 'cpu' when it is None."""
    if requested_device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(requested_device)


@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig)
def main(cfg: AlignmentConfig):
    """Align every utterance of cfg.manifest_filepath and write the requested output files.

    Steps: validate the config and the manifest, resolve the transcription and
    Viterbi devices, load the ASR model, then process the manifest batch by
    batch. For each utterance the CTM and/or ASS files selected in
    cfg.save_output_file_formats are written under cfg.output_dir, and one line
    is appended to "<manifest stem>_with_output_file_paths.json" in that folder.

    Args:
        cfg: an AlignmentConfig dataclass instance or an equivalent OmegaConf config.

    Returns:
        None.

    Raises:
        ValueError: a config field is missing or out of range.
        RuntimeError: the manifest lacks required entries, or contains 'pred_text'
            entries while cfg.align_using_pred_text is True.
        NotImplementedError: the loaded model is neither an EncDecCTCModel nor an
            EncDecHybridRNNTCTCModel.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    # Callers may pass a plain dataclass instance; convert it to an OmegaConf config.
    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    # Validate config
    _validate_alignment_config(cfg)

    # Validate manifest contents
    if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"):
        raise RuntimeError(
            "At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. "
            "All lines must contain an 'audio_filepath' entry."
        )

    if cfg.align_using_pred_text:
        if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"):
            raise RuntimeError(
                "Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath "
                "contains 'pred_text' entries. This is because the audio will be transcribed and may produce "
                "a different 'pred_text'. This may cause confusion."
            )
    else:
        if not is_entry_in_all_lines(cfg.manifest_filepath, "text"):
            raise RuntimeError(
                "At least one line in cfg.manifest_filepath does not contain a 'text' entry. "
                "NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=False."
            )

    # init devices
    transcribe_device = _resolve_device(cfg.transcribe_device)
    logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}")

    viterbi_device = _resolve_device(cfg.viterbi_device)
    logging.info(f"Device to be used for viterbi step (`viterbi_device`) is {viterbi_device}")

    if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda':
        logging.warning(
            'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors '
            'it may help to change both devices to be the CPU.'
        )

    # load model
    model, _ = setup_model(cfg, transcribe_device)
    model.eval()

    # Reject unsupported model types before calling any model-mutating method.
    if not isinstance(model, (EncDecCTCModel, EncDecHybridRNNTCTCModel)):
        raise NotImplementedError(
            "Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel."
            " Currently only instances of these models are supported"
        )

    if isinstance(model, EncDecHybridRNNTCTCModel):
        model.change_decoding_strategy(decoder_type="ctc")

    if cfg.use_local_attention:
        logging.info(
            "Flag use_local_attention is set to True => will try to use local attention for model if it allows it"
        )
        model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[64, 64])

    if cfg.ctm_file_config.minimum_timestamp_duration > 0:
        logging.warning(
            f"cfg.ctm_file_config.minimum_timestamp_duration has been set to {cfg.ctm_file_config.minimum_timestamp_duration} seconds. "
            "This may cause the alignments for some tokens/words/additional segments to be overlapping."
        )

    buffered_chunk_params = {}
    if cfg.use_buffered_chunked_streaming:
        model_cfg = copy.deepcopy(model._cfg)

        OmegaConf.set_struct(model_cfg.preprocessor, False)
        # some changes for streaming scenario
        model_cfg.preprocessor.dither = 0.0
        model_cfg.preprocessor.pad_to = 0

        if model_cfg.preprocessor.normalize != "per_feature":
            logging.error(
                "Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently"
            )
        # Disable config overwriting
        OmegaConf.set_struct(model_cfg.preprocessor, True)

        feature_stride = model_cfg.preprocessor['window_stride']
        # NOTE(review): requires the config to provide model_downsample_factor;
        # confirm it is declared on the config schema.
        model_stride_in_secs = feature_stride * cfg.model_downsample_factor
        total_buffer = cfg.total_buffer_in_secs
        chunk_len = float(cfg.chunk_len_in_secs)
        tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
        mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
        logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

        # From here on `model` is the FrameBatchASR wrapper around the ASR model.
        model = FrameBatchASR(
            asr_model=model,
            frame_len=chunk_len,
            total_buffer=cfg.total_buffer_in_secs,
            batch_size=cfg.chunk_batch_size,
        )
        buffered_chunk_params = {
            "delay": mid_delay,
            "model_stride_in_secs": model_stride_in_secs,
            "tokens_per_chunk": tokens_per_chunk,
        }

    # get start and end line IDs of batches
    starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size)

    # init output_timestep_duration = None and we will calculate and update it during the first batch
    output_timestep_duration = None

    # init the output manifest path
    os.makedirs(cfg.output_dir, exist_ok=True)
    tgt_manifest_name = str(Path(cfg.manifest_filepath).stem) + "_with_output_file_paths.json"
    tgt_manifest_filepath = str(Path(cfg.output_dir) / tgt_manifest_name)

    # The context manager closes the manifest even if a batch raises.
    with open(tgt_manifest_filepath, 'w') as f_manifest_out:
        # get alignment and save in CTM batch-by-batch
        for start, end in zip(starts, ends):
            manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end)

            (
                log_probs_batch,
                y_batch,
                T_batch,
                U_batch,
                utt_obj_batch,
                output_timestep_duration,
            ) = get_batch_variables(
                manifest_lines_batch,
                model,
                cfg.additional_segment_grouping_separator,
                cfg.align_using_pred_text,
                cfg.audio_filepath_parts_in_utt_id,
                output_timestep_duration,
                cfg.simulate_cache_aware_streaming,
                cfg.use_buffered_chunked_streaming,
                buffered_chunk_params,
            )

            alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device)

            for utt_obj, alignment_utt in zip(utt_obj_batch, alignments_batch):

                utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration)

                if "ctm" in cfg.save_output_file_formats:
                    utt_obj = make_ctm_files(utt_obj, cfg.output_dir, cfg.ctm_file_config,)

                if "ass" in cfg.save_output_file_formats:
                    utt_obj = make_ass_files(utt_obj, cfg.output_dir, cfg.ass_file_config)

                write_manifest_out_line(
                    f_manifest_out, utt_obj,
                )

    return None


if __name__ == "__main__":
    main()
|
app.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import librosa
|
3 |
+
import soundfile
|
4 |
+
import tempfile
|
5 |
+
import os
|
6 |
+
import uuid
|
7 |
+
import json
|
8 |
+
|
9 |
+
import jieba
|
10 |
+
|
11 |
+
import nemo.collections.asr as nemo_asr
|
12 |
+
from nemo.collections.asr.models import ASRModel
|
13 |
+
from nemo.utils import logging
|
14 |
+
|
15 |
+
from align import main, AlignmentConfig, ASSFileConfig
|
16 |
+
|
17 |
+
|
18 |
+
# Sample rate (Hz) that input audio is resampled to and written out at.
SAMPLE_RATE = 16000

# Pre-download every model the demo can use so that it is cached on disk.
# Logging is lowered to ERROR only while the models are being fetched.
logging.setLevel(logging.ERROR)
for tmp_model_name in (
    "stt_en_fastconformer_hybrid_large_pc",
    "stt_de_fastconformer_hybrid_large_pc",
    "stt_es_fastconformer_hybrid_large_pc",
    "stt_fr_conformer_ctc_large",
    "stt_zh_citrinet_1024_gamma_0_25",
):
    tmp_model = ASRModel.from_pretrained(tmp_model_name, map_location='cpu')
    # Only the download matters here; drop the loaded model straight away.
    del tmp_model
logging.setLevel(logging.INFO)
|
32 |
+
|
33 |
+
|
34 |
+
def get_audio_data_and_duration(file):
    """Load an audio file and return its samples together with its duration.

    Args:
        file: path (or other object accepted by `librosa.load`) of the audio.

    Returns:
        A `(data, duration)` tuple: the samples at SAMPLE_RATE after being passed
        through `librosa.to_mono`, and the duration reported by
        `librosa.get_duration` for those samples.
    """
    samples, loaded_rate = librosa.load(file)

    # Bring the audio to the rate used by the rest of the pipeline.
    needs_resampling = loaded_rate != SAMPLE_RATE
    if needs_resampling:
        samples = librosa.resample(samples, orig_sr=loaded_rate, target_sr=SAMPLE_RATE)

    # monochannel
    mono_samples = librosa.to_mono(samples)

    return mono_samples, librosa.get_duration(y=mono_samples, sr=SAMPLE_RATE)
|
45 |
+
|
46 |
+
|
47 |
+
def get_char_tokens(text, model):
    """Map each character of `text` to its index in `model.decoder.vocabulary`.

    Args:
        text: the string to tokenize character by character.
        model: object exposing `decoder.vocabulary`, a sequence of characters.

    Returns:
        A list with one int per character of `text`. Characters missing from the
        vocabulary are mapped to `len(vocabulary)`, i.e. the unk token, which is
        the same id as the blank token.
    """
    vocabulary = model.decoder.vocabulary
    unk_id = len(vocabulary)

    # Build the lookup once instead of scanning the vocabulary twice per
    # character. setdefault keeps the first index of a duplicated entry, which
    # is what list.index returns.
    index_by_char = {}
    for position, symbol in enumerate(vocabulary):
        index_by_char.setdefault(symbol, position)

    return [index_by_char.get(character, unk_id) for character in text]
|
56 |
+
|
57 |
+
|
58 |
+
def get_S_prime_and_T(text, model_name, model, audio_duration):
    """Estimate the number of output timesteps and the minimum timesteps the text needs.

    Args:
        text: reference text to be aligned.
        model_name: model name; substrings of it select the output timestep duration.
        model: model exposing either `tokenizer.text_to_ids` or `decoder.vocabulary`.
        audio_duration: duration of the audio, in the same unit as the timestep
            durations below.

    Returns:
        `(S_prime, T)` where `S_prime` is the number of tokens plus the number of
        adjacent repeated tokens, and `T` is `int(audio_duration / timestep) + 1`.

    Raises:
        RuntimeError: if the model name is not recognised, or tokens cannot be
            obtained from the model.
    """
    # estimate T: the first entry with a marker contained in model_name wins,
    # so the order of this table matters.
    timestep_by_markers = (
        (("citrinet", "_fastconformer_"), 0.08),
        (("_conformer_",), 0.04),
        (("quartznet",), 0.02),
    )
    output_timestep_duration = None
    for markers, seconds in timestep_by_markers:
        if any(marker in model_name for marker in markers):
            output_timestep_duration = seconds
            break
    if output_timestep_duration is None:
        raise RuntimeError("unexpected model name")

    T = int(audio_duration / output_timestep_duration) + 1

    # calculate S_prime = num tokens + num repetitions
    if hasattr(model, 'tokenizer'):
        all_tokens = model.tokenizer.text_to_ids(text)
    elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based
        all_tokens = get_char_tokens(text, model)
    else:
        raise RuntimeError("cannot obtain tokens from this model")

    # Count positions whose token equals the one directly before it.
    n_token_repetitions = sum(
        1 for previous, current in zip(all_tokens, all_tokens[1:]) if current == previous
    )

    print('all_tokens', all_tokens)
    print(len(all_tokens))
    print(n_token_repetitions)

    return len(all_tokens) + n_token_repetitions, T
|
92 |
+
|
93 |
+
|
94 |
+
def hex_to_rgb_list(hex_string):
    """Convert a hex colour string such as "#fcba03" into an `[r, g, b]` list of ints.

    Leading "#" characters are stripped. The first two characters are read as
    red, the next two as green and everything after that as blue, each parsed
    as a base-16 integer.
    """
    digits = hex_string.lstrip("#")
    channels = (digits[:2], digits[2:4], digits[4:])
    return [int(channel, 16) for channel in channels]
|
100 |
+
|
101 |
+
def delete_mp4s_except_given_filepath(filepath):
    """Remove every ".mp4" entry of the current working directory except `filepath`.

    Entries are compared to `filepath` by name as returned from `os.listdir()`.
    Each removed entry is reported on stdout.
    """
    for entry in os.listdir():
        # Skip anything that is not an MP4, and the one file we must keep.
        if not entry.endswith(".mp4") or entry == filepath:
            continue
        print('deleting', entry)
        os.remove(entry)
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
def align(lang, Microphone, File_Upload, text, col1, col2, col3, progress=gr.Progress()):
    """Align audio with text and render a video in which the text is highlighted as it is spoken.

    Args:
        lang: language code; one of "en", "de", "es", "fr" or "zh".
        Microphone: path of the microphone recording, or None.
        File_Upload: path of the uploaded audio file, or None. Exactly one of
            Microphone / File_Upload must be provided.
        text: reference text; when empty the ASR model's transcription is used.
        col1, col2, col3: hex colour strings for text already spoken, being
            spoken and not yet spoken, respectively.
        progress: gradio progress tracker.

    Returns:
        `(output_video_filepath, update, output_video_filepath)`, where `update`
        is a `gr.update` carrying the info message.

    Raises:
        gr.Error: on an unsupported language, invalid or too-long input, audio
            without detected speech, or a failed ffmpeg run.
    """
    # Local import so that this function does not depend on a file-level import.
    import subprocess

    # Create utt_id, specify output_video_filepath and delete any MP4s
    # that are not that filepath. These stray MP4s can be created
    # if a user refreshes or exits the page while this 'align' function is executing.
    # This deletion will not delete any other users' video as long as this 'align' function
    # is run one at a time.
    utt_id = uuid.uuid4()
    output_video_filepath = f"{utt_id}.mp4"
    delete_mp4s_except_given_filepath(output_video_filepath)

    output_info = ""

    progress(0, desc="Validating input")

    # choose model
    if lang in ["en", "de", "es"]:
        model_name = f"stt_{lang}_fastconformer_hybrid_large_pc"
    elif lang in ["fr"]:
        model_name = f"stt_{lang}_conformer_ctc_large"
    elif lang in ["zh"]:
        model_name = f"stt_{lang}_citrinet_1024_gamma_0_25"
    else:
        # Without this branch model_name would be unbound further down.
        raise gr.Error(f"Unsupported language '{lang}' - please choose one of: de, en, es, fr, zh")

    # decide which of Mic / File_Upload is used as input & do error handling
    if (Microphone is not None) and (File_Upload is not None):
        raise gr.Error("Please use either the microphone or file upload input - not both")

    elif (Microphone is None) and (File_Upload is None):
        raise gr.Error("You have to either use the microphone or upload an audio file")

    elif Microphone is not None:
        file = Microphone
    else:
        file = File_Upload

    # check audio is not too long
    audio_data, duration = get_audio_data_and_duration(file)

    if duration > 4 * 60:
        raise gr.Error(
            f"Detected that uploaded audio has duration {duration/60:.1f} mins - please only upload audio of less than 4 mins duration"
        )

    # loading model
    progress(0.1, desc="Loading speech recognition model")
    model = ASRModel.from_pretrained(model_name)

    if text:  # check input text is not too long compared to audio
        S_prime, T = get_S_prime_and_T(text, model_name, model, duration)

        if S_prime > T:
            raise gr.Error(
                f"The number of tokens in the input text is too long compared to the duration of the audio."
                f" This model can handle {T} tokens + token repetitions at most. You have provided {S_prime} tokens + token repetitions. "
                f" (Adjacent tokens that are not in the model's vocabulary are also counted as a token repetition.)"
            )

    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, f'{utt_id}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)

        # getting the text if it hasn't been provided
        if not text:
            progress(0.2, desc="Transcribing audio")
            text = model.transcribe([audio_path])[0]
            if 'hybrid' in model_name:
                text = text[0]
            print('transcribed text:', text)

            if text == "":
                raise gr.Error(
                    "ERROR: the ASR model did not detect any speech in the input audio. Please upload audio with speech."
                )

            output_info += (
                "You did not enter any input text, so the ASR model's transcription will be used:\n"
                "--------------------------\n"
                f"{text}\n"
                "--------------------------\n"
                f"You could try pasting the transcription into the text input box, correcting any"
                " transcription errors, and clicking 'Submit' again."
            )

        if lang == "zh" and " " not in text:
            # use jieba to add spaces between zh characters
            text = " ".join(jieba.cut(text))

        data = {
            "audio_filepath": audio_path,
            "text": text,
        }
        manifest_path = os.path.join(tmpdir, f"{utt_id}_manifest.json")
        with open(manifest_path, 'w') as fout:
            fout.write(f"{json.dumps(data)}\n")

        # run alignment
        # Text that already contains '|' separators keeps the user's own segments.
        resegment_text_to_fill_space = "|" not in text

        alignment_config = AlignmentConfig(
            pretrained_name=model_name,
            manifest_filepath=manifest_path,
            output_dir=f"{tmpdir}/nfa_output/",
            audio_filepath_parts_in_utt_id=1,
            batch_size=1,
            use_local_attention=True,
            additional_segment_grouping_separator="|",
            # transcribe_device='cpu',
            # viterbi_device='cpu',
            save_output_file_formats=["ass"],
            ass_file_config=ASSFileConfig(
                fontsize=45,
                resegment_text_to_fill_space=resegment_text_to_fill_space,
                max_lines_per_segment=4,
                text_already_spoken_rgb=hex_to_rgb_list(col1),
                text_being_spoken_rgb=hex_to_rgb_list(col2),
                text_not_yet_spoken_rgb=hex_to_rgb_list(col3),
            ),
        )

        progress(0.5, desc="Aligning audio")

        main(alignment_config)

        progress(0.95, desc="Saving generated alignments")

        if lang == "zh":
            # make video file from the token-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/tokens/{utt_id}.ass"
        else:
            # make video file from the word-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/words/{utt_id}.ass"

        # Pass the arguments as a list so that no shell parses the command line.
        ffmpeg_command = [
            "ffmpeg", "-y", "-i", audio_path,
            "-f", "lavfi", "-i", "color=c=white:s=1280x720:r=50",
            "-crf", "1", "-shortest", "-vcodec", "libx264", "-pix_fmt", "yuv420p",
            "-vf", f"ass={ass_file_for_video}",
            output_video_filepath,
        ]

        # Report a failed render instead of returning a path to a missing video.
        try:
            ffmpeg_result = subprocess.run(ffmpeg_command)
        except OSError as err:
            raise gr.Error(f"ERROR: could not run ffmpeg to generate the output video ({err})")
        if ffmpeg_result.returncode != 0:
            raise gr.Error(
                f"ERROR: ffmpeg failed to generate the output video (exit code {ffmpeg_result.returncode})"
            )

    return output_video_filepath, gr.update(value=output_info, visible=True), output_video_filepath
|
258 |
+
|
259 |
+
|
260 |
+
def delete_non_tmp_video(video_path):
    """
    Delete the file at `video_path`, if there is one.

    Args:
        video_path: path of the file to delete. Any falsy value (None, "", [])
            means there is nothing to delete.

    Returns:
        None, always.
    """
    if video_path:
        # Remove directly instead of checking os.path.exists() first: the file
        # may disappear between the check and the removal.
        try:
            os.remove(video_path)
        except FileNotFoundError:
            pass  # already gone - nothing left to clean up
    return None
|
265 |
+
|
266 |
+
|
267 |
+
# Build the Gradio UI: a title row, then an input column and an output column side by side.
with gr.Blocks(title="NeMo Forced Aligner", theme="huggingface") as demo:
    # Receives the third output of `align` and is then passed to `delete_non_tmp_video`.
    non_tmp_output_video_filepath = gr.State([])

    with gr.Row():
        with gr.Column():
            gr.Markdown("# NeMo Forced Aligner")
            gr.Markdown(
                "Demo for [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) (NFA). "
                "Upload audio and (optionally) the text spoken in the audio to generate a video where each part of the text will be highlighted as it is spoken. ",
            )

    with gr.Row():

        with gr.Column(scale=1):
            gr.Markdown("## Input")
            lang_drop = gr.Dropdown(choices=["de", "en", "es", "fr", "zh"], value="en", label="Audio language",)

            mic_in = gr.Audio(source="microphone", type='filepath', label="Microphone input (max 4 mins)")
            audio_file_in = gr.Audio(source="upload", type='filepath', label="File upload (max 4 mins)")
            ref_text = gr.Textbox(
                label="[Optional] The reference text. Use '|' separators to specify which text will appear together. "
                "Leave this field blank to use an ASR model's transcription as the reference text instead."
            )

            gr.Markdown("[Optional] For fun - adjust the colors of the text in the output video")
            with gr.Row():
                col1 = gr.ColorPicker(label="text already spoken", value="#fcba03")
                col2 = gr.ColorPicker(label="text being spoken", value="#bf45bf")
                col3 = gr.ColorPicker(label="text to be spoken", value="#3e1af0")

            submit_button = gr.Button("Submit")

        with gr.Column(scale=1):
            gr.Markdown("## Output")
            video_out = gr.Video(label="output video")
            text_out = gr.Textbox(label="output info", visible=False)

    # Run the alignment on submit, then run the clean-up step on the stored video path.
    submit_button.click(
        fn=align,
        inputs=[lang_drop, mic_in, audio_file_in, ref_text, col1, col2, col3,],
        outputs=[video_out, text_out, non_tmp_output_video_filepath],
    ).then(
        fn=delete_non_tmp_video, inputs=[non_tmp_output_video_filepath], outputs=None,
    )

demo.queue()
demo.launch()
|
314 |
+
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
ffmpeg
|
2 |
+
libsndfile1
|
3 |
+
build-essential
|
pre-requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Cython
|
2 |
+
torch
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
nemo_toolkit[all]
|
utils/constants.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Text used for Token objects that represent a blank.
BLANK_TOKEN = "<b>"

# Text used for Token objects that represent a space between words/segments.
SPACE_TOKEN = "<space>"

V_NEGATIVE_NUM = -3.4e38  # this is just above the most negative number in torch.float32
|
utils/data_prep.py
ADDED
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
from pathlib import Path
|
18 |
+
from typing import List, Union
|
19 |
+
|
20 |
+
import soundfile as sf
|
21 |
+
import torch
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from utils.constants import BLANK_TOKEN, SPACE_TOKEN, V_NEGATIVE_NUM
|
24 |
+
|
25 |
+
from nemo.utils import logging
|
26 |
+
|
27 |
+
|
28 |
+
def _get_utt_id(audio_filepath, audio_filepath_parts_in_utt_id):
|
29 |
+
fp_parts = Path(audio_filepath).parts[-audio_filepath_parts_in_utt_id:]
|
30 |
+
utt_id = Path("_".join(fp_parts)).stem
|
31 |
+
utt_id = utt_id.replace(" ", "-") # replace any spaces in the filepath with dashes
|
32 |
+
return utt_id
|
33 |
+
|
34 |
+
|
35 |
+
def get_batch_starts_ends(manifest_filepath, batch_size):
    """
    Get the start and end ids of the lines we will use for each 'batch'.

    Args:
        manifest_filepath: path to the manifest file (one entry per line).
        batch_size: number of manifest lines per batch.

    Returns:
        (starts, ends): two lists of equal length. `starts[i]` is the index of the first
        line of batch i and `ends[i]` is the (inclusive) index of its last line. The final
        end is the total number of lines, i.e. one past the last valid line index.
        Both lists are empty if the manifest has no lines.
    """

    # utf-8-sig matches how get_manifest_lines_batch reads the same file
    with open(manifest_filepath, 'r', encoding="utf-8-sig") as f:
        num_lines_in_manifest = sum(1 for _ in f)

    starts = list(range(0, num_lines_in_manifest, batch_size))
    if not starts:
        # empty manifest => no batches (previously this raised IndexError)
        return [], []

    # each batch ends on the line just before the next batch starts
    ends = [next_start - 1 for next_start in starts[1:]]
    ends.append(num_lines_in_manifest)

    return starts, ends
|
49 |
+
|
50 |
+
|
51 |
+
def is_entry_in_any_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in any of the JSON lines in manifest_filepath.

    Reading stops at the first line that contains the key.
    """

    # utf-8-sig strips a leading BOM, which json.loads would otherwise reject;
    # it is also the encoding get_manifest_lines_batch uses for the same file.
    with open(manifest_filepath, 'r', encoding="utf-8-sig") as f:
        return any(entry in json.loads(line) for line in f)
|
66 |
+
|
67 |
+
|
68 |
+
def is_entry_in_all_lines(manifest_filepath, entry):
    """
    Returns True if entry is a key in all of the JSON lines in manifest_filepath.

    Reading stops at the first line that lacks the key. A manifest with no lines
    yields True.
    """
    # utf-8-sig strips a leading BOM, which json.loads would otherwise reject;
    # it is also the encoding get_manifest_lines_batch uses for the same file.
    with open(manifest_filepath, 'r', encoding="utf-8-sig") as f:
        return all(entry in json.loads(line) for line in f)
|
80 |
+
|
81 |
+
|
82 |
+
def get_manifest_lines_batch(manifest_filepath, start, end):
    """
    Load the manifest entries on lines `start` to `end` (0-based, both inclusive).

    If an entry has a "text" field, any BOM characters are removed from it and all
    runs of whitespace (including newlines) are collapsed to single spaces.

    Args:
        manifest_filepath: path to the manifest file (one JSON object per line).
        start: index of the first line to load.
        end: index of the last line to load; may exceed the last line of the file.

    Returns:
        A list with one dict per loaded manifest line.
    """
    manifest_lines_batch = []
    with open(manifest_filepath, "r", encoding="utf-8-sig") as f:
        for line_i, line in enumerate(f):
            if start <= line_i <= end:
                data = json.loads(line)
                if "text" in data:
                    # remove any BOM, any duplicated spaces, convert any
                    # newline chars to spaces
                    data["text"] = data["text"].replace("\ufeff", "")
                    data["text"] = " ".join(data["text"].split())
                manifest_lines_batch.append(data)

            if line_i >= end:
                # nothing at or after `end` is wanted - stop reading the file
                break
    return manifest_lines_batch
|
98 |
+
|
99 |
+
|
100 |
+
def get_char_tokens(text, model):
    """
    Convert `text` to a list of token ids using the character vocabulary of `model`.

    Args:
        text: string to convert, one token per character.
        model: object exposing `model.decoder.vocabulary`, a sequence of characters.

    Returns:
        A list with one int per character of `text`: the index of the character's first
        occurrence in the vocabulary, or `len(vocabulary)` if the character is not in it.
    """
    vocabulary = model.decoder.vocabulary
    unk_id = len(vocabulary)  # unk token (same as blank token)

    # Build the lookup once instead of scanning the vocabulary twice per character.
    # setdefault keeps the first index of a duplicated entry, like list.index() does.
    char_to_id = {}
    for vocab_i, vocab_char in enumerate(vocabulary):
        char_to_id.setdefault(vocab_char, vocab_i)

    return [char_to_id.get(character, unk_id) for character in text]
|
109 |
+
|
110 |
+
|
111 |
+
# Maps each superscript / subscript digit to its plain digit.
_SUB_OR_SUPERSCRIPT_TO_NUM = {
    "⁰": "0",
    "¹": "1",
    "²": "2",
    "³": "3",
    "⁴": "4",
    "⁵": "5",
    "⁶": "6",
    "⁷": "7",
    "⁸": "8",
    "⁹": "9",
    "₀": "0",
    "₁": "1",
    "₂": "2",
    "₃": "3",
    "₄": "4",
    "₅": "5",
    "₆": "6",
    "₇": "7",
    "₈": "8",
    "₉": "9",
}


def is_sub_or_superscript_pair(ref_text, text):
    """
    Returns True if text is a subscript or superscript version of ref_text.

    Args:
        ref_text: the plain digit character, e.g. "2".
        text: the candidate subscript/superscript character, e.g. "²" or "₂".

    Returns:
        True if `text` is a subscript or superscript digit whose plain form equals
        `ref_text`, otherwise False.
    """
    return text in _SUB_OR_SUPERSCRIPT_TO_NUM and _SUB_OR_SUPERSCRIPT_TO_NUM[text] == ref_text
|
140 |
+
|
141 |
+
|
142 |
+
def restore_token_case(word, word_tokens):
    """
    Re-apply the casing of `word` to its tokens.

    The characters of the tokens are walked in parallel with the characters of `word`.
    A token character that matches the word only after upper-casing (or whose word
    counterpart is a sub/superscript digit of it) is emitted upper-cased.

    Args:
        word: the original (cased) word.
        word_tokens: list of token strings for `word`.

    Returns:
        A list of strings, one per entry of `word_tokens`.

    Raises:
        RuntimeError: if a token character cannot be matched to the current word character.
    """

    # remove repeated "▁" and "_" from word as that is what the tokenizer will do
    while "▁▁" in word:
        word = word.replace("▁▁", "▁")

    while "__" in word:
        word = word.replace("__", "_")  # was misspelled `repalce` => AttributeError

    word_tokens_cased = []
    word_char_pointer = 0

    for token in word_tokens:
        token_cased = ""

        for token_char in token:
            word_char = word[word_char_pointer]

            if token_char == word_char:
                token_cased += token_char
                word_char_pointer += 1

            elif token_char.upper() == word_char or is_sub_or_superscript_pair(token_char, word_char):
                token_cased += token_char.upper()
                word_char_pointer += 1

            elif token_char in ("▁", "_"):
                if word_char in ("▁", "_"):
                    token_cased += token_char
                    word_char_pointer += 1
                elif word_char_pointer == 0:
                    # word-start marker with no counterpart in word:
                    # keep it without consuming a word character
                    token_cased += token_char
                # otherwise the marker is dropped without raising

            else:
                raise RuntimeError(
                    f"Unexpected error - failed to recover capitalization of tokens for word {word}"
                )

        word_tokens_cased.append(token_cased)

    return word_tokens_cased
|
184 |
+
|
185 |
+
|
186 |
+
@dataclass
class Token:
    """A single token (possibly a blank or space token) of an utterance."""

    # token text
    text: Union[str, None] = None
    # token text with the casing of the original word
    text_cased: Union[str, None] = None
    # first / last index this token occupies in Utterance.token_ids_with_blanks
    s_start: Union[int, None] = None
    s_end: Union[int, None] = None
    # start / end time of the token; None until timings are added
    t_start: Union[float, None] = None
    t_end: Union[float, None] = None
|
194 |
+
|
195 |
+
|
196 |
+
@dataclass
class Word:
    """A word of an utterance together with the tokens it consists of."""

    # word text
    text: Union[str, None] = None
    # first / last index this word occupies in Utterance.token_ids_with_blanks
    s_start: Union[int, None] = None
    s_end: Union[int, None] = None
    # start / end time of the word; None until timings are added
    t_start: Union[float, None] = None
    t_end: Union[float, None] = None
    # Token objects of the word, including the blank tokens in between them
    tokens: List[Token] = field(default_factory=list)
|
204 |
+
|
205 |
+
|
206 |
+
@dataclass
class Segment:
    """A segment of an utterance together with the words and tokens it consists of."""

    # segment text
    text: Union[str, None] = None
    # first / last index this segment occupies in Utterance.token_ids_with_blanks
    s_start: Union[int, None] = None
    s_end: Union[int, None] = None
    # start / end time of the segment; None until timings are added
    t_start: Union[float, None] = None
    t_end: Union[float, None] = None
    # Word objects of the segment, and Token objects for the blank/space tokens in between words
    words_and_tokens: List[Union[Word, Token]] = field(default_factory=list)
|
214 |
+
|
215 |
+
|
216 |
+
@dataclass
class Utterance:
    """All the segment / word / token information of one utterance."""

    # ids of the tokens of the utterance, with blank ids in between them
    token_ids_with_blanks: List[int] = field(default_factory=list)
    # Segment objects, and Token objects for the blank/space tokens in between segments
    segments_and_tokens: List[Union[Segment, Token]] = field(default_factory=list)
    # reference text of the utterance
    text: Union[str, None] = None
    # NOTE(review): presumably the ASR transcription of the audio - not set in this part of the file
    pred_text: Union[str, None] = None
    # path of the audio file of the utterance
    audio_filepath: Union[str, None] = None
    # identifier of the utterance
    utt_id: Union[str, None] = None
    # NOTE(review): presumably maps output kinds to saved file paths - not populated in this part of the file
    saved_output_files: dict = field(default_factory=dict)
|
225 |
+
|
226 |
+
|
227 |
+
def _blank_token_at(s):
    """Return a blank Token occupying the single index `s` of token_ids_with_blanks."""
    return Token(text=BLANK_TOKEN, text_cased=BLANK_TOKEN, s_start=s, s_end=s)


def _space_token_at(s):
    """Return a space Token occupying the single index `s` of token_ids_with_blanks."""
    return Token(text=SPACE_TOKEN, text_cased=SPACE_TOKEN, s_start=s, s_end=s)


def _has_too_many_tokens(all_tokens, T, utt_id):
    """Return True (and log it) if # tokens + # adjacent token repetitions is > T."""
    n_token_repetitions = sum(
        1 for i_tok in range(1, len(all_tokens)) if all_tokens[i_tok] == all_tokens[i_tok - 1]
    )

    if len(all_tokens) + n_token_repetitions > T:
        logging.info(
            f"Utterance {utt_id} has too many tokens compared to the audio file duration."
            " Will not generate output alignment files for this utterance."
        )
        return True
    return False


def _fill_utt_obj_with_tokenizer(utt, text, segments, model, T):
    """Build utt.token_ids_with_blanks and utt.segments_and_tokens using model.tokenizer."""
    if hasattr(model, 'blank_id'):
        BLANK_ID = model.blank_id
    else:
        BLANK_ID = len(model.tokenizer.vocab)  # TODO: check

    utt.token_ids_with_blanks = [BLANK_ID]
    token_ids = utt.token_ids_with_blanks  # same list object - extended in place below

    # check for text being 0 length
    if len(text) == 0:
        return utt

    # check for # tokens + token repetitions being > T
    if _has_too_many_tokens(model.tokenizer.text_to_ids(text), T, utt.utt_id):
        return utt

    # build up data structures containing segments/words/tokens
    utt.segments_and_tokens.append(_blank_token_at(0))

    segment_s_pointer = 1  # first segment will start at s=1 because s=0 is a blank
    word_s_pointer = 1  # first word will start at s=1 because s=0 is a blank

    for segment in segments:
        # segment tokens do not contain blanks => need to multiply their number by 2
        n_segment_slots = len(model.tokenizer.text_to_tokens(segment)) * 2
        segment_obj = Segment(
            text=segment,
            s_start=segment_s_pointer,
            # s_end needs to be the index of the final token (including blanks) of the current segment:
            # segment_s_pointer + n_segment_slots is the index of the first token of the next segment =>
            # => need to subtract 2
            s_end=segment_s_pointer + n_segment_slots - 2,
        )
        utt.segments_and_tokens.append(segment_obj)
        segment_s_pointer += n_segment_slots

        words = segment.split(" ")  # we define words to be space-separated sub-strings
        for word_i, word in enumerate(words):

            word_tokens = model.tokenizer.text_to_tokens(word)
            word_token_ids = model.tokenizer.text_to_ids(word)
            word_tokens_cased = restore_token_case(word, word_tokens)

            # word_tokens do not contain blanks => need to multiply by 2;
            # s_end is the index of the final non-blank token of the word => subtract 2
            word_obj = Word(text=word, s_start=word_s_pointer, s_end=word_s_pointer + len(word_tokens) * 2 - 2)
            segment_obj.words_and_tokens.append(word_obj)
            word_s_pointer += len(word_tokens) * 2

            for token_i, (token, token_id, token_cased) in enumerate(
                zip(word_tokens, word_token_ids, word_tokens_cased)
            ):
                # add the text token and the blank after it
                token_ids.extend([token_id, BLANK_ID])
                # token_ids has the form [...., <this non-blank token>, <blank>] =>
                # => len(token_ids) - 2 is the index of <this non-blank token>
                s_token = len(token_ids) - 2
                word_obj.tokens.append(
                    Token(text=token, text_cased=token_cased, s_start=s_token, s_end=s_token)
                )

                # adding Token object for blank tokens in between the tokens of the word
                # (ie do not add another blank if you have reached the end)
                if token_i < len(word_tokens) - 1:
                    word_obj.tokens.append(_blank_token_at(len(token_ids) - 1))

            # add a Token object for blanks in between words in this segment
            # (but only *in between* - do not add the token if it is after the final word)
            if word_i < len(words) - 1:
                segment_obj.words_and_tokens.append(_blank_token_at(len(token_ids) - 1))

        # add the blank token in between segments/after the final segment
        utt.segments_and_tokens.append(_blank_token_at(len(token_ids) - 1))

    return utt


def _fill_utt_obj_with_chars(utt, text, segments, model, T):
    """Build utt.token_ids_with_blanks and utt.segments_and_tokens using character tokens."""
    BLANK_ID = len(model.decoder.vocabulary)  # TODO: check this is correct
    SPACE_ID = model.decoder.vocabulary.index(" ")

    utt.token_ids_with_blanks = [BLANK_ID]
    token_ids = utt.token_ids_with_blanks  # same list object - extended in place below

    # check for text being 0 length
    if len(text) == 0:
        return utt

    # check for # tokens + token repetitions being > T
    if _has_too_many_tokens(get_char_tokens(text, model), T, utt.utt_id):
        return utt

    # build up data structures containing segments/words/tokens
    utt.segments_and_tokens.append(_blank_token_at(0))

    segment_s_pointer = 1  # first segment will start at s=1 because s=0 is a blank
    word_s_pointer = 1  # first word will start at s=1 because s=0 is a blank

    for i_segment, segment in enumerate(segments):
        # segment tokens do not contain blanks => need to multiply their number by 2
        n_segment_slots = len(get_char_tokens(segment, model)) * 2
        segment_obj = Segment(
            text=segment,
            s_start=segment_s_pointer,
            # s_end is the index of the final non-blank token of the segment => subtract 2
            s_end=segment_s_pointer + n_segment_slots - 2,
        )
        utt.segments_and_tokens.append(segment_obj)

        # + 2 to account for [<token for space in between segments>, <blank token after that space token>]
        segment_s_pointer += n_segment_slots + 2

        words = segment.split(" ")  # we define words to be space-separated substrings
        for i_word, word in enumerate(words):

            # convert string to list of characters
            word_tokens = list(word)
            # convert list of characters to list of their ids in the vocabulary
            word_token_ids = get_char_tokens(word, model)

            # word_tokens do not contain blanks => need to multiply by 2;
            # s_end is the index of the final non-blank token of the word => subtract 2
            word_obj = Word(text=word, s_start=word_s_pointer, s_end=word_s_pointer + len(word_tokens) * 2 - 2)
            segment_obj.words_and_tokens.append(word_obj)

            # + 2 to account for [<token for space in between words>, <blank token after that space token>]
            word_s_pointer += len(word_tokens) * 2 + 2

            for token_i, (token, token_id) in enumerate(zip(word_tokens, word_token_ids)):
                token_ids.extend([token_id])
                # token_ids has the form [..., <this non-blank token>]
                s_token = len(token_ids) - 1
                word_obj.tokens.append(Token(text=token, text_cased=token, s_start=s_token, s_end=s_token))

                if token_i < len(word_tokens) - 1:  # only add blank tokens that are in the middle of words
                    token_ids.extend([BLANK_ID])
                    word_obj.tokens.append(_blank_token_at(len(token_ids) - 1))

            # add space token (and the blanks around it) unless this is the final word in a segment
            if i_word < len(words) - 1:
                token_ids.extend([BLANK_ID, SPACE_ID, BLANK_ID])
                # token_ids has the form
                # [..., <final token of previous word>, <blank token>, <space token>, <blank token>]
                segment_obj.words_and_tokens.append(_blank_token_at(len(token_ids) - 3))
                segment_obj.words_and_tokens.append(_space_token_at(len(token_ids) - 2))
                segment_obj.words_and_tokens.append(_blank_token_at(len(token_ids) - 1))

        # add a blank to the segment, and add a space after if this is not the final segment
        token_ids.extend([BLANK_ID])
        utt.segments_and_tokens.append(_blank_token_at(len(token_ids) - 1))

        if i_segment < len(segments) - 1:
            token_ids.extend([SPACE_ID, BLANK_ID])
            # token_ids has the form [..., <space token>, <blank token>]
            utt.segments_and_tokens.append(_space_token_at(len(token_ids) - 2))
            utt.segments_and_tokens.append(_blank_token_at(len(token_ids) - 1))

    return utt


def get_utt_obj(
    text, model, separator, T, audio_filepath, utt_id,
):
    """
    Function to create an Utterance object and add all necessary information to it except
    for timings of the segments / words / tokens according to the alignment - that will
    be done later in a different function, after the alignment is done.

    The Utterance object has a list segments_and_tokens which contains Segment objects and
    Token objects (for blank tokens in between segments).
    Within the Segment objects, there is a list words_and_tokens which contains Word objects and
    Token objects (for blank tokens in between words).
    Within the Word objects, there is a list tokens which contains Token objects for
    blank and non-blank tokens.
    We will be building up these lists in this function. This data structure will then be useful for
    generating the various output files that we wish to save.

    Args:
        text: the reference text of the utterance.
        model: model providing either a `tokenizer` attribute or a character
            vocabulary at `model.decoder.vocabulary`.
        separator: string on which `text` is split into segments; if falsy, the whole
            text is treated as one segment.
        T: upper bound for (# tokens + # adjacent token repetitions) of `text`.
        audio_filepath: path of the audio file, stored on the returned object.
        utt_id: utterance id, stored on the returned object.

    Returns:
        The Utterance object. If `text` is empty or has too many tokens for `T`, only
        token_ids_with_blanks (a single blank id) is filled in.

    Raises:
        RuntimeError: if the model has neither a tokenizer nor a character vocabulary.
    """

    if not separator:  # if separator is not defined - treat the whole text as one segment
        raw_segments = [text]
    else:
        raw_segments = text.split(separator)

    # remove any spaces at start and end of segments, then drop any empty segments
    stripped_segments = (seg.strip() for seg in raw_segments)
    segments = [seg for seg in stripped_segments if len(seg) > 0]

    utt = Utterance(text=text, audio_filepath=audio_filepath, utt_id=utt_id,)

    # build up lists: token_ids_with_blanks, segments_and_tokens.
    # The code for these is different depending on whether we use char-based tokens or not
    if hasattr(model, 'tokenizer'):
        return _fill_utt_obj_with_tokenizer(utt, text, segments, model, T)

    elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based
        return _fill_utt_obj_with_chars(utt, text, segments, model, T)

    else:
        raise RuntimeError("Cannot get tokens of this model.")
|
580 |
+
|
581 |
+
|
582 |
+
def add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration):
    """
    Fill in t_start and t_end (times in seconds) on every segment, word and token of utt_obj.

    Args:
        utt_obj: Utterance object whose segments/words/tokens already carry s_start and s_end
            indices; their t_start/t_end attributes are overwritten in place.
        alignment_utt: a list of ints giving, for each output timestep, the index of the token
            (blanks included) that the alignment passes through at that timestep, e.g.
            [0, 0, 1, 1, ..., <num of tokens including blanks in utterance>].
        output_timestep_duration: a float, the duration in seconds of one output timestep of
            the ASR Model.

    Returns:
        utt_obj: the same Utterance object, updated.
    """

    # A token index "starts" at the first timestep where it appears in alignment_utt and
    # "ends" at the last timestep where it appears. Build both lookups in a single pass by
    # watching for the moments at which the index increases.
    first_timestep_of = {}
    last_timestep_of = {}

    previous_index = -1  # -1 means "nothing seen yet"
    for timestep, index in enumerate(alignment_utt):
        if index > previous_index:
            first_timestep_of[index] = timestep
            if previous_index >= 0:  # there is no real index to close before the first one
                last_timestep_of[previous_index] = timestep - 1
        previous_index = index
    # the final index lasts until the end of the alignment
    last_timestep_of[previous_index] = len(alignment_utt) - 1

    def start_seconds(s_start):
        return first_timestep_of[s_start] * output_timestep_duration

    def end_seconds(s_end):
        # +1 because the end time is the end of the last occupied timestep
        return (last_timestep_of[s_end] + 1) * output_timestep_duration

    def stamp_token(token):
        # tokens the alignment skipped (should only be blanks) get t_start = t_end = -1
        if token.s_start in first_timestep_of:
            token.t_start = start_seconds(token.s_start)
        else:
            token.t_start = -1

        if token.s_end in last_timestep_of:
            token.t_end = end_seconds(token.s_end)
        else:
            token.t_end = -1

    for item in utt_obj.segments_and_tokens:
        if type(item) is not Segment:
            stamp_token(item)
            continue

        item.t_start = start_seconds(item.s_start)
        item.t_end = end_seconds(item.s_end)

        for child in item.words_and_tokens:
            if type(child) is not Word:
                stamp_token(child)
                continue

            child.t_start = start_seconds(child.s_start)
            child.t_end = end_seconds(child.s_end)

            for token in child.tokens:
                stamp_token(token)

    return utt_obj
|
669 |
+
|
670 |
+
|
671 |
+
def get_batch_variables(
    manifest_lines_batch,
    model,
    separator,
    align_using_pred_text,
    audio_filepath_parts_in_utt_id,
    output_timestep_duration,
    simulate_cache_aware_streaming=False,
    use_buffered_chunked_streaming=False,
    buffered_chunk_params=None,
):
    """
    Run the ASR model on one batch of manifest lines and build the tensors needed for
    Viterbi decoding.

    Args:
        manifest_lines_batch: list of manifest entries (dict-like). Each must contain
            "audio_filepath"; "text" is required unless align_using_pred_text is True.
        model: the ASR model (or buffered-streaming wrapper) used to obtain log-probs.
        separator: passed through unchanged to get_utt_obj.
        align_using_pred_text: if True, align against the model's predicted text (with
            whitespace normalised) instead of the "text" field of the manifest line.
        audio_filepath_parts_in_utt_id: passed through unchanged to _get_utt_id.
        output_timestep_duration: duration of one output timestep of the ASR model, or None
            to have it calculated from the first audio file of this batch.
        simulate_cache_aware_streaming: if True, use
            model.transcribe_simulate_cache_aware_streaming instead of model.transcribe.
        use_buffered_chunked_streaming: if True, transcribe file-by-file through the buffered
            interface (model.reset / model.read_audio_file / model.transcribe).
        buffered_chunk_params: dict with keys "delay", "model_stride_in_secs" and
            "tokens_per_chunk"; only read when use_buffered_chunked_streaming is True.
            Defaults to an empty dict.

    Returns:
        log_probs, y, T, U (y and U are s.t. every other token is a blank) - these are the tensors we will need
            during Viterbi decoding.
        utt_obj_batch: a list of Utterance objects for every utterance in the batch.
        output_timestep_duration: a float indicating the duration of a single output timestep from
            the ASR Model.
    """

    # avoid a shared mutable default argument
    if buffered_chunk_params is None:
        buffered_chunk_params = {}

    # get hypotheses by calling 'transcribe'
    # we will use the output log_probs, the duration of the log_probs,
    # and (optionally) the predicted ASR text from the hypotheses
    audio_filepaths_batch = [line["audio_filepath"] for line in manifest_lines_batch]
    B = len(audio_filepaths_batch)
    log_probs_list_batch = []
    T_list_batch = []
    pred_text_batch = []

    if not use_buffered_chunked_streaming:
        with torch.no_grad():
            if not simulate_cache_aware_streaming:
                hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)
            else:
                hypotheses = model.transcribe_simulate_cache_aware_streaming(
                    audio_filepaths_batch, return_hypotheses=True, batch_size=B
                )

        # if hypotheses form a tuple (from Hybrid model), extract just "best" hypothesis
        if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
            hypotheses = hypotheses[0]

        for hypothesis in hypotheses:
            log_probs_list_batch.append(hypothesis.y_sequence)
            T_list_batch.append(hypothesis.y_sequence.shape[0])
            pred_text_batch.append(hypothesis.text)
    else:
        delay = buffered_chunk_params["delay"]
        model_stride_in_secs = buffered_chunk_params["model_stride_in_secs"]
        tokens_per_chunk = buffered_chunk_params["tokens_per_chunk"]
        for audio_filepath in tqdm(audio_filepaths_batch, desc="Sample:"):
            model.reset()
            model.read_audio_file(audio_filepath, delay, model_stride_in_secs)
            hyp, logits = model.transcribe(tokens_per_chunk, delay, keep_logits=True)
            log_probs_list_batch.append(logits)
            T_list_batch.append(logits.shape[0])
            pred_text_batch.append(hyp)

    # we loop over every line in the manifest that is in our current batch,
    # and record the y (list of tokens, including blanks), U (list of lengths of y) and
    # the Utterance object holding the token/word/segment information
    y_list_batch = []
    U_list_batch = []
    utt_obj_batch = []

    for i_line, line in enumerate(manifest_lines_batch):
        if align_using_pred_text:
            gt_text_for_alignment = " ".join(pred_text_batch[i_line].split())
        else:
            gt_text_for_alignment = line["text"]
        utt_obj = get_utt_obj(
            gt_text_for_alignment,
            model,
            separator,
            T_list_batch[i_line],
            audio_filepaths_batch[i_line],
            _get_utt_id(audio_filepaths_batch[i_line], audio_filepath_parts_in_utt_id),
        )

        # update utt_obj.pred_text or utt_obj.text
        if align_using_pred_text:
            utt_obj.pred_text = pred_text_batch[i_line]
            if len(utt_obj.pred_text) == 0:
                logging.info(
                    f"'pred_text' of utterance {utt_obj.utt_id} is empty - we will not generate"
                    " any output alignment files for this utterance"
                )
            if "text" in line:
                utt_obj.text = line["text"]  # keep the text as we will save it in the output manifest
        else:
            utt_obj.text = line["text"]
            if len(utt_obj.text) == 0:
                logging.info(
                    f"'text' of utterance {utt_obj.utt_id} is empty - we will not generate"
                    " any output alignment files for this utterance"
                )

        y_list_batch.append(utt_obj.token_ids_with_blanks)
        U_list_batch.append(len(utt_obj.token_ids_with_blanks))
        utt_obj_batch.append(utt_obj)

    # turn log_probs, y, T, U into dense tensors for fast computation during Viterbi decoding
    T_max = max(T_list_batch)
    U_max = max(U_list_batch)
    # V = the number of tokens in the vocabulary + 1 for the blank token.
    if hasattr(model, 'tokenizer'):
        V = len(model.tokenizer.vocab) + 1
    else:
        V = len(model.decoder.vocabulary) + 1
    T_batch = torch.tensor(T_list_batch)
    U_batch = torch.tensor(U_list_batch)

    # make log_probs_batch tensor of shape (B x T_max x V)
    # positions beyond each utterance's own length keep the V_NEGATIVE_NUM fill value
    log_probs_batch = V_NEGATIVE_NUM * torch.ones((B, T_max, V))
    for b, log_probs_utt in enumerate(log_probs_list_batch):
        t = log_probs_utt.shape[0]
        log_probs_batch[b, :t, :] = log_probs_utt

    # make y tensor of shape (B x U_max)
    # populate it initially with all 'V' numbers so that the 'V's will remain in the areas that
    # are 'padding'. This will be useful for when we make 'log_probs_reordered' during Viterbi decoding
    # in a different function.
    y_batch = V * torch.ones((B, U_max), dtype=torch.int64)
    for b, y_utt in enumerate(y_list_batch):
        U_utt = U_batch[b]
        y_batch[b, :U_utt] = torch.tensor(y_utt)

    # calculate output_timestep_duration if it is None
    if output_timestep_duration is None:
        if 'window_stride' not in model.cfg.preprocessor:
            raise ValueError(
                "Don't have attribute 'window_stride' in 'model.cfg.preprocessor' => cannot calculate "
                " model_downsample_factor => stopping process"
            )

        if 'sample_rate' not in model.cfg.preprocessor:
            raise ValueError(
                "Don't have attribute 'sample_rate' in 'model.cfg.preprocessor' => cannot calculate start "
                " and end time of segments => stopping process"
            )

        # estimate the downsample factor from the first utterance of the batch:
        # (number of input feature frames) / (number of output timesteps)
        with sf.SoundFile(audio_filepaths_batch[0]) as f:
            audio_dur = f.frames / f.samplerate
        n_input_frames = audio_dur / model.cfg.preprocessor.window_stride
        model_downsample_factor = round(n_input_frames / int(T_batch[0]))

        output_timestep_duration = (
            model.preprocessor.featurizer.hop_length * model_downsample_factor / model.cfg.preprocessor.sample_rate
        )

        logging.info(
            f"Calculated that the model downsample factor is {model_downsample_factor}"
            f" and therefore the ASR model output timestep duration is {output_timestep_duration}"
            " -- will use this for all batches"
        )

    return (
        log_probs_batch,
        y_batch,
        T_batch,
        U_batch,
        utt_obj_batch,
        output_timestep_duration,
    )
|
utils/make_ass_files.py
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
This file contains functions for making ASS-format subtitle files based on the generated alignment.
|
17 |
+
ASS files can be generated highlighting token-level alignments or word-level alignments.
|
18 |
+
In both cases, 'segment' boundaries will be used to determine which parts of the text will appear
|
19 |
+
at the same time.
|
20 |
+
For the token-level ASS files, the text will be highlighted token-by-token, with the timings determined
|
21 |
+
by the NFA alignments.
|
22 |
+
For the word-level ASS files, the text will be highlighted word-by-word, with the timings determined
|
23 |
+
by the NFA alignments.
|
24 |
+
"""
|
25 |
+
|
26 |
+
import os
|
27 |
+
|
28 |
+
from utils.constants import BLANK_TOKEN, SPACE_TOKEN
|
29 |
+
from utils.data_prep import Segment, Token, Word
|
30 |
+
|
31 |
+
# Script resolution written to the "PlayResX" / "PlayResY" fields of the ASS header;
# also used to estimate how much text fits on screen when resegmenting.
PLAYERRESX = 384
PLAYERRESY = 288
# Margins written to the "MarginL" / "MarginR" / "MarginV" fields of the Default style;
# also subtracted from the resolution when estimating how much text fits on screen.
MARGINL = 10
MARGINR = 10
MARGINV = 20
|
36 |
+
|
37 |
+
|
38 |
+
def seconds_to_ass_format(seconds_float):
    """
    Convert a time in seconds to a "HH:MM:SS.cc" timestamp string (cc = hundredths of a second).

    Args:
        seconds_float: the time in seconds; anything accepted by float() (number or numeric string).
            Negative values are clamped to zero, since a negative time cannot be written in
            this format.

    Returns:
        the timestamp as a string, with hours, minutes and seconds zero-padded to at least
        two digits, e.g. 3661.25 -> "01:01:01.25".
    """
    # Round once, to whole centiseconds, *before* splitting into fields. Rounding only the
    # seconds field would turn e.g. 59.999 into "00:00:60.00" instead of "00:01:00.00".
    total_centiseconds = max(round(float(seconds_float) * 100), 0)

    total_seconds, centiseconds = divmod(total_centiseconds, 100)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)

    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{centiseconds:02d}"
|
58 |
+
|
59 |
+
|
60 |
+
def rgb_list_to_hex_bgr(rgb_list):
    """
    Convert an [R, G, B] list of ints into a 6-digit lowercase hex string in B, G, R order.

    Args:
        rgb_list: sequence of exactly three ints (red, green, blue).

    Returns:
        a string such as "09ab39" for [57, 171, 9].
    """
    r, g, b = rgb_list
    # each component must occupy exactly two hex digits: without zero-padding a value
    # below 16 would shorten the string and shift the other components
    return f"{b:02x}{g:02x}{r:02x}"
|
63 |
+
|
64 |
+
|
65 |
+
def make_ass_files(
    utt_obj, output_dir_root, ass_file_config,
):
    """
    Write the word-level and the token-level ASS subtitle files for one utterance.

    Args:
        utt_obj: the Utterance object to make files for.
        output_dir_root: root directory handed to the two file writers.
        ass_file_config: config object; resegment_text_to_fill_space decides whether
            the utterance is resegmented before the files are written.

    Returns:
        utt_obj: the Utterance object as returned by the file writers (or untouched if
        nothing was written).
    """

    # nothing to write if utt_obj.segments_and_tokens is empty, which will happen in the case
    # of the ground truth text being empty or the number of tokens being too large vs audio duration
    if utt_obj.segments_and_tokens:
        if ass_file_config.resegment_text_to_fill_space:
            utt_obj = resegment_utt_obj(utt_obj, ass_file_config)

        # word-level file first, then token-level file
        for write_file in (make_word_level_ass_file, make_token_level_ass_file):
            utt_obj = write_file(utt_obj, output_dir_root, ass_file_config)

    return utt_obj
|
81 |
+
|
82 |
+
|
83 |
+
def _get_word_n_chars(word):
|
84 |
+
n_chars = 0
|
85 |
+
for token in word.tokens:
|
86 |
+
if token.text != BLANK_TOKEN:
|
87 |
+
n_chars += len(token.text)
|
88 |
+
return n_chars
|
89 |
+
|
90 |
+
|
91 |
+
def _get_segment_n_chars(segment):
|
92 |
+
n_chars = 0
|
93 |
+
for word_or_token in segment.words_and_tokens:
|
94 |
+
if word_or_token.text == SPACE_TOKEN:
|
95 |
+
n_chars += 1
|
96 |
+
elif word_or_token.text != BLANK_TOKEN:
|
97 |
+
n_chars += len(word_or_token.text)
|
98 |
+
return n_chars
|
99 |
+
|
100 |
+
|
101 |
+
def resegment_utt_obj(utt_obj, ass_file_config):
    """
    Regroup the words and tokens of utt_obj into new segments sized to fill one subtitle 'slide'.

    Args:
        utt_obj: the Utterance object; its segments_and_tokens list is replaced.
        ass_file_config: config object; fontsize and max_lines_per_segment are read.

    Returns:
        utt_obj: the same Utterance object with the new segments_and_tokens.
    """

    # flatten the utterance into a single list of words and tokens
    flat_items = []
    for segment_or_token in utt_obj.segments_and_tokens:
        if type(segment_or_token) is Segment:
            flat_items.extend(segment_or_token.words_and_tokens)
        else:
            flat_items.append(segment_or_token)

    # work out how many chars fit into one 'slide'; this is the max size of a segment
    usable_width = PLAYERRESX - MARGINL - MARGINR
    usable_height = PLAYERRESY - MARGINV
    approx_chars_per_line = usable_width / (ass_file_config.fontsize * 0.6)  # assume chars 0.6 as wide as tall
    approx_lines_per_segment = min(
        usable_height / (ass_file_config.fontsize * 1.15),  # assume line spacing is 1.15
        ass_file_config.max_lines_per_segment,
    )
    max_chars_per_segment = int(approx_chars_per_line * approx_lines_per_segment)

    # tokens that come before the first word stay outside of any segment
    n_leading_tokens = 0
    while n_leading_tokens < len(flat_items) and type(flat_items[n_leading_tokens]) is Token:
        n_leading_tokens += 1
    regrouped = list(flat_items[:n_leading_tokens])

    current_segment = Segment()
    regrouped.append(current_segment)

    for item in flat_items[n_leading_tokens:]:
        # Only a word that is not the first item of its segment can trigger a new segment:
        # the first item always goes in, and tokens always join the current segment.
        # (Tokens at segment boundaries therefore end up inside the segment, which breaks the
        # convention of listing them separately.
        # TODO: change code so we follow this convention)
        if type(item) is Word and current_segment.words_and_tokens:
            projected_n_chars = _get_word_n_chars(item) + _get_segment_n_chars(current_segment)
            if projected_n_chars >= max_chars_per_segment:
                current_segment = Segment()
                regrouped.append(current_segment)

        current_segment.words_and_tokens.append(item)

    utt_obj.segments_and_tokens = regrouped

    return utt_obj
|
167 |
+
|
168 |
+
|
169 |
+
def make_word_level_ass_file(
    utt_obj, output_dir_root, ass_file_config,
):
    """
    Write an ASS subtitle file in which the text of each segment is highlighted word-by-word.

    The file is saved as <output_dir_root>/ass/words/<utt_id>.ass and its path is recorded in
    utt_obj.saved_output_files["words_level_ass_filepath"].

    Args:
        utt_obj: the Utterance object; the words of its segments must already have
            t_start and t_end set.
        output_dir_root: root directory under which the "ass/words" directory is created.
        ass_file_config: config object; fontsize, vertical_alignment ("top", "center" or
            "bottom") and the three text_*_rgb colour lists are read.

    Returns:
        utt_obj: the same Utterance object, with saved_output_files updated.

    Raises:
        ValueError: if ass_file_config.vertical_alignment is not one of the accepted values.
    """

    default_style_dict = {
        "Name": "Default",
        "Fontname": "Arial",
        "Fontsize": str(ass_file_config.fontsize),
        "PrimaryColour": "&Hffffff",
        "SecondaryColour": "&Hffffff",
        "OutlineColour": "&H0",
        "BackColour": "&H0",
        "Bold": "0",
        "Italic": "0",
        "Underline": "0",
        "StrikeOut": "0",
        "ScaleX": "100",
        "ScaleY": "100",
        "Spacing": "0",
        "Angle": "0",
        "BorderStyle": "1",
        "Outline": "1",
        "Shadow": "0",
        "Alignment": None,  # will specify below
        "MarginL": str(MARGINL),
        "MarginR": str(MARGINR),
        "MarginV": str(MARGINV),
        "Encoding": "0",
    }

    if ass_file_config.vertical_alignment == "top":
        default_style_dict["Alignment"] = "8"  # text will be 'center-justified' and in the top of the screen
    elif ass_file_config.vertical_alignment == "center":
        default_style_dict["Alignment"] = "5"  # text will be 'center-justified' and in the middle of the screen
    elif ass_file_config.vertical_alignment == "bottom":
        default_style_dict["Alignment"] = "2"  # text will be 'center-justified' and in the bottom of the screen
    else:
        raise ValueError("got an unexpected value for ass_file_config.vertical_alignment")

    output_dir = os.path.join(output_dir_root, "ass", "words")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{utt_obj.utt_id}.ass")

    already_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_already_spoken_rgb) + r"&}"
    being_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_being_spoken_rgb) + r"&}"
    not_yet_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_not_yet_spoken_rgb) + r"&}"

    # write as UTF-8 explicitly so that non-ASCII transcripts do not depend on the platform's
    # default encoding
    with open(output_file, 'w', encoding="utf-8") as f:
        default_style_top_line = "Format: " + ", ".join(default_style_dict.keys())
        default_style_bottom_line = "Style: " + ",".join(default_style_dict.values())

        f.write(
            (
                "[Script Info]\n"
                "ScriptType: v4.00+\n"
                f"PlayResX: {PLAYERRESX}\n"
                f"PlayResY: {PLAYERRESY}\n"
                "\n"
                "[V4+ Styles]\n"
                f"{default_style_top_line}\n"
                f"{default_style_bottom_line}\n"
                "\n"
                "[Events]\n"
                "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n"
            )
        )

        # write first set of subtitles for text before speech starts to be spoken
        words_in_first_segment = []
        for segment_or_token in utt_obj.segments_and_tokens:
            if type(segment_or_token) is Segment:
                words_in_first_segment = [x for x in segment_or_token.words_and_tokens if type(x) is Word]
                break  # only the first segment is of interest here

        # if there is no segment, or the first segment holds no words, there is nothing to
        # show before speech starts (and no first word to take the start time from)
        if words_in_first_segment:
            text_before_speech = (
                not_yet_spoken_color_code + " ".join([x.text for x in words_in_first_segment]) + r"{\r}"
            )
            subtitle_text = (
                f"Dialogue: 0,{seconds_to_ass_format(0)},{seconds_to_ass_format(words_in_first_segment[0].t_start)},Default,,0,0,0,,"
                + text_before_speech.rstrip()
            )

            f.write(subtitle_text + '\n')

        for segment_or_token in utt_obj.segments_and_tokens:
            if type(segment_or_token) is Segment:
                segment = segment_or_token

                words_in_segment = [x for x in segment.words_and_tokens if type(x) is Word]

                for word_i, word in enumerate(words_in_segment):

                    text_before = " ".join([x.text for x in words_in_segment[:word_i]])
                    if text_before != "":
                        text_before += " "
                    text_before = already_spoken_color_code + text_before + r"{\r}"

                    if word_i < len(words_in_segment) - 1:
                        text_after = " " + " ".join([x.text for x in words_in_segment[word_i + 1 :]])
                    else:
                        text_after = ""
                    text_after = not_yet_spoken_color_code + text_after + r"{\r}"

                    aligned_text = being_spoken_color_code + word.text + r"{\r}"
                    aligned_text_off = already_spoken_color_code + word.text + r"{\r}"

                    subtitle_text = (
                        f"Dialogue: 0,{seconds_to_ass_format(word.t_start)},{seconds_to_ass_format(word.t_end)},Default,,0,0,0,,"
                        + text_before
                        + aligned_text
                        + text_after.rstrip()
                    )
                    f.write(subtitle_text + '\n')

                    # add subtitles without word-highlighting for when words are not being spoken
                    if word_i < len(words_in_segment) - 1:
                        last_word_end = float(words_in_segment[word_i].t_end)
                        next_word_start = float(words_in_segment[word_i + 1].t_start)
                        if next_word_start - last_word_end > 0.001:
                            subtitle_text = (
                                f"Dialogue: 0,{seconds_to_ass_format(last_word_end)},{seconds_to_ass_format(next_word_start)},Default,,0,0,0,,"
                                + text_before
                                + aligned_text_off
                                + text_after.rstrip()
                            )
                            f.write(subtitle_text + '\n')

    utt_obj.saved_output_files["words_level_ass_filepath"] = output_file

    return utt_obj
|
304 |
+
|
305 |
+
|
306 |
+
def make_token_level_ass_file(
    utt_obj, output_dir_root, ass_file_config,
):
    """
    Write an ASS subtitle file in which the text of each segment is highlighted token-by-token.

    The file is saved as <output_dir_root>/ass/tokens/<utt_id>.ass and its path is recorded in
    utt_obj.saved_output_files["tokens_level_ass_filepath"].

    Note: as a side effect, text_cased of every non-blank token inside a segment is rewritten
    in place so that the "▁" subword marker and the space token become plain spaces.

    Args:
        utt_obj: the Utterance object; the tokens of its segments must already have
            t_start and t_end set.
        output_dir_root: root directory under which the "ass/tokens" directory is created.
        ass_file_config: config object; fontsize, vertical_alignment ("top", "center" or
            "bottom") and the three text_*_rgb colour lists are read.

    Returns:
        utt_obj: the same Utterance object, with saved_output_files updated.

    Raises:
        ValueError: if ass_file_config.vertical_alignment is not one of the accepted values.
    """

    def _get_displayable_tokens(segment):
        """Collect the non-blank tokens of segment and make their text_cased displayable."""
        tokens = []
        for word_or_token in segment.words_and_tokens:
            if type(word_or_token) is Token:
                if word_or_token.text != BLANK_TOKEN:
                    tokens.append(word_or_token)
            else:
                for token in word_or_token.tokens:
                    if token.text != BLANK_TOKEN:
                        tokens.append(token)

        for token in tokens:
            # replace the "▁" marker used in subword tokens with a space
            token.text_cased = token.text_cased.replace("▁", " ")
            # replace the space token with an actual space
            token.text_cased = token.text_cased.replace(SPACE_TOKEN, " ")

        return tokens

    default_style_dict = {
        "Name": "Default",
        "Fontname": "Arial",
        "Fontsize": str(ass_file_config.fontsize),
        "PrimaryColour": "&Hffffff",
        "SecondaryColour": "&Hffffff",
        "OutlineColour": "&H0",
        "BackColour": "&H0",
        "Bold": "0",
        "Italic": "0",
        "Underline": "0",
        "StrikeOut": "0",
        "ScaleX": "100",
        "ScaleY": "100",
        "Spacing": "0",
        "Angle": "0",
        "BorderStyle": "1",
        "Outline": "1",
        "Shadow": "0",
        "Alignment": None,  # will specify below
        "MarginL": str(MARGINL),
        "MarginR": str(MARGINR),
        "MarginV": str(MARGINV),
        "Encoding": "0",
    }

    if ass_file_config.vertical_alignment == "top":
        default_style_dict["Alignment"] = "8"  # text will be 'center-justified' and in the top of the screen
    elif ass_file_config.vertical_alignment == "center":
        default_style_dict["Alignment"] = "5"  # text will be 'center-justified' and in the middle of the screen
    elif ass_file_config.vertical_alignment == "bottom":
        default_style_dict["Alignment"] = "2"  # text will be 'center-justified' and in the bottom of the screen
    else:
        raise ValueError("got an unexpected value for ass_file_config.vertical_alignment")

    output_dir = os.path.join(output_dir_root, "ass", "tokens")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{utt_obj.utt_id}.ass")

    already_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_already_spoken_rgb) + r"&}"
    being_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_being_spoken_rgb) + r"&}"
    not_yet_spoken_color_code = r"{\c&H" + rgb_list_to_hex_bgr(ass_file_config.text_not_yet_spoken_rgb) + r"&}"

    # write as UTF-8 explicitly so that non-ASCII transcripts do not depend on the platform's
    # default encoding
    with open(output_file, 'w', encoding="utf-8") as f:
        default_style_top_line = "Format: " + ", ".join(default_style_dict.keys())
        default_style_bottom_line = "Style: " + ",".join(default_style_dict.values())

        f.write(
            (
                "[Script Info]\n"
                "ScriptType: v4.00+\n"
                f"PlayResX: {PLAYERRESX}\n"
                f"PlayResY: {PLAYERRESY}\n"
                "ScaledBorderAndShadow: yes\n"
                "\n"
                "[V4+ Styles]\n"
                f"{default_style_top_line}\n"
                f"{default_style_bottom_line}\n"
                "\n"
                "[Events]\n"
                "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n"
            )
        )

        # write first set of subtitles for text before speech starts to be spoken
        tokens_in_first_segment = []
        for segment_or_token in utt_obj.segments_and_tokens:
            if type(segment_or_token) is Segment:
                tokens_in_first_segment = _get_displayable_tokens(segment_or_token)
                break  # only the first segment is of interest here

        # if there is no segment, or the first segment holds no non-blank tokens, there is
        # nothing to show before speech starts (and no first token to take the start time from)
        if tokens_in_first_segment:
            text_before_speech = (
                not_yet_spoken_color_code + "".join([x.text_cased for x in tokens_in_first_segment]) + r"{\r}"
            )
            subtitle_text = (
                f"Dialogue: 0,{seconds_to_ass_format(0)},{seconds_to_ass_format(tokens_in_first_segment[0].t_start)},Default,,0,0,0,,"
                + text_before_speech.rstrip()
            )

            f.write(subtitle_text + '\n')

        for segment_or_token in utt_obj.segments_and_tokens:
            if type(segment_or_token) is Segment:
                segment = segment_or_token

                tokens_in_segment = _get_displayable_tokens(segment)  # list of (non-blank) tokens

                for token_i, token in enumerate(tokens_in_segment):

                    text_before = "".join([x.text_cased for x in tokens_in_segment[:token_i]])
                    text_before = already_spoken_color_code + text_before + r"{\r}"

                    if token_i < len(tokens_in_segment) - 1:
                        text_after = "".join([x.text_cased for x in tokens_in_segment[token_i + 1 :]])
                    else:
                        text_after = ""
                    text_after = not_yet_spoken_color_code + text_after + r"{\r}"

                    aligned_text = being_spoken_color_code + token.text_cased + r"{\r}"
                    aligned_text_off = already_spoken_color_code + token.text_cased + r"{\r}"

                    subtitle_text = (
                        f"Dialogue: 0,{seconds_to_ass_format(token.t_start)},{seconds_to_ass_format(token.t_end)},Default,,0,0,0,,"
                        + text_before
                        + aligned_text
                        + text_after.rstrip()
                    )
                    f.write(subtitle_text + '\n')

                    # add subtitles without token-highlighting for when tokens are not being spoken
                    if token_i < len(tokens_in_segment) - 1:
                        last_token_end = float(tokens_in_segment[token_i].t_end)
                        next_token_start = float(tokens_in_segment[token_i + 1].t_start)
                        if next_token_start - last_token_end > 0.001:
                            subtitle_text = (
                                f"Dialogue: 0,{seconds_to_ass_format(last_token_end)},{seconds_to_ass_format(next_token_start)},Default,,0,0,0,,"
                                + text_before
                                + aligned_text_off
                                + text_after.rstrip()
                            )
                            f.write(subtitle_text + '\n')

    utt_obj.saved_output_files["tokens_level_ass_filepath"] = output_file

    return utt_obj
|
utils/make_ctm_files.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
|
17 |
+
import soundfile as sf
|
18 |
+
from utils.constants import BLANK_TOKEN, SPACE_TOKEN
|
19 |
+
from utils.data_prep import Segment, Word
|
20 |
+
|
21 |
+
|
22 |
+
def make_ctm_files(
    utt_obj, output_dir_root, ctm_file_config,
):
    """
    Save token-, word- and segment-level CTM files for one utterance object.

    Args:
        utt_obj: utterance object; its `segments_and_tokens` and `audio_filepath`
            attributes are read here, and it is passed on to `make_ctm`.
        output_dir_root: root directory handed to `make_ctm` for every level.
        ctm_file_config: config object; `minimum_timestamp_duration` is read here
            and the whole object is passed on to `make_ctm`.

    Returns:
        The utterance object as returned by the last `make_ctm` call, or the
        untouched `utt_obj` if it has no segments/tokens.
    """

    # Nothing to write when segments_and_tokens is empty, which will happen in the case of
    # the ground truth text being empty or the number of tokens being too large vs audio duration.
    if not utt_obj.segments_and_tokens:
        return utt_obj

    # The audio duration is only looked up when a minimum timestamp duration is requested.
    audio_file_duration = None
    if ctm_file_config.minimum_timestamp_duration > 0:
        with sf.SoundFile(utt_obj.audio_filepath) as audio:
            audio_file_duration = audio.frames / audio.samplerate

    # One CTM file per alignment level, in the same order as before.
    for level in ("tokens", "words", "segments"):
        utt_obj = make_ctm(level, utt_obj, output_dir_root, audio_file_duration, ctm_file_config,)

    return utt_obj
|
46 |
+
|
47 |
+
|
48 |
+
def make_ctm(
    alignment_level, utt_obj, output_dir_root, audio_file_duration, ctm_file_config,
):
    """
    Write one CTM file for `utt_obj` at the requested alignment level.

    The file is saved as `<output_dir_root>/ctm/<alignment_level>/<utt_id>.ctm`, and each
    written line has the form `<utt_id> 1 <start> <duration> <text>`.

    Args:
        alignment_level: "tokens", "words" or "segments"; selects which objects of the
            utterance are written. Any other value produces an empty file.
        utt_obj: utterance object; `segments_and_tokens`, `utt_id` and
            `saved_output_files` are used.
        output_dir_root: root directory under which the `ctm/<alignment_level>` folder
            is created if needed.
        audio_file_duration: upper bound for timestamps that get widened to the minimum
            duration. If None, the widened end time is not clamped.
        ctm_file_config: config object providing `minimum_timestamp_duration` and
            `remove_blank_tokens`.

    Returns:
        `utt_obj`, with the path of the written file stored in
        `utt_obj.saved_output_files["<alignment_level>_level_ctm_filepath"]`.
    """
    output_dir = os.path.join(output_dir_root, "ctm", alignment_level)
    os.makedirs(output_dir, exist_ok=True)
    ctm_filepath = os.path.join(output_dir, f"{utt_obj.utt_id}.ctm")

    # Collect the objects (tokens, words or segments) that will be written, in order.
    # NOTE: the exact-type checks are kept as in the original code.
    boundary_info_utt = []
    for segment_or_token in utt_obj.segments_and_tokens:
        if type(segment_or_token) is not Segment:
            # a token sitting directly at the utterance level
            if alignment_level == "tokens":
                boundary_info_utt.append(segment_or_token)
            continue

        segment = segment_or_token
        if alignment_level == "segments":
            boundary_info_utt.append(segment)

        for word_or_token in segment.words_and_tokens:
            if type(word_or_token) is not Word:
                # a token sitting directly at the segment level
                if alignment_level == "tokens":
                    boundary_info_utt.append(word_or_token)
                continue

            word = word_or_token
            if alignment_level == "words":
                boundary_info_utt.append(word)
            if alignment_level == "tokens":
                boundary_info_utt.extend(word.tokens)

    min_duration = ctm_file_config.minimum_timestamp_duration

    # Write explicitly as UTF-8: the text may contain non-ASCII characters, which can fail
    # to encode with a non-UTF-8 platform default encoding.
    with open(ctm_filepath, "w", encoding="utf-8") as f_ctm:
        for boundary_info_ in boundary_info_utt:  # loop over every token/word/segment

            # skip if t_start = t_end = negative number because we used it as a marker to skip some blank tokens
            if boundary_info_.t_start < 0 or boundary_info_.t_end < 0:
                continue

            text = boundary_info_.text
            start_time = boundary_info_.t_start
            end_time = boundary_info_.t_end

            if min_duration > 0 and min_duration > end_time - start_time:
                # make the predicted duration of the token/word/segment longer, growing it outwards equal
                # amounts from the predicted center of the token/word/segment
                token_mid_point = (start_time + end_time) / 2
                start_time = max(token_mid_point - min_duration / 2, 0)
                end_time = token_mid_point + min_duration / 2
                if audio_file_duration is not None:
                    # do not let the widened timestamp run past the end of the audio
                    end_time = min(end_time, audio_file_duration)

            # don't save blanks if we don't want to
            if text == BLANK_TOKEN and ctm_file_config.remove_blank_tokens:
                continue

            # replace any spaces with SPACE_TOKEN so we don't introduce extra space characters to our CTM files
            text = text.replace(" ", SPACE_TOKEN)

            f_ctm.write(f"{utt_obj.utt_id} 1 {start_time:.2f} {end_time - start_time:.2f} {text}\n")

    utt_obj.saved_output_files[f"{alignment_level}_level_ctm_filepath"] = ctm_filepath

    return utt_obj
|
utils/make_output_manifest.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
|
17 |
+
|
18 |
+
def write_manifest_out_line(
    f_manifest_out, utt_obj,
):
    """
    Append one JSON line describing `utt_obj` to an already-open manifest file.

    The line always contains "audio_filepath"; "text" and "pred_text" are included only
    when the corresponding attribute is not None; every entry of
    `utt_obj.saved_output_files` is then copied in under its own key.

    Args:
        f_manifest_out: writable text file object (anything with a `write` method).
        utt_obj: utterance object providing `audio_filepath`, `text`, `pred_text`
            and `saved_output_files`.

    Returns:
        None
    """

    data = {"audio_filepath": utt_obj.audio_filepath}
    if utt_obj.text is not None:
        data["text"] = utt_obj.text

    if utt_obj.pred_text is not None:
        data["pred_text"] = utt_obj.pred_text

    # saved-file entries are applied last, so they win if a key is repeated
    data.update(utt_obj.saved_output_files)

    new_line = json.dumps(data)
    f_manifest_out.write(f"{new_line}\n")

    return None
|
utils/viterbi_decoding.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from utils.constants import V_NEGATIVE_NUM
|
17 |
+
|
18 |
+
|
19 |
+
def viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device):
    """
    Do Viterbi decoding with an efficient algorithm (the only for-loop in the 'forward pass' is over the time dimension).
    Args:
        log_probs_batch: tensor of shape (B, T_max, V). The parts of log_probs_batch which are 'padding' are filled
            with 'V_NEGATIVE_NUM' - a large negative number which represents a very low probability.
        y_batch: tensor of shape (B, U_max) - contains token IDs including blanks in every other position. The parts of
            y_batch which are padding are filled with the number 'V'. V = the number of tokens in the vocabulary + 1 for
            the blank token.
        T_batch: 1-D tensor of shape (B,) - contains the durations of the log_probs_batch (so we can ignore the
            parts of log_probs_batch which are padding). The broadcasting below relies on it being 1-D.
        U_batch: 1-D tensor of shape (B,) - contains the lengths of y_batch (so we can ignore the parts of y_batch
            which are padding). The broadcasting below relies on it being 1-D.
        viterbi_device: the torch device on which Viterbi decoding will be done.

    Returns:
        alignments_batch: list of lists containing locations for the tokens we align to at each timestep.
            Looks like: [[0, 0, 1, 2, 2, 3, 3, ...,  ], ..., [0, 1, 2, 2, 2, 3, 4, ....]].
            Each list inside alignments_batch is of length T_batch[location of utt in batch].
    """

    B, T_max, _ = log_probs_batch.shape
    U_max = y_batch.shape[1]

    # transfer all tensors to viterbi_device
    log_probs_batch = log_probs_batch.to(viterbi_device)
    y_batch = y_batch.to(viterbi_device)
    T_batch = T_batch.to(viterbi_device)
    U_batch = U_batch.to(viterbi_device)

    # make tensor that we will put at timesteps beyond the duration of the audio
    padding_for_log_probs = V_NEGATIVE_NUM * torch.ones((B, T_max, 1), device=viterbi_device)
    # make log_probs_padded tensor of shape (B, T_max, V + 1) where all of
    # log_probs_padded[:,:,-1] is the 'V_NEGATIVE_NUM'
    log_probs_padded = torch.cat((log_probs_batch, padding_for_log_probs), dim=2)

    # initialize v_prev - tensor of previous timestep's viterbi probabilities, of shape (B, U_max)
    v_prev = V_NEGATIVE_NUM * torch.ones((B, U_max), device=viterbi_device)
    v_prev[:, :2] = torch.gather(input=log_probs_padded[:, 0, :], dim=1, index=y_batch[:, :2])

    # initialize backpointers_rel - which contains values like 0 to indicate the backpointer is to the same u index,
    # 1 to indicate the backpointer pointing to the u-1 index and 2 to indicate the backpointer is pointing to the u-2 index
    backpointers_rel = -99 * torch.ones((B, T_max, U_max), dtype=torch.int8, device=viterbi_device)

    # Make a letter_repetition_mask the same shape as y_batch.
    # The letter_repetition_mask will have 'True' where the token (including blanks) is the same
    # as the token two places before it in the ground truth (and 'False' everywhere else).
    # We will use letter_repetition_mask to decide whether the transition from the position
    # two places back (u-2) is allowed.
    y_shifted_right = torch.roll(y_batch, shifts=2, dims=1)
    letter_repetition_mask = y_batch - y_shifted_right
    letter_repetition_mask[:, :2] = 1  # make sure we don't apply the mask to the first 2 tokens
    letter_repetition_mask = letter_repetition_mask == 0

    # U_can_be_final does not depend on the timestep, so build it once outside the time loop.
    # It is True at token positions U_batch and U_batch - 1 of every utterance.
    u_positions = torch.arange(0, U_max, device=viterbi_device).unsqueeze(0)
    U_can_be_final = torch.logical_or(
        u_positions == (U_batch.unsqueeze(1) - 0), u_positions == (U_batch.unsqueeze(1) - 1),
    )

    for t in range(1, T_max):

        # e_current is a tensor of shape (B, U_max) of the log probs of every possible token at the current timestep
        e_current = torch.gather(input=log_probs_padded[:, t, :], dim=1, index=y_batch)

        # apply a mask to e_current to cope with the fact that we do not keep the whole v_matrix and continue
        # calculating viterbi probabilities during some 'padding' timesteps
        t_exceeded_T_batch = t >= T_batch

        mask = torch.logical_not(torch.logical_and(t_exceeded_T_batch.unsqueeze(1), U_can_be_final,)).long()

        e_current = e_current * mask

        # v_prev_shifted is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 1 token position back
        v_prev_shifted = torch.roll(v_prev, shifts=1, dims=1)
        # by doing a roll shift of size 1, we have brought the viterbi probability in the final token position to the
        # first token position - let's overcome this by 'zeroing out' the probabilities in the first token position
        v_prev_shifted[:, 0] = V_NEGATIVE_NUM

        # v_prev_shifted2 is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 2 token position back
        v_prev_shifted2 = torch.roll(v_prev, shifts=2, dims=1)
        v_prev_shifted2[:, :2] = V_NEGATIVE_NUM  # zero out as we did for v_prev_shifted
        # use our letter_repetition_mask to remove the connections between 2 blanks (so we don't skip over a letter)
        # and to remove the connections between 2 consecutive letters (so we don't skip over a blank)
        v_prev_shifted2.masked_fill_(letter_repetition_mask, V_NEGATIVE_NUM)

        # we need this v_prev_dup tensor so we can calculate the viterbi probability of every possible
        # token position simultaneously
        v_prev_dup = torch.cat(
            (v_prev.unsqueeze(2), v_prev_shifted.unsqueeze(2), v_prev_shifted2.unsqueeze(2),), dim=2,
        )

        # candidates_v_current are our candidate viterbi probabilities for every token position, from which
        # we will pick the max and record the argmax
        candidates_v_current = v_prev_dup + e_current.unsqueeze(2)
        # we straight away save results in v_prev instead of v_current, so that the variable v_prev will be ready for the
        # next iteration of the for-loop
        v_prev, bp_relative = torch.max(candidates_v_current, dim=2)

        backpointers_rel[:, t, :] = bp_relative

    # trace backpointers
    alignments_batch = []
    for b in range(B):
        T_b = int(T_batch[b])
        U_b = int(U_batch[b])

        if U_b == 1:  # i.e. we put only a blank token in the reference text because the reference text is empty
            current_u = 0  # set initial u to 0 and let the rest of the code block run as usual
        else:
            current_u = int(torch.argmax(v_prev[b, U_b - 2 : U_b])) + U_b - 2

        # Walk from the last timestep to the first, collecting positions in reverse order,
        # then reverse once (avoids repeated O(T) inserts at the front of the list).
        alignment_b = [current_u]
        for t in range(T_max - 1, 0, -1):
            current_u = current_u - int(backpointers_rel[b, t, current_u])
            alignment_b.append(current_u)
        alignment_b.reverse()

        alignments_batch.append(alignment_b[:T_b])

    return alignments_batch
|