from transformers.models.whisper.configuration_whisper import WhisperConfig from typing import List, Optional, Dict, Any, Callable """Custom config to support modification of the Whisper encoder.""" class CustomWhisperConfig(WhisperConfig): def __init__( self, use_first_embeddings: bool = False, embedding_stride: int = 1, conv_preprocessing_layers: Optional[List[Dict[str, Any]]] = None, slide_feature_dim: Optional[int] = None, conv_dropout: float = 0.0, conv_bias: bool = True, conv_activation: str = "gelu", skip_connections: bool = False, **kwargs ): super().__init__(**kwargs) self.use_first_embeddings = use_first_embeddings self.embedding_stride = embedding_stride self.slide_feature_dim = slide_feature_dim if conv_preprocessing_layers is None: conv_preprocessing_layers = [ { "in_channels": self.num_mel_bins, "out_channels": self.d_model, "kernel_size": 3, "stride": 1, "padding": 1, "activation": "gelu", "bias": True }, { "in_channels": self.d_model, "out_channels": self.d_model, "kernel_size": 3, "stride": 2, "padding": 1, "activation": "gelu", "bias": True } ] self.conv_preprocessing_layers = conv_preprocessing_layers self.conv_dropout = conv_dropout self.conv_bias = conv_bias self.conv_activation = conv_activation self.skip_connections = skip_connections """PyTorch Whisper model with customization.""" import math from typing import Optional, Tuple, Union import warnings import numpy as np import torch import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss import torch.nn.functional as F from transformers.generation.logits_process import ( LogitsProcessorList, SuppressTokensLogitsProcessor ) from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, SequenceClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from transformers.models.whisper.generation_whisper import WhisperGenerationMixin if is_flash_attn_2_available(): from transformers.modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) _HIDDEN_STATES_START_POSITION = 1 _CONFIG_FOR_DOC = "WhisperConfig" _CHECKPOINT_FOR_DOC = "openai/whisper-tiny" def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: """Returns sinusoids for positional embedding""" if channels % 2 != 0: raise ValueError( f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." ) log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) # Copied from transformers.models.bart.modeling_bart.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, attention_mask: Optional[torch.LongTensor] = None, min_masks: int = 0, ) -> np.ndarray: """ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on CPU as part of the preprocessing during training. Args: shape: The shape for which to compute masks. This should be of a tuple of size 2 where the first element is the batch size and the second element is the length of the axis to span. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of independently generated mask spans of length `mask_length` is computed by `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the actual percentage will be smaller. mask_length: size of the mask min_masks: minimum number of masked spans attention_mask: A (right-padded) attention mask which independently shortens the feature axis of each batch dimension. """ batch_size, sequence_length = shape if mask_length < 1: raise ValueError("`mask_length` has to be bigger than 0.") if mask_length > sequence_length: raise ValueError( f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" f" and `sequence_length`: {sequence_length}`" ) # epsilon is used for probabilistic rounding epsilon = np.random.rand(1).item() def compute_num_masked_span(input_length): """Given input length, compute how many spans should be masked""" num_masked_span = int(mask_prob * input_length / mask_length + epsilon) num_masked_span = max(num_masked_span, min_masks) # make sure num masked span <= sequence_length if num_masked_span * mask_length > sequence_length: num_masked_span = sequence_length // mask_length # make sure num_masked span is also <= input_length - (mask_length - 1) if input_length - (mask_length - 1) < num_masked_span: num_masked_span = max(input_length - (mask_length - 1), 0) return num_masked_span # compute number of masked spans in batch input_lengths = ( attention_mask.sum(-1).detach().tolist() if attention_mask is not None else [sequence_length for _ in range(batch_size)] ) # SpecAugment mask to fill spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) spec_aug_mask_idxs = [] max_num_masked_span = compute_num_masked_span(sequence_length) if max_num_masked_span == 0: return spec_aug_mask for input_length in input_lengths: # compute num of masked spans for this input num_masked_span = compute_num_masked_span(input_length) # get random indices to mask spec_aug_mask_idx = np.random.choice( np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False ) # pick first sampled index that will serve as a dummy index to pad vector # to ensure same dimension for all batches due to probabilistic rounding # Picking first sample just pads those vectors twice. if len(spec_aug_mask_idx) == 0: # this case can only happen if `input_length` is strictly smaller then # `sequence_length` in which case the last token has to be a padding # token which we can use as a dummy mask id dummy_mask_idx = sequence_length - 1 else: dummy_mask_idx = spec_aug_mask_idx[0] spec_aug_mask_idx = np.concatenate( [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] ) spec_aug_mask_idxs.append(spec_aug_mask_idx) spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) # expand masked indices to masked spans spec_aug_mask_idxs = np.broadcast_to( spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) ) spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) # add offset to the starting indexes so that indexes now create a span offsets = np.arange(mask_length)[None, None, :] offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( batch_size, max_num_masked_span * mask_length ) spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length if spec_aug_mask_idxs.max() > sequence_length - 1: spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 # scatter indices to mask np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) return spec_aug_mask class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) def forward(self, input_ids, past_key_values_length=0, position_ids=None): if position_ids is None: return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] else: return self.weight[position_ids] class WhisperAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, is_causal: bool = False, layer_idx: Optional[int] = None, config: Optional[CustomWhisperConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads})." ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal if layer_idx is None and is_decoder: logger.warning_once( f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache past_key_value.is_updated[self.layer_idx] = True past_key_value = past_key_value.cross_attention_cache else: past_key_value = past_key_value.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape(self.k_proj(current_states), -1, bsz) value_states = self._shape(self.v_proj(current_states), -1, bsz) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_probs, value_states) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights, past_key_value class WhisperFlashAttention2(WhisperAttention): """ Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. " "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" ) # WhisperFlashAttention2 attention does not support output_attentions if output_attentions: raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache past_key_value.is_updated[self.layer_idx] = True past_key_value = past_key_value.cross_attention_cache else: past_key_value = past_key_value.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape(self.k_proj(current_states), -1, bsz) value_states = self._shape(self.v_proj(current_states), -1, bsz) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, : key_states.shape[-2]] # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = _flash_attention_forward( query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout if self.training else 0.0, is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(bsz, tgt_len, -1) attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class WhisperSdpaAttention(WhisperAttention): def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states, key_value_states=key_value_states, past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache past_key_value.is_updated[self.layer_idx] = True past_key_value = past_key_value.cross_attention_cache else: past_key_value = past_key_value.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: key_states = self._shape(self.k_proj(current_states), -1, bsz) value_states = self._shape(self.v_proj(current_states), -1, bsz) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value WHISPER_ATTENTION_CLASSES = { "eager": WhisperAttention, "flash_attention_2": WhisperFlashAttention2, "sdpa": WhisperSdpaAttention, } # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER class WhisperEncoderLayer(nn.Module): def __init__(self, config: CustomWhisperConfig): super().__init__() self.embed_dim = config.d_model self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, layer_head_mask: torch.Tensor, output_attentions: bool = False, ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class WhisperDecoderLayer(nn.Module): def __init__(self, config: CustomWhisperConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.d_model self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, is_causal=True, layer_idx=layer_idx, config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, layer_idx=layer_idx, config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # add cross-attn to positions 1 of present_key_value tuple present_key_value = (present_key_value, cross_attn_present_key_value) # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) if use_cache: outputs += (present_key_value,) return outputs class WhisperPreTrainedModel(PreTrainedModel): config_class = CustomWhisperConfig base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, WhisperEncoder): with torch.no_grad(): embed_positions = module.embed_positions.weight embed_positions.copy_(sinusoids(*embed_positions.shape)) def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers """ input_lengths = (input_lengths - 1) // 2 + 1 return input_lengths WHISPER_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ WHISPER_INPUTS_DOCSTRING = r""" Args: input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. If you want to change padding behavior, you should read [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True` Two formats are allowed: - An [`~cache_utils.EncoderDecoderCache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache in the correct position and to infer the complete sequence length. """ WHISPER_ENCODER_INPUTS_DOCSTRING = r""" Args: input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ class WhisperEncoder(WhisperPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`WhisperEncoderLayer`]. Args: config: WhisperConfig """ def __init__(self, config: CustomWhisperConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop embed_dim = config.d_model self.num_mel_bins = config.num_mel_bins self.padding_idx = config.pad_token_id self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 # CUSTOM # Create conv layers dynamically based on config self.conv_layers = nn.ModuleList() for layer_config in config.conv_preprocessing_layers: # Create sequential module for each conv+activation pair conv_sequence = nn.Sequential( nn.Conv1d( layer_config["in_channels"], layer_config["out_channels"], kernel_size=layer_config["kernel_size"], stride=layer_config["stride"], padding=layer_config["padding"], bias=True ), nn.GELU() if layer_config["activation"] == "gelu" else nn.ReLU() ) self.conv_layers.append(conv_sequence) self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) self.embed_positions.requires_grad_(False) self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() # CUSTOM def get_conv_stride(self): total_stride = 1 for layer in self.conv_layers: total_stride *= layer[0].stride[0] return total_stride def _freeze_parameters(self): for param in self.parameters(): param.requires_grad = False self._requires_grad = False def get_input_embeddings(self) -> nn.Module: return self.conv1 def set_input_embeddings(self, value: nn.Module): self.conv1 = value # CUSTOM def get_position_embeddings(self, sequence_length: int): """Get position embeddings with optional zero padding""" embed_pos = self.embed_positions.weight max_pos_len = embed_pos.shape[0] if sequence_length <= max_pos_len: if self.config.use_first_embeddings: return embed_pos[:sequence_length] else: return embed_pos[-sequence_length:] # Handle case where sequence_length > max_pos_len if self.config.allow_position_padding: # Create zero-padded embeddings device = embed_pos.device dtype = embed_pos.dtype padded_embeddings = torch.zeros((sequence_length, embed_pos.shape[1]), device=device, dtype=dtype) if self.config.use_first_embeddings: padded_embeddings[:max_pos_len] = embed_pos else: padded_embeddings[-max_pos_len:] = embed_pos return padded_embeddings else: # Use available embeddings (truncate to max_pos_len) if self.config.use_first_embeddings: return embed_pos[:max_pos_len] else: return embed_pos[-max_pos_len:] def forward( self, input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): r""" Args: input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] attention_mask (`torch.Tensor`)`, *optional*): Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ # CUSTOM #expected_seq_length = self.config.max_source_positions * self.get_conv_stride() # CUSTOM # Must be deactivated for our purpose, theoretically Whisper supports any sequence length for the encoder # if input_features.shape[-1] != expected_seq_length: # raise ValueError( # f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}." # ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # CUSTOM # Process through conv layers inputs_embeds = input_features for i, conv_layer in enumerate(self.conv_layers): residual = None if self.config.skip_connections and i > 0 and inputs_embeds.shape == conv_layer(inputs_embeds).shape: residual = inputs_embeds inputs_embeds = conv_layer(inputs_embeds) # Apply activation if self.config.conv_activation.lower() == "gelu": inputs_embeds = nn.functional.gelu(inputs_embeds) elif self.config.conv_activation.lower() == "relu": inputs_embeds = nn.functional.relu(inputs_embeds) # Apply skip connection if enabled if self.config.skip_connections and i > 0 and residual and residual.shape == inputs_embeds.shape: inputs_embeds = inputs_embeds + residual # Apply dropout if self.config.conv_dropout > 0: inputs_embeds = nn.functional.dropout( inputs_embeds, p=self.config.conv_dropout, training=self.training ) inputs_embeds = inputs_embeds.permute(0, 2, 1) sequence_length = inputs_embeds.shape[1] # CUSTOM position_embeddings = self.get_position_embeddings(sequence_length) print(f"Position_embeddings shape: {position_embeddings.shape}") hidden_states = inputs_embeds + position_embeddings hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) to_drop = False if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: # skip the layer to_drop = True if to_drop: layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, None, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) class WhisperDecoder(WhisperPreTrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`] Args: config: WhisperConfig """ main_input_name = "input_ids" def __init__(self, config: CustomWhisperConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_target_positions self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) self.layers = nn.ModuleList( [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)] ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def _freeze_parameters(self): for param in self.parameters(): param.requires_grad = False self._requires_grad = False def forward( self, input_ids=None, attention_mask=None, encoder_hidden_states=None, head_mask=None, cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, cache_position=None, ): r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention on hidden heads. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or when `config.use_cache=True` Two formats are allowed: - An [`~cache_utils.EncoderDecoderCache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False return_self_attention_cache = False if use_cache or past_key_values is not None: if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): return_self_attention_cache = True past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) elif not isinstance(past_key_values, EncoderDecoderCache): return_legacy_cache = True logger.warning_once( "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. " "You should pass an instance of `EncoderDecoderCache` instead, e.g. " "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) past_key_values_length = 0 if cache_position is not None: past_key_values_length = cache_position[0] elif past_key_values is not None: past_key_values_length = past_key_values.get_seq_length() if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) # embed positions if input_ids is not None: positions = self.embed_positions( input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids ) else: positions = self.embed_positions( inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids ) hidden_states = inputs_embeds + positions.to(inputs_embeds.device) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values.self_attention_cache if past_key_values is not None else None, output_attentions, ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: assert attn_mask.size()[0] == (len(self.layers)), ( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: continue if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, encoder_hidden_states, None, # encoder attention mask head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, # past_key_value output_attentions, use_cache, cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), past_key_value=past_key_values if use_cache else None, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = past_key_values if use_cache else None if return_self_attention_cache: next_cache = past_key_values.self_attention_cache if return_legacy_cache: next_cache = past_key_values.to_legacy_cache() if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to plcae the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask @add_start_docstrings( "The bare Whisper Model outputting raw hidden-states without any specific head on top.", WHISPER_START_DOCSTRING, ) class WhisperModel(WhisperPreTrainedModel): def __init__(self, config: CustomWhisperConfig): super().__init__(config) self.encoder = WhisperEncoder(config) self.decoder = WhisperDecoder(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.decoder.embed_tokens def set_input_embeddings(self, value): self.decoder.embed_tokens = value def get_encoder(self): return self.encoder def get_decoder(self): return self.decoder def freeze_encoder(self): """ Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will not be updated during training. """ self.encoder._freeze_parameters() def _mask_input_features( self, input_features: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, ): """ Masks extracted features along time axis and/or along feature axis according to [SpecAugment](https://arxiv.org/abs/1904.08779). """ # `config.apply_spec_augment` can set masking to False if not getattr(self.config, "apply_spec_augment", True): return input_features # generate indices & apply SpecAugment along time axis batch_size, hidden_size, sequence_length = input_features.size() if self.config.mask_time_prob > 0 and self.training: # generate indices & apply SpecAugment along time axis mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, attention_mask=attention_mask, min_masks=self.config.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool) mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1) input_features[mask_time_indices] = 0 if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis mask_feature_indices = _compute_mask_indices( (batch_size, hidden_size), mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, min_masks=self.config.mask_feature_min_masks, ) mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool) input_features[mask_feature_indices] = 0 return input_features @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_features: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" Returns: Example: ```python >>> import torch >>> from transformers import AutoFeatureExtractor, WhisperModel >>> from datasets import load_dataset >>> model = WhisperModel.from_pretrained("openai/whisper-base") >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") >>> input_features = inputs.input_features >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state >>> list(last_hidden_state.shape) [1, 2, 512] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if encoder_outputs is None: input_features = self._mask_input_features(input_features, attention_mask=attention_mask) encoder_outputs = self.encoder( input_features, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) if not return_dict: return decoder_outputs + encoder_outputs return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) def _pad_to_max_length( current_segments, pad_token_id, device, padding_side="right", padding="longest", bos_token_tensor=None, cut_off_length=None, ): max_total_length = 0 sequences = [] if padding_side not in ["right", "left"]: raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}") if padding not in ["longest", "max_length"]: raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}") elif padding == "max_length" and cut_off_length is None: raise ValueError("`cut_off_length` must be specified when `padding='max_length'`") for current_segment_list in current_segments: if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) if cut_off_length is not None: sequence = sequence[-cut_off_length:] if bos_token_tensor is not None: sequence = torch.cat([bos_token_tensor, sequence]) sequences.append(sequence) max_total_length = max(max_total_length, len(sequences[-1])) elif bos_token_tensor is not None: sequences.append(bos_token_tensor) else: sequences.append(torch.tensor([], device=device)) max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length for i in range(len(current_segments)): pad_length = max_total_length - len(sequences[i]) pad = (0, pad_length) if padding_side == "right" else (pad_length, 0) sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) sequences = torch.stack(sequences, dim=0) return sequences def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): if logits_processor is not None: logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) if logit_processor: return getattr(logit_processor, attribute_name, None) return None # CUSTOM (patch the generation method) class CustomWhisperGenerationMixin(WhisperGenerationMixin): def generate( self, input_features: Optional[torch.Tensor] = None, generation_config: Optional[Any] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[Any] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: bool = False, return_timestamps: Optional[bool] = None, task: Optional[str] = None, language: Optional[Union[str, List[str]]] = None, is_multilingual: Optional[bool] = None, prompt_ids: Optional[torch.Tensor] = None, prompt_condition_type: Optional[str] = None, # first-segment, all-segments condition_on_prev_tokens: Optional[bool] = None, temperature: Optional[Union[float, Tuple[float, ...]]] = None, compression_ratio_threshold: Optional[float] = None, logprob_threshold: Optional[float] = None, no_speech_threshold: Optional[float] = None, num_segment_frames: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, time_precision: float = 0.02, return_token_timestamps: Optional[bool] = None, return_segments: bool = False, return_dict_in_generate: Optional[bool] = None, **kwargs, ): # 0. deprecate old inputs if "inputs" in kwargs: input_features = kwargs.pop("inputs") warnings.warn( "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", FutureWarning, ) # 1. prepare generation config generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs) # 2. set global generate variables input_stride = self.model.encoder.get_conv_stride() #self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] num_segment_frames = input_stride * self.config.max_source_positions batch_size, total_input_frames = self._retrieve_total_input_frames( input_features=input_features, input_stride=input_stride, kwargs=kwargs ) is_shortform = total_input_frames <= num_segment_frames # 3. Make sure generation config is correctly set # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not return_dict_in_generate = self._set_return_outputs( return_dict_in_generate=return_dict_in_generate, return_token_timestamps=return_token_timestamps, logprob_threshold=logprob_threshold, generation_config=generation_config, ) timestamp_begin = self._set_return_timestamps( return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config ) self._set_language_and_task( language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config ) self._set_num_frames( return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs ) self._set_thresholds_and_condition( generation_config=generation_config, logprob_threshold=logprob_threshold, compression_ratio_threshold=compression_ratio_threshold, no_speech_threshold=no_speech_threshold, condition_on_prev_tokens=condition_on_prev_tokens, ) self._set_prompt_condition_type( generation_config=generation_config, prompt_condition_type=prompt_condition_type, ) # pass self.config for backward compatibility init_tokens = self._retrieve_init_tokens( input_features, batch_size=batch_size, generation_config=generation_config, config=self.config, num_segment_frames=num_segment_frames, kwargs=kwargs, ) # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation # where the input ids are handled explicitly by the generate method self._check_decoder_input_ids(kwargs=kwargs) # 3. Retrieve logits processors device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device begin_index = init_tokens.shape[1] logits_processor = self._retrieve_logit_processors( generation_config=generation_config, logits_processor=logits_processor, begin_index=begin_index, # begin index is index of first generated decoder token num_beams=kwargs.get("num_beams", 1), device=device, ) # 4 Set and retrieve global generation variables self._set_condition_on_prev_tokens( condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config ) temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature temperature = temperatures[0] max_frames, seek = self._retrieve_max_frames_and_seek( batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames, is_shortform=is_shortform, ) # 5 Prepare running variables, list for generation num_return_sequences = generation_config.num_return_sequences ( batch_idx_map, cur_bsz, input_features, seek, max_frames, init_tokens, do_condition_on_prev_tokens, ) = self._expand_variables_for_generation( input_features=input_features, seek=seek, max_frames=max_frames, init_tokens=init_tokens, batch_size=batch_size, condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config, ) current_segments = self._prepare_segments( prompt_ids=prompt_ids, batch_size=cur_bsz, generation_config=generation_config, ) # 6 Transcribe audio until we reach the end of all input audios while (seek < max_frames).any(): # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order # to know which original audio is being decoded # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch( input_features=input_features, seek=seek, max_frames=max_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, ) time_offset = seek * time_precision / input_stride seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) # 6.2 cut out next 30s segment from input features segment_input = self._get_input_segment( input_features=input_features, seek=seek, seek_num_frames=seek_num_frames, num_segment_frames=num_segment_frames, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, ) # 6.3 prepare decoder input ids suppress_tokens = _get_attr_from_logit_processors( logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens" ) decoder_input_ids, kwargs = self._prepare_decoder_input_ids( cur_bsz=cur_bsz, init_tokens=init_tokens, current_segments=current_segments, batch_idx_map=batch_idx_map, do_condition_on_prev_tokens=do_condition_on_prev_tokens, prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, device=init_tokens.device, suppress_tokens=suppress_tokens, kwargs=kwargs, ) # 6.4 set max new tokens or max length self._set_max_new_tokens_and_length( config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, ) # 6.5 Set current `begin_index` for all logit processors if logits_processor is not None: for proc in logits_processor: if hasattr(proc, "set_begin_index"): proc.set_begin_index(decoder_input_ids.shape[-1]) # 6.6 Run generate with fallback ( seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type, ) = self.generate_with_fallback( segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, batch_idx_map=batch_idx_map, seek=seek, num_segment_frames=num_segment_frames, max_frames=max_frames, temperatures=temperatures, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, do_condition_on_prev_tokens=do_condition_on_prev_tokens, is_shortform=is_shortform, batch_size=batch_size, attention_mask=attention_mask, kwargs=kwargs, ) # 6.7 In every generated sequence, split by timestamp tokens and extract segments for i, seek_sequence in enumerate(seek_sequences): prev_i = batch_idx_map[i] if should_skip[i]: seek[prev_i] += seek_num_frames[prev_i] continue segments, segment_offset = self._retrieve_segment( seek_sequence=seek_sequence, seek_outputs=seek_outputs, time_offset=time_offset, timestamp_begin=timestamp_begin, seek_num_frames=seek_num_frames, time_precision=time_precision, input_stride=input_stride, prev_idx=prev_i, idx=i, return_token_timestamps=return_token_timestamps, ) current_segments[prev_i] += segments if is_shortform: seek[prev_i] += max_frames[i] else: seek[prev_i] += segment_offset # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output final_segments = ( [x[1:] for x in current_segments] if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment") else current_segments ) sequences = _pad_to_max_length( final_segments, generation_config.pad_token_id, device=self.device, padding_side="right" ) # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. if return_segments: return {"sequences": sequences, "segments": final_segments} if is_shortform: # add eos token: if generation_config.max_new_tokens is None and generation_config.max_length is None: eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id) sequences = torch.cat([sequences, eos_tokens], dim=-1) if return_token_timestamps: outputs = {} outputs["sequences"] = sequences outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0) else: outputs = sequences if return_dict_in_generate and generation_config.return_dict_in_generate: dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs) if num_return_sequences > 1: if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None: dict_outputs.encoder_attentions = tuple( dict_outputs.encoder_attentions[i][::num_return_sequences] for i in range(len(dict_outputs.encoder_attentions)) ) if ( hasattr(dict_outputs, "encoder_hidden_states") and dict_outputs.encoder_hidden_states is not None ): dict_outputs.encoder_hidden_states = tuple( dict_outputs.encoder_hidden_states[i][::num_return_sequences] for i in range(len(dict_outputs.encoder_hidden_states)) ) if return_token_timestamps: dict_outputs["token_timestamps"] = outputs["token_timestamps"] return dict_outputs return outputs return sequences @add_start_docstrings( "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.", WHISPER_START_DOCSTRING, ) class CustomWhisperForConditionalGeneration(CustomWhisperGenerationMixin, WhisperPreTrainedModel): base_model_prefix = "model" _tied_weights_keys = ["proj_out.weight"] def __init__(self, config: CustomWhisperConfig): super().__init__(config) self.model = WhisperModel(config) self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False) self.max_target_positions = config.max_target_positions # Initialize weights and apply final processing self.post_init() def get_encoder(self): return self.model.get_encoder() def get_decoder(self): return self.model.get_decoder() def get_output_embeddings(self): return self.proj_out def set_output_embeddings(self, new_embeddings): self.proj_out = new_embeddings def get_input_embeddings(self) -> nn.Module: return self.model.get_input_embeddings() def freeze_encoder(self): """ Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will not be updated during training. """ self.model.encoder._freeze_parameters() ## CUSTOM def freeze_decoder(self): """ Calling this function will disable the gradient computation for the Whisper decoder so that its parameters will not be updated during training. """ self.model.decoder._freeze_parameters() @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_features: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`. Returns: Example: ```python >>> import torch >>> from transformers import AutoProcessor, WhisperForConditionalGeneration >>> from datasets import load_dataset >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") >>> input_features = inputs.input_features >>> generated_ids = model.generate(inputs=input_features) >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> transcription ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: if labels.shape[1] > self.max_target_positions: raise ValueError( f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens." ) if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) outputs = self.model( input_features, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, decoder_position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) lm_logits = self.proj_out(outputs[0]) loss = None if labels is not None: loss_fct = CrossEntropyLoss() # move labels to correct device to enable PP labels = labels.to(lm_logits.device) loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return Seq2SeqLMOutput( loss=loss, logits=lm_logits, past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, ) def prepare_inputs_for_generation( self, decoder_input_ids, past_key_values=None, use_cache=None, encoder_outputs=None, attention_mask=None, decoder_attention_mask=None, cache_position=None, **kwargs, ): # Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general function. Next time # this function needs to be touched, let's try to sort out the commonalities between the two and remove the # overwrite. decoder_position_ids = None if decoder_attention_mask is not None: decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) past_length = 0 if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() else: past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: remove_prefix_length = past_length else: # Default to old behavior: keep only final ID remove_prefix_length = decoder_input_ids.shape[1] - 1 decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] if decoder_position_ids is not None: decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format) if cache_position is None: cache_position = torch.arange( past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device ) elif use_cache: cache_position = cache_position[-decoder_input_ids.shape[1] :] # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 decoder_input_ids = decoder_input_ids.contiguous() if ( isinstance(past_key_values, EncoderDecoderCache) and ( isinstance(past_key_values.self_attention_cache, StaticCache) or isinstance(past_key_values.cross_attention_cache, StaticCache) ) and decoder_attention_mask is not None and decoder_attention_mask.ndim == 2 ): batch_size, sequence_length = decoder_input_ids.shape decoder_attention_mask = self.get_decoder()._prepare_4d_causal_attention_mask_with_cache_position( decoder_attention_mask, sequence_length=sequence_length, target_length=past_key_values.self_attention_cache.get_max_cache_shape(), dtype=self.proj_out.weight.dtype, device=decoder_input_ids.device, cache_position=cache_position, batch_size=batch_size, ) return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "use_cache": use_cache, "decoder_attention_mask": decoder_attention_mask, "decoder_position_ids": decoder_position_ids, "cache_position": cache_position, } class WhisperDecoderWrapper(WhisperPreTrainedModel): """ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is used in combination with the [`EncoderDecoderModel`] framework. """ def __init__(self, config): super().__init__(config) config.is_encoder_decoder = False self.decoder = WhisperDecoder(config) def get_input_embeddings(self): return self.decoder.embed_tokens def set_input_embeddings(self, value): self.decoder.embed_tokens = value def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) @add_start_docstrings( """ Whisper decoder with a language modeling head on top (linear layer with weights tied to the input embeddings). """, WHISPER_START_DOCSTRING, ) class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin): _tied_weights_keys = ["proj_out.weight"] main_input_name = "input_ids" def __init__(self, config): super().__init__(config) config.is_encoder_decoder = False self.model = WhisperDecoderWrapper(config) self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.proj_out def set_output_embeddings(self, new_embeddings): self.proj_out = new_embeddings def get_input_embeddings(self) -> nn.Module: return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def set_decoder(self, decoder): self.model.decoder = decoder def get_decoder(self): return self.model.decoder @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache in the correct position and to infer the complete sequence length. Returns: Example: ```python >>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor >>> import torch >>> from datasets import load_dataset >>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2") >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2") >>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2") >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> sample = ds[0]["audio"] >>> input_features = processor( ... sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt" ... ).input_features >>> predicted_ids = model.generate(input_features, assistant_model=assistant_model) >>> # decode token ids to text >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] >>> transcription ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)): encoder_outputs = encoder_outputs[0] # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model.decoder( input_ids=input_ids, attention_mask=attention_mask, encoder_hidden_states=encoder_outputs, head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) logits = self.proj_out(outputs[0]) loss = None if labels is not None: labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past @add_start_docstrings( """ Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB Keyword Spotting. """, WHISPER_ENCODER_INPUTS_DOCSTRING, ) class WhisperForAudioClassification(WhisperPreTrainedModel): def __init__(self, config): super().__init__(config) self.encoder = WhisperEncoder(config) num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings if config.use_weighted_layer_sum: self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def freeze_encoder(self): """ Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will not be updated during training. Only the projection layers and classification head will be updated. """ self.encoder._freeze_parameters() def get_input_embeddings(self) -> nn.Module: return self.encoder.get_input_embeddings() def set_input_embeddings(self, value: nn.Module): self.encoder.set_input_embeddings(value) @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_features: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Returns: Example: ```python >>> import torch >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification >>> from datasets import load_dataset >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id") >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True) >>> sample = next(iter(ds)) >>> inputs = feature_extractor( ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt" ... ) >>> input_features = inputs.input_features >>> with torch.no_grad(): ... logits = model(input_features).logits >>> predicted_class_ids = torch.argmax(logits).item() >>> predicted_label = model.config.id2label[predicted_class_ids] >>> predicted_label 'Afrikaans' ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if self.config.use_weighted_layer_sum: output_hidden_states = True elif output_hidden_states is None: output_hidden_states = self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict if encoder_outputs is None: encoder_outputs = self.encoder( input_features, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.config.use_weighted_layer_sum: hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION] hidden_states = torch.stack(hidden_states, dim=1) norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) else: hidden_states = encoder_outputs[0] hidden_states = self.projector(hidden_states) pooled_output = hidden_states.mean(dim=1) logits = self.classifier(pooled_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() # move labels to correct device to enable PP labels = labels.to(logits.device) loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + encoder_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, )