Update modeling_reguler.py
Browse files- modeling_reguler.py +22 -2
modeling_reguler.py
CHANGED
@@ -385,14 +385,25 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
|
|
385 |
|
386 |
def _get_logits_processor(
|
387 |
self,
|
388 |
-
generation_config:
|
389 |
input_ids_seq_length: int,
|
390 |
encoder_input_ids: torch.LongTensor,
|
391 |
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
392 |
logits_processor: Optional[LogitsProcessorList],
|
|
|
|
|
|
|
393 |
) -> LogitsProcessorList:
|
|
|
394 |
processors = super()._get_logits_processor(
|
395 |
-
generation_config,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
)
|
397 |
if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
|
398 |
if generation_config.num_beams <= 1:
|
@@ -405,8 +416,17 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
|
|
405 |
self.generation_config.ctc_margin,
|
406 |
self.generation_config.ctc_weight,
|
407 |
self.generation_config.num_beams,
|
|
|
|
|
|
|
408 |
)
|
409 |
processors.append(self.ctc_rescorer)
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
return processors
|
411 |
|
412 |
def _prepare_encoder_decoder_kwargs_for_generation(
|
|
|
385 |
|
386 |
def _get_logits_processor(
|
387 |
self,
|
388 |
+
generation_config: GenerationConfig,
|
389 |
input_ids_seq_length: int,
|
390 |
encoder_input_ids: torch.LongTensor,
|
391 |
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
392 |
logits_processor: Optional[LogitsProcessorList],
|
393 |
+
model_kwargs: Optional[Dict[str, Any]] = None,
|
394 |
+
negative_prompt_ids: Optional[torch.Tensor] = None,
|
395 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
396 |
) -> LogitsProcessorList:
|
397 |
+
# pylint: disable=no-member
|
398 |
processors = super()._get_logits_processor(
|
399 |
+
generation_config,
|
400 |
+
input_ids_seq_length,
|
401 |
+
encoder_input_ids,
|
402 |
+
prefix_allowed_tokens_fn,
|
403 |
+
logits_processor,
|
404 |
+
model_kwargs,
|
405 |
+
negative_prompt_ids,
|
406 |
+
negative_prompt_attention_mask,
|
407 |
)
|
408 |
if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
|
409 |
if generation_config.num_beams <= 1:
|
|
|
416 |
self.generation_config.ctc_margin,
|
417 |
self.generation_config.ctc_weight,
|
418 |
self.generation_config.num_beams,
|
419 |
+
self.generation_config.space_token_id,
|
420 |
+
self.generation_config.apply_eos_space_trick,
|
421 |
+
self.generation_config.eos_space_trick_weight,
|
422 |
)
|
423 |
processors.append(self.ctc_rescorer)
|
424 |
+
if hasattr(generation_config, "lm_weight") and generation_config.lm_weight > 0:
|
425 |
+
if not hasattr(generation_config, "lm_model"):
|
426 |
+
raise ValueError("If `lm_weight` is specified, make sure that `lm_model` is defined.")
|
427 |
+
processors.append(
|
428 |
+
LMRescorerLogitsProcessor(generation_config.lm_weight, generation_config.lm_model, device=self.device)
|
429 |
+
)
|
430 |
return processors
|
431 |
|
432 |
def _prepare_encoder_decoder_kwargs_for_generation(
|