Update ctc_scorer.py
Browse files- ctc_scorer.py +66 -12
ctc_scorer.py
CHANGED
@@ -1,14 +1,7 @@
|
|
1 |
# pylint: skip-file
|
2 |
# Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
|
3 |
import torch
|
4 |
-
from transformers import
|
5 |
-
|
6 |
-
|
7 |
-
class GenerationConfigWithCTC(GenerationConfig):
|
8 |
-
def __init__(self, ctc_weight=0.0, ctc_margin=0, **kwargs):
|
9 |
-
super().__init__(**kwargs)
|
10 |
-
self.ctc_weight = ctc_weight
|
11 |
-
self.ctc_margin = ctc_margin
|
12 |
|
13 |
|
14 |
class CTCPrefixScoreTH(object):
|
@@ -93,7 +86,7 @@ class CTCPrefixScoreTH(object):
|
|
93 |
else:
|
94 |
r_prev, s_prev, f_min_prev, f_max_prev = state
|
95 |
|
96 |
-
# select input dimensions for
|
97 |
if self.scoring_num > 0:
|
98 |
scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
|
99 |
snum = self.scoring_num
|
@@ -173,8 +166,8 @@ class CTCPrefixScoreTH(object):
|
|
173 |
dim=0,
|
174 |
)
|
175 |
|
176 |
-
for si in range(n_bh):
|
177 |
-
|
178 |
|
179 |
# exclude blank probs
|
180 |
log_psi[:, self.blank] = self.logzero
|
@@ -273,8 +266,14 @@ class CTCRescorerLogitsProcessor(LogitsProcessor):
|
|
273 |
ctc_margin: int,
|
274 |
ctc_weight: float,
|
275 |
num_beams: int,
|
|
|
|
|
|
|
|
|
276 |
):
|
277 |
super().__init__()
|
|
|
|
|
278 |
self.pad_token_id = pad_token_id
|
279 |
self.ctc_prefix_scorer = CTCPrefixScoreTH(
|
280 |
torch.nn.functional.log_softmax(encoder_logits, dim=-1),
|
@@ -286,6 +285,41 @@ class CTCRescorerLogitsProcessor(LogitsProcessor):
|
|
286 |
self.ctc_weight = ctc_weight
|
287 |
self.ctc_states = None
|
288 |
self.num_beams = num_beams
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
|
290 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
291 |
scores[:, self.pad_token_id] = self.ctc_prefix_scorer.logzero
|
@@ -296,7 +330,27 @@ class CTCRescorerLogitsProcessor(LogitsProcessor):
|
|
296 |
ctc_scores, ctc_states = self.ctc_prefix_scorer(input_ids, self.ctc_states)
|
297 |
self.ctc_states = ctc_states
|
298 |
next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
|
299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
return next_token_scores
|
301 |
|
302 |
|
|
|
1 |
# pylint: skip-file
|
2 |
# Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
|
3 |
import torch
|
4 |
+
from transformers import LogitsProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
class CTCPrefixScoreTH(object):
|
|
|
86 |
else:
|
87 |
r_prev, s_prev, f_min_prev, f_max_prev = state
|
88 |
|
89 |
+
# select input dimensions for scoring
|
90 |
if self.scoring_num > 0:
|
91 |
scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
|
92 |
snum = self.scoring_num
|
|
|
166 |
dim=0,
|
167 |
)
|
168 |
|
169 |
+
# for si in range(n_bh):
|
170 |
+
# log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
|
171 |
|
172 |
# exclude blank probs
|
173 |
log_psi[:, self.blank] = self.logzero
|
|
|
266 |
ctc_margin: int,
|
267 |
ctc_weight: float,
|
268 |
num_beams: int,
|
269 |
+
space_token_id: int,
|
270 |
+
apply_eos_space_trick: bool,
|
271 |
+
eos_space_trick_weight: float,
|
272 |
+
debug: bool = False,
|
273 |
):
|
274 |
super().__init__()
|
275 |
+
# reduce_lens_by = (encoder_logits.argmax(dim=-1) == eos_token_id).sum(dim=-1)
|
276 |
+
# encoder_output_lens = encoder_output_lens - reduce_lens_by
|
277 |
self.pad_token_id = pad_token_id
|
278 |
self.ctc_prefix_scorer = CTCPrefixScoreTH(
|
279 |
torch.nn.functional.log_softmax(encoder_logits, dim=-1),
|
|
|
285 |
self.ctc_weight = ctc_weight
|
286 |
self.ctc_states = None
|
287 |
self.num_beams = num_beams
|
288 |
+
self.eos_token_id = eos_token_id
|
289 |
+
self.apply_eos_space_trick = apply_eos_space_trick
|
290 |
+
self.space_token_id = space_token_id
|
291 |
+
self.eos_space_trick_weight = eos_space_trick_weight
|
292 |
+
self.debug = debug
|
293 |
+
|
294 |
+
@staticmethod
|
295 |
+
def analyze_predictions(
|
296 |
+
scores, ctc_scores, next_token_scores, input_ids, k=10, tokenizer="Lakoc/english_corpus_uni5000_normalized"
|
297 |
+
):
|
298 |
+
from transformers import AutoTokenizer
|
299 |
+
|
300 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
301 |
+
best_att_ids = scores.topk(k=k, dim=1)
|
302 |
+
best_ctc_ids = ctc_scores.topk(k=k, dim=1)
|
303 |
+
best_ids = next_token_scores.topk(k=k, dim=1)
|
304 |
+
|
305 |
+
def print_prediction(best_ids, name):
|
306 |
+
new_tensor = torch.zeros((best_ids.indices.shape[0], best_ids.indices.shape[1] * 2), dtype=torch.long)
|
307 |
+
new_tensor[:, 0::2] = best_ids.indices
|
308 |
+
new_tensor[:, 1::2] = 4976
|
309 |
+
print(f"{name}:")
|
310 |
+
for index, (next_ids, scores) in enumerate(zip(tokenizer.batch_decode(new_tensor), best_ids.values)):
|
311 |
+
print(f"HYP {index}:\n{next_ids} {scores}")
|
312 |
+
|
313 |
+
print(f"PREFIX:")
|
314 |
+
for index, prefix in enumerate(tokenizer.batch_decode(input_ids)):
|
315 |
+
print(f"HYP {index}:\n{prefix}")
|
316 |
+
print_prediction(best_att_ids, "ATT_SCORES")
|
317 |
+
print()
|
318 |
+
print_prediction(best_ctc_ids, "CTC_SCORES")
|
319 |
+
print()
|
320 |
+
print(f"CTC_EOS: {ctc_scores[:, 1]}")
|
321 |
+
print_prediction(best_ids, "NEXT_TOKEN_SCORES")
|
322 |
+
print()
|
323 |
|
324 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
325 |
scores[:, self.pad_token_id] = self.ctc_prefix_scorer.logzero
|
|
|
330 |
ctc_scores, ctc_states = self.ctc_prefix_scorer(input_ids, self.ctc_states)
|
331 |
self.ctc_states = ctc_states
|
332 |
next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
|
333 |
+
if self.apply_eos_space_trick:
|
334 |
+
space_eos_conflict = torch.logical_and(
|
335 |
+
scores.argmax(dim=1) == self.eos_token_id, ctc_scores.argmax(dim=1) == self.space_token_id
|
336 |
+
)
|
337 |
+
if space_eos_conflict.any():
|
338 |
+
apply_trick_on = torch.logical_and(
|
339 |
+
torch.logical_and(
|
340 |
+
space_eos_conflict,
|
341 |
+
next_token_scores[:, self.eos_token_id] < next_token_scores[:, self.space_token_id],
|
342 |
+
),
|
343 |
+
self.eos_space_trick_weight * next_token_scores[:, self.eos_token_id]
|
344 |
+
> next_token_scores[:, self.space_token_id],
|
345 |
+
)
|
346 |
+
if apply_trick_on.any():
|
347 |
+
next_token_scores[apply_trick_on, self.eos_token_id] = (
|
348 |
+
next_token_scores[apply_trick_on, self.eos_token_id] * self.eos_space_trick_weight
|
349 |
+
)
|
350 |
+
|
351 |
+
if self.debug:
|
352 |
+
self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids)
|
353 |
+
|
354 |
return next_token_scores
|
355 |
|
356 |
|