|
from typing import Dict, Optional, Tuple, List, Any, Union |
|
import torch |
|
from transformers.cache_utils import Cache |
|
|
|
class EvaCache(Cache): |
|
""" |
|
    A cache for Eva attention that grows dynamically as more tokens are generated.

    It stores the singleton (window) key/value states, the query/key/value states of the current partial
    chunk, and the accumulated chunk-level random-feature statistics, each as a list of tensors with one
    entry per layer. The expected shape for each key/value tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
|
""" |
|
|
|
def __init__(self) -> None: |
|
        # per-layer key/value states of the current (partial) attention window ("singletons")
        self.w_k: List[torch.Tensor] = []
|
self.w_v: List[torch.Tensor] = [] |
|
|
|
        # per-layer query/key/value states accumulated for the current (partial) chunk
        self.rf_q: List[torch.Tensor] = []
|
self.rf_k: List[torch.Tensor] = [] |
|
self.rf_v: List[torch.Tensor] = [] |
|
|
|
        # per-layer accumulated chunk-level random-feature statistics
        self.softmax_phi_k_v: List[torch.Tensor] = []
|
self.log_sum_phi_k: List[torch.Tensor] = [] |
|
self.rf_k_bar: List[torch.Tensor] = [] |
|
        # running count of tokens seen so far, shared across layers
        self._seen_tokens = 0
|
|
|
|
|
        # per-layer masks for the random-feature chunks, singletons, and chunk-level attention
        self.rf_mask: List[Optional[torch.Tensor]] = []
|
self.s_mask: List[torch.Tensor] = [] |
|
self.chunk_mask: List[torch.Tensor] = [] |
|
|
|
def __len__(self): |
|
""" |
|
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds |
|
to the number of layers in the model. |
|
""" |
|
return len(self.w_k) |
|
|
|
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: |
|
"""Given the sequence length of the new inputs, returns the usable length of the cache.""" |
|
|
|
|
|
|
|
max_length = self.get_max_length() |
|
previous_seq_length = self.get_seq_length(layer_idx) |
|
if max_length is not None and previous_seq_length + new_seq_length > max_length: |
|
return max_length - new_seq_length |
|
return previous_seq_length |
|
|
|
    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        per_layer_caches = (
            self.w_k,
            self.w_v,
            self.rf_q,
            self.rf_k,
            self.rf_v,
            self.softmax_phi_k_v,
            self.log_sum_phi_k,
            self.rf_k_bar,
            self.rf_mask,
            self.s_mask,
            self.chunk_mask,
        )
        for layer_idx in range(len(self.w_k)):
            for cache_list in per_layer_caches:
                # some per-layer entries (e.g. random-feature states and masks) may be absent or `None`
                if layer_idx >= len(cache_list) or cache_list[layer_idx] is None:
                    continue
                device = cache_list[layer_idx].device
                cache_list[layer_idx] = cache_list[layer_idx].index_select(0, beam_idx.to(device))
|
@property |
|
def seen_tokens(self): |
|
if hasattr(self, "_seen_tokens"): |
|
return self._seen_tokens |
|
else: |
|
return None |
|
|
|
def update_past_len( |
|
self, |
|
cur_q_len: int, |
|
layer_idx: int |
|
): |
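        """Advances the running token count (only layer 0 increments it, since the count is shared across
        layers) and returns the total number of tokens seen so far."""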
|
|
|
if layer_idx == 0: |
|
self._seen_tokens += cur_q_len |
|
return self._seen_tokens |
|
|
|
def update_mask( |
|
self, |
|
prev_s_mask, |
|
cur_s_mask, |
|
chunk_mask, |
|
rf_mask, |
|
layer_idx, |
|
window_size, |
|
chunk_size, |
|
): |
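        """
        Updates the per-layer singleton (`s_mask`), chunk-level (`chunk_mask`), and random-feature (`rf_mask`)
        attention masks for `layer_idx`, caching the pieces that carry over to later steps and returning
        `(prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask)` for the current step.
        """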
|
|
|
|
|
|
|
        # `q_len` is only computed and used during prefill, i.e. when this layer has no cached masks yet
        q_len = None
|
if len(self.s_mask) <= layer_idx: |
|
q_len = chunk_mask.shape[-2] |
|
|
|
|
|
if q_len < window_size: |
|
assert prev_s_mask is None |
|
|
|
|
|
|
|
self.s_mask.append(cur_s_mask[..., -1:, :] if cur_s_mask is not None else prev_s_mask[..., -1, -1:, :]) |
|
else: |
|
|
|
prev_s_mask = None |
|
|
|
cached_s_mask = self.s_mask[layer_idx] |
|
assert cached_s_mask is not None |
|
            if cached_s_mask.shape[-1] != window_size:
                # the cached singleton mask has not filled the window yet; extend it with the incoming mask
                cur_s_mask = torch.cat([cached_s_mask, cur_s_mask], dim=-1)
            # otherwise the cached window is full, so a new window starts from the incoming `cur_s_mask` alone
|
|
|
|
|
self.s_mask[layer_idx] = cur_s_mask |
|
|
|
|
|
|
|
|
|
dump_rf_mask = None |
|
if len(self.rf_mask) <= layer_idx: |
|
|
|
|
|
if q_len < chunk_size: |
|
cur_rf_mask = rf_mask |
|
else: |
|
if q_len % chunk_size == 0: |
|
dump_rf_mask = rf_mask |
|
cur_rf_mask = None |
|
else: |
|
remainder_tokens = q_len % chunk_size |
|
if rf_mask is not None: |
|
dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
else: |
|
dump_rf_mask = None |
|
cur_rf_mask = None |
|
self.rf_mask.append(cur_rf_mask) |
|
else: |
|
past_rf_mask = self.rf_mask[layer_idx] |
|
if past_rf_mask is not None: |
|
|
|
|
|
cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2) |
|
else: |
|
|
|
cur_rf_mask = None |
|
|
|
|
|
|
|
|
|
if cur_rf_mask is not None and cur_rf_mask.shape[-2] == chunk_size: |
|
dump_rf_mask = cur_rf_mask |
|
cur_rf_mask = None |
|
|
|
self.rf_mask[layer_idx] = cur_rf_mask |
|
|
|
|
|
|
|
|
|
if len(self.chunk_mask) <= layer_idx: |
|
|
|
|
|
if q_len < window_size: |
|
cur_chunk_mask = chunk_mask |
|
prev_chunk_mask = None |
|
else: |
|
if q_len % window_size == 0: |
|
cur_chunk_mask = None |
|
prev_chunk_mask = chunk_mask |
|
else: |
|
remainder_tokens = q_len % window_size |
|
|
|
prev_chunk_mask, cur_chunk_mask = torch.split(chunk_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
bsz, num_heads, _, head_dim = prev_chunk_mask.shape |
|
prev_chunk_mask = prev_chunk_mask.reshape(bsz, num_heads, -1, window_size, head_dim) |
|
|
|
assert prev_s_mask is not None |
|
if prev_s_mask.shape[-3] == 1 and prev_chunk_mask.shape[-3] > 1: |
|
|
|
prev_s_mask = prev_s_mask.expand(-1, -1, prev_chunk_mask.shape[-3], -1, -1) |
|
|
|
|
|
self.chunk_mask.append(cur_chunk_mask[..., -1:, :] if cur_chunk_mask is not None else prev_chunk_mask[..., -1, -1:, :]) |
|
else: |
|
|
|
prev_chunk_mask = None |
|
cur_chunk_mask = self.chunk_mask[layer_idx] |
|
|
|
|
|
|
|
seen_seq_len = self.get_seq_length(layer_idx) |
|
if seen_seq_len > 0 and seen_seq_len % chunk_size == 0: |
|
past_chunk_mask = self.chunk_mask[layer_idx] |
|
if past_chunk_mask is not None: |
|
|
|
|
|
cur_chunk_mask = torch.cat([past_chunk_mask, chunk_mask], dim=-1) |
|
else: |
|
cur_chunk_mask = chunk_mask |
|
self.chunk_mask[layer_idx] = cur_chunk_mask |
|
|
|
|
|
|
|
if seen_seq_len > 0 and seen_seq_len % window_size == 1: |
|
cur_chunk_mask = self.chunk_mask[layer_idx] |
|
|
|
num_chunks_per_window = window_size // chunk_size |
|
cur_chunk_mask[..., -num_chunks_per_window:] = False |
|
self.chunk_mask[layer_idx] = cur_chunk_mask |
|
|
|
return (prev_s_mask, cur_s_mask, prev_chunk_mask, cur_chunk_mask, dump_rf_mask) |
|
|
|
def update_singletons( |
|
self, |
|
q, |
|
k, |
|
v, |
|
layer_idx, |
|
window_size, |
|
): |
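        """
        Splits the incoming query/key/value states into completed past windows and the current (partial)
        window, caches the current-window keys/values for `layer_idx`, and returns
        `(past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v)`.
        """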
|
if len(self.w_k) <= layer_idx: |
|
|
|
|
|
q_len = q.shape[-2] |
|
if q_len < window_size: |
|
w_q = q |
|
w_k = k |
|
w_v = v |
|
past_w_q = past_w_k = past_w_v = None |
|
else: |
|
if q_len % window_size == 0: |
|
w_q = None |
|
w_k = None |
|
w_v = None |
|
past_w_q = q |
|
past_w_k = k |
|
past_w_v = v |
|
else: |
|
remainder_tokens = q_len % window_size |
|
|
|
past_w_q, w_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
past_w_k, w_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
past_w_v, w_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
bsz, num_heads, _, head_dim = past_w_q.shape |
|
past_w_q = past_w_q.reshape(bsz, num_heads, -1, window_size, head_dim) |
|
past_w_k = past_w_k.reshape(bsz, num_heads, -1, window_size, head_dim) |
|
past_w_v = past_w_v.reshape(bsz, num_heads, -1, window_size, head_dim) |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.w_k.append(w_k if w_k is not None else past_w_k[..., -1, :, :]) |
|
self.w_v.append(w_v if w_v is not None else past_w_v[..., -1, :, :]) |
|
else: |
|
|
|
past_w_q = past_w_k = past_w_v = None |
|
|
|
w_q = q |
|
w_k = k |
|
w_v = v |
|
|
|
cached_w_k = self.w_k[layer_idx] |
|
assert cached_w_k is not None |
|
            if cached_w_k.shape[-2] != window_size:
                # the cached window is not full yet; extend it with the incoming keys
                w_k = torch.cat([cached_w_k, w_k], dim=-2)
            # otherwise the cached window is already full, so a new window starts from the incoming keys alone
|
|
|
cached_w_v = self.w_v[layer_idx] |
|
assert cached_w_v is not None |
|
            if cached_w_v.shape[-2] != window_size:
                # the cached window is not full yet; extend it with the incoming values
                w_v = torch.cat([cached_w_v, w_v], dim=-2)
            # otherwise the cached window is already full, so a new window starts from the incoming values alone
|
|
|
|
|
self.w_k[layer_idx] = w_k |
|
self.w_v[layer_idx] = w_v |
|
return (past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v) |
|
|
|
def update_chunks( |
|
self, |
|
q, |
|
k, |
|
v, |
|
layer_idx, |
|
chunk_size |
|
): |
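        """
        Accumulates query/key/value states for the current (partial) chunk of `layer_idx` and returns
        `(dump_q, dump_k, dump_v)`, the states that now form complete chunks and are ready to be summarized.
        """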
|
q_len = q.shape[-2] |
|
dump_q = None |
|
dump_k = None |
|
dump_v = None |
|
if len(self.rf_q) <= layer_idx: |
|
|
|
|
|
if q_len < chunk_size: |
|
rf_q = q |
|
rf_k = k |
|
rf_v = v |
|
else: |
|
if q_len % chunk_size == 0: |
|
rf_q = None |
|
rf_k = None |
|
rf_v = None |
|
dump_q = q |
|
dump_k = k |
|
dump_v = v |
|
else: |
|
remainder_tokens = q_len % chunk_size |
|
|
|
dump_q, rf_q = torch.split(q, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
dump_k, rf_k = torch.split(k, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
dump_v, rf_v = torch.split(v, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
self.rf_q.append(rf_q) |
|
self.rf_k.append(rf_k) |
|
self.rf_v.append(rf_v) |
|
else: |
|
|
|
|
|
past_rf_q = self.rf_q[layer_idx] |
|
if past_rf_q is not None: |
|
rf_q = torch.cat([past_rf_q, q], dim=-2) |
|
else: |
|
rf_q = q |
|
|
|
past_rf_k = self.rf_k[layer_idx] |
|
if past_rf_k is not None: |
|
rf_k = torch.cat([past_rf_k, k], dim=-2) |
|
else: |
|
rf_k = k |
|
|
|
past_rf_v = self.rf_v[layer_idx] |
|
if past_rf_v is not None: |
|
rf_v = torch.cat([past_rf_v, v], dim=-2) |
|
else: |
|
rf_v = v |
|
|
|
|
|
|
|
|
|
|
|
if rf_q.shape[-2] == chunk_size: |
|
dump_q = rf_q |
|
dump_k = rf_k |
|
dump_v = rf_v |
|
|
|
rf_q = None |
|
rf_k = None |
|
rf_v = None |
|
|
|
self.rf_q[layer_idx] = rf_q |
|
self.rf_k[layer_idx] = rf_k |
|
self.rf_v[layer_idx] = rf_v |
|
|
|
return dump_q, dump_k, dump_v |
|
|
|
def update_chunk_rfas( |
|
self, |
|
softmax_phi_k_v, |
|
log_sum_phi_k, |
|
rf_k_bar, |
|
layer_idx, |
|
random_feature_dim |
|
): |
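        """
        Concatenates the newly computed chunk-level statistics (`softmax_phi_k_v`, `log_sum_phi_k`,
        `rf_k_bar`) with the cached ones for `layer_idx` and returns the updated tensors.
        """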
|
if len(self.softmax_phi_k_v) <= layer_idx: |
|
|
|
self.softmax_phi_k_v.append(softmax_phi_k_v) |
|
self.log_sum_phi_k.append(log_sum_phi_k) |
|
self.rf_k_bar.append(rf_k_bar) |
|
else: |
|
|
|
past_softmax_phi_k_v = self.softmax_phi_k_v[layer_idx] |
|
past_log_sum_phi_k = self.log_sum_phi_k[layer_idx] |
|
past_rf_k_bar = self.rf_k_bar[layer_idx] |
|
|
|
            dim = -2 if random_feature_dim == 1 else -3
            if past_softmax_phi_k_v is not None:
                softmax_phi_k_v = torch.cat([past_softmax_phi_k_v, softmax_phi_k_v], dim=dim)

            if past_log_sum_phi_k is not None:
                log_sum_phi_k = torch.cat([past_log_sum_phi_k, log_sum_phi_k], dim=dim)
|
|
|
if past_rf_k_bar is not None: |
|
rf_k_bar = torch.cat([past_rf_k_bar, rf_k_bar], dim=-2) |
|
|
|
self.softmax_phi_k_v[layer_idx] = softmax_phi_k_v |
|
self.log_sum_phi_k[layer_idx] = log_sum_phi_k |
|
self.rf_k_bar[layer_idx] = rf_k_bar |
|
|
|
return softmax_phi_k_v, log_sum_phi_k, rf_k_bar |
|
|
|
def get_chunk_rfas(self, layer_idx): |
|
if len(self.softmax_phi_k_v) <= layer_idx: |
|
return ( |
|
None, |
|
None, |
|
None |
|
) |
|
else: |
|
return ( |
|
self.softmax_phi_k_v[layer_idx], |
|
self.log_sum_phi_k[layer_idx], |
|
self.rf_k_bar[layer_idx] |
|
) |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
if len(self.w_k) <= layer_idx: |
|
return 0 |
|
return self._seen_tokens |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" |
|
return None |
|
|
|
def update( |
|
self, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
raise NotImplementedError("`update` is not used in Eva Cache.") |
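

# A minimal, illustrative usage sketch (not part of the original module): it shows how EvaCache's per-layer
# buffers grow during prefill and a single decoding step. The helper name and the batch/head/window/chunk
# sizes below are arbitrary assumptions chosen only for demonstration.
def _eva_cache_usage_sketch() -> None:
    cache = EvaCache()
    bsz, num_heads, head_dim = 1, 2, 4
    window_size, chunk_size = 8, 4

    # prefill with 10 tokens: 8 of them form one full window (two full chunks), 2 remain as singletons
    q = torch.randn(bsz, num_heads, 10, head_dim)
    k = torch.randn(bsz, num_heads, 10, head_dim)
    v = torch.randn(bsz, num_heads, 10, head_dim)
    cache.update_past_len(q.shape[-2], layer_idx=0)
    (past_w_q, past_w_k, past_w_v), (w_q, w_k, w_v) = cache.update_singletons(
        q, k, v, layer_idx=0, window_size=window_size
    )
    dump_q, dump_k, dump_v = cache.update_chunks(q, k, v, layer_idx=0, chunk_size=chunk_size)
    assert past_w_k.shape == (bsz, num_heads, 1, window_size, head_dim)  # one completed window
    assert w_k.shape == (bsz, num_heads, 2, head_dim)                    # two leftover singletons
    assert dump_k.shape == (bsz, num_heads, 8, head_dim)                 # two full chunks to summarize

    # a single decoding step appends one more singleton to the cached window
    q1 = torch.randn(bsz, num_heads, 1, head_dim)
    k1 = torch.randn(bsz, num_heads, 1, head_dim)
    v1 = torch.randn(bsz, num_heads, 1, head_dim)
    cache.update_past_len(1, layer_idx=0)
    _, (_, w_k, w_v) = cache.update_singletons(q1, k1, v1, layer_idx=0, window_size=window_size)
    assert w_k.shape == (bsz, num_heads, 3, head_dim)
    # a full decoding step would also call update_chunks / update_chunk_rfas and update_mask; omitted here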
|
|
|
class EvaStaticCacheForTriton(Cache): |
|
""" |
|
    A variant of `EvaCache` for Eva's Triton kernels: the window key/value states are kept in preallocated,
    fixed-size per-layer buffers instead of dynamically growing lists.
|
""" |
|
|
|
    def __init__(
        self,
        batch_size: int,
        num_key_value_heads: int,
        window_size: int,
        head_dim: int,
        num_layers: int,
        dtype: torch.dtype,
        device: Union[torch.device, str],
    ) -> None:
|
        # preallocated fixed-size per-layer buffers for the current attention window's keys/values
        self.past_window_k: List[torch.Tensor] = []
|
self.past_window_v: List[torch.Tensor] = [] |
|
|
|
cache_shape = (batch_size, num_key_value_heads, window_size, head_dim) |
|
for idx in range(num_layers): |
|
new_window_k = torch.zeros(cache_shape, dtype=dtype, device=device) |
|
new_window_v = torch.zeros(cache_shape, dtype=dtype, device=device) |
|
self.past_window_k.append(new_window_k) |
|
self.past_window_v.append(new_window_v) |
|
|
|
        # current write position inside each layer's window buffer
        self.past_window_pos: List[int] = []
|
|
|
        # per-layer chunk-level RFA key/value summaries
        self.rfa_k: List[torch.Tensor] = []
|
self.rfa_v: List[torch.Tensor] = [] |
|
|
|
|
|
self._seen_tokens = 0 |
|
|
|
|
|
self.rf_mask: List[Optional[torch.Tensor]] = [] |
|
self.s_mask: List[torch.Tensor] = [] |
|
|
|
def __len__(self): |
|
""" |
|
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds |
|
to the number of layers in the model. |
|
""" |
|
return len(self.past_window_pos) |
|
|
|
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: |
|
"""Given the sequence length of the new inputs, returns the usable length of the cache.""" |
|
|
|
|
|
|
|
max_length = self.get_max_length() |
|
previous_seq_length = self.get_seq_length(layer_idx) |
|
if max_length is not None and previous_seq_length + new_seq_length > max_length: |
|
return max_length - new_seq_length |
|
return previous_seq_length |
|
|
|
    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        per_layer_caches = (
            self.past_window_k,
            self.past_window_v,
            self.rfa_k,
            self.rfa_v,
            self.rf_mask,
            self.s_mask,
        )
        for layer_idx in range(len(self.past_window_k)):
            for cache_list in per_layer_caches:
                # some per-layer entries (e.g. masks and RFA summaries) may be absent or `None`
                if layer_idx >= len(cache_list) or cache_list[layer_idx] is None:
                    continue
                device = cache_list[layer_idx].device
                cache_list[layer_idx] = cache_list[layer_idx].index_select(0, beam_idx.to(device))
|
|
|
@property |
|
def seen_tokens(self): |
|
if hasattr(self, "_seen_tokens"): |
|
return self._seen_tokens |
|
else: |
|
return None |
|
|
|
def update_past_len( |
|
self, |
|
cur_q_len: int, |
|
layer_idx: int |
|
): |
|
|
|
if layer_idx == 0: |
|
self._seen_tokens += cur_q_len |
|
return self._seen_tokens |
|
|
|
def update_mask( |
|
self, |
|
s_mask, |
|
rf_mask, |
|
layer_idx, |
|
window_size, |
|
): |
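        """
        Updates the cached singleton (`s_mask`) and random-feature (`rf_mask`) masks for `layer_idx` and
        returns `(dump_s_mask, dump_rf_mask)` for the current step of the Triton kernels.
        """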
|
|
|
|
|
|
|
if len(self.s_mask) <= layer_idx: |
|
|
|
|
|
|
|
|
|
if s_mask is None: |
|
cur_s_mask = None |
|
else: |
|
q_len = s_mask.shape[-2] |
|
|
|
|
|
|
|
|
|
remainder_tokens = q_len % window_size |
|
if remainder_tokens == 0: |
|
cur_s_mask = None |
|
else: |
|
cur_s_mask = s_mask[..., -1:, :remainder_tokens] |
|
self.s_mask.append(cur_s_mask) |
|
|
|
dump_s_mask = s_mask |
|
else: |
|
|
|
past_s_mask = self.s_mask[layer_idx] |
|
if past_s_mask is None: |
|
assert s_mask is None |
|
cur_s_mask = None |
|
else: |
|
assert s_mask is not None |
|
cur_s_mask = torch.cat([past_s_mask, s_mask], dim=-1) |
|
|
|
dump_s_mask = cur_s_mask |
|
if cur_s_mask is not None and cur_s_mask.shape[-1] == window_size: |
|
cur_s_mask = None |
|
|
|
self.s_mask[layer_idx] = cur_s_mask |
|
|
|
|
|
|
|
|
|
dump_rf_mask = None |
|
if len(self.rf_mask) <= layer_idx: |
|
|
|
|
|
if rf_mask is None: |
|
cur_rf_mask = None |
|
else: |
|
q_len = rf_mask.shape[-2] |
|
if q_len < window_size: |
|
dump_rf_mask = None |
|
cur_rf_mask = rf_mask |
|
else: |
|
if q_len % window_size == 0: |
|
dump_rf_mask = rf_mask |
|
cur_rf_mask = None |
|
else: |
|
remainder_tokens = q_len % window_size |
|
dump_rf_mask, cur_rf_mask = torch.split(rf_mask, [q_len - remainder_tokens, remainder_tokens], dim=-2) |
|
self.rf_mask.append(cur_rf_mask) |
|
else: |
|
past_rf_mask = self.rf_mask[layer_idx] |
|
if past_rf_mask is not None: |
|
|
|
|
|
cur_rf_mask = torch.cat([past_rf_mask, rf_mask], dim=-2) |
|
else: |
|
cur_rf_mask = None |
|
|
|
if cur_rf_mask is not None and cur_rf_mask.shape[-2] == window_size: |
|
dump_rf_mask = cur_rf_mask |
|
cur_rf_mask = None |
|
|
|
self.rf_mask[layer_idx] = cur_rf_mask |
|
|
|
return dump_s_mask, dump_rf_mask |
|
|
|
def update_singletons_and_chunks( |
|
self, |
|
k, |
|
v, |
|
layer_idx, |
|
window_size, |
|
): |
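        """
        Writes the incoming key/value states into the preallocated window buffer for `layer_idx` and returns
        `(s_k, s_v, dump_k, dump_v)`: the keys/values of the current window plus, if a window has been
        completed, the states to be summarized into chunk-level RFAs.
        """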
|
if len(self.past_window_pos) <= layer_idx: |
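            # prefill: keep the most recent (partial) window in the preallocated buffer and dump the
            # preceding full windows so they can be summarized into chunk-level RFAs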
|
|
|
s_k = k |
|
s_v = v |
|
input_len = k.shape[-2] |
|
window_pos = 0 |
|
if input_len <= window_size: |
|
new_window_pos = window_pos + input_len |
|
|
|
cached_window_k = k |
|
cached_window_v = v |
|
dump_k = None |
|
dump_v = None |
|
else: |
|
remainder_tokens = input_len % window_size |
|
if remainder_tokens == 0: |
|
remainder_tokens = window_size |
|
new_window_pos = window_pos + remainder_tokens |
|
|
|
|
|
cached_window_k = k[..., -remainder_tokens:, :] |
|
cached_window_v = v[..., -remainder_tokens:, :] |
|
dump_k = k[..., :-remainder_tokens, :] |
|
dump_v = v[..., :-remainder_tokens, :] |
|
|
|
self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_k |
|
self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = cached_window_v |
|
self.past_window_pos.append(new_window_pos) |
|
else: |
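            # decoding: if the window buffer is already full, dump its contents (to be summarized into
            # chunk-level RFAs) and wrap the write position back to the start of the buffer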
|
|
|
|
|
|
|
if self.past_window_pos[layer_idx] == window_size: |
|
self.past_window_pos[layer_idx] = 0 |
|
dump_k = self.past_window_k[layer_idx].clone() |
|
dump_v = self.past_window_v[layer_idx].clone() |
|
else: |
|
dump_k = None |
|
dump_v = None |
|
|
|
input_len = k.shape[-2] |
|
window_pos = self.past_window_pos[layer_idx] |
|
new_window_pos = window_pos + input_len |
|
|
|
self.past_window_k[layer_idx][:, :, window_pos : new_window_pos, :] = k |
|
self.past_window_v[layer_idx][:, :, window_pos : new_window_pos, :] = v |
|
|
|
s_k = self.past_window_k[layer_idx][:, :, : new_window_pos, :] |
|
s_v = self.past_window_v[layer_idx][:, :, : new_window_pos, :] |
|
|
|
self.past_window_pos[layer_idx] = new_window_pos |
|
|
|
return s_k, s_v, dump_k, dump_v |
|
|
|
def update_chunk_rfas( |
|
self, |
|
rfa_k, |
|
rfa_v, |
|
layer_idx, |
|
): |
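        """Appends the newly computed chunk-level RFA keys/values to the per-layer cache and returns the
        concatenated `(rfa_k, rfa_v)` for `layer_idx`."""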
|
if len(self.rfa_k) <= layer_idx: |
|
|
|
self.rfa_k.append(rfa_k) |
|
self.rfa_v.append(rfa_v) |
|
else: |
|
|
|
past_rfa_k = self.rfa_k[layer_idx] |
|
past_rfa_v = self.rfa_v[layer_idx] |
|
|
|
if past_rfa_k is not None: |
|
rfa_k = torch.cat([past_rfa_k, rfa_k], dim=-2) |
|
|
|
if past_rfa_v is not None: |
|
rfa_v = torch.cat([past_rfa_v, rfa_v], dim=-2) |
|
|
|
self.rfa_k[layer_idx] = rfa_k |
|
self.rfa_v[layer_idx] = rfa_v |
|
|
|
return rfa_k, rfa_v |
|
|
|
def get_past_window_pos(self, layer_idx): |
|
if len(self.past_window_pos) <= layer_idx: |
|
return None |
|
else: |
|
return self.past_window_pos[layer_idx] |
|
|
|
def get_past_window_kv(self, layer_idx): |
|
if len(self.past_window_pos) <= layer_idx: |
|
return None, None |
|
else: |
|
return ( |
|
self.past_window_k[layer_idx][:, :, : self.past_window_pos[layer_idx], :], |
|
self.past_window_v[layer_idx][:, :, : self.past_window_pos[layer_idx], :] |
|
) |
|
|
|
def get_chunk_rfas(self, layer_idx): |
|
if len(self.rfa_k) <= layer_idx: |
|
return None, None |
|
else: |
|
return self.rfa_k[layer_idx], self.rfa_v[layer_idx] |
|
|
|
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
|
|
|
|
if len(self.past_window_pos) <= layer_idx: |
|
return 0 |
|
return self._seen_tokens |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" |
|
return None |
|
|
|
def update( |
|
self, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
raise NotImplementedError("`update` is not used in Eva Cache.") |
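

# A minimal, illustrative usage sketch (not part of the original module) of how EvaStaticCacheForTriton
# manages its preallocated window buffers. The helper name and the sizes below are arbitrary assumptions
# chosen only for demonstration.
def _eva_static_cache_usage_sketch() -> None:
    bsz, num_kv_heads, head_dim = 1, 2, 4
    window_size, num_layers = 8, 1
    cache = EvaStaticCacheForTriton(
        batch_size=bsz,
        num_key_value_heads=num_kv_heads,
        window_size=window_size,
        head_dim=head_dim,
        num_layers=num_layers,
        dtype=torch.float32,
        device="cpu",
    )

    # prefill with 10 tokens: the first 8 are dumped (to be summarized into chunk-level RFAs),
    # the last 2 are written into the preallocated window buffer
    k = torch.randn(bsz, num_kv_heads, 10, head_dim)
    v = torch.randn(bsz, num_kv_heads, 10, head_dim)
    cache.update_past_len(10, layer_idx=0)
    s_k, s_v, dump_k, dump_v = cache.update_singletons_and_chunks(k, v, layer_idx=0, window_size=window_size)
    assert dump_k.shape == (bsz, num_kv_heads, 8, head_dim)
    assert cache.get_past_window_pos(0) == 2

    # a decoding step writes one more key/value pair into the window buffer
    k1 = torch.randn(bsz, num_kv_heads, 1, head_dim)
    v1 = torch.randn(bsz, num_kv_heads, 1, head_dim)
    cache.update_past_len(1, layer_idx=0)
    s_k, s_v, dump_k, dump_v = cache.update_singletons_and_chunks(k1, v1, layer_idx=0, window_size=window_size)
    assert dump_k is None and cache.get_past_window_pos(0) == 3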
|
|