import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from huggingface_hub import snapshot_download
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import Phi3Config, Phi3Model
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Phi3Transformer(Phi3Model):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers, each a [`Phi3DecoderLayer`].
    Compared to the stock [`Phi3Model`], only the attention-mask handling is modified and optional
    per-layer CPU offloading is added.

    Args:
        config: Phi3Config
    """

    def prefetch_layer(self, layer_idx: int, device: torch.device):
        "Start prefetching the parameters of layer `layer_idx` to `device` on a side CUDA stream."
        with torch.cuda.stream(self.prefetch_stream):
            # Prefetch the layer's tensors to the GPU without blocking the compute stream.
            for name, param in self.layers[layer_idx].named_parameters():
                param.data = param.data.to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Move the parameters of the previous layer back to the CPU."
        prev_layer_idx = layer_idx - 1
        for name, param in self.layers[prev_layer_idx].named_parameters():
            param.data = param.data.to("cpu", non_blocking=True)

    def get_offlaod_layer(self, layer_idx: int, device: torch.device):
        # Lazily create the side stream used for asynchronous prefetching.
        if not hasattr(self, "prefetch_stream"):
            self.prefetch_stream = torch.cuda.Stream()

        # Wait for compute on the current stream, then evict the previous layer to the CPU.
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(layer_idx)

        # Make sure the prefetch of the current layer has finished before it is used.
        torch.cuda.synchronize(self.prefetch_stream)

        # Start prefetching the next layer (wrapping around at the last layer).
        self.prefetch_layer((layer_idx + 1) % len(self.layers), device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        offload_model: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co./docs/transformers/kv_cache#legacy-cache-format)"
                )

        # if inputs_embeds is None:
        #     inputs_embeds = self.embed_tokens(input_ids)

        # if cache_position is None:
        #     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        #     cache_position = torch.arange(
        #         past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        #     )
        # if position_ids is None:
        #     position_ids = cache_position.unsqueeze(0)

        if attention_mask is not None and attention_mask.dim() == 3:
            # Turn the (batch, seq_len, seq_len) mask of 1s/0s into an additive
            # (batch, 1, seq_len, seq_len) mask: 0 where attending, a large negative value otherwise.
            dtype = inputs_embeds.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
        else:
            raise ValueError(
                "Phi3Transformer expects a 3D attention_mask of shape (batch_size, seq_len, seq_len)."
            )
        # causal_mask = self._update_causal_mask(
        #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        # )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                if offload_model and not self.training:
                    self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
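

# Minimal smoke-test sketch (not part of the original module) showing how this decoder is
# driven with the 3D attention mask that `forward` expects. It assumes a `transformers`
# version whose `Phi3DecoderLayer` still consumes `position_ids` directly, as the forward
# pass above does; the config sizes below are illustrative only, and the
# `offload_model=True` path additionally requires a CUDA device.
if __name__ == "__main__":
    config = Phi3Config(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=64,
    )
    model = Phi3Transformer(config).eval()

    batch_size, seq_len = 1, 8
    # `forward` reads `inputs_embeds` directly (the `embed_tokens` call is disabled above).
    inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size)
    # (batch, seq_len, seq_len) mask of 1s/0s; `forward` converts it to an additive mask itself.
    attention_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
    position_ids = torch.arange(seq_len).unsqueeze(0)

    with torch.no_grad():
        out = model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=torch.arange(seq_len),
            use_cache=False,
        )
    print(out.last_hidden_state.shape)  # expected: torch.Size([1, 8, 64])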