from typing import Dict, Optional, Tuple, List, Any, Union

import torch
from torch import nn
import torch.nn.functional as F

from .eva_agg_kernel import triton_eva_agg_fwd
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd

try:
    import triton
    USE_TRITON_IMPL = True
except ImportError:
    USE_TRITON_IMPL = False
    raise ImportError("Triton is not installed. Please install it by running `pip install triton`.")


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotates half the hidden dims (last dim) of the input.

    Args:
        x: Rotary embedded tensor

    Returns:
        Tensor with the second half of the last dim negated and moved to the front.
    """
    x1, x2 = x.split(x.shape[-1] // 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embedding (cos, sin) to the query and key tensors along the sequence dimension.

    The legend for dimensions is:
        num_heads: number of attention heads
        current_seq_len: the current batch's sequence length, either 1 or max_seq_len
        max_seq_len: the static sequence length; differs from current_seq_len in the cached
            inference case, where it is always the maximum length, e.g. the static sequence
            length of the KV cache

    Args:
        q: Query tensor, of size (batch_size, num_heads, current_seq_len, head_dim)
        k: Key tensor, of size (batch_size, num_key_value_heads, current_seq_len, head_dim)
        cos: Cosine base of rotary embedding, broadcastable against q and k, with a last dim of head_dim
        sin: Sine base of rotary embedding, broadcastable against q and k, with a last dim of head_dim
        position_ids: The position indices of the tokens corresponding to the query and key tensors,
            of size (batch_size, current_seq_len).

    Returns:
        Embedded query and key tensors of the same size as the inputs.
    """
    bs, nheads, cur_seq_len, head_dim = q.shape
    assert len(k.shape) == 4, \
        f"k should be of shape (batch_size, num_heads, current_seq_len, head_dim), got {k.shape} instead"
    assert k.shape[0] == bs, f"k has a different batch_size {k.shape[0]} compared to q {bs}"
    assert list(k.shape[2:]) == [cur_seq_len, head_dim], \
        "k has different current_seq_len and/or head_dim compared to q"
    assert cos.shape[-1] == head_dim, f"cos should have a last dim of head_dim {head_dim}, got {cos.shape[-1]} instead"
    assert list(position_ids.shape) in [[bs, cur_seq_len], [1, cur_seq_len]], \
        f"position_ids should be of shape {[bs, cur_seq_len]} or {[1, cur_seq_len]}, got {position_ids.shape} instead"
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
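

# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the model): a toy invocation of the rotary
# helpers above. All sizes here are arbitrary, and cos/sin are dummy tables
# chosen so that the embedding reduces to the identity (cos = 1, sin = 0),
# which makes the shape contract easy to verify.
# ---------------------------------------------------------------------------
def _rotary_pos_emb_example() -> None:
    bs, num_heads, seq_len, head_dim = 1, 2, 4, 8
    q = torch.randn(bs, num_heads, seq_len, head_dim)
    k = torch.randn(bs, num_heads, seq_len, head_dim)
    # real callers pass rotary tables gathered for the current positions,
    # shaped to broadcast against [b, h, seq, d]
    cos = torch.ones(1, 1, seq_len, head_dim)
    sin = torch.zeros(1, 1, seq_len, head_dim)
    position_ids = torch.arange(seq_len).unsqueeze(0)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    assert q_rot.shape == q.shape and torch.allclose(q_rot, q)
    assert k_rot.shape == k.shape and torch.allclose(k_rot, k)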


class EvaAttention(nn.Module):
    """
    Causal EVA for language modeling.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.head_dim_scaling = self.head_dim ** -0.5
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.window_size = config.window_size

        self.num_chunks = config.num_chunks
        self.chunk_size = config.chunk_size
        if self.chunk_size is not None:
            assert self.window_size >= self.chunk_size and self.window_size % self.chunk_size == 0
            # chunk_size overrides the number of landmarks
            self.num_chunks = None
            self.chunks_per_window = int(self.window_size // self.chunk_size)

        self.adaptive_phi = nn.Parameter(
            torch.randn(
                1, self.num_heads, 1, 1, self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )
        self.adaptive_mu_k = nn.Parameter(
            torch.randn(
                1, self.num_heads, 1, 1, self.head_dim
            ).clamp(-1., 1.) * self.head_dim_scaling
        )
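
    # Worked example (hypothetical numbers, not taken from any config in this repo):
    # with hidden_size=256, num_attention_heads=8, window_size=2048 and chunk_size=128,
    # the constructor above yields
    #     head_dim          = 256 // 8    = 32
    #     chunks_per_window = 2048 // 128 = 16
    #     adaptive_phi, adaptive_mu_k     : shape [1, 8, 1, 1, 32]
    # i.e. one learnable head_dim-vector per head, broadcastable over the batch
    # and chunk dimensions consumed by the Triton kernels below.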

    def _triton_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        assert not output_attentions
        bsz, q_len, _ = hidden_states.size()

        if use_cache and past_key_value is None:
            raise ValueError("`use_cache=True` requires `past_key_value` to be passed in.")

        assert isinstance(attention_mask, tuple)
        # infer the model's running mode
        is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
        is_decoding = use_cache and past_key_value.get_seq_length(self.layer_idx) > 0
        if is_prefilling:
            assert len(attention_mask) == 2
            window_mask, intra_chunk_mask = attention_mask
            chunk_dummy_mask = None
        elif is_decoding:
            assert len(attention_mask) == 3
            window_mask, intra_chunk_mask, chunk_dummy_mask = attention_mask
        else:
            window_mask, intra_chunk_mask = attention_mask
            chunk_dummy_mask = None

        ############################################
        # compute q, k, v from hidden states
        ############################################
        # [b, h, q_len, d]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        if use_cache:
            past_key_value.update_past_len(q.shape[-2], self.layer_idx)

        ############################################
        # apply rotary positional embeddings to q, k
        ############################################
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        ############################################
        # update and get cached singleton tokens;
        # update and cache k and v for calculating chunk-level RFAs
        ############################################
        if use_cache:
            s_k, s_v, dump_k, dump_v = past_key_value.update_singletons_and_chunks(
                k,
                v,
                self.layer_idx,
                self.window_size,
            )
        else:
            s_k, s_v = k, v
            dump_k, dump_v = k, v

        if use_cache:
            singleton_mask, dump_rf_mask = past_key_value.update_mask(
                s_mask=window_mask,
                rf_mask=intra_chunk_mask,
                layer_idx=self.layer_idx,
                window_size=self.window_size,
            )
        else:
            singleton_mask = window_mask
            dump_rf_mask = intra_chunk_mask

        if dump_k is not None and dump_v is not None:
            # 1. in prefilling, the input shape is
            #    dump_k/dump_v: [b, h, n, d]
            #    rfa_k/rfa_v: [b, h, n // c, d]
            # 2. in decoding, the input shape is
            #    k/v: [b, h, w, d]
            #    rfa_k/rfa_v: [b, h, w // c, d]
            # 3. in single-forward inference, the seq_len is already divisible
            rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                dump_k, dump_v,
                self.adaptive_mu_k, self.adaptive_phi,
                dump_rf_mask,
                self.head_dim_scaling,
                self.chunk_size
            )
            # rfa_mask = get_rfa_chunk_mask(dump_rf_mask)
            if use_cache:
                rfa_k, rfa_v = past_key_value.update_chunk_rfas(
                    rfa_k, rfa_v, self.layer_idx
                )
        elif use_cache:
            # if there are not enough elements within the last chunk,
            # we will only use the cached chunk-level RFAs
            rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)
        else:
            rfa_k, rfa_v = None, None

        ############################################
        # compute the full attention output
        ############################################
        if is_prefilling:
            # 1. in prefilling, the input shape is
            #    q: [b, h, n, d]
            #    k/v: [b, h, n, d]
            #    rfa_k/rfa_v: [b, h, n // c, d]
            attn_output = triton_eva_agg_fwd(
                q,
                s_k, s_v,
                rfa_k, rfa_v,
                singleton_mask,
                self.head_dim_scaling,
                self.window_size,
                self.chunks_per_window
            )
        elif is_decoding:
            # 2. in decoding, the input shape is
            #    q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
            #    k/v: [b, h, 1 + s, d]
            #    rfa_k/rfa_v: [b, h, n // c, d]
            if rfa_k is not None and rfa_v is not None:
                # we only take the chunk-level RFAs not in the current window
                seen_seq_len = past_key_value.get_seq_length(self.layer_idx)
                if seen_seq_len <= self.window_size:
                    agg_k = s_k
                    agg_v = s_v
                    attn_mask = singleton_mask
                else:
                    # NOTE: we already updated the cache, so the length now
                    # includes the current token. We subtract 1 from seen_seq_len
                    # because we want:
                    #   seen_seq_len = 2048 -> num_windows_seen_so_far = 0
                    #   seen_seq_len = 4096 -> num_windows_seen_so_far = 1
                    #   seen_seq_len = 4097 -> num_windows_seen_so_far = 2
                    # NOTE: the concatenation order matters; it must align with
                    # the order based on which the attention mask is constructed
                    num_windows_seen_so_far = (seen_seq_len - 1) // self.window_size
                    agg_k = torch.cat(
                        [s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]],
                        dim=-2
                    )
                    agg_v = torch.cat(
                        [s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]],
                        dim=-2
                    )
                    if singleton_mask is not None:
                        assert chunk_dummy_mask is not None
                        attn_mask = torch.cat([singleton_mask, chunk_dummy_mask], dim=-1)
                    else:
                        attn_mask = singleton_mask
            else:
                agg_k = s_k
                agg_v = s_v
                attn_mask = singleton_mask
            attn_output = F.scaled_dot_product_attention(
                q, agg_k, agg_v,
                attn_mask=attn_mask,
                is_causal=False,
                dropout_p=0.0,
                scale=self.head_dim_scaling
            )
        else:
            # 3. in single-forward inference
            attn_output = triton_eva_agg_fwd(
                q,
                s_k, s_v,
                rfa_k, rfa_v,
                singleton_mask,
                self.head_dim_scaling,
                self.window_size,
                self.chunks_per_window
            )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        attn_weights = None
        return attn_output, attn_weights, past_key_value
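
    # Worked example of the decoding-time bookkeeping above (hypothetical sizes):
    # with window_size=2048 and chunk_size=128 (so chunks_per_window=16), and a cache
    # that has already seen seen_seq_len=4097 tokens including the current one,
    #     num_windows_seen_so_far = (4097 - 1) // 2048 = 2
    # so the query attends to the current window's keys/values (s_k/s_v) plus the
    # first 2 * 16 = 32 chunk-level RFAs, i.e. only the chunk summaries of windows
    # that lie entirely outside the current window.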

    def _multibyte_decoding_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # during multi-byte forwarding, we only read the cache state: the window buffer
        # is written in place below, but the cached offsets are never advanced
        assert not output_attentions
        bsz, q_len, _ = hidden_states.size()

        if use_cache and past_key_value is None:
            raise ValueError("`use_cache=True` requires `past_key_value` to be passed in.")

        assert USE_TRITON_IMPL
        assert isinstance(attention_mask, torch.Tensor) and attention_mask.dtype == torch.bool
        assert use_cache and past_key_value.get_seq_length(self.layer_idx) > 0

        ############################################
        # compute q, k, v from hidden states
        ############################################
        # [b, h, q_len, d]
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # [b, h, kv_len, d]
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        ############################################
        # apply rotary positional embeddings to q, k
        ############################################
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        ############################################
        # update and get cached singleton tokens
        ############################################
        input_len = k.shape[-2]
        window_pos = past_key_value.past_window_pos[self.layer_idx]
        new_window_pos = window_pos + input_len
        past_key_value.past_window_k[self.layer_idx][:, :, window_pos : new_window_pos, :] = k
        past_key_value.past_window_v[self.layer_idx][:, :, window_pos : new_window_pos, :] = v
        s_k = past_key_value.past_window_k[self.layer_idx][:, :, : new_window_pos, :]
        s_v = past_key_value.past_window_v[self.layer_idx][:, :, : new_window_pos, :]

        rfa_k, rfa_v = past_key_value.get_chunk_rfas(self.layer_idx)

        ############################################
        # compute the full attention output
        ############################################
        # 2. in decoding, the input shape is
        #    q: [b, h, 1, d] or [b, h, z, d] (for multi-byte prediction)
        #    k/v: [b, h, 1 + s, d]
        #    rfa_k/rfa_v: [b, h, n // c, d]
        if rfa_k is not None and rfa_v is not None:
            # NOTE: the concatenation order matters; it must align with
            # the order based on which the attention mask is constructed
            # agg_k = torch.cat([s_k, rfa_k], dim=-2)
            # agg_v = torch.cat([s_v, rfa_v], dim=-2)
            agg_k = torch.cat([rfa_k, s_k], dim=-2)
            agg_v = torch.cat([rfa_v, s_v], dim=-2)
        else:
            agg_k = s_k
            agg_v = s_v
        attn_output = F.scaled_dot_product_attention(
            q, agg_k, agg_v,
            attn_mask=attention_mask,
            is_causal=False,
            dropout_p=0.0,
            scale=self.head_dim_scaling
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        attn_weights = None
        return attn_output, attn_weights, past_key_value
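
    # Note on the indexing above (illustrative numbers): if past_window_pos for this
    # layer is 5 and the current call carries input_len=3 positions, the new k/v are
    # written into window slots 5:8 and attention runs over slots [:8]; the
    # past_window_pos pointer itself is left untouched, so these writes do not
    # advance the cache state. Because agg_k/agg_v are built as [rfa | window],
    # the boolean attention_mask passed in must lay out its key columns in that
    # same order.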

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        multibyte_decoding: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        assert not output_attentions
        if use_cache and past_key_value is None:
            raise ValueError("`use_cache=True` requires `past_key_value` to be passed in.")
        assert USE_TRITON_IMPL

        if use_cache and multibyte_decoding:
            return self._multibyte_decoding_forward(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cos=cos,
                sin=sin,
            )
        else:
            return self._triton_forward(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cos=cos,
                sin=sin,
            )
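

# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the model): the decoding-time aggregation
# pattern from `_multibyte_decoding_forward`, reproduced with toy CPU tensors.
# Chunk-level RFA summaries and window keys/values are concatenated along the
# key dimension and consumed by scaled_dot_product_attention with a boolean
# mask whose columns follow the same [rfa | window] order (True = attend).
# All sizes below are arbitrary.
# ---------------------------------------------------------------------------
def _decoding_aggregation_example() -> torch.Tensor:
    b, h, d = 1, 2, 16
    num_rfa, window_len, q_len = 4, 3, 1
    q = torch.randn(b, h, q_len, d)
    rfa_k, rfa_v = torch.randn(b, h, num_rfa, d), torch.randn(b, h, num_rfa, d)
    s_k, s_v = torch.randn(b, h, window_len, d), torch.randn(b, h, window_len, d)
    agg_k = torch.cat([rfa_k, s_k], dim=-2)
    agg_v = torch.cat([rfa_v, s_v], dim=-2)
    # attend to every RFA slot and every window slot
    attn_mask = torch.ones(b, h, q_len, num_rfa + window_len, dtype=torch.bool)
    return F.scaled_dot_product_attention(
        q, agg_k, agg_v,
        attn_mask=attn_mask,
        is_causal=False,
        dropout_p=0.0,
        scale=d ** -0.5,
    )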