|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from transformers.generation.stopping_criteria import ( |
|
MaxLengthCriteria, |
|
StoppingCriteriaList, |
|
) |
|
from typing import Union, List |
|
from .eva_cache import EvaStaticCacheForTriton |
|
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd |
|
|
|
class MultibyteEosTokenCriteria: |
|
""" |
|
    This class implements a simple stopping criterion that halts generation whenever
    an end-of-sequence token appears among the last `new_tokens` generated tokens.
|
|
|
Adapted from |
|
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446 |
|
By default, it uses the `model.generation_config.eos_token_id`. |
|
|
|
Args: |
|
        eos_token_ids (`Union[int, List[int]]`):
            The id(s) of the *end-of-sequence* token(s).
|
""" |
|
|
|
def __init__(self, eos_token_ids: Union[int, List[int]]): |
|
if isinstance(eos_token_ids, int): |
|
eos_token_ids = [eos_token_ids] |
|
self.eos_token_ids = eos_token_ids |
|
|
|
def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool: |
|
current_input_len = input_ids.shape[-1] |
|
new_token_ids = input_ids[:, current_input_len - new_tokens:] |
|
for eos_token_id in self.eos_token_ids: |
|
if torch.any(new_token_ids == eos_token_id): |
|
return True |
|
return False |
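
# Example (sketch): with eos id 59, `MultibyteEosTokenCriteria(59)(input_ids, new_tokens=3)`
# returns True exactly when one of the last 3 generated ids equals 59; the id 59 here is
# purely illustrative.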
|
|
|
def build_tree(spec): |
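    """
    Expand a per-depth branching specification into a list of tree paths.

    `spec[k][i]` is the number of children of the i-th node at depth k (the root
    is the single node at depth 0); nodes not covered by the list get no
    children.  Each returned path is a tuple of child indices from the root,
    e.g. build_tree([[2], [1, 1]]) -> [(0,), (1,), (0, 0), (1, 0)].
    """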
|
nodes_at_depth = [] |
|
nodes_at_depth.append([()]) |
|
|
|
for d in range(1, len(spec) + 1): |
|
prev_nodes = nodes_at_depth[d - 1] |
|
spec_list = spec[d - 1] |
|
current_nodes = [] |
|
for node_idx, node in enumerate(prev_nodes): |
|
if node_idx < len(spec_list): |
|
num_children = spec_list[node_idx] |
|
else: |
|
num_children = 0 |
|
for child_idx in range(num_children): |
|
new_node = node + (child_idx,) |
|
current_nodes.append(new_node) |
|
nodes_at_depth.append(current_nodes) |
|
|
|
|
|
all_nodes = [node for depth_nodes in nodes_at_depth for node in depth_nodes if node] |
|
return all_nodes |
|
|
|
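# Draft-tree specifications used for EvaByte 7B multi-byte decoding; the suffix is the
# number of tree nodes (candidate positions) excluding the root: 95 and 31, respectively.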
evabyte_7b_95 = build_tree( |
|
[ |
|
[10], |
|
[10, 8, 2, 2, 1, 1], |
|
[10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1], |
|
[8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1], |
|
[6, 2, 1, 1], |
|
[4, 2, 1, 1], |
|
[4, 2, 1], |
|
] |
|
) |
|
evabyte_7b_31 = build_tree( |
|
[ |
|
[4], |
|
[3, 2, 1, 1], |
|
[3, 2, 1, 1], |
|
[2, 1, 1], |
|
[2, 1], |
|
[2, 1], |
|
[2, 1], |
|
] |
|
) |
|
TOPK = 10 |
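# Number of top candidates taken from each multi-byte prediction head per step;
# the branching factors in the specs above never exceed this value.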
|
|
|
def pad_path(path, length, pad_value=-2): |
|
""" |
|
Pad the given path list with a specific value up to a specified length. |
|
|
|
Parameters: |
|
- path (list): The original list that needs padding. |
|
- length (int): The desired length of the padded list. |
|
- pad_value (optional, default=-2): The value to use for padding. |
|
|
|
Returns: |
|
- list: A new list based on the original path but padded to the desired length. |
|
|
|
Example: |
|
>>> pad_path([1,2,3], 5) |
|
[1, 2, 3, -2, -2] |
|
|
|
Note: |
|
If the given path is already longer than the specified length, |
|
then no padding occurs, and the original path is returned. |
|
""" |
|
return path + [pad_value] * (length - len(path)) |
|
|
|
def reset_past_key_values(passed_key_values): |
|
""" |
|
Resets the current lengths in the passed key-values to zero. |
|
|
|
This function is designed to be used during the evaluation of a baseline model. |
|
It iterates through each layer's key-values and sets their current lengths to zero, |
|
effectively resetting their state. |
|
|
|
Args: |
|
- passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer. |
|
|
|
Returns: |
|
- passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths. |
|
""" |
|
for i in range(len(passed_key_values)): |
|
for j in range(2): |
|
passed_key_values[i][j].current_length.fill_(0) |
|
return passed_key_values |
|
|
|
def get_nucleus_one_token(logit, temperature, top_p): |
|
""" |
|
Performs token sampling based on the nucleus (top-p) sampling method. |
|
|
|
This function selects a token from a given logit distribution using the nucleus sampling strategy. |
|
It allows for more controlled and diverse generation compared to traditional top-k sampling. |
|
|
|
Args: |
|
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC). |
|
        temperature (float): A temperature parameter to control the randomness in sampling.
            Higher values increase diversity; lower values make selections more deterministic.
|
top_p (float): The cumulative probability threshold for nucleus sampling. |
|
It controls the size of the set of high-probability tokens to consider for sampling. |
|
|
|
Returns: |
|
torch.Tensor: A tensor containing the indices of the sampled tokens. |
|
""" |
|
if top_p >= 1: |
|
return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1) |
|
logit = logit / temperature |
|
probs = torch.softmax(logit, dim=-1) |
|
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
|
sorted_indices_to_remove = cum_probs > top_p |
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
|
sorted_indices_to_remove[..., 0] = 0 |
|
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) |
|
logit[indices_to_remove] = float('-inf') |
|
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1) |
|
return sampled_tokens |
|
|
|
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha): |
|
""" |
|
Implements token sampling based on the typical sampling method. |
|
|
|
This function selects a token from a given logit distribution using the typical sampling strategy, |
|
aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods. |
|
|
|
Args: |
|
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor. |
|
        temperature (float): A parameter to control the randomness in sampling.
            Higher values increase diversity; lower values make selections more deterministic.
|
posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling. |
|
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold. |
|
|
|
Returns: |
|
torch.Tensor: A tensor containing the indices of the sampled tokens. |
|
""" |
|
logit = logit / temperature |
|
probs = torch.softmax(logit, dim=-1) |
|
entropy = -torch.sum( |
|
probs * torch.log(probs + 1e-5), dim=-1 |
|
) |
|
threshold = torch.minimum( |
|
torch.ones_like(entropy) * posterior_threshold, |
|
torch.exp(-entropy) * posterior_alpha, |
|
) |
|
indices_to_remove = probs < threshold.unsqueeze(-1) |
|
logit[indices_to_remove] = float('-inf') |
|
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1) |
|
return sampled_tokens |
|
|
|
|
|
|
|
def generate_medusa_buffers(medusa_choices, device="cuda"): |
|
""" |
|
Generate buffers for the Medusa structure based on the provided choices. |
|
|
|
Parameters: |
|
    - medusa_choices (list): A list of tuples, each encoding one path of the Medusa draft tree.
|
- device (str): Device to which the tensors should be moved. Default is "cuda". |
|
|
|
Returns: |
|
- dict: A dictionary containing buffers related to the Medusa structure. |
|
""" |
|
|
|
|
|
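    # Sort choices by depth, then lexicographically; the extra slot in `medusa_len`
    # is the root node (the token sampled from the base LM head).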
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x)) |
|
medusa_len = len(sorted_medusa_choices) + 1 |
|
|
|
|
|
depth_counts = [0] * max([len(path) for path in sorted_medusa_choices]) |
|
for path in sorted_medusa_choices: |
|
depth_counts[len(path) - 1] += 1 |
|
|
|
|
|
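    # Tree attention mask: every node attends to the root, to itself, and to all of
    # its ancestors along its path.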
medusa_attn_mask = torch.eye(medusa_len, medusa_len) |
|
medusa_attn_mask[:, 0] = 1 |
|
start = 0 |
|
for i in range(len(depth_counts)): |
|
for j in range(depth_counts[i]): |
|
cur_medusa_choice = sorted_medusa_choices[start + j] |
|
|
|
if len(cur_medusa_choice) == 1: |
|
continue |
|
ancestor_idx = [] |
|
for c in range(len(cur_medusa_choice) - 1): |
|
ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1) |
|
medusa_attn_mask[j + start + 1, ancestor_idx] = 1 |
|
start += depth_counts[i] |
|
|
|
|
|
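    # Map each tree node to its index in the flat candidate list
    # [base-head token, head-0 top-K, head-1 top-K, ...] built in generate_candidates.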
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long) |
|
medusa_tree_indices[0] = 0 |
|
start = 0 |
|
for i in range(len(depth_counts)): |
|
for j in range(depth_counts[i]): |
|
cur_medusa_choice = sorted_medusa_choices[start + j] |
|
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1 |
|
start += depth_counts[i] |
|
|
|
|
|
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long) |
|
start = 0 |
|
for i in range(len(depth_counts)): |
|
medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1 |
|
start += depth_counts[i] |
|
|
|
|
|
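    # For every leaf, record the node indices along its root-to-leaf path; shorter
    # paths are padded with -2, which after the +1 shift below points at the dummy
    # token appended in generate_candidates.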
retrieve_indices_nest = [] |
|
retrieve_paths = [] |
|
for i in range(len(sorted_medusa_choices)): |
|
cur_medusa_choice = sorted_medusa_choices[-i-1] |
|
retrieve_indice = [] |
|
if cur_medusa_choice in retrieve_paths: |
|
continue |
|
else: |
|
for c in range(len(cur_medusa_choice)): |
|
retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1])) |
|
retrieve_paths.append(cur_medusa_choice[:c+1]) |
|
retrieve_indices_nest.append(retrieve_indice) |
|
max_length = max([len(x) for x in retrieve_indices_nest]) |
|
retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest] |
|
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) |
|
retrieve_indices = retrieve_indices + 1 |
|
retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1) |
|
|
|
|
|
medusa_buffers = { |
|
"medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0), |
|
"tree_indices": medusa_tree_indices, |
|
"medusa_position_ids": medusa_position_ids.unsqueeze(0), |
|
"retrieve_indices": retrieve_indices, |
|
} |
|
|
|
|
|
medusa_buffers = { |
|
k: v.clone().to(device) |
|
if isinstance(v, torch.Tensor) |
|
else torch.tensor(v, device=device) |
|
for k, v in medusa_buffers.items() |
|
} |
|
return medusa_buffers |
|
|
|
def generate_candidates( |
|
medusa_logits, |
|
logits, |
|
tree_indices, |
|
retrieve_indices, |
|
temperature = 0, |
|
posterior_threshold=0.3, |
|
posterior_alpha = 0.09, |
|
top_p=0.8, |
|
sampling = 'typical', |
|
fast = False |
|
): |
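    """
    Build one step's draft tree of candidate token ids.

    The base token is sampled (or argmax-ed) from `logits`, the top-`TOPK`
    predictions are taken from each multi-byte head in `medusa_logits`, and the
    combined candidate list is laid out as a tree via `tree_indices`.

    Returns:
        tree_candidate_ids (torch.LongTensor): shape (1, tree_len), candidates in
            tree order, to be scored with the tree attention mask.
        unflattened_candidate_ids (torch.LongTensor): shape (num_paths, depth + 1),
            one candidate sequence per root-to-leaf path, used for verification.
    """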
|
|
|
|
|
|
|
|
|
|
|
|
|
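    # Sample the base token for the current position (greedy when temperature == 0
    # or fast=True, otherwise typical/nucleus sampling).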
if temperature == 0 or fast: |
|
candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0) |
|
else: |
|
if sampling == 'typical': |
|
candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0) |
|
elif sampling == 'nucleus': |
|
candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0) |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
|
|
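    # Top-K candidates from each multi-byte prediction head at the last position.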
candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices |
|
|
|
|
|
candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1) |
|
|
|
|
|
|
|
|
|
tree_candidate_ids = candidate_ids[tree_indices] |
|
|
|
|
|
|
|
|
|
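    # Append a dummy token so padded retrieve indices (-1) have something to point to.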
tree_candidate_ids_ext = torch.cat( |
|
[ |
|
tree_candidate_ids, |
|
torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device) |
|
], |
|
dim=0 |
|
) |
|
|
|
unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices] |
|
|
|
tree_candidate_ids = tree_candidate_ids.unsqueeze(0) |
|
|
|
return tree_candidate_ids, unflattened_candidate_ids |
|
|
|
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p): |
|
""" |
|
Generates a posterior mask for token candidates using nucleus (top-p) sampling. |
|
|
|
This function applies nucleus sampling to a set of logits, and then generates a mask indicating |
|
    which candidate tokens are selected. It adapts the sampling strategy to account for
    temperature scaling and cumulative probability thresholding.
|
|
|
Args: |
|
logits (torch.Tensor): A tensor of logits from a language model output. |
|
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens. |
|
temperature (float): A parameter to scale the logits, controlling randomness in sampling. |
|
top_p (float): The cumulative probability threshold for nucleus sampling. |
|
|
|
Returns: |
|
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens. |
|
""" |
|
|
|
|
|
|
|
logits = logits[:, :-1] / temperature |
|
n_samples, n_tokens = logits.shape[0], logits.shape[1] |
|
logits = logits.view(n_samples*n_tokens, -1) |
|
if top_p >= 1: |
|
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1) |
|
sampled_tokens = sampled_tokens.view(n_samples, n_tokens) |
|
posterior_mask = (candidates[:, 1:] == sampled_tokens).int() |
|
return posterior_mask |
|
|
|
probs = F.softmax(logits, dim=-1) |
|
|
|
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    cum_probs = torch.cumsum(sorted_probs, dim=-1)
|
|
|
|
|
sorted_indices_to_remove = cum_probs > top_p |
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
|
sorted_indices_to_remove[..., 0] = 0 |
|
|
|
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) |
|
|
|
|
|
|
|
logits[indices_to_remove] = float('-inf') |
|
|
|
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1) |
|
sampled_tokens = sampled_tokens.view(n_samples, n_tokens) |
|
|
|
posterior_mask = (candidates[:, 1:] == sampled_tokens).int() |
|
|
|
return posterior_mask |
|
|
|
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha): |
|
""" |
    Generates a posterior mask for token candidates using typical sampling,
    analogous to `get_nucleus_posterior_mask` but with an entropy-adaptive threshold.

    Args:
|
logits (torch.Tensor): A tensor of logits from a language model output. |
|
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens. |
|
temperature (float): A parameter to scale the logits, controlling randomness in sampling. |
|
posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling. |
|
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold. |
|
|
|
Returns: |
|
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens. |
|
""" |
|
logits = logits[:, :-1] / temperature |
|
n_samples, n_tokens = logits.shape[0], logits.shape[1] |
|
logits = logits.view(n_samples*n_tokens, -1) |
|
probs = F.softmax(logits, dim=-1) |
|
entropy = -torch.sum( |
|
probs * torch.log(probs + 1e-5), dim=-1 |
|
) |
|
threshold = torch.minimum( |
|
torch.ones_like(entropy) * posterior_threshold, |
|
torch.exp(-entropy) * posterior_alpha, |
|
) |
|
indices_to_remove = probs < threshold.unsqueeze(-1) |
|
logits[indices_to_remove] = float('-inf') |
|
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1) |
|
sampled_tokens = sampled_tokens.view(n_samples, n_tokens) |
|
posterior_mask = (candidates[:, 1:] == sampled_tokens).int() |
|
return posterior_mask |
|
|
|
|
|
|
|
def evaluate_posterior( |
|
logits, |
|
candidates, |
|
temperature, |
|
posterior_threshold=0.3, |
|
posterior_alpha = 0.09, |
|
top_p=0.8, |
|
sampling = 'typical', |
|
fast = True |
|
): |
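    """
    Verify drafted candidates against the model's own predictions.

    `logits` holds the model outputs along each candidate path, shaped
    (num_paths, depth + 1, vocab), and `candidates` the corresponding token ids,
    shaped (num_paths, depth + 1).  Position i + 1 of a path is accepted if it
    matches greedy decoding at position i (temperature == 0) or passes the
    typical/nucleus acceptance test; acceptance stops at the first rejection.

    Returns:
        best_candidate (torch.LongTensor): index of the path with the longest
            accepted prefix.
        accept_length (int): number of accepted draft tokens on that path.
    """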
|
if logits.shape[1] <= 1: |
|
return torch.tensor(0, dtype=torch.long, device=candidates.device), 0 |
|
|
|
if temperature == 0: |
|
|
|
posterior_mask = ( |
|
candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1) |
|
).int() |
|
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) |
|
accept_length = candidates_accept_length.max().item() |
|
|
|
if accept_length == 0: |
|
|
|
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) |
|
else: |
|
best_candidate = torch.argmax(candidates_accept_length).to(torch.long) |
|
return best_candidate, accept_length |
|
elif sampling == 'typical': |
|
if fast: |
|
posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1) |
|
candidates_prob = torch.gather( |
|
posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1) |
|
).squeeze(-1) |
|
posterior_entropy = -torch.sum( |
|
posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1 |
|
) |
|
threshold = torch.minimum( |
|
torch.ones_like(posterior_entropy) * posterior_threshold, |
|
torch.exp(-posterior_entropy) * posterior_alpha, |
|
) |
|
posterior_mask = candidates_prob > threshold |
|
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) |
|
|
|
|
|
accept_length = candidates_accept_length.max().item() |
|
if accept_length == 0: |
|
|
|
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) |
|
else: |
|
best_candidates = torch.where(candidates_accept_length == accept_length)[0] |
|
|
|
likelihood = torch.sum( |
|
torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1 |
|
) |
|
best_candidate = best_candidates[torch.argmax(likelihood)] |
|
return best_candidate, accept_length |
|
|
|
posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha) |
|
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) |
|
|
|
accept_length = candidates_accept_length.max().item() |
|
|
|
if accept_length == 0: |
|
|
|
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) |
|
else: |
|
best_candidate = torch.argmax(candidates_accept_length).to(torch.long) |
|
|
|
return best_candidate, accept_length |
|
elif sampling == 'nucleus': |
|
        assert top_p < 1.0 + 1e-6, "top_p should be between 0 and 1"
|
posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p) |
|
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) |
|
accept_length = candidates_accept_length.max().item() |
|
|
|
if accept_length == 0: |
|
|
|
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) |
|
else: |
|
best_candidate = torch.argmax(candidates_accept_length).to(torch.long) |
|
return best_candidate, accept_length |
|
else: |
|
raise NotImplementedError |
|
|
|
def update_inference_inputs( |
|
input_ids, |
|
medusa_logits, |
|
logits, |
|
candidate_ids, |
|
best_candidate, |
|
accept_length, |
|
): |
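    # Append the best path's accepted tokens (root token + accept_length draft
    # tokens) to input_ids, and slice the base / multi-byte head logits at the last
    # accepted position so the next iteration drafts from them.  Returns the number
    # of tokens appended as `new_token`.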
|
input_ids = torch.cat( |
|
[ |
|
input_ids, |
|
candidate_ids[None, best_candidate, : accept_length + 1] |
|
], |
|
dim=-1 |
|
) |
|
logits = logits[ |
|
None, best_candidate, accept_length : accept_length + 1 |
|
] |
|
medusa_logits = medusa_logits[ |
|
:, None, best_candidate, accept_length : accept_length + 1 |
|
] |
|
|
|
new_token = accept_length + 1 |
|
return input_ids, medusa_logits, logits, new_token |
|
|
|
def split_logits(full_logits): |
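    # full_logits has shape (batch, seq, 1 + num_heads, vocab): index 0 along the
    # head dimension is the base LM head, the remaining entries are the multi-byte
    # prediction heads, returned as (num_heads, batch, seq, vocab).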
|
|
|
logits = full_logits[..., 0, :] |
|
medusa_logits = full_logits[..., 1:, :].permute(2, 0, 1, 3) |
|
return medusa_logits, logits |
|
|
|
class MultiByteDecodingMixin: |
|
def multi_byte_pred_update_cache( |
|
self, |
|
past_key_values, |
|
retrieve_indices, |
|
best_candidate, |
|
new_tokens, |
|
): |
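        # After verification the window cache still holds K/V entries for every node
        # of the draft tree.  Gather the accepted path's entries (via
        # retrieve_indices) into contiguous window positions; if the window
        # overflows, compress one full window into chunk-level RFAs with
        # triton_eva_prep_kv_fwd and shift the remainder to the window's start.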
|
prev_window_len = past_key_values.get_past_window_pos(0) |
|
select_indices = ( |
|
retrieve_indices[best_candidate, : new_tokens] + prev_window_len |
|
) |
|
for layer_idx in range(self.config.num_hidden_layers): |
|
|
|
past_key_values.update_past_len(new_tokens, layer_idx) |
|
|
|
past_window_k = past_key_values.past_window_k[layer_idx] |
|
past_window_v = past_key_values.past_window_v[layer_idx] |
|
|
|
tgt_window_k = past_window_k[..., select_indices, :] |
|
tgt_window_v = past_window_v[..., select_indices, :] |
|
|
|
dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :] |
|
dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :] |
|
|
|
dst_window_k.copy_(tgt_window_k, non_blocking=True) |
|
dst_window_v.copy_(tgt_window_v, non_blocking=True) |
|
|
|
new_window_len = prev_window_len + new_tokens |
|
if new_window_len >= self.config.window_size: |
|
assert new_window_len < 2 * self.config.window_size |
|
|
|
dump_k = past_window_k[..., :self.config.window_size, :].clone() |
|
dump_v = past_window_v[..., :self.config.window_size, :].clone() |
|
|
|
_window_len = new_window_len - self.config.window_size |
|
|
|
if _window_len > 0: |
|
new_window_k = past_window_k[..., self.config.window_size : new_window_len, :] |
|
new_window_v = past_window_v[..., self.config.window_size : new_window_len, :] |
|
|
|
_dst_window_k = past_window_k[..., : _window_len, :] |
|
_dst_window_v = past_window_v[..., : _window_len, :] |
|
|
|
_dst_window_k.copy_(new_window_k, non_blocking=True) |
|
_dst_window_v.copy_(new_window_v, non_blocking=True) |
|
|
|
past_key_values.past_window_pos[layer_idx] = _window_len |
|
else: |
|
dump_k = None |
|
dump_v = None |
|
past_key_values.past_window_pos[layer_idx] = new_window_len |
|
|
|
if dump_k is not None and dump_v is not None: |
|
rfa_k, rfa_v = triton_eva_prep_kv_fwd( |
|
dump_k, dump_v, |
|
self.model.layers[layer_idx].self_attn.adaptive_mu_k, |
|
self.model.layers[layer_idx].self_attn.adaptive_phi, |
|
None, |
|
self.model.layers[layer_idx].self_attn.head_dim_scaling, |
|
self.model.layers[layer_idx].self_attn.chunk_size |
|
) |
|
rfa_k, rfa_v = past_key_values.update_chunk_rfas( |
|
rfa_k, rfa_v, layer_idx |
|
) |
|
return past_key_values |
|
|
|
    def _multi_byte_pred_update_cache_when_prefill_len_eq_window_size(
|
self, |
|
past_key_values, |
|
): |
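        # Corner case after prefill: when the prompt exactly fills one attention
        # window, fold that window into chunk-level RFAs now and reset the window
        # pointer so drafting starts from an empty window.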
|
prev_window_len = past_key_values.get_past_window_pos(0) |
|
for layer_idx in range(self.config.num_hidden_layers): |
|
|
|
past_window_k = past_key_values.past_window_k[layer_idx] |
|
past_window_v = past_key_values.past_window_v[layer_idx] |
|
|
|
new_window_len = prev_window_len |
|
if new_window_len == self.config.window_size: |
|
dump_k = past_window_k[..., :self.config.window_size, :].clone() |
|
dump_v = past_window_v[..., :self.config.window_size, :].clone() |
|
past_key_values.past_window_pos[layer_idx] = 0 |
|
|
|
if dump_k is not None and dump_v is not None: |
|
rfa_k, rfa_v = triton_eva_prep_kv_fwd( |
|
dump_k, dump_v, |
|
self.model.layers[layer_idx].self_attn.adaptive_mu_k, |
|
self.model.layers[layer_idx].self_attn.adaptive_phi, |
|
None, |
|
self.model.layers[layer_idx].self_attn.head_dim_scaling, |
|
self.model.layers[layer_idx].self_attn.chunk_size |
|
) |
|
rfa_k, rfa_v = past_key_values.update_chunk_rfas( |
|
rfa_k, rfa_v, layer_idx |
|
) |
|
return past_key_values |
|
|
|
def multi_byte_pred_update_attn_mask( |
|
self, |
|
last_iter_new_tokens, |
|
tree_candidate_ids, |
|
past_attn_mask, |
|
medusa_attn_mask, |
|
past_key_values, |
|
): |
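        # Build the attention mask for the tree forward pass: `past_attn_mask`
        # covers everything already in the cache (raw in-window tokens and, after a
        # rollover, per-chunk RFA summaries), while the Medusa mask restricts each
        # draft node to the root and its ancestors.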
|
batch_size, tree_candidate_len = tree_candidate_ids.shape |
|
seen_tokens = past_key_values.get_seq_length() |
|
|
|
|
|
assert seen_tokens > 0 |
|
|
|
assert last_iter_new_tokens < self.config.window_size |
|
|
|
if past_attn_mask is not None and seen_tokens < self.config.window_size: |
|
past_attn_mask = torch.cat( |
|
[ |
|
past_attn_mask, |
|
torch.ones( |
|
[batch_size, 1, tree_candidate_len, last_iter_new_tokens], |
|
dtype=torch.bool, |
|
device=self.device |
|
) |
|
], |
|
dim=-1 |
|
) |
|
else: |
|
|
|
|
|
|
|
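            # After a window rollover the cached context is represented by
            # chunks_per_window RFA summaries for every completed window plus the raw
            # tokens still inside the current window, hence this mask length.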
chunks_per_window = int(self.config.window_size // self.config.chunk_size) |
|
|
|
window_tokens = seen_tokens % self.config.window_size |
|
num_windows_seen_so_far = seen_tokens // self.config.window_size |
|
attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens |
|
past_attn_mask = torch.ones( |
|
(batch_size, 1, tree_candidate_len, attn_mask_len), |
|
dtype=torch.bool, |
|
device=self.device |
|
) |
|
|
|
|
|
tree_attn_mask = torch.cat( |
|
[ |
|
past_attn_mask, |
|
medusa_attn_mask.to(torch.bool) |
|
], |
|
dim=-1 |
|
) |
|
return tree_attn_mask, past_attn_mask |
|
|
|
@torch.no_grad() |
|
def multi_byte_generate( |
|
self, |
|
input_ids, |
|
attention_mask=None, |
|
temperature=0.0, |
|
max_length=None, |
|
max_new_tokens=None, |
|
stopping_criteria=None, |
|
posterior_threshold=0.09, |
|
posterior_alpha=0.3, |
|
top_p=0.8, |
|
sampling='typical', |
|
fast=True, |
|
do_sample=False, |
|
medusa_choices=None, |
|
return_acc_lengths=False |
|
): |
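        """
        Self-speculative multi-byte generation.

        Each iteration drafts a tree of candidate bytes from the multi-byte
        prediction heads, scores the whole tree in one forward pass with a tree
        attention mask, accepts the longest verified prefix, and brings the EVA
        cache back in sync.  Only batch size 1 with attention_mask=None is
        supported.  Returns the extended input_ids, plus the per-iteration
        accepted lengths when return_acc_lengths=True.
        """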
|
if do_sample or temperature > 0.0: |
|
fast = False |
|
|
|
|
|
if max_new_tokens is not None: |
|
max_length = max_new_tokens + input_ids.shape[-1] |
|
elif max_new_tokens is None and max_length is None: |
|
max_length = getattr(self.config, "max_position_embeddings", 32768) |
|
|
|
|
|
eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id) |
|
stop_criteria = StoppingCriteriaList() |
|
if max_length is not None: |
|
max_position_embeddings = getattr(self.config, "max_position_embeddings", None) |
|
stop_criteria.append( |
|
MaxLengthCriteria( |
|
max_length=max_length, |
|
max_position_embeddings=max_position_embeddings, |
|
) |
|
) |
|
if stopping_criteria is not None and len(stopping_criteria) > 0: |
|
stop_criteria.extend(stopping_criteria) |
|
|
|
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now"
        assert attention_mask is None, "Only attention_mask=None is supported for now"
|
|
|
input_ids = input_ids.clone() |
|
position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1) |
|
|
|
|
|
|
|
|
|
if medusa_choices is None: |
|
medusa_choices = evabyte_7b_95 |
|
medusa_buffers = generate_medusa_buffers( |
|
medusa_choices, device=self.device |
|
) |
|
|
|
past_key_values = EvaStaticCacheForTriton( |
|
input_ids.shape[0], |
|
self.config.num_attention_heads, |
|
|
|
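            # one attention window plus headroom for the draft-tree tokens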
self.config.window_size + 256, |
|
self.config.hidden_size // self.config.num_attention_heads, |
|
self.config.num_hidden_layers, |
|
self.lm_head.weight.dtype, |
|
self.lm_head.weight.device, |
|
) |
|
|
|
full_logits, past_key_values = self.forward( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
use_cache=True, |
|
past_key_values=past_key_values, |
|
return_all_pred_logits=True, |
|
multibyte_decoding=False, |
|
) |
|
|
|
|
|
        past_key_values = self._multi_byte_pred_update_cache_when_prefill_len_eq_window_size(
|
past_key_values |
|
) |
|
medusa_logits, logits = split_logits(full_logits) |
|
|
|
past_attn_mask = None |
|
last_iter_new_tokens = 0 |
|
max_iters = 32768 |
|
if return_acc_lengths: |
|
acc_lengths = [] |
|
for _ in range(max_iters): |
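            # One decoding iteration:
            #   1. draft a candidate tree from the current base / multi-byte head logits,
            #   2. score the whole tree in a single forward pass with the tree attention mask,
            #   3. verify the candidates and accept the longest matching prefix,
            #   4. append the accepted bytes and update the EVA cache.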
|
|
|
|
|
|
|
tree_candidate_ids, unflattened_candidate_ids = generate_candidates( |
|
medusa_logits, |
|
logits, |
|
medusa_buffers["tree_indices"], |
|
medusa_buffers["retrieve_indices"], |
|
temperature=temperature, |
|
posterior_alpha=posterior_alpha, |
|
posterior_threshold=posterior_threshold, |
|
top_p=top_p, |
|
sampling=sampling, |
|
fast=fast, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask( |
|
last_iter_new_tokens, |
|
tree_candidate_ids, |
|
past_attn_mask, |
|
medusa_buffers["medusa_attn_mask"], |
|
past_key_values, |
|
) |
|
medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1] |
|
|
|
|
|
|
|
|
|
tree_full_logits, past_key_values = self.forward( |
|
tree_candidate_ids, |
|
past_key_values=past_key_values, |
|
attention_mask=medusa_attn_mask, |
|
position_ids=medusa_position_ids, |
|
return_all_pred_logits=True, |
|
multibyte_decoding=True, |
|
) |
|
_medusa_logits, _logits = split_logits(tree_full_logits) |
|
medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :] |
|
logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
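            # Cap how many tokens can be accepted this step so the accepted path never
            # spills past the end of the current attention window.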
tree_depth = unflattened_candidate_ids.shape[-1] |
|
if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size: |
|
max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0) |
|
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len] |
|
_trimmed_logits = logits[:, :max_acc_len] |
|
else: |
|
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids |
|
_trimmed_logits = logits |
|
best_candidate, accept_length = evaluate_posterior( |
|
_trimmed_logits, |
|
_trimmed_unflattened_candidate_ids, |
|
temperature, |
|
posterior_threshold, |
|
posterior_alpha, |
|
top_p=top_p, |
|
sampling=sampling, |
|
fast=fast |
|
) |
|
|
|
|
|
|
|
|
|
input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs( |
|
input_ids, |
|
medusa_logits, |
|
logits, |
|
unflattened_candidate_ids, |
|
best_candidate, |
|
accept_length, |
|
) |
|
|
|
past_key_values = self.multi_byte_pred_update_cache( |
|
past_key_values, |
|
medusa_buffers["retrieve_indices"], |
|
best_candidate, |
|
last_iter_new_tokens, |
|
) |
|
|
|
if return_acc_lengths: |
|
acc_lengths.append(last_iter_new_tokens) |
|
if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens): |
|
if return_acc_lengths: |
|
return input_ids, acc_lengths |
|
else: |
|
return input_ids |
|
if return_acc_lengths: |
|
return input_ids, acc_lengths |
|
else: |
|
return input_ids |
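

# Minimal usage sketch (assumptions: `model` is an EvaByte checkpoint whose class mixes in
# MultiByteDecodingMixin, and `tokenizer` is its byte-level tokenizer; both names are
# illustrative, not defined in this module):
#
#   input_ids = tokenizer("def fib(n):", return_tensors="pt").input_ids.to(model.device)
#   output_ids = model.multi_byte_generate(input_ids, max_new_tokens=256, temperature=0.0)
#   print(tokenizer.decode(output_ids[0]))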
|
|