# coding=utf-8
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

import transformers
from transformers import PreTrainedModel, LlamaPreTrainedModel, LlamaForCausalLM, AutoModelForCausalLM
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask

from .configuration_llama3 import Llama3Config
from .mllama_audio_model import Llama3Embedding


logger = logging.get_logger(__name__)


class Llama3PreTrainedModel(PreTrainedModel):
    """Base class binding `Llama3Config` to `PreTrainedModel`.

    Declares the capability flags read by the Transformers loading machinery
    and provides the weight initialisation used for `nn.Linear` and
    `nn.Embedding` modules.
    """

    config_class = Llama3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialise `module` in place.

        Linear and embedding weights are drawn from a normal distribution with
        mean 0 and standard deviation `initializer_range` (taken from the text
        config). Linear biases and the embedding padding row, when present,
        are zeroed. Other module types are left untouched.
        """
        std = self.config.get_text_config().initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class Llama3ForCausalLM(Llama3PreTrainedModel, GenerationMixin):
    """Causal LM that feeds audio-augmented embeddings into a `LlamaForCausalLM`.

    The model owns two sub-modules:

    * `language_model` -- a `LlamaForCausalLM` built from `config.text_config`;
      it provides the token embedding table, the decoder and the LM head.
    * `embed_tokens` -- a `Llama3Embedding` that receives the token ids, their
      text embeddings and `audio_features`, and returns the `inputs_embeds`
      actually given to the language model.

    All embedding / decoder accessors delegate to `language_model`.
    """

    config_class = Llama3Config
    base_model_prefix = "model"
    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting

    def __init__(self, config: Llama3Config):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        # -1 is used as a sentinel when the config defines no padding token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

        self.language_model = LlamaForCausalLM._from_config(config.text_config)
        self.embed_tokens = Llama3Embedding(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped language model."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the wrapped language model."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the output embedding (LM head) of the wrapped language model."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the output embedding (LM head) of the wrapped language model."""
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Replace the decoder of the wrapped language model."""
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder of the wrapped language model."""
        return self.language_model.get_decoder()

    def tie_weights(self):
        """Delegate weight tying to the wrapped language model."""
        return self.language_model.tie_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audio_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Embed the inputs (merging in audio features) and run the wrapped language model.

        Exactly one of `input_ids` and `inputs_embeds` must be provided. When `input_ids` is given, the ids are
        clamped to be non-negative, looked up in the language model's embedding table, and the result is handed
        to `self.embed_tokens` together with the original `input_ids` and `audio_features`; its output becomes
        `inputs_embeds`. When `inputs_embeds` is given it is used as-is and `audio_features` is not consulted.

        Args:
            audio_features (`torch.FloatTensor`, *optional*):
                Audio features forwarded to `self.embed_tokens`. The expected shape is defined by
                `Llama3Embedding` and is not visible from this module.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:
            Whatever `self.language_model` returns for the given `return_dict` setting.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.text_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.text_config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Negative ids are clamped to 0 only for the embedding-table lookup; the
            # unclamped `input_ids` are still passed to `embed_tokens`.
            # NOTE(review): presumably negative ids mark audio placeholder positions -- confirm
            # against `Llama3Embedding`.
            input_embeddings = self.get_input_embeddings()(input_ids.clamp_min(0).detach())
            inputs_embeds = self.embed_tokens(input_ids=input_ids, input_embeddings=input_embeddings, audio_features=audio_features)

        # NOTE(review): `language_model` is a plain `LlamaForCausalLM`; the cross-attention
        # keywords below are forwarded unchanged -- confirm the installed transformers version
        # accepts (and ignores) them.
        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=None,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        audio_features=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """Build the keyword arguments passed to `forward` for one generation step.

        `aspect_ratio_ids` and `aspect_ratio_mask` are accepted only so that existing callers
        keep working; they are deliberately NOT forwarded, because `forward` declares no such
        parameters and passing them would raise `TypeError` on the first generation step.

        Returns:
            dict: keys `input_ids`, `inputs_embeds`, `audio_features`, `position_ids`,
            `cache_position`, `past_key_values`, `use_cache`, `attention_mask`,
            `cross_attention_mask`, plus `num_logits_to_keep` when it is not `None`.
        """
        # Overwritten -- the generic implementation does not know about `audio_features`

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "audio_features": audio_features,
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cross_attention_mask": cross_attention_mask,
            }
        )

        # `aspect_ratio_ids` / `aspect_ratio_mask` are intentionally dropped here: `forward`
        # has no matching parameters, so forwarding them made `self(**model_inputs)` fail
        # with `TypeError` during the pre-fill step.
        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
        """Update `model_kwargs` after a generation step.

        Runs the parent-class update and then, if a `cross_attention_mask` was present before
        the update, extends it along dim 1 by repeating its last entry so that it also covers
        the newly generated token.
        """
        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

        # add cross-attn mask for new token
        if cross_attention_mask_prev is not None:
            model_kwargs["cross_attention_mask"] = torch.cat(
                [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
            )
        return model_kwargs


# Make the model discoverable through `AutoModelForCausalLM` and as `transformers.Llama3ForCausalLM`.
AutoModelForCausalLM.register(Llama3Config, Llama3ForCausalLM)
transformers.Llama3ForCausalLM = Llama3ForCausalLM