In modeling_minimax_text_01.py, the attention mask is not passed correctly to the MiniMaxText01FlashAttention2::forward() method

#13
by sszymczyk - opened

In the modeling_minimax_text_01.py file, the MiniMaxText01DecoderLayer::forward() method contains:

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attn_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            slope_rate=slope_rate,
        )

where self_attn can be an object of class MiniMaxText01LightningAttention or MiniMaxText01FlashAttention2, depending on the layer number.
Note that the attention mask is always passed as the keyword argument attn_mask.

However, in MiniMaxText01FlashAttention2::forward() we have:

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            **kwargs,
    ):

so the argument name for the attention mask here is attention_mask, not attn_mask as passed in MiniMaxText01DecoderLayer::forward().
Since attention_mask defaults to None and the unrecognized attn_mask keyword is absorbed by **kwargs, the attention mask will always be None here.
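
To illustrate the mechanism, here is a minimal sketch in plain Python, not the actual model code. FlashAttentionLike and call_like_decoder_layer are hypothetical stand-ins that only copy the keyword names from the two snippets above; no torch is needed to see the effect.

```python
from typing import Any, Optional


class FlashAttentionLike:
    """Hypothetical stand-in; only the parameter names mirror the quoted signature."""

    def forward(
        self,
        hidden_states: Any,
        attention_mask: Optional[Any] = None,
        position_ids: Optional[Any] = None,
        **kwargs: Any,
    ) -> dict:
        # Report what the method actually received.
        return {
            "attention_mask": attention_mask,
            "swallowed_kwargs": sorted(kwargs),
        }


def call_like_decoder_layer(attn: FlashAttentionLike, mask: Any) -> dict:
    # Same keyword names as the decoder-layer call: the mask is passed as attn_mask.
    return attn.forward(
        hidden_states="h",
        position_ids=None,
        attn_mask=mask,
        slope_rate=None,
    )


result = call_like_decoder_layer(FlashAttentionLike(), mask=[1, 1, 0])
print(result)
```

Because the signature ends with **kwargs, Python raises no TypeError for the unexpected attn_mask keyword; it silently lands in kwargs while attention_mask keeps its default.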

Is this intentional or an error? Did you use this code to train the model?

MiniMax org

Yes, this is intentional. Within MiniMaxText01FlashAttention2, the setting self.is_causal = True ensures that only the causal mask is used for each request. This code only implements inference adapted to the HuggingFace format; it has not been used for training.
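
As a rough illustration of what "only the causal mask" means, the sketch below builds boolean masks in plain Python. The helper names and the left-padded example are assumptions made for illustration and are not taken from the model code. It compares a causal-only mask with one that also excludes padding tokens.

```python
from typing import List


def causal_mask(seq_len: int) -> List[List[bool]]:
    # True means query position i may attend to key position j (j <= i).
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def causal_and_padding_mask(padding: List[int]) -> List[List[bool]]:
    # padding[j] == 1 marks a real token, 0 marks a padding token.
    base = causal_mask(len(padding))
    return [
        [allowed and padding[j] == 1 for j, allowed in enumerate(row)]
        for row in base
    ]


padding = [0, 1, 1]  # hypothetical left-padded sequence
causal_only = causal_mask(len(padding))
combined = causal_and_padding_mask(padding)

for row_causal, row_combined in zip(causal_only, combined):
    print(row_causal, row_combined)
```

In this sketch the two masks coincide when a sequence has no padding, and differ only in the columns that correspond to padding tokens.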
