|
import os |
|
import torch |
|
import time |
|
import numpy as np |
|
import torch.distributed as dist |
|
from copy import deepcopy |
|
from transformers.utils import logging |
|
from transformers import AutoTokenizer |
|
from itertools import cycle |
|
from typing import List |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class Memory(torch.nn.Module): |
|
def __init__( |
|
self, |
|
model_config, |
|
k_seq_dim:int=2, |
|
v_seq_dim:int=2, |
|
): |
|
"""Setup necessary attributes.""" |
|
super().__init__() |
|
|
|
self.config = model_config |
|
|
|
|
|
self.k_seq_dim = k_seq_dim |
|
self.v_seq_dim = v_seq_dim |
|
self.rng = np.random.default_rng(42) |
|
|
|
self._post_validation() |
|
self.reset() |
|
|
|
@property |
|
def beacon_token(self): |
|
return self.config.vocab_size |
|
|
|
def _post_validation(self, verbose=True): |
|
assert self.config.beacon_window >= self.config.beacon_stride, f"Make sure the beacon_window {self.config.beacon_window} >= beacon_stride {self.config.beacon_stride}!" |
|
for ratio in self.config.beacon_ratio: |
|
assert ratio >= 0, f"Make sure all beacon ratios are greater than or equal to 0, found {self.config.beacon_ratio}!" |
|
assert self.config.beacon_attn in ["segmentation", "step-expansion", "full-coverage"], f"beacon_attn {self.config.beacon_attn} not implemented!" |
|
assert self.config.beacon_ratio_mix in ["instance-random", "step-random", "sequence"] or "adapt-" in self.config.beacon_ratio_mix, f"beacon_ratio_mix {self.config.beacon_ratio_mix} not implemented!" |
|
|
|
if self.config.beacon_pos == "interleave": |
|
assert self.config.beacon_window == self.config.beacon_stride, f"Make sure the beacon_window equals to beacon_stride when using interleaving mode." |
|
if self.config.beacon_parallel_window > 1: |
|
assert self.config._attn_implementation != "flash_attention_2", f"Currently parallel window does not support flash_attention_2!" |
|
|
|
self._cpu = torch.device("cpu") |
|
|
|
if verbose: |
|
info = f"applying activation beacon on {self.config.beacon_param} (the beacon embedding is initialized from {'bos' if self.config.beacon_embed_init == 'bos' else 'eos'} embedding, the beacon tokens are positioned with '{self.config.beacon_pos}' method), with window size {self.config.beacon_window}, stride {self.config.beacon_stride}, {self.config.beacon_attn} attention{' (attending to previous beacons)' if self.config.beacon_attend_prev else ' (no attending to previous beacons)'}, sink size {self.config.beacon_sink_size}, compression ratio {self.config.beacon_ratio} (mixed by {self.config.beacon_ratio_mix})..." |
|
logger.info(info) |
|
|
|
def set(self, verbose=True, **kwargs): |
|
""" |
|
Set attributes out of the constructor. |
|
""" |
|
for k, v in kwargs.items(): |
|
setattr(self.config, k, v) |
|
self._post_validation(verbose=verbose) |
|
|
|
def reset(self, **kwargs): |
|
"""Initialize attributes for a new sequence.""" |
|
|
|
self.start_idx = 0 |
|
|
|
self.end_idx = 0 |
|
|
|
self.all_beacon_sizes = [] |
|
|
|
self.batch_loss = None |
|
|
|
self.valid_token_num = None |
|
|
|
self.step_idx = 0 |
|
|
|
self.compression_ratio = None |
|
|
|
self.is_full_window = True |
|
|
|
self.raw_size_to_cache = 0 |
|
|
|
|
|
self.interleave_remainder = 0 |
|
|
|
self.interleave_compression_ratio = None |
|
self.beacon_indices = None |
|
|
|
self.all_input_ids = None |
|
self.all_attention_mask = None |
|
self.all_labels = None |
|
|
|
|
|
self.beacon_skip_first = None |
|
self.beacon_skip_last = None |
|
|
|
|
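        # three per-layer KV caches: sink activations (a small prefix that is always kept),
        # beacon activations (the compressed history), and raw activations (the uncompressed
        # tail of the current window)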
|
self.sink_activations = [(None, None) for _ in range(self.config.num_hidden_layers)] |
|
|
|
self.beacon_activations = [(None, None) for _ in range(self.config.num_hidden_layers)] |
|
|
|
self.raw_activations = [(None, None) for _ in range(self.config.num_hidden_layers)] |
|
|
|
|
|
for k, v in kwargs.items(): |
|
|
|
            setattr(self, k, deepcopy(v))
|
|
|
def export(self): |
|
"""Export all necessary attributes of the memory module.""" |
|
return { |
|
"start_idx": self.start_idx, |
|
"end_idx": self.end_idx, |
|
"all_beacon_sizes": self.all_beacon_sizes, |
|
"batch_loss": self.batch_loss, |
|
"valid_token_num": self.valid_token_num, |
|
"step_idx": self.step_idx, |
|
"compression_ratio": self.compression_ratio, |
|
"is_full_window": self.is_full_window, |
|
"raw_size_to_cache": self.raw_size_to_cache, |
|
"interleave_remainder": self.interleave_remainder, |
|
"interleave_compression_ratio": self.interleave_compression_ratio, |
|
"beacon_indices": self.beacon_indices, |
|
"all_input_ids": self.all_input_ids, |
|
"all_attention_mask": self.all_attention_mask, |
|
"all_labels": self.all_labels, |
|
"beacon_skip_first": self.beacon_skip_first, |
|
"beacon_skip_last": self.beacon_skip_last, |
|
|
|
"sink_activations": deepcopy(self.sink_activations), |
|
"beacon_activations": deepcopy(self.beacon_activations), |
|
"raw_activations": deepcopy(self.raw_activations), |
|
} |
|
|
|
@property |
|
def all_sequence_length(self): |
|
if self.all_input_ids is None: |
|
return 0 |
|
else: |
|
return self.all_input_ids.shape[1] |
|
|
|
@property |
|
def batch_size(self): |
|
if self.all_input_ids is None: |
|
return 0 |
|
else: |
|
return self.all_input_ids.shape[0] |
|
|
|
@property |
|
def finish(self): |
|
is_finish = self.end_idx == self.all_sequence_length |
|
return is_finish |
|
|
|
@property |
|
def dtype(self): |
|
return self.config.torch_dtype |
|
|
|
@property |
|
def min_value(self): |
|
return torch.finfo(self.dtype).min |
|
|
|
@property |
|
def max_position_embeddings(self): |
|
max_position_embeddings = self.config.max_position_embeddings |
|
if getattr(self.config, "rope_scaling", None) is not None: |
|
scaling_factor = self.config.rope_scaling["factor"] |
|
max_position_embeddings = max_position_embeddings * scaling_factor |
|
return max_position_embeddings |
|
|
|
@property |
|
def beacon_window(self): |
|
if ( |
|
self.beacon_skip_last is not None |
|
and self.start_idx < self.beacon_skip_last |
|
and self.start_idx + self.config.beacon_window > self.beacon_skip_last |
|
): |
|
return self.beacon_skip_last - self.start_idx |
|
else: |
|
return self.config.beacon_window |
|
|
|
@property |
|
def beacon_stride(self): |
|
if ( |
|
self.beacon_skip_last is not None |
|
and self.start_idx < self.beacon_skip_last |
|
and self.start_idx + self.config.beacon_window > self.beacon_skip_last |
|
): |
|
return self.beacon_skip_last - self.start_idx |
|
else: |
|
return self.config.beacon_stride |
|
|
|
def get_memory_size(self): |
|
""" |
|
Sink memory size, beacon memory size and raw memory size. |
|
""" |
|
sink_memory_size = 0 |
|
beacon_memory_size = 0 |
|
raw_memory_size = 0 |
|
if self.sink_activations[0][0] is not None: |
|
sink_memory_size += self.sink_activations[0][0].shape[self.k_seq_dim] |
|
if self.beacon_activations[0][0] is not None: |
|
beacon_memory_size += self.beacon_activations[0][0].shape[self.k_seq_dim] |
|
if self.raw_activations[0][0] is not None: |
|
raw_memory_size += self.raw_activations[0][0].shape[self.k_seq_dim] |
|
return sink_memory_size, beacon_memory_size, raw_memory_size |
|
|
|
def prepare(self, input_ids, attention_mask, labels, skip_first=None, skip_last=None): |
|
""" |
|
Prepare inputs for the model. These inputs belong to the same sequence. |
|
""" |
|
|
|
|
|
|
|
self._device = input_ids.device |
|
|
|
|
|
if self.all_input_ids is None: |
|
self.all_input_ids = input_ids.cpu() |
|
else: |
|
self.all_input_ids = torch.cat([self.all_input_ids, input_ids.cpu()], dim=1) |
|
|
|
|
|
if attention_mask is None: |
|
attention_mask = torch.ones_like(input_ids, device=torch.device("cpu")) |
|
if self.all_attention_mask is None: |
|
self.all_attention_mask = attention_mask.cpu() |
|
else: |
|
self.all_attention_mask = torch.cat([self.all_attention_mask, attention_mask.cpu()], dim=1) |
|
|
|
|
|
if labels is not None: |
|
|
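            # shift labels one step to the left so the label at position t is the id of token t+1;
            # the final position is padded with -100 and ignored by the loss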
|
labels = torch.cat([labels[:, 1:].cpu(), torch.tensor([-100]).expand(labels.shape[0], 1)], dim=1) |
|
if self.all_labels is None: |
|
self.all_labels = labels.cpu() |
|
else: |
|
self.all_labels = torch.cat([self.all_labels, labels], dim=1) |
|
assert self.all_input_ids.shape[1] == self.all_labels.shape[1], f"Found inconsistent all_input_ids {self.all_input_ids.shape} and all_labels {self.all_labels.shape}!" |
|
|
|
|
|
if skip_first is not None: |
|
assert self.config.beacon_parallel_window == 1, f"Make sure the parallel window is set to 1 when using beacon_skip!" |
|
assert self.config.beacon_window == self.config.beacon_stride, f"Make sure the beacon_window equals to beacon_stride when using beacon_skip." |
|
assert self.config.beacon_sink_size == 0, f"Make sure the beacon_sink_size is set to 0 when using beacon_skip!" |
|
|
|
if skip_last is not None: |
|
skip_first = skip_first if skip_first is not None else 0 |
|
|
|
assert self.config.beacon_sink_size == 0, "Make sure the beacon_sink_size is zero when using skip_last!" |
|
self.beacon_skip_first = skip_first |
|
self.beacon_skip_last = skip_last |
|
|
|
def set_compression_ratio(self, start_idx, end_idx): |
|
"""Choose a condensing ratio from self.config.beacon_ratio""" |
|
def filter_ratio(ratios, stride): |
|
valid_ratios = [] |
|
for ratio in ratios: |
|
|
|
if stride < ratio: |
|
continue |
|
|
|
if ratio > 0 and (stride % ratio) != 0: |
|
continue |
|
|
|
if ratio == 0 and self.training: |
|
previous_has_zero = -1 in self.all_beacon_sizes |
|
following_has_nonzero = (start_idx + stride + self.beacon_window) <= self.all_sequence_length |
|
if previous_has_zero or (not following_has_nonzero): |
|
continue |
|
valid_ratios.append(ratio) |
|
assert len(valid_ratios), f"Cannot find valid condensing ratio (among {ratios}) for stride {stride}!" |
|
return valid_ratios |
|
|
|
def get_max_length(ratios): |
|
max_lengths = [] |
|
for compression_ratio in ratios: |
|
if compression_ratio > 0: |
|
|
|
max_lengths.append((self.max_position_embeddings - self.beacon_window) * compression_ratio + self.beacon_window) |
|
else: |
|
max_lengths.append(self.max_position_embeddings) |
|
return max_lengths |
|
|
|
if len(self.config.beacon_ratio) == 1: |
|
return self.config.beacon_ratio[0] |
|
|
|
ratio_mix = self.config.beacon_ratio_mix |
|
|
|
beacon_ratio = filter_ratio(self.config.beacon_ratio, self.beacon_stride) |
|
|
|
if ratio_mix == "instance-random": |
|
if self.compression_ratio is None: |
|
beacon_ratio = self.rng.choice(beacon_ratio).tolist() |
|
self.compression_ratio = beacon_ratio |
|
else: |
|
beacon_ratio = self.compression_ratio |
|
|
|
elif ratio_mix == "step-random": |
|
beacon_ratio = self.rng.choice(beacon_ratio).tolist() |
|
|
|
elif ratio_mix == "sequence": |
|
if self.compression_ratio is None: |
|
self.compression_ratio = cycle(beacon_ratio) |
|
beacon_ratio = next(self.compression_ratio) |
|
|
|
elif "adapt" in ratio_mix: |
|
if self.compression_ratio is None: |
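                # e.g. ratio_mix == "adapt-1024": pick the smallest ratio whose maximum
                # supported length covers the current sequence plus 1024 future tokens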
|
future_length = int(ratio_mix.split("-")[1]) |
|
sequence_length = self.all_input_ids.shape[1] + future_length |
|
max_lengths = get_max_length(beacon_ratio) |
|
|
|
valid_max_lengths_and_indices = [x for x in enumerate(max_lengths) if x[1] >= sequence_length] |
|
if len(valid_max_lengths_and_indices): |
|
minimum_length_index = min(valid_max_lengths_and_indices, key=lambda x: x[1])[0] |
|
|
|
beacon_ratio = beacon_ratio[minimum_length_index] |
|
else: |
|
beacon_ratio = max(beacon_ratio) |
|
|
|
self.compression_ratio = beacon_ratio |
|
else: |
|
beacon_ratio = self.compression_ratio |
|
|
|
return beacon_ratio |
|
|
|
def step(self): |
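        """
        Pack one or more windows into the inputs for the next forward pass. When
        beacon_parallel_window > 1 and the remaining sequence is long enough, several
        windows are encoded in a single call; otherwise fall back to a single `_step`.
        """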
|
|
|
|
|
|
|
if ( |
|
self.config.beacon_parallel_window > 1 |
|
and self.config.beacon_stride == self.config.beacon_window |
|
and 0 not in self.config.beacon_ratio |
|
and self.all_input_ids[:, self.end_idx:].shape[1] >= self.config.beacon_parallel_window * self.config.beacon_window |
|
): |
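            # pack beacon_parallel_window windows into a single forward pass: only the first
            # window attends to the existing memory, later windows are encoded independently
            # (ignore_memory=True), and their beacon activations are stitched together afterwards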
|
input_ids_list = [] |
|
attention_mask_list = [] |
|
position_ids_list = [] |
|
labels_list = [] |
|
|
|
beacon_size_list = [] |
|
beacon_indices_list = [] |
|
|
|
for i in range(self.config.beacon_parallel_window): |
|
if i == 0: |
|
_input_ids, _attention_mask, _position_ids, _past_key_values, _labels = self._step() |
|
else: |
|
_input_ids, _attention_mask, _position_ids, _past_key_values, _labels = self._step(ignore_memory=True) |
|
|
|
input_ids_list.append(_input_ids) |
|
attention_mask_list.append(_attention_mask) |
|
position_ids_list.append(_position_ids) |
|
labels_list.append(_labels) |
|
beacon_size_list.append(_past_key_values[0][2]) |
|
beacon_indices_list.append(_past_key_values[0][3]) |
|
|
|
if i == 0: |
|
past_key_values = _past_key_values |
|
if past_key_values[0][0] is None: |
|
mem_size = 0 |
|
else: |
|
mem_size = past_key_values[0][0].shape[self.k_seq_dim] |
|
|
|
else: |
|
|
|
assert _past_key_values[0][0] is None |
|
|
|
batch_size = self.all_input_ids.shape[0] |
|
|
|
seq_len = sum(x.shape[1] for x in input_ids_list) + sum(beacon_size_list) - beacon_size_list[-1] |
|
|
|
input_ids = _input_ids.new_zeros((batch_size, seq_len)) + self.beacon_token |
|
|
|
attention_mask = _attention_mask.new_zeros((batch_size, 1, seq_len, mem_size + seq_len)) + self.min_value |
|
position_ids = torch.arange(mem_size + seq_len, device=self._device).expand(batch_size, mem_size + seq_len) |
|
|
|
beacon_indices = beacon_indices_list[0].new_zeros(seq_len) + 2 |
|
if _labels is not None: |
|
|
|
labels = _labels.new_zeros((batch_size, seq_len)) - 100 |
|
else: |
|
labels = None |
|
|
|
start_idx = 0 |
|
position_offset = mem_size |
|
for i in range(self.config.beacon_parallel_window): |
|
beacon_size = beacon_size_list[i] |
|
|
|
|
|
_input_ids = input_ids_list[i] |
|
cur_seq_len = _input_ids.shape[1] |
|
input_ids[:, start_idx: start_idx + cur_seq_len] = _input_ids |
|
|
|
|
|
_attention_mask = attention_mask_list[i] |
|
_position_ids = position_ids_list[i] |
|
|
|
if i == 0: |
|
_attention_mask = _attention_mask[:, :, :, mem_size:] |
|
_position_ids = _position_ids[:, mem_size:] - mem_size |
|
|
|
attention_mask[:, :, start_idx: start_idx + cur_seq_len, mem_size + start_idx: mem_size + start_idx + cur_seq_len] = _attention_mask |
|
position_ids[:, mem_size + start_idx: mem_size + start_idx + cur_seq_len] = _position_ids + position_offset |
|
|
|
|
|
_beacon_indices = beacon_indices_list[i] |
|
beacon_indices[start_idx: start_idx + cur_seq_len] = _beacon_indices |
|
|
|
|
|
if labels is not None: |
|
|
|
_labels = labels_list[i] |
|
labels[:, start_idx: start_idx + cur_seq_len] = _labels |
|
|
|
|
|
if i == 0 and self.config.beacon_sink_size > 0 and self.sink_activations[0][0] is None: |
|
position_offset += 1 |
|
|
|
|
|
if i != self.config.beacon_parallel_window - 1: |
|
replicate_beacon_row_start = start_idx + cur_seq_len |
|
replicate_beacon_col_start = mem_size + start_idx + cur_seq_len |
|
|
|
attention_mask[:, :, replicate_beacon_row_start: replicate_beacon_row_start + beacon_size, replicate_beacon_col_start: replicate_beacon_col_start + beacon_size] = _attention_mask.new_full((beacon_size, beacon_size), self.min_value).triu(1) |
|
|
|
attention_mask[:, :, replicate_beacon_row_start + beacon_size:, replicate_beacon_col_start: replicate_beacon_col_start + beacon_size] = 0 |
|
|
|
                    position_ids[:, mem_size + start_idx + cur_seq_len: mem_size + start_idx + cur_seq_len + beacon_size] = torch.arange(position_offset, position_offset + beacon_size, device=_input_ids.device)[None, :]
|
|
|
start_idx += cur_seq_len + beacon_size |
|
position_offset += beacon_size |
|
|
|
|
|
attention_mask[:, :, :, :max(mem_size, self.config.beacon_sink_size)] = 0 |
|
|
|
|
|
for i, (key, value, _, _) in enumerate(past_key_values): |
|
past_key_values[i] = (key, value, sum(beacon_size_list), beacon_indices) |
|
|
|
|
|
self.beacon_indices = beacon_indices |
|
|
|
return input_ids, attention_mask, position_ids, past_key_values, labels |
|
|
|
else: |
|
return self._step() |
|
|
|
def _step(self, ignore_memory=False): |
|
""" |
|
        Return the inputs for the current sliding window: the input_ids, attention_mask, position_ids, past_key_values, and labels.
|
""" |
|
|
|
|
|
|
|
|
|
|
|
start_idx = self.start_idx |
|
|
|
end_idx = start_idx + self.beacon_window |
|
|
|
|
|
if end_idx > self.all_sequence_length: |
|
|
|
end_idx = self.all_sequence_length |
|
is_full_window = False |
|
else: |
|
is_full_window = True |
|
|
|
|
|
|
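        # sentinel conventions used below: raw_size_to_cache == -1 disables fixed-size raw
        # caching for this step, beacon_size == -1 keeps the window uncompressed (ratio 0),
        # and compression_ratio == -1 means no condensing ratio applies to this step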
|
if self.training and end_idx == self.all_sequence_length: |
|
next_start_idx = start_idx |
|
is_full_window = False |
|
raw_size_to_cache = -1 |
|
beacon_size = 0 |
|
compression_ratio = -1 |
|
|
|
|
|
elif self.step_idx == 0 and self.beacon_skip_first is not None: |
|
end_idx = start_idx + self.beacon_skip_first |
|
assert end_idx <= self.all_sequence_length |
|
next_start_idx = end_idx |
|
is_full_window = True |
|
raw_size_to_cache = -1 |
|
beacon_size = 0 |
|
compression_ratio = -1 |
|
|
|
|
|
elif self.beacon_skip_last is not None and start_idx >= self.beacon_skip_last: |
|
end_idx = min(start_idx + self.beacon_window, self.all_sequence_length) |
|
next_start_idx = end_idx |
|
is_full_window = False |
|
raw_size_to_cache = -1 |
|
beacon_size = 0 |
|
compression_ratio = -1 |
|
|
|
else: |
|
|
|
|
|
|
|
if self.config.beacon_pos == "append": |
|
if is_full_window: |
|
|
|
beacon_stride = self.beacon_stride |
|
compression_ratio = self.set_compression_ratio(start_idx=start_idx, end_idx=end_idx) |
|
|
|
if compression_ratio > 0: |
|
|
|
beacon_size = beacon_stride // compression_ratio |
|
else: |
|
|
|
beacon_size = -1 |
|
|
|
|
|
next_start_idx = start_idx + beacon_stride |
|
|
|
raw_size_to_cache = end_idx - next_start_idx |
|
else: |
|
|
|
next_start_idx = start_idx |
|
|
|
raw_size_to_cache = -1 |
|
beacon_size = 0 |
|
compression_ratio = 0 |
|
|
|
elif self.config.beacon_pos == "interleave": |
|
|
|
input_size = end_idx - self.end_idx |
|
|
|
if self.is_full_window: |
|
compression_ratio = self.set_compression_ratio(start_idx=start_idx, end_idx=end_idx) |
|
self.interleave_compression_ratio = compression_ratio |
|
else: |
|
compression_ratio = self.interleave_compression_ratio |
|
|
|
|
|
if compression_ratio > 0: |
|
|
|
beacon_size = (input_size + self.interleave_remainder) // compression_ratio |
|
else: |
|
|
|
beacon_size = -1 |
|
|
|
if is_full_window: |
|
|
|
next_start_idx = start_idx + self.beacon_stride |
|
|
|
raw_size_to_cache = 0 |
|
else: |
|
|
|
next_start_idx = start_idx |
|
|
|
raw_size_to_cache = -1 |
|
|
|
|
|
|
|
|
|
input_ids = self.all_input_ids[:, self.end_idx: end_idx].to(self._device) |
|
attention_mask = self.all_attention_mask[:, self.end_idx: end_idx].to(self._device) |
|
if self.all_labels is not None: |
|
labels = self.all_labels[:, self.end_idx: end_idx].to(self._device) |
|
else: |
|
labels = None |
|
batch_size = input_ids.shape[0] |
|
|
|
|
|
|
|
|
|
|
|
|
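        # insert the beacon tokens: "append" places all beacons at the end of the window,
        # "interleave" scatters one beacon after every compression_ratio raw tokens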
|
if self.config.beacon_pos == "append": |
|
|
|
if is_full_window and beacon_size > 0: |
|
input_ids = torch.cat([input_ids, input_ids.new_full((batch_size, beacon_size), self.beacon_token)], dim=1) |
|
|
|
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(batch_size, beacon_size)], dim=1) |
|
if labels is not None: |
|
labels = torch.cat([labels, labels.new_zeros(batch_size, beacon_size) - 100], dim=1) |
|
|
|
elif self.config.beacon_pos == "interleave": |
|
input_len = input_ids.shape[1] |
|
if beacon_size > 0: |
|
|
|
input_ids_with_beacons = input_ids.new_full((input_ids.shape[0], input_len + beacon_size), self.beacon_token) |
|
raw_token_indices = torch.arange(input_ids_with_beacons.shape[1], device=input_ids.device) |
|
interleave_start_idx = compression_ratio - self.interleave_remainder |
|
raw_token_indices = raw_token_indices[raw_token_indices % (compression_ratio + 1) != interleave_start_idx].unsqueeze(0).expand_as(input_ids) |
|
input_ids_with_beacons = input_ids_with_beacons.scatter(dim=1, index=raw_token_indices, src=input_ids) |
|
input_ids = input_ids_with_beacons |
|
|
|
attention_mask_with_beacons = attention_mask.new_full((attention_mask.shape[0], attention_mask.shape[1] + beacon_size), 1) |
|
attention_mask_with_beacons = attention_mask_with_beacons.scatter(dim=1, index=raw_token_indices, src=attention_mask) |
|
attention_mask = attention_mask_with_beacons |
|
|
|
if labels is not None: |
|
labels_with_beacons = labels.new_full((labels.shape[0], labels.shape[1] + beacon_size), -100) |
|
labels_with_beacons = labels_with_beacons.scatter(dim=1, index=raw_token_indices, src=labels) |
|
labels = labels_with_beacons |
|
|
|
if compression_ratio > 0: |
|
|
|
self.interleave_remainder = (input_len + self.interleave_remainder) % compression_ratio |
|
|
|
|
|
if self.training and self.step_idx == 0 and not (self.config.beacon_pos == 'interleave' and self.config.beacon_attn == 'full-coverage'): |
|
labels[:] = -100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
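        # beacon_indices marks each position of the current input: 1 for beacon tokens and 0
        # for raw tokens; raw positions are re-marked as -1 below when the window is kept
        # uncompressed so that they are cached as "beacon" (full) memory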
|
beacon_indices = (input_ids[0] == self.beacon_token).long() |
|
if self.is_full_window: |
|
self.beacon_indices = torch.tensor([], dtype=torch.long, device=input_ids.device) |
|
|
|
beacon_indices = torch.cat([self.beacon_indices, beacon_indices]) |
|
|
|
self.beacon_indices = beacon_indices |
|
if is_full_window and beacon_size == -1: |
|
|
|
|
|
beacon_indices[:self.beacon_stride] = -1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
past_key_values = [] |
|
for layer_idx in range(self.config.num_hidden_layers): |
|
if ignore_memory: |
|
key, value = None, None |
|
else: |
|
sink_key, sink_value = self.sink_activations[layer_idx] |
|
beacon_key, beacon_value = self.beacon_activations[layer_idx] |
|
raw_key, raw_value = self.raw_activations[layer_idx] |
|
|
|
key = cat_tensor([ |
|
sink_key, beacon_key, raw_key, |
|
], dim=self.k_seq_dim) |
|
value = cat_tensor([ |
|
sink_value, beacon_value, raw_value, |
|
], dim=self.v_seq_dim) |
|
|
|
layer_past_key_values = (key, value, beacon_size, beacon_indices) |
|
past_key_values.append(layer_past_key_values) |
|
|
|
|
|
|
|
|
|
|
|
|
|
first_key = past_key_values[0][0] |
|
mem_size = first_key.shape[self.k_seq_dim] if first_key is not None else 0 |
|
if mem_size > 0: |
|
attention_mask = torch.cat([attention_mask.new_ones(batch_size, mem_size), attention_mask], dim=1) |
|
|
|
input_length = input_ids.shape[1] |
|
position_ids = torch.arange(attention_mask.shape[-1], dtype=torch.long, device=self._device).repeat(batch_size, 1) |
|
|
|
if self.config._attn_implementation == "flash_attention_2": |
|
assert self.config.beacon_attn == "full-coverage", f"Make sure to set beacon_attn='full-coverage' when using flash attention! Found {self.config.beacon_attn}." |
|
            # flash attention takes a 2D mask; pass None when there is no padding at all
            if 0 not in attention_mask:
                attention_mask = None
|
elif self.config._attn_implementation == "sdpa" and self.config.beacon_pos == "append" and beacon_size <= 0 and (input_length == 1 or mem_size == 0): |
|
attention_mask = None |
|
else: |
|
attention_mask, position_ids = self._make_4d_attention_mask_and_position_ids( |
|
attention_mask, |
|
position_ids, |
|
mem_size, |
|
beacon_size, |
|
compression_ratio, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.is_full_window = is_full_window |
|
|
|
self.raw_size_to_cache = raw_size_to_cache |
|
|
|
self.all_beacon_sizes.append(beacon_size) |
|
|
|
|
|
|
|
self.start_idx = next_start_idx |
|
self.end_idx = end_idx |
|
self.step_idx += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return input_ids, attention_mask, position_ids, past_key_values, labels |
|
|
|
def update_memory(self, past_key_values): |
|
""" |
|
Accumulate beacon activations and raw activations. |
|
""" |
|
for layer_idx, (key, value, beacon_size, beacon_indices) in enumerate(past_key_values): |
|
|
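            # the returned key/value of each layer contain the memory that was fed in plus the
            # activations of the current window; split them back into sink / beacon / raw caches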
|
previous_raw_key, previous_raw_value = self.raw_activations[layer_idx] |
|
|
|
if self.beacon_skip_first is not None and self.sink_activations[layer_idx][0] is None: |
|
assert key.shape[self.k_seq_dim] == self.beacon_skip_first |
|
assert value.shape[self.k_seq_dim] == self.beacon_skip_first |
|
self.sink_activations[layer_idx] = [ |
|
key, |
|
value, |
|
] |
|
|
|
continue |
|
|
|
if self.beacon_activations[layer_idx][0] is None and self.config.beacon_sink_size > 0: |
|
|
|
|
|
self.sink_activations[layer_idx] = [ |
|
slice_tensor(key, end=self.config.beacon_sink_size, dim=self.k_seq_dim), |
|
slice_tensor(value, end=self.config.beacon_sink_size, dim=self.v_seq_dim), |
|
] |
|
|
|
if not self.is_full_window: |
|
|
|
|
|
assert self.raw_size_to_cache == -1 |
|
raw_key = cat_tensor([ |
|
previous_raw_key, |
|
key |
|
], dim=self.k_seq_dim) |
|
raw_value = cat_tensor([ |
|
previous_raw_value, |
|
value |
|
], dim=self.v_seq_dim) |
|
self.raw_activations[layer_idx] = (raw_key, raw_value) |
|
|
|
else: |
|
|
|
previous_beacon_key, previous_beacon_value = self.beacon_activations[layer_idx] |
|
|
|
beacon_key, beacon_value, raw_key, raw_value = self._extract_beacon_and_raw_memory( |
|
key, |
|
value, |
|
previous_beacon_key, |
|
previous_beacon_value, |
|
previous_raw_key, |
|
previous_raw_value, |
|
beacon_indices, |
|
) |
|
|
|
self.beacon_activations[layer_idx] = (beacon_key, beacon_value) |
|
self.raw_activations[layer_idx] = (raw_key, raw_value) |
|
|
|
def update_loss(self, batch_loss, valid_token_num): |
|
""" |
|
Accumulate loss for later perplexity computation and backward pass. |
|
""" |
|
if self.batch_loss is None: |
|
|
|
self.batch_loss = batch_loss * valid_token_num |
|
self.valid_token_num = valid_token_num |
|
else: |
|
|
|
self.batch_loss = self.batch_loss + batch_loss * valid_token_num |
|
self.valid_token_num = self.valid_token_num + valid_token_num |
|
|
|
def output(self, model_outputs): |
|
""" |
|
Override loss with accumulated loss. Update the next-token logits. |
|
""" |
|
|
|
if self.batch_loss is not None: |
|
|
|
loss = self.batch_loss.sum() / self.valid_token_num.sum() |
|
|
|
|
|
batch_loss = self.batch_loss / self.valid_token_num |
|
if (self.valid_token_num == 0).any(): |
|
batch_loss = batch_loss.masked_fill(self.valid_token_num == 0, 0.) |
|
|
|
|
|
model_outputs["loss"] = loss |
|
model_outputs["batch_loss"] = batch_loss |
|
|
|
|
|
beacon_size = self.all_beacon_sizes[-1] |
|
|
|
if beacon_size > 0: |
|
logits = model_outputs["logits"] |
|
beacon_indices = self.beacon_indices[-logits.shape[1]:] |
|
model_outputs["logits"] = logits[:, beacon_indices == 0] |
|
|
|
return model_outputs |
|
|
|
def _make_4d_attention_mask_and_position_ids( |
|
self, |
|
attention_mask, |
|
position_ids, |
|
mem_size, |
|
beacon_size, |
|
compression_ratio, |
|
): |
|
""" |
|
Convert attention_mask into causal 4D attention_mask (batch_size, head_num, query_len, key_len). |
|
""" |
|
tgt_size = attention_mask.size(-1) - mem_size |
|
dtype = self.dtype |
|
min_value = self.min_value |
|
device = self._device |
|
batch_size, src_size = attention_mask.size() |
|
|
|
|
|
causal_mask = torch.full((tgt_size, tgt_size), min_value, device=device, dtype=dtype) |
|
mask_cond = torch.arange(causal_mask.size(-1), device=device) |
|
causal_mask.masked_fill_(mask_cond < (mask_cond + 1).view(causal_mask.size(-1), -1), 0) |
|
causal_mask = torch.cat([torch.zeros(tgt_size, mem_size, dtype=dtype, device=device), causal_mask], dim=-1) |
|
causal_mask = causal_mask[None, None, ...].expand(batch_size, 1, tgt_size, src_size) |
|
|
|
expand_mask = attention_mask[:, None, None, :].expand(batch_size, 1, tgt_size, src_size) |
|
invert_mask = 1.0 - expand_mask |
|
invert_mask.masked_fill_(invert_mask.bool(), min_value) |
|
|
|
attention_mask = causal_mask.masked_fill(invert_mask.bool(), min_value) |
|
|
|
if self.config.beacon_attn == "step-expansion": |
|
|
|
|
|
if self.config.beacon_pos == "append" and beacon_size > 0: |
|
window_size = self.beacon_window |
|
window_size_with_beacon = window_size + beacon_size |
|
beacon_start_idx = -beacon_size |
|
|
|
reference_attention_mask = attention_mask[..., -beacon_size - 1, -window_size_with_beacon: -beacon_size] |
|
|
|
|
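                # step-expansion rule: the k-th beacon of the window attends to the first
                # k * compression_ratio raw tokens of that window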
|
beacon_arange = torch.arange(1, beacon_size + 1, device=device) * compression_ratio |
|
|
|
ordinal_arange = torch.arange(window_size, device=device) |
|
|
|
valid_pos = ordinal_arange.expand(beacon_size, window_size) < beacon_arange.unsqueeze(-1) |
|
|
|
ordinal_attention_mask = torch.where(valid_pos, 0, min_value) |
|
|
|
ordinal_attention_mask = ordinal_attention_mask[None, None, ...] + reference_attention_mask.unsqueeze(-2) |
|
|
|
if self.config.beacon_attend_prev: |
|
beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).triu(1) |
|
|
|
ordinal_position_ids = position_ids[:, -window_size_with_beacon: -beacon_size] |
|
beacon_position_ids = ordinal_position_ids[:, compression_ratio - 1::compression_ratio] + torch.arange(1, beacon_size + 1, device=device)[None] |
|
position_ids[:, beacon_start_idx:] = beacon_position_ids |
|
else: |
|
beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).fill_diagonal_(0) |
|
|
|
ordinal_position_ids = position_ids[:, -window_size_with_beacon: -beacon_size] |
|
beacon_position_ids = ordinal_position_ids[:, compression_ratio - 1::compression_ratio] + 1 |
|
position_ids[:, beacon_start_idx:] = beacon_position_ids |
|
|
|
attention_mask[..., beacon_start_idx:, -window_size_with_beacon: -beacon_size] = ordinal_attention_mask |
|
attention_mask[..., beacon_start_idx:, beacon_start_idx:] = beacon_attention_mask |
|
|
|
|
|
elif self.config.beacon_pos == "interleave" and (self.beacon_indices == 1).any(): |
|
assert self.config.beacon_attend_prev == False, f"Make sure beacon_attend_prev is False if using 'interleave' beacon pos!" |
|
|
|
beacon_indices = self.beacon_indices |
|
|
|
cur_position_ids = position_ids[:, -len(beacon_indices):] |
|
base_position = cur_position_ids[:, 0] - 1 |
|
|
|
position_template = cur_position_ids.new_ones(cur_position_ids.shape) |
|
position_template[:, compression_ratio + 1::compression_ratio + 1] = 0 |
|
cur_position_ids = base_position + position_template.cumsum(-1) |
|
position_ids[:, -len(beacon_indices):] = cur_position_ids |
|
|
|
cur_input_length = len(beacon_indices) |
|
cur_attention_mask = attention_mask[..., -cur_input_length:, -cur_input_length:] |
|
|
|
                # mask out all beacon key positions (boolean mask over the key dimension)
                cur_attention_mask[..., beacon_indices.bool()] = min_value
|
|
|
input_ids_attention_mask = cur_attention_mask[..., -tgt_size:, -tgt_size:] |
|
input_ids_attention_mask[..., range(tgt_size), range(tgt_size)] = 0 |
|
|
|
elif self.config.beacon_attn == "segmentation": |
|
|
|
|
|
if self.config.beacon_pos == "append" and beacon_size > 0: |
|
window_size = self.beacon_window |
|
window_size_with_beacon = window_size + beacon_size |
|
beacon_start_idx = -beacon_size |
|
|
|
reference_attention_mask = attention_mask[..., -beacon_size - 1, -window_size_with_beacon: -beacon_size] |
|
|
|
|
|
indices = torch.arange(compression_ratio * beacon_size, device=device).view(beacon_size, -1) |
|
|
|
ordinal_attention_mask = attention_mask.new_full((beacon_size, window_size), min_value) |
|
ordinal_attention_mask.scatter_(dim=-1, index=indices, value=0) |
|
|
|
|
|
ordinal_attention_mask = ordinal_attention_mask[None, None, ...] + reference_attention_mask.unsqueeze(-2) |
|
|
|
if self.config.beacon_attend_prev: |
|
beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).triu(1) |
|
|
|
                        beacon_position_ids = position_ids.new_full((beacon_size,), fill_value=compression_ratio + mem_size)
                        beacon_position_ids = beacon_position_ids + torch.arange(beacon_size, device=device)
|
position_ids[:, beacon_start_idx:] = beacon_position_ids |
|
else: |
|
beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).fill_diagonal_(0) |
|
|
|
                        beacon_position_ids = position_ids.new_full((beacon_size,), fill_value=compression_ratio + mem_size)
|
position_ids[:, beacon_start_idx:] = beacon_position_ids |
|
|
|
attention_mask[..., beacon_start_idx:, -window_size_with_beacon: -beacon_size] = ordinal_attention_mask |
|
attention_mask[..., beacon_start_idx:, beacon_start_idx:] = beacon_attention_mask |
|
|
|
attention_mask[..., beacon_start_idx:, -beacon_size: beacon_start_idx] = min_value |
|
|
|
elif self.config.beacon_pos == "interleave": |
|
raise NotImplementedError |
|
|
|
elif self.config.beacon_attn == "full-coverage": |
|
pass |
|
|
|
return attention_mask, position_ids |
|
|
|
def _extract_beacon_and_raw_memory( |
|
self, |
|
key, |
|
value, |
|
previous_beacon_key, |
|
previous_beacon_value, |
|
previous_raw_key, |
|
previous_raw_value, |
|
beacon_indices, |
|
): |
|
"""Extract beacon and raw memory from the returned key and value when the window is full.""" |
|
key = cat_tensor([ |
|
previous_raw_key, |
|
key |
|
], dim=self.k_seq_dim) |
|
value = cat_tensor([ |
|
previous_raw_value, |
|
value |
|
], dim=self.v_seq_dim) |
|
|
|
|
|
beacon_key = slice_tensor(key, index=torch.logical_or(beacon_indices == 1, beacon_indices == -1), dim=self.k_seq_dim) |
|
beacon_value = slice_tensor(value, index=torch.logical_or(beacon_indices == 1, beacon_indices == -1), dim=self.v_seq_dim) |
|
|
|
if self.config.beacon_accum: |
|
beacon_key = cat_tensor([previous_beacon_key, beacon_key], dim=self.k_seq_dim) |
|
beacon_value = cat_tensor([previous_beacon_value, beacon_value], dim=self.v_seq_dim) |
|
|
|
if self.raw_size_to_cache > 0: |
|
            raw_key = slice_tensor(key, index=beacon_indices == 0, dim=self.k_seq_dim)
            raw_key = slice_tensor(raw_key, start=-self.raw_size_to_cache, dim=self.k_seq_dim)

            raw_value = slice_tensor(value, index=beacon_indices == 0, dim=self.v_seq_dim)
            raw_value = slice_tensor(raw_value, start=-self.raw_size_to_cache, dim=self.v_seq_dim)
|
|
|
else: |
|
raw_key = None |
|
raw_value = None |
|
|
|
return beacon_key, beacon_value, raw_key, raw_value |
|
|
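# A minimal sketch of how this Memory class is meant to be driven (typically this loop lives
# inside the wrapping model's forward pass; `forward_one_window` and its output fields below
# are placeholders for illustration, not names defined in this module):
#
#   memory = Memory(model_config, k_seq_dim=2, v_seq_dim=2)
#   memory.reset()
#   memory.prepare(input_ids, attention_mask, labels)
#   while not memory.finish:
#       _input_ids, _attention_mask, _position_ids, _past_key_values, _labels = memory.step()
#       outputs = forward_one_window(
#           input_ids=_input_ids,
#           attention_mask=_attention_mask,
#           position_ids=_position_ids,
#           past_key_values=_past_key_values,
#           labels=_labels,
#       )
#       memory.update_memory(outputs.past_key_values)
#       if _labels is not None:
#           memory.update_loss(outputs.batch_loss, outputs.valid_token_num)
#   outputs = memory.output(outputs)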
|
|
|
def slice_tensor(x, start=None, end=None, step=None, index=None, dim=2): |
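    """Slice x along dim (only dim 1 or 2 supported); returns None if x is None or the requested slice would be empty."""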
|
if x is None: |
|
return None |
|
if end == 0: |
|
return None |
|
if start == x.shape[dim]: |
|
return None |
|
if start is not None and start == end: |
|
return None |
|
if dim == 2: |
|
if index is not None: |
|
return x[:, :, index] |
|
elif start is None and end is not None: |
|
if step is None: |
|
return x[:, :, :end, ...] |
|
else: |
|
return x[:, :, :end:step, ...] |
|
elif start is not None and end is None: |
|
if step is None: |
|
return x[:, :, start:, ...] |
|
else: |
|
return x[:, :, start::step, ...] |
|
elif start is not None and end is not None: |
|
if step is None: |
|
return x[:, :, start:end, ...] |
|
else: |
|
return x[:, :, start:end:step, ...] |
|
elif dim == 1: |
|
        if index is not None:
            return x[:, index]
|
elif start is None and end is not None: |
|
if step is None: |
|
return x[:, :end, ...] |
|
else: |
|
return x[:, :end:step, ...] |
|
elif start is not None and end is None: |
|
if step is None: |
|
return x[:, start:, ...] |
|
else: |
|
return x[:, start::step, ...] |
|
elif start is not None and end is not None: |
|
if step is None: |
|
return x[:, start:end, ...] |
|
else: |
|
return x[:, start:end:step, ...] |
|
else: |
|
raise NotImplementedError |
|
|
|
def cat_tensor(list_of_tensors, dim=-1): |
|
list_of_tensors = [t for t in list_of_tensors if t is not None] |
|
if len(list_of_tensors) > 1: |
|
result = torch.cat(list_of_tensors, dim=dim) |
|
elif len(list_of_tensors) == 1: |
|
result = list_of_tensors[0] |
|
else: |
|
result = None |
|
return result |
|
|
|
def slice_activations(activations, start=None, end=None, k_seq_dim=2, v_seq_dim=2): |
|
new_activations = [] |
|
for key, value in activations: |
|
new_key = slice_tensor(key, start=start, end=end, dim=k_seq_dim) |
|
new_value = slice_tensor(value, start=start, end=end, dim=v_seq_dim) |
|
new_activations.append([new_key, new_value]) |
|
return new_activations |
|
|
|
def cat_activations(list_of_activations, k_seq_dim=2, v_seq_dim=2): |
|
assert all(len(x) == len(list_of_activations[0]) for x in list_of_activations), f"Make sure all activations have the same number of layers! Found {[len(x) for x in list_of_activations]}." |
|
|
|
new_activations = [] |
|
for layer_idx in range(len(list_of_activations[0])): |
|
keys = [x[layer_idx][0] for x in list_of_activations] |
|
values = [x[layer_idx][1] for x in list_of_activations] |
|
|
|
new_key = cat_tensor(keys, dim=k_seq_dim) |
|
new_value = cat_tensor(values, dim=v_seq_dim) |
|
new_activations.append([new_key, new_value]) |
|
return new_activations |
|
|
|
def interleave_activations(main_activations, augment_activations, main_spans, augment_spans, k_seq_dim=2, v_seq_dim=2, device=torch.device("cuda")): |
|
""" Interleave main_activations and augment_activations according to main_span and augment_span. |
|
|
|
Args: |
|
main_span: a list of tuples (start_idx, end_idx). when start_idx and end_idx is None, the augment_activations will be plugged in. |
|
augment_span: a list of tuples (start_idx, end_idx) |
|
""" |
|
    assert len(main_activations) == len(augment_activations), f"Make sure main and augment activations have the same number of layers! Found {len(main_activations)} and {len(augment_activations)}!"

    assert sum(x[0] is None and x[1] is None for x in main_spans) == len(augment_spans), f"Make sure the number of slots for augmentation (start_idx=None and end_idx=None in main_spans) matches the number of augmentations. Found {sum(x[0] is None and x[1] is None for x in main_spans)} slots but {len(augment_spans)} augmentations!"
|
|
|
new_activations = [] |
|
for layer_idx in range(len(main_activations)): |
|
main_key, main_value = main_activations[layer_idx] |
|
augment_key, augment_value = augment_activations[layer_idx] |
|
|
|
sliced_keys = [] |
|
sliced_values = [] |
|
|
|
augment_idx = 0 |
|
for start, end in main_spans: |
|
if start is None and end is None: |
|
|
|
augment_start, augment_end = augment_spans[augment_idx] |
|
sliced_key = slice_tensor( |
|
augment_key, |
|
start=augment_start, |
|
end=augment_end, |
|
dim=k_seq_dim |
|
).to(device) |
|
sliced_value = slice_tensor( |
|
augment_value, |
|
start=augment_start, |
|
end=augment_end, |
|
dim=v_seq_dim |
|
                ).to(device)
                # advance to the next augmentation span
                augment_idx += 1
|
|
|
else: |
|
sliced_key = slice_tensor( |
|
main_key, |
|
start=start, |
|
end=end, |
|
dim=k_seq_dim |
|
) |
|
sliced_value = slice_tensor( |
|
main_value, |
|
start=start, |
|
end=end, |
|
dim=v_seq_dim |
|
) |
|
|
|
sliced_keys.append(sliced_key) |
|
sliced_values.append(sliced_value) |
|
|
|
new_key = cat_tensor(sliced_keys, dim=k_seq_dim) |
|
new_value = cat_tensor(sliced_values, dim=v_seq_dim) |
|
new_activations.append([new_key, new_value]) |
|
|
|
return new_activations |
|
|
|
def softmax(x:np.ndarray, axis=-1, temperature=1): |
|
if isinstance(x, list): |
|
x = np.array(x) |
|
x = x / temperature |
|
x = x - x.max(axis=axis, keepdims=True) |
|
y = np.exp(x) |
|
return y / y.sum(axis=axis, keepdims=True) |
|
|
|
def l1_norm(x): |
|
sum_x = sum(x) |
|
x = [y/sum_x for y in x] |
|
return x |