import json

import torch
import transformers
from transformers.cache_utils import *
from transformers.models.llama.modeling_llama import *

from .modules.inf_llm import InfLLMGenerator, inf_llm_forward
from .modules.minference_forward import (
    gather_last_q_vertical_slash_topk_v4,
    gather_last_q_vertical_slash_topk_vllm,
    init_minference_parameters,
    minference_forward,
    minference_kv_cache_cpu_forward,
    minference_vllm_forward,
    minference_with_snapkv_forward,
    search_pattern,
    sum_all_diagonal_matrix,
)
from .ops.streaming_kernel import stream_llm_forward

# Number of sequence positions handled per iteration by the chunked
# layernorm / MLP / lm_head loops below (presumably to bound peak activation
# memory on very long inputs).
_SEQ_CHUNK = 32000


def _seq_chunks(seq_len, chunk=_SEQ_CHUNK):
    """Yield ``(start, end)`` pairs that tile ``range(seq_len)`` in steps of ``chunk``."""
    for start in range(0, seq_len, chunk):
        yield start, min(seq_len, start + chunk)


class RotaryEmbeddingESM(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation matrices which depend on their
    relative positions. The cos/sin tables are cached and only rebuilt when a
    longer sequence is requested; cached tables carry leading singleton axes so
    they broadcast against rank-2/3/4 inputs, and positions are always selected
    along axis -2.
    """

    def __init__(
        self,
        dim: int,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
        device="cuda",
    ):
        """
        Args:
            dim: rotary dimension; one inverse frequency is created for each
                even index in ``range(dim)``.
            base: frequency base of the rotary schedule.
            distance_scale: multiplier applied to the position index before
                computing angles.
            device: device for the ``inv_freq`` buffer (defaults to ``"cuda"``,
                the previously hard-coded value).
        """
        super().__init__()

        self.base = base
        self.distance_scale = distance_scale

        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None

    def rotate_half(self, x):
        """Map the two halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _slice_positions(table, start, stop):
        """Select positions ``[start, stop)`` on axis -2 of a rank-2/3/4 table.

        Tables of any other rank are returned unchanged.
        """
        if table.dim() in (2, 3, 4):
            return table[..., start:stop, :]
        return table

    def _rotate(self, x, cos, sin):
        """Apply the rotation in float32 and cast the result back to ``x``'s dtype."""
        return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(x.dtype)

    def apply_rotary_pos_emb(self, x, length, right, cos, sin):
        """Rotate ``x`` with the table rows for positions ``[right - length, right)``."""
        cos = self._slice_positions(cos, right - length, right)
        sin = self._slice_positions(sin, right - length, right)
        return self._rotate(x, cos, sin)

    def _build_tables(self, seq_len, device, dim):
        """Recompute cos/sin tables for ``seq_len`` positions, shaped for rank ``dim``.

        ``_seq_len_cached`` is always updated; the cached tables are only
        replaced when ``dim`` is 2, 3 or 4.
        """
        self._seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.outer(t * self.distance_scale, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        if dim in (2, 3, 4):
            # Prepend (dim - 2) singleton axes so the table broadcasts over
            # the leading axes of a rank-`dim` input.
            lead = (None,) * (dim - 2)
            self._cos_cached = emb.cos()[lead]
            self._sin_cached = emb.sin()[lead]

    def _update_cos_sin_tables(self, x, seq_dim):
        """Grow the cached tables to cover ``x.size(seq_dim)`` positions if needed."""
        seq_len = x.size(seq_dim)
        if seq_len > self._seq_len_cached:
            self._build_tables(seq_len, x.device, x.dim())
        return self._cos_cached, self._sin_cached

    def _update_cos_sin_tables_len(self, seq_len, device, dim=None):
        """Grow the cached tables to ``seq_len`` positions.

        When ``dim`` is omitted, the rank of the existing cached table is
        reused (a table must already exist in that case).
        """
        if seq_len > self._seq_len_cached:
            if dim is None:
                assert self._cos_cached is not None
                dim = self._cos_cached.dim()
            self._build_tables(seq_len, device, dim)
        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb_one_angle(self, x: torch.Tensor, index):
        """Rotate ``x`` with the single table row for position ``index - 1``."""
        cos, sin = self._update_cos_sin_tables_len(index, x.device)
        cos = self._slice_positions(cos, index - 1, index)
        sin = self._slice_positions(sin, index - 1, index)
        return self._rotate(x, cos, sin)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_dim=-2
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k``; ``q`` is aligned to the last positions of ``k``."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dim=seq_dim
        )
        k_len = k.size(seq_dim)
        return (
            self.apply_rotary_pos_emb(
                q, q.size(seq_dim), k_len, self._cos_cached, self._sin_cached
            ),
            self.apply_rotary_pos_emb(
                k, k_len, k_len, self._cos_cached, self._sin_cached
            ),
        )


# Attention-forward factories selectable through ``patch_hf(attn_type=...)``.
ATTN_FORWRAD = {
    "streaming": stream_llm_forward,
    "minference": minference_forward,
    "inf_llm": inf_llm_forward,
}
# Correctly spelled alias; ``ATTN_FORWRAD`` is kept for existing importers.
ATTN_FORWARD = ATTN_FORWRAD


def huggingface_forward(forward):
    """Adapt a projection-level attention ``forward`` to the HF attention signature.

    The wrapped ``forward`` is called positionally with: the module, the hidden
    states (twice), ``position_ids``, ``use_cache``, ``past_key_value``, the
    q/k/v/o projections, ``head_dim``, ``num_heads`` and
    ``num_key_value_heads``. It must return ``(output, cache)`` when
    ``use_cache`` is true and just ``output`` otherwise.
    """

    def hf_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        # Attention weights are never materialized by these kernels.
        assert not output_attentions
        ret = forward(
            self,
            hidden_states,
            hidden_states,
            position_ids,
            use_cache,
            past_key_value,
            self.q_proj,
            self.k_proj,
            self.v_proj,
            self.o_proj,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
        )
        if use_cache:
            o, pkv = ret
        else:
            o = ret
            pkv = None

        # HF layers expect (attn_output, attn_weights, present_key_value).
        return o, None, pkv

    return hf_forward


def hf_437_prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """``prepare_inputs_for_generation`` in the style of transformers 4.37.

    Trims ``input_ids`` to the tokens not yet in the cache, crops the attention
    mask to the cache's maximum length, and derives ``position_ids`` from the
    attention mask when none are given.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs


def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    **kwargs,
):
    """``prepare_inputs_for_generation`` for the ``cache_position``-aware HF API.

    Used by ``patch_hf``, whose patched model returns a plain tuple of
    per-layer caches. For such non-``Cache`` values the past length is read
    from ``cache_position[0]`` (so ``cache_position`` must be provided in that
    case).
    """
    # With static cache, the `past_key_values` is None
    # TODO joao: standardize interface for the different Cache classes and remove of this if
    has_static_cache = False
    if past_key_values is None:
        past_key_values = getattr(
            getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None
        )
        has_static_cache = past_key_values is not None

    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            past_length = (
                cache_position[0]
                if cache_position is not None
                else past_key_values.get_seq_length()
            )
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                if past_key_values.get_max_length() is not None
                else None
            )
            cache_length = (
                past_length
                if max_cache_length is None
                else torch.min(max_cache_length, past_length)
            )
        # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
        else:
            # The per-layer caches produced by the patched forwards are not
            # (key, value) tensors, so the past length comes from cache_position
            # rather than from past_key_values[0][0].shape[2].
            cache_length = past_length = cache_position[0]
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
        # TODO: use `next_tokens` directly instead.
        model_inputs = {"input_ids": input_ids.contiguous()}

    input_length = (
        position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
    )
    if cache_position is None:
        cache_position = torch.arange(
            past_length, past_length + input_length, device=input_ids.device
        )
    else:
        cache_position = cache_position[-input_length:]

    if has_static_cache:
        past_key_values = None

    model_inputs.update(
        {
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs


def prepare_inputs_for_generation_snapkv(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """``prepare_inputs_for_generation`` variant for the SnapKV patch.

    Resets every attention layer's ``kv_seq_len`` at the first generation step.
    For non-``Cache`` past values, the past length is taken from layer 0's
    ``kv_seq_len`` instead of the cached tensor shape.
    """
    if past_key_values is None:  # [SnapKV]
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # NOTE(review): presumably SnapKV compresses the stored cache, so
            # its tensor length no longer equals the number of processed tokens.
            cache_length = past_length = self.model.layers[0].self_attn.kv_seq_len
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs


def _prepare_decoder_attention_mask_inference(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    """Extend a 2-D ``[bsz, seq_len]`` mask over the cached prefix.

    Returns ``None`` when every position is attended to, so attention can use
    its unmasked path; otherwise returns the extended mask.
    """
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            (
                torch.full(
                    (input_shape[0], past_key_values_length),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ),
            dim=-1,
        )

    if attention_mask is not None and torch.all(attention_mask):
        return None  # This uses the faster call when training with full samples

    return attention_mask


def forward_llama_decoder_layer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Decoder layer forward that applies layernorms and the MLP in sequence chunks.

    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`.
            It is overwritten in place by the input layernorm (the residual is
            cloned first), so callers must not rely on its contents afterwards.
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up
            decoding (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

    Returns:
        ``(hidden_states,)`` plus the attention weights if ``output_attentions``
        and the present key/value if ``use_cache``.
    """
    residual = hidden_states.clone()
    _, seq_len, _ = hidden_states.shape

    for start, end in _seq_chunks(seq_len):
        hidden_states[:, start:end, :] = self.input_layernorm(
            hidden_states[:, start:end, :]
        )

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        padding_mask=padding_mask,
    )
    hidden_states = residual + hidden_states

    # Fully Connected: residual MLP added in place, one chunk at a time.
    for start, end in _seq_chunks(seq_len):
        part_hidden_states = hidden_states[:, start:end, :].clone()
        part_hidden_states = self.post_attention_layernorm(part_hidden_states)
        part_hidden_states = self.mlp(part_hidden_states)
        hidden_states[:, start:end, :] += part_hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs


def forward_llama_model(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """LlamaModel forward with a chunked final norm, used by the MInference patches.

    Legacy tuple caches are wrapped in ``DynamicCache`` and converted back on
    return. The attention mask is built via ``self._prepare_decoder_attention_mask``.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time"
        )
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past),
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            # The patched decoder layer normalizes its input in place, so keep
            # a copy; a bare reference would be overwritten by that layer.
            all_hidden_states += (hidden_states.clone(),)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    # Final norm, applied in place chunk by chunk.
    _, seq_len, _ = hidden_states.shape
    for start, end in _seq_chunks(seq_len):
        hidden_states[:, start:end, :] = self.norm(hidden_states[:, start:end, :])

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache()
            if use_legacy_cache
            else next_decoder_cache
        )
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def forward_llama_for_causal_lm(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Causal-LM forward that avoids materializing full-sequence logits.

    * With ``labels``: the next-token loss is accumulated chunk by chunk
      (summed cross-entropy divided by the number of targets with label >= 0)
      and ``logits`` is ``None``.
    * Without ``labels``: logits for the whole sequence are returned when
      ``config.is_ppl`` is truthy, otherwise only the last position's logits
      (as float32).

    Always returns a ``CausalLMOutputWithPast``; ``outputs.past_key_values`` is
    read by attribute, so the inner model must return a ModelOutput.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    torch.cuda.empty_cache()

    hidden_states = outputs[0]
    if labels is not None:
        loss_fct = CrossEntropyLoss(reduction="sum")
        # Number of next-token predictions. Read from hidden_states (equal to
        # input_ids.shape[-1] - 1) so inputs_embeds-only calls also work.
        valid_seq_len = hidden_states.shape[-2] - 1
        # Number of targets counted for the mean (label >= 0); negative labels
        # such as -100 are CrossEntropyLoss's default ignore_index.
        valid_seq_len_slide_win = torch.sum(labels[:, 1:] >= 0).item()
        loss = 0.0
        for start, end in _seq_chunks(valid_seq_len):
            shift_logits = self.lm_head(hidden_states[..., start:end, :]).float()
            # Position t predicts the token at t + 1.
            shift_labels = labels[..., start + 1 : end + 1].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss += loss_fct(shift_logits, shift_labels)

        loss /= valid_seq_len_slide_win
        logits = None
    else:
        if self.config.to_dict().get("is_ppl", False):
            logits = self.lm_head(hidden_states)
        else:
            logits = self.lm_head(hidden_states[:, -1:]).float()
        loss = None

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
    )


def _patch_llama_for_minference(model, attn_forward, prepare_inputs_fn):
    """Shared steps of the HF MInference patches; mutates and returns ``model``.

    Binds ``attn_forward`` (plus the MInference helper methods) onto every
    module of layer 0's attention class, the chunked decoder-layer forward onto
    every module of layer 0's class, and replaces the model-level
    ``prepare_inputs_for_generation``, mask builder, and forwards.
    """
    Attention = model.model.layers[0].self_attn.__class__
    DecoderLayer = model.model.layers[0].__class__

    def update_module(m):
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = attn_forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = prepare_inputs_fn.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False

    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)
    return model


def minference_patch(model, config):
    """Patch an HF Llama-style model for MInference.

    Dispatches to the KV-cache-on-CPU or SnapKV variants when
    ``config.kv_cache_cpu`` / ``config.use_snapkv`` are set. Otherwise it
    applies the standard patch and sets ``model.has_patch = True``.
    """
    if config.kv_cache_cpu:
        return minference_patch_kv_cache_cpu(model)
    if config.use_snapkv:
        return minference_patch_with_snapkv(model)

    _patch_llama_for_minference(
        model, minference_forward(), hf_437_prepare_inputs_for_generation
    )
    model.has_patch = True

    print("Patched model for minference..")
    return model


def minference_patch_kv_cache_cpu(model):
    """MInference patch that keeps the KV cache on CPU.

    NOTE: this replaces ``DynamicCache.update`` / ``DynamicCache.get`` on the
    transformers class itself, which affects every ``DynamicCache`` in the
    process, not only this model.
    """
    transformers.cache_utils.DynamicCache.update = cpu_cache_update
    transformers.cache_utils.DynamicCache.get = cpu_cache_get

    _patch_llama_for_minference(
        model,
        minference_kv_cache_cpu_forward(),
        hf_437_prepare_inputs_for_generation,
    )

    print("Patched model for MInference load KV Cache to CPU.")
    return model


def minference_patch_with_snapkv(model):
    """MInference patch combined with SnapKV cache compression."""
    _patch_llama_for_minference(
        model,
        minference_with_snapkv_forward(),
        prepare_inputs_for_generation_snapkv,
    )

    print("Patched model for minference with SanpKV..")
    return model


def llama_model_forward_vllm(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """vLLM LlamaModel forward that additionally passes each layer its index."""
    if inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = self.get_input_embeddings(input_ids)
    residual = None
    for i in range(len(self.layers)):
        layer = self.layers[i]
        hidden_states, residual = layer(
            positions,
            hidden_states,
            kv_caches[i],
            attn_metadata,
            residual,
            layer_idx=i,
        )
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states


def llama_layer_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    residual: Optional[torch.Tensor],
    layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """vLLM LlamaDecoderLayer forward that threads ``layer_idx`` into attention."""
    # Self Attention
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        # Two-argument call: the layernorm also returns the updated residual.
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        layer_idx=layer_idx,
    )

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
    hidden_states = self.mlp(hidden_states)
    return hidden_states, residual


def llama_attn_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    layer_idx: int,
) -> torch.Tensor:
    """vLLM LlamaAttention forward that forwards ``layer_idx`` to the attention op."""
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx)
    output, _ = self.o_proj(attn_output)
    return output


def vllm_attn_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: Optional[torch.Tensor],
    attn_metadata,
    kv_scale: float = 1.0,
    layer_idx: int = 0,
) -> torch.Tensor:
    """vLLM ``Attention.forward`` that passes ``layer_idx`` through to ``self.impl``."""
    return self.impl.forward(
        query, key, value, kv_cache, attn_metadata, kv_scale, layer_idx
    )


def minference_patch_vllm(
    llm,
    config_file,
):
    """Patch a vLLM ``LLM`` instance in place for MInference.

    ``config_file`` is a JSON file whose parsed content is passed to
    ``minference_vllm_forward``.
    """
    from vllm.attention import Attention
    from vllm.model_executor.models.llama import (
        LlamaAttention,
        LlamaDecoderLayer,
        LlamaModel,
    )

    # Use a context manager so the config file handle is closed promptly.
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)
    attn_forward = minference_vllm_forward(config)

    def update_module(m):
        if isinstance(m, Attention):
            m.forward = vllm_attn_forward.__get__(m, Attention)

            # The sparse kernel is bound onto the backend implementation object.
            m = m.impl
            m_cls = m.__class__
            m.gather_last_q_vertical_slash_topk_vllm = (
                gather_last_q_vertical_slash_topk_vllm.__get__(m, m_cls)
            )
            m.forward = attn_forward.__get__(m, m_cls)
        if isinstance(m, LlamaDecoderLayer):
            m.forward = llama_layer_forward_vllm.__get__(m, LlamaDecoderLayer)
        if isinstance(m, LlamaModel):
            m.forward = llama_model_forward_vllm.__get__(m, LlamaModel)
        if isinstance(m, LlamaAttention):
            m.forward = llama_attn_forward_vllm.__get__(m, LlamaAttention)

    llm.llm_engine.model_executor.driver_worker.model_runner.model.apply(update_module)

    print("Patched model for minference with VLLM..")
    return llm


def patch_hf(
    model,
    attn_type: str = "inf_llm",
    attn_kwargs: Optional[dict] = None,
    base=None,
    distance_scale=None,
    **kwargs,
):
    """Patch an HF causal LM to use one of the ``ATTN_FORWRAD`` attention kernels.

    Args:
        model: a Llama / Mistral / Qwen2 / MiniCPM / Phi3 causal LM.
        attn_type: key into ``ATTN_FORWRAD``.
        attn_kwargs: keyword arguments for the attention-forward factory; merged
            with ``**kwargs`` (which take precedence). Not modified.
        base: RoPE base; defaults to the model's rotary embedding ``base``.
        distance_scale: RoPE position multiplier; defaults to 1.0.

    Returns:
        The patched model, wrapped in ``InfLLMGenerator`` when
        ``attn_type == "inf_llm"``.

    Raises:
        ValueError: if the model architecture is not supported.
    """
    # Merge into a fresh dict: the previous ``{}`` default was shared between
    # calls and ``.update`` mutated both it and the caller's dict.
    attn_kwargs = {**(attn_kwargs or {}), **kwargs}

    # This approach lacks scalability and will be refactored.
    from transformers import LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM
    from transformers.models.llama.modeling_llama import BaseModelOutputWithPast

    def model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        *args,
        **kwargs,
    ):
        # Decoder-stack forward that passes the shared RotaryEmbeddingESM
        # (`self.position_bias`) to every layer in place of `position_ids`,
        # and returns the cache as a plain tuple of per-layer caches.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            # MiniCPM-style configs scale the token embeddings.
            if hasattr(self, "config") and hasattr(self.config, "scale_emb"):
                inputs_embeds = inputs_embeds * self.config.scale_emb

        if use_cache:
            pkv = tuple()
        else:
            pkv = None

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for i, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=self.position_bias,
                past_key_value=(
                    past_key_values[i] if past_key_values is not None else None
                ),
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                _cache = layer_outputs[2 if output_attentions else 1]
                pkv = pkv + (_cache,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final norm, applied in place chunk by chunk.
        for start, end in _seq_chunks(hidden_states.size(1)):
            hidden_states[:, start:end, :] = self.norm(hidden_states[:, start:end, :])

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, pkv, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=pkv,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    forward = huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))

    supported = isinstance(
        model, (LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM)
    ) or model.__class__.__name__ in ("MiniCPMForCausalLM", "Phi3ForCausalLM")
    if not supported:
        raise ValueError(
            "Only supports llama, mistral, qwen2, minicpm and phi3 models."
        )
    Attention = model.model.layers[0].self_attn.__class__
    Model = model.model.__class__

    hf_rope = model.model.layers[0].self_attn.rotary_emb
    base = base if base is not None else hf_rope.base
    distance_scale = distance_scale if distance_scale is not None else 1.0
    rope = RotaryEmbeddingESM(hf_rope.dim, base, distance_scale)
    model.model.position_bias = rope
    model.model.hf_position_bias = hf_rope

    def set_forward(m):
        if isinstance(m, Attention):
            m._old_forward = m.forward
            m.forward = forward.__get__(m, Attention)

    model.apply(set_forward)

    model._old_prepare_inputs_for_generation = model.prepare_inputs_for_generation
    model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._old_forward = model.model.forward
    model.model.forward = model_forward.__get__(model.model, Model)

    if attn_type == "inf_llm":
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model.config._name_or_path
        )
        model = InfLLMGenerator(model, tokenizer)

    print("Patched model ...")
    return model


def _add_seen_tokens(cache, num_tokens):
    """Advance the cache's seen-token counter.

    Newer transformers keep the counter in ``_seen_tokens``; older ones in
    ``seen_tokens``. Whichever is stored on the instance is incremented.
    """
    if "_seen_tokens" in cache.__dict__:
        cache._seen_tokens += num_tokens
    else:
        cache.seen_tokens += num_tokens


def fp8_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`,
    storing them as `torch.float8_e5m2`.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

    Return:
        A tuple containing the full cached key and value states for the layer,
        cast back to the dtypes of the incoming key and value states.
    """
    # Update the number of seen tokens (counted once per step, on layer 0).
    if layer_idx == 0:
        _add_seen_tokens(self, key_states.shape[-2])

    # Update the cache; new states are appended along axis -2.
    if len(self.key_cache) <= layer_idx:
        self.key_cache.append(key_states.to(torch.float8_e5m2))
        self.value_cache.append(value_states.to(torch.float8_e5m2))
    else:
        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], key_states.to(torch.float8_e5m2)], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], value_states.to(torch.float8_e5m2)], dim=-2
        )

    return (
        self.key_cache[layer_idx].to(key_states.dtype),
        self.value_cache[layer_idx].to(value_states.dtype),
    )


def cpu_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Append new key/value states for ``layer_idx`` to a CPU-resident cache.

    Installed as ``DynamicCache.update`` by ``minference_patch_kv_cache_cpu``.
    Unlike ``DynamicCache.update`` this returns nothing; cached states are read
    back per head through ``cpu_cache_get``.
    """
    if layer_idx == 0:
        _add_seen_tokens(self, key_states.shape[-2])

    # Update the cache; new states are appended along axis -2.
    if len(self.key_cache) <= layer_idx:
        self.key_cache.append(key_states.cpu())
        self.value_cache.append(value_states.cpu())
    else:
        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], key_states.cpu()], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], value_states.cpu()], dim=-2
        )


def cpu_cache_get(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    head_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Prepend one head's cached keys/values to the given new states.

    Slices ``head_idx`` on axis 1 of the CPU cache (presumably the head axis
    — confirm against the caller), moves it to the device of the incoming
    states, and concatenates along axis -2. If ``layer_idx`` has no cache yet,
    the inputs are returned unchanged.
    """
    # NOTE(review): this also advances the seen-token counter, as
    # cpu_cache_update does; confirm the caller does not double count.
    if layer_idx == 0:
        _add_seen_tokens(self, key_states.shape[-2])

    if len(self.key_cache) <= layer_idx:
        return key_states, value_states

    # Move to the states' own device (not the current CUDA device) so the
    # concatenation works when the model is spread over several GPUs.
    key_states = torch.cat(
        [
            self.key_cache[layer_idx][:, head_idx : head_idx + 1].to(key_states.device),
            key_states,
        ],
        dim=-2,
    )
    value_states = torch.cat(
        [
            self.value_cache[layer_idx][:, head_idx : head_idx + 1].to(
                value_states.device
            ),
            value_states,
        ],
        dim=-2,
    )
    return key_states, value_states