Lakoc committed on
Commit
4020d91
1 Parent(s): 5c310a0

Update modeling_reguler.py

Files changed (1)
  1. modeling_reguler.py +22 -2
modeling_reguler.py CHANGED
@@ -385,14 +385,25 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
 
     def _get_logits_processor(
         self,
-        generation_config: GenerationConfigWithCTC,
+        generation_config: GenerationConfig,
         input_ids_seq_length: int,
         encoder_input_ids: torch.LongTensor,
         prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
         logits_processor: Optional[LogitsProcessorList],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
     ) -> LogitsProcessorList:
+        # pylint: disable=no-member
         processors = super()._get_logits_processor(
-            generation_config, input_ids_seq_length, encoder_input_ids, prefix_allowed_tokens_fn, logits_processor
+            generation_config,
+            input_ids_seq_length,
+            encoder_input_ids,
+            prefix_allowed_tokens_fn,
+            logits_processor,
+            model_kwargs,
+            negative_prompt_ids,
+            negative_prompt_attention_mask,
         )
         if hasattr(generation_config, "ctc_weight") and generation_config.ctc_weight > 0:
             if generation_config.num_beams <= 1:
@@ -405,8 +416,17 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
                 self.generation_config.ctc_margin,
                 self.generation_config.ctc_weight,
                 self.generation_config.num_beams,
+                self.generation_config.space_token_id,
+                self.generation_config.apply_eos_space_trick,
+                self.generation_config.eos_space_trick_weight,
             )
             processors.append(self.ctc_rescorer)
+        if hasattr(generation_config, "lm_weight") and generation_config.lm_weight > 0:
+            if not hasattr(generation_config, "lm_model"):
+                raise ValueError("If `lm_weight` is specified, make sure that `lm_model` is defined.")
+            processors.append(
+                LMRescorerLogitsProcessor(generation_config.lm_weight, generation_config.lm_model, device=self.device)
+            )
         return processors
 
     def _prepare_encoder_decoder_kwargs_for_generation(
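
For context, below is a minimal sketch of how the generation settings read by the overridden `_get_logits_processor` might be assembled. Only the attribute names come from the diff above; the specific values, the idea of attaching the config to the model, and the external LM object are illustrative assumptions, not part of this commit.

# Hedged usage sketch: attribute names taken from the diff; values and the
# model/LM objects are hypothetical placeholders.
from transformers import GenerationConfig

generation_config = GenerationConfig(
    num_beams=4,                  # CTC rescoring requires beam search (num_beams > 1)
    ctc_weight=0.3,               # > 0 enables the CTC rescorer branch
    ctc_margin=0,                 # read from the model's own generation config in the diff
    space_token_id=4,             # hypothetical token id; newly forwarded to the rescorer
    apply_eos_space_trick=False,  # new flag forwarded to the rescorer
    eos_space_trick_weight=1.0,   # new weight forwarded to the rescorer
    lm_weight=0.0,                # > 0 would additionally append LMRescorerLogitsProcessor,
                                  # in which case `lm_model` must also be set
)

# The rescorer fields are read from `self.generation_config`, so the config would be
# attached to the model before decoding, e.g. (model and inputs are placeholders):
# model.generation_config = generation_config
# model.generate(input_features, generation_config=generation_config)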