# EvaByte-SFT / multibyte_decoding_evabyte.py
# The implementation of multibyte decoding is largely adapted from
# Medusa decoding: https://github.com/FasterDecoding/Medusa
import torch
import torch.nn.functional as F
from transformers.generation.stopping_criteria import (
MaxLengthCriteria,
StoppingCriteriaList,
)
from typing import Union, List
from .eva_cache import EvaStaticCacheForTriton
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
class MultibyteEosTokenCriteria:
"""
This class implements a simple stopping criteria to stop generation whenever
the "end-of-sequence" token is generated in the last `new_tokens` tokens.
Adapted from
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
By default, it uses the `model.generation_config.eos_token_id`.
Args:
eos_token_id (`Union[int, List[int]]`):
The id(s) of the *end-of-sequence* token.
"""
def __init__(self, eos_token_ids: Union[int, List[int]]):
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
self.eos_token_ids = eos_token_ids
def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
current_input_len = input_ids.shape[-1]
new_token_ids = input_ids[:, current_input_len - new_tokens:]
for eos_token_id in self.eos_token_ids:
if torch.any(new_token_ids == eos_token_id):
return True
return False
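# Minimal illustration of the criteria above (ids are hypothetical): with eos
# token id 2,
#   criteria = MultibyteEosTokenCriteria(2)
#   criteria(input_ids, new_tokens=4)  # True iff any of the last 4 ids == 2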
def build_tree(spec):
nodes_at_depth = []
nodes_at_depth.append([()]) # Root at depth 1
for d in range(1, len(spec) + 1):
prev_nodes = nodes_at_depth[d - 1]
spec_list = spec[d - 1]
current_nodes = []
for node_idx, node in enumerate(prev_nodes):
if node_idx < len(spec_list):
num_children = spec_list[node_idx]
else:
num_children = 0
for child_idx in range(num_children):
new_node = node + (child_idx,)
current_nodes.append(new_node)
nodes_at_depth.append(current_nodes)
    # Flatten the list of nodes, excluding the empty root tuple
all_nodes = [node for depth_nodes in nodes_at_depth for node in depth_nodes if node]
return all_nodes
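# Small illustration of build_tree (this spec is not one of the presets below):
#   build_tree([[2], [1, 1]]) -> [(0,), (1,), (0, 0), (1, 0)]
# spec[d][i] gives the number of children assigned to the i-th node at depth
# d + 1; nodes without an entry in spec[d] get no children.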
evabyte_7b_95 = build_tree(
[
[10],
[10, 8, 2, 2, 1, 1],
[10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
[8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
[6, 2, 1, 1],
[4, 2, 1, 1],
[4, 2, 1],
]
)
evabyte_7b_31 = build_tree(
[
[4],
[3, 2, 1, 1],
[3, 2, 1, 1],
[2, 1, 1],
[2, 1],
[2, 1],
[2, 1],
]
)
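# The numeric suffixes in the names above correspond to the number of candidate
# nodes in each tree: evabyte_7b_95 expands to 95 nodes and evabyte_7b_31 to 31.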
TOPK = 10  # number of top-k predictions kept per Medusa head when building the sparse tree (10 is sufficient)
def pad_path(path, length, pad_value=-2):
"""
Pad the given path list with a specific value up to a specified length.
Parameters:
- path (list): The original list that needs padding.
- length (int): The desired length of the padded list.
- pad_value (optional, default=-2): The value to use for padding.
Returns:
- list: A new list based on the original path but padded to the desired length.
Example:
>>> pad_path([1,2,3], 5)
[1, 2, 3, -2, -2]
Note:
If the given path is already longer than the specified length,
then no padding occurs, and the original path is returned.
"""
return path + [pad_value] * (length - len(path))
def reset_past_key_values(passed_key_values):
"""
Resets the current lengths in the passed key-values to zero.
This function is designed to be used during the evaluation of a baseline model.
It iterates through each layer's key-values and sets their current lengths to zero,
effectively resetting their state.
Args:
- passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
Returns:
- passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
"""
for i in range(len(passed_key_values)):
for j in range(2):
passed_key_values[i][j].current_length.fill_(0)
return passed_key_values
def get_nucleus_one_token(logit, temperature, top_p):
"""
Performs token sampling based on the nucleus (top-p) sampling method.
This function selects a token from a given logit distribution using the nucleus sampling strategy.
It allows for more controlled and diverse generation compared to traditional top-k sampling.
Args:
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
temperature (float): A temperature parameter to control the randomness in sampling.
Higher values increase diversity, lower values make selections more deterministic.
top_p (float): The cumulative probability threshold for nucleus sampling.
It controls the size of the set of high-probability tokens to consider for sampling.
Returns:
torch.Tensor: A tensor containing the indices of the sampled tokens.
"""
if top_p >= 1:
return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
logit = logit / temperature
probs = torch.softmax(logit, dim=-1)
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
cum_probs = torch.cumsum(sorted_logits, dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logit[indices_to_remove] = float('-inf')
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
return sampled_tokens
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
"""
Implements token sampling based on the typical sampling method.
This function selects a token from a given logit distribution using the typical sampling strategy,
aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
Args:
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
temperature (float): A parameter to control the randomness in sampling.
Higher values increase diversity, lower values make selections more deterministic.
posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
Returns:
torch.Tensor: A tensor containing the indices of the sampled tokens.
"""
logit = logit / temperature
probs = torch.softmax(logit, dim=-1)
entropy = -torch.sum(
probs * torch.log(probs + 1e-5), dim=-1
)
threshold = torch.minimum(
torch.ones_like(entropy) * posterior_threshold,
torch.exp(-entropy) * posterior_alpha,
)
indices_to_remove = probs < threshold.unsqueeze(-1)
logit[indices_to_remove] = float('-inf')
sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
return sampled_tokens
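# Sketch of how the two samplers above are invoked downstream (parameter values
# here are illustrative, not the module defaults):
#   last_logit = logits[:, -1]  # [batch, vocab]
#   tok = get_nucleus_one_token(last_logit, temperature=0.7, top_p=0.9)
#   tok = get_typical_one_token(last_logit, temperature=0.7,
#                               posterior_threshold=0.09, posterior_alpha=0.3)
# Both return a [batch, 1] tensor of sampled token ids.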
def generate_medusa_buffers(medusa_choices, device="cuda"):
"""
Generate buffers for the Medusa structure based on the provided choices.
Parameters:
    - medusa_choices (list): A list of node tuples describing the Medusa candidate tree.
- device (str): Device to which the tensors should be moved. Default is "cuda".
Returns:
- dict: A dictionary containing buffers related to the Medusa structure.
"""
# Sort the medusa_choices based on their lengths and then their values
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
medusa_len = len(sorted_medusa_choices) + 1
# Initialize depth_counts to keep track of how many choices have a particular depth
depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
for path in sorted_medusa_choices:
depth_counts[len(path) - 1] += 1
# Create the attention mask for Medusa
medusa_attn_mask = torch.eye(medusa_len, medusa_len)
medusa_attn_mask[:, 0] = 1
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
# retrieve ancestor position
if len(cur_medusa_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_medusa_choice) - 1):
ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
medusa_attn_mask[j + start + 1, ancestor_idx] = 1
start += depth_counts[i]
# Generate tree indices for the Medusa structure
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
medusa_tree_indices[0] = 0
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
start += depth_counts[i]
# Generate position IDs for the Medusa structure
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
start = 0
for i in range(len(depth_counts)):
medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
start += depth_counts[i]
# Generate retrieval indices for Medusa structure verification
retrieve_indices_nest = []
retrieve_paths = []
for i in range(len(sorted_medusa_choices)):
cur_medusa_choice = sorted_medusa_choices[-i-1]
retrieve_indice = []
if cur_medusa_choice in retrieve_paths:
continue
else:
for c in range(len(cur_medusa_choice)):
retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
retrieve_paths.append(cur_medusa_choice[:c+1])
retrieve_indices_nest.append(retrieve_indice)
max_length = max([len(x) for x in retrieve_indices_nest])
retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
retrieve_indices = retrieve_indices + 1
retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
# Aggregate the generated buffers into a dictionary
medusa_buffers = {
"medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
"tree_indices": medusa_tree_indices,
"medusa_position_ids": medusa_position_ids.unsqueeze(0),
"retrieve_indices": retrieve_indices,
}
# Move the tensors in the dictionary to the specified device
medusa_buffers = {
k: v.clone().to(device)
if isinstance(v, torch.Tensor)
else torch.tensor(v, device=device)
for k, v in medusa_buffers.items()
}
return medusa_buffers
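# For a tree with N candidate nodes (e.g., N=31 for evabyte_7b_31), medusa_len
# is N + 1 and the returned buffers look roughly as follows:
#   medusa_attn_mask:     [1, 1, N+1, N+1]  (each node attends to itself, its
#                         ancestors, and the root)
#   tree_indices:         [N+1]  indices into the flattened
#                         [base top-1; per-head top-TOPK] candidate list
#   medusa_position_ids:  [1, N+1]  depth of each node (root = 0)
#   retrieve_indices:     [num_leaf_paths, max_depth+1]  gather indices for each
#                         root-to-leaf path, padded with -1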
def generate_candidates(
medusa_logits,
logits,
tree_indices,
retrieve_indices,
temperature = 0,
posterior_threshold=0.3,
posterior_alpha = 0.09,
top_p=0.8,
sampling = 'typical',
fast = False
):
    # Say we have the base LM head plus 2 Medusa heads, and the top-4
    # predictions for each are:
    # [10, 3, 8, 4]
    # [9, 5, 1, 6]
    # [7, 16, 3, 2]
    # candidates_ids = [10] (the greedy pick from the base head)
if temperature == 0 or fast:
candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
else:
if sampling == 'typical':
candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
elif sampling == 'nucleus':
candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
else:
raise NotImplementedError
# this calculates the top-k medusa logits
# candidates_medusa_id = [
# [9, 5, 1, 6]
# [7, 16, 3, 2]
# ]
candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
# [10, 9, 5, 1, 6, 7, 16, 3, 2]
candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)
# based on the pre-defined tree_indices, select the corresponding candidates
    # if we keep top-1 from the base head, top-2 from the first Medusa head,
    # and top-3 from the second:
    # tree_candidate_ids = [10, 9, 5, 7, 16, 3, 7, 16, 3]
tree_candidate_ids = candidate_ids[tree_indices]
    # tree_candidate_ids_ext = [10, 9, 5, 7, 16, 3, 7, 16, 3, 0]
    # retrieve_indices may contain padding (-1 after the +1 shift), so we append
    # a zero here so that all padded indices select this appended zero.
tree_candidate_ids_ext = torch.cat(
[
tree_candidate_ids,
torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
],
dim=0
)
# [[10, 9, 7], [10, 9, 16], [10, 9, 3], [10, 5, 7], [10, 5, 16], [10, 5, 3]]
unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]
tree_candidate_ids = tree_candidate_ids.unsqueeze(0)
return tree_candidate_ids, unflattened_candidate_ids
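# Shapes returned by generate_candidates: `tree_candidate_ids` is
# [1, medusa_len] (one row of tree-ordered draft tokens fed to the model), and
# `unflattened_candidate_ids` is [num_paths, max_depth + 1], one row per
# root-to-leaf path used later for verification.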
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
"""
Generates a posterior mask for token candidates using nucleus (top-p) sampling.
This function applies nucleus sampling to a set of logits, and then generates a mask indicating
which candidate tokens are selected. It adapts the sampling strategy to accommodate for
temperature scaling and cumulative probability thresholding.
Args:
logits (torch.Tensor): A tensor of logits from a language model output.
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
top_p (float): The cumulative probability threshold for nucleus sampling.
Returns:
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
"""
# adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
# Apply temperature
logits = logits[:, :-1] / temperature
n_samples, n_tokens = logits.shape[0], logits.shape[1]
logits = logits.view(n_samples*n_tokens, -1)
if top_p >= 1:
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
return posterior_mask
# Convert to probabilities (softmax)
probs = F.softmax(logits, dim=-1)
# Sort the probabilities
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
# Compute cumulative probabilities
cum_probs = torch.cumsum(sorted_logits, dim=-1)
# Create mask for the top-p nucleus
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
# Remove low-probability tokens
logits[indices_to_remove] = float('-inf')
# Sample from the remaining tokens
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
# Create a mask for selected tokens
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
return posterior_mask
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
"""
Args:
logits (torch.Tensor): A tensor of logits from a language model output.
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
Returns:
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
"""
logits = logits[:, :-1] / temperature
n_samples, n_tokens = logits.shape[0], logits.shape[1]
logits = logits.view(n_samples*n_tokens, -1)
probs = F.softmax(logits, dim=-1)
entropy = -torch.sum(
probs * torch.log(probs + 1e-5), dim=-1
)
threshold = torch.minimum(
torch.ones_like(entropy) * posterior_threshold,
torch.exp(-entropy) * posterior_alpha,
)
indices_to_remove = probs < threshold.unsqueeze(-1)
logits[indices_to_remove] = float('-inf')
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
return posterior_mask
def evaluate_posterior(
logits,
candidates,
temperature,
posterior_threshold=0.3,
posterior_alpha = 0.09,
top_p=0.8,
sampling = 'typical',
fast = True
):
if logits.shape[1] <= 1:
return torch.tensor(0, dtype=torch.long, device=candidates.device), 0
# Greedy decoding based on temperature value
if temperature == 0:
# Find the tokens that match the maximum logits for each position in the sequence
posterior_mask = (
candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
).int()
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
accept_length = candidates_accept_length.max().item()
# Choose the best candidate
if accept_length == 0:
# Default to the first candidate if none are accepted
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
return best_candidate, accept_length
elif sampling == 'typical':
if fast:
posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
candidates_prob = torch.gather(
posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
).squeeze(-1)
posterior_entropy = -torch.sum(
posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
) # torch.sum(torch.log(*)) is faster than torch.prod
threshold = torch.minimum(
torch.ones_like(posterior_entropy) * posterior_threshold,
torch.exp(-posterior_entropy) * posterior_alpha,
)
posterior_mask = candidates_prob > threshold
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
# Choose the best candidate based on the evaluated posterior probabilities
accept_length = candidates_accept_length.max().item()
if accept_length == 0:
# If no candidates are accepted, just choose the first one
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
best_candidates = torch.where(candidates_accept_length == accept_length)[0]
# Accept the best one according to likelihood
likelihood = torch.sum(
torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
)
best_candidate = best_candidates[torch.argmax(likelihood)]
return best_candidate, accept_length
# Calculate posterior probabilities and thresholds for candidate selection
posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
# Choose the best candidate based on the evaluated posterior probabilities
accept_length = candidates_accept_length.max().item()
if accept_length == 0:
# If no candidates are accepted, just choose the first one
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
            # Pick the candidate with the longest accepted prefix
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
return best_candidate, accept_length
elif sampling == 'nucleus':
        assert top_p < 1.0 + 1e-6, "top_p should be between 0 and 1"
posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
accept_length = candidates_accept_length.max().item()
# Choose the best candidate
if accept_length == 0:
# Default to the first candidate if none are accepted
best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
return best_candidate, accept_length
else:
raise NotImplementedError
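# A compact illustration of the acceptance rule above (greedy case): if the
# element-wise match mask for one candidate path is [1, 1, 0, 1], then cumprod
# gives [1, 1, 0, 0] and the accepted length is 2 -- matches after the first
# mismatch are discarded. `best_candidate` is the path with the largest
# accepted length, and the caller then appends `accept_length + 1` tokens (the
# already-sampled base token plus the accepted speculative tokens).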
def update_inference_inputs(
input_ids,
medusa_logits,
logits,
candidate_ids,
best_candidate,
accept_length,
):
input_ids = torch.cat(
[
input_ids,
candidate_ids[None, best_candidate, : accept_length + 1]
],
dim=-1
)
logits = logits[
None, best_candidate, accept_length : accept_length + 1
]
medusa_logits = medusa_logits[
:, None, best_candidate, accept_length : accept_length + 1
]
# Update the new token counter
new_token = accept_length + 1
return input_ids, medusa_logits, logits, new_token
def split_logits(full_logits):
# logits has shape [b, n, heads, vocab_size]
logits = full_logits[..., 0, :]
medusa_logits = full_logits[..., 1:, :].permute(2, 0, 1, 3)
return medusa_logits, logits
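# split_logits assumes `full_logits` has shape [batch, seq, num_heads, vocab],
# where head 0 is the base LM head and heads 1..k are the Medusa heads. It
# returns (medusa_logits, logits) with shapes [k, batch, seq, vocab] and
# [batch, seq, vocab] respectively.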
class MultiByteDecodingMixin:
def multi_byte_pred_update_cache(
self,
past_key_values,
retrieve_indices,
best_candidate,
new_tokens,
):
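        # After a tree-decoding step, compact the KV cache: gather the accepted
        # path's key/value entries (selected via retrieve_indices) to the front
        # of the sliding window, and once the window fills past window_size,
        # dump the completed window into chunk-level RFA summaries via
        # triton_eva_prep_kv_fwd.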
prev_window_len = past_key_values.get_past_window_pos(0)
select_indices = (
retrieve_indices[best_candidate, : new_tokens] + prev_window_len
)
for layer_idx in range(self.config.num_hidden_layers):
past_key_values.update_past_len(new_tokens, layer_idx)
past_window_k = past_key_values.past_window_k[layer_idx]
past_window_v = past_key_values.past_window_v[layer_idx]
tgt_window_k = past_window_k[..., select_indices, :]
tgt_window_v = past_window_v[..., select_indices, :]
dst_window_k = past_window_k[..., prev_window_len : prev_window_len + new_tokens, :]
dst_window_v = past_window_v[..., prev_window_len : prev_window_len + new_tokens, :]
dst_window_k.copy_(tgt_window_k, non_blocking=True)
dst_window_v.copy_(tgt_window_v, non_blocking=True)
new_window_len = prev_window_len + new_tokens
if new_window_len >= self.config.window_size:
assert new_window_len < 2 * self.config.window_size
dump_k = past_window_k[..., :self.config.window_size, :].clone()
dump_v = past_window_v[..., :self.config.window_size, :].clone()
_window_len = new_window_len - self.config.window_size
if _window_len > 0:
new_window_k = past_window_k[..., self.config.window_size : new_window_len, :]
new_window_v = past_window_v[..., self.config.window_size : new_window_len, :]
_dst_window_k = past_window_k[..., : _window_len, :]
_dst_window_v = past_window_v[..., : _window_len, :]
_dst_window_k.copy_(new_window_k, non_blocking=True)
_dst_window_v.copy_(new_window_v, non_blocking=True)
past_key_values.past_window_pos[layer_idx] = _window_len
else:
dump_k = None
dump_v = None
past_key_values.past_window_pos[layer_idx] = new_window_len
if dump_k is not None and dump_v is not None:
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
dump_k, dump_v,
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
self.model.layers[layer_idx].self_attn.adaptive_phi,
None,
self.model.layers[layer_idx].self_attn.head_dim_scaling,
self.model.layers[layer_idx].self_attn.chunk_size
)
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
rfa_k, rfa_v, layer_idx
)
return past_key_values
def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
self,
past_key_values,
):
prev_window_len = past_key_values.get_past_window_pos(0)
for layer_idx in range(self.config.num_hidden_layers):
past_window_k = past_key_values.past_window_k[layer_idx]
past_window_v = past_key_values.past_window_v[layer_idx]
new_window_len = prev_window_len
if new_window_len == self.config.window_size:
dump_k = past_window_k[..., :self.config.window_size, :].clone()
dump_v = past_window_v[..., :self.config.window_size, :].clone()
past_key_values.past_window_pos[layer_idx] = 0
if dump_k is not None and dump_v is not None:
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
dump_k, dump_v,
self.model.layers[layer_idx].self_attn.adaptive_mu_k,
self.model.layers[layer_idx].self_attn.adaptive_phi,
None,
self.model.layers[layer_idx].self_attn.head_dim_scaling,
self.model.layers[layer_idx].self_attn.chunk_size
)
rfa_k, rfa_v = past_key_values.update_chunk_rfas(
rfa_k, rfa_v, layer_idx
)
return past_key_values
def multi_byte_pred_update_attn_mask(
self,
last_iter_new_tokens,
tree_candidate_ids,
past_attn_mask,
medusa_attn_mask,
past_key_values,
):
batch_size, tree_candidate_len = tree_candidate_ids.shape
seen_tokens = past_key_values.get_seq_length()
        # NOTE: past_key_values has been updated, so seen_tokens now
        # includes the new tokens from the last tree iteration
assert seen_tokens > 0
# so one iteration would not cross two windows
assert last_iter_new_tokens < self.config.window_size
if past_attn_mask is not None and seen_tokens < self.config.window_size:
past_attn_mask = torch.cat(
[
past_attn_mask,
torch.ones(
[batch_size, 1, tree_candidate_len, last_iter_new_tokens],
dtype=torch.bool,
device=self.device
)
],
dim=-1
)
else:
            # we re-initialize the attention mask whenever
            # 1. the model crosses a window boundary, or
            # 2. right after prefilling
chunks_per_window = int(self.config.window_size // self.config.chunk_size)
window_tokens = seen_tokens % self.config.window_size
num_windows_seen_so_far = seen_tokens // self.config.window_size
attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
past_attn_mask = torch.ones(
(batch_size, 1, tree_candidate_len, attn_mask_len),
dtype=torch.bool,
device=self.device
)
# note that 1 indicates the position is not masked
tree_attn_mask = torch.cat(
[
past_attn_mask,
medusa_attn_mask.to(torch.bool)
],
dim=-1
)
return tree_attn_mask, past_attn_mask
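    # Worked example for the mask length above (window_size and chunk_size are
    # hypothetical here, not necessarily the actual EvaByte config): with
    # window_size=2048 and chunk_size=256, chunks_per_window=8; after
    # seen_tokens=5000, num_windows_seen_so_far=2 and window_tokens=904, so the
    # past part of the mask covers 2*8 + 904 = 920 positions (each dumped
    # window is compressed into `chunks_per_window` RFA chunk slots, while the
    # current window keeps one slot per token).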
@torch.no_grad()
def multi_byte_generate(
self,
input_ids,
attention_mask=None,
temperature=0.0,
max_length=None,
max_new_tokens=None,
stopping_criteria=None,
posterior_threshold=0.09,
posterior_alpha=0.3,
top_p=0.8,
sampling='typical',
fast=True,
do_sample=False,
medusa_choices=None,
return_acc_lengths=False
):
if do_sample or temperature > 0.0:
fast = False
### Prepare `max_length` depending on other stopping criteria.
if max_new_tokens is not None:
max_length = max_new_tokens + input_ids.shape[-1]
elif max_new_tokens is None and max_length is None:
max_length = getattr(self.config, "max_position_embeddings", 32768)
### Set up stopping criteria
eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
stop_criteria = StoppingCriteriaList()
if max_length is not None:
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
stop_criteria.append(
MaxLengthCriteria(
max_length=max_length,
max_position_embeddings=max_position_embeddings,
)
)
if stopping_criteria is not None and len(stopping_criteria) > 0:
stop_criteria.extend(stopping_criteria)
assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
assert attention_mask is None, "Only support attention mask None for now"
# Avoid modifying the input_ids in-place
input_ids = input_ids.clone()
position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)
####################################################
# 0. initialize the medusa buffers
####################################################
if medusa_choices is None:
medusa_choices = evabyte_7b_95
medusa_buffers = generate_medusa_buffers(
medusa_choices, device=self.device
)
past_key_values = EvaStaticCacheForTriton(
input_ids.shape[0],
self.config.num_attention_heads,
            # add 256 extra slots to accommodate the tree candidate tokens
self.config.window_size + 256,
self.config.hidden_size // self.config.num_attention_heads,
self.config.num_hidden_layers,
self.lm_head.weight.dtype,
self.lm_head.weight.device,
)
# prefill to get medusa logits and logits
full_logits, past_key_values = self.forward(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
past_key_values=past_key_values,
return_all_pred_logits=True,
multibyte_decoding=False,
)
# handles an edge case where the prefill length == window_size
# we force the previous window to be dumped into RFA chunks
past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
past_key_values
)
medusa_logits, logits = split_logits(full_logits)
past_attn_mask = None
last_iter_new_tokens = 0
max_iters = 32768
if return_acc_lengths:
acc_lengths = []
for _ in range(max_iters):
####################################################
# 1. generate candidate_ids with topk predictions from Medusa heads
####################################################
tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
medusa_logits,
logits,
medusa_buffers["tree_indices"],
medusa_buffers["retrieve_indices"],
temperature=temperature,
posterior_alpha=posterior_alpha,
posterior_threshold=posterior_threshold,
top_p=top_p,
sampling=sampling,
fast=fast,
)
####################################################
# 2. Build the medusa attention mask and position ids
####################################################
# NOTE: 1 indicates the position is not masked
medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
last_iter_new_tokens,
tree_candidate_ids,
past_attn_mask,
medusa_buffers["medusa_attn_mask"],
past_key_values,
)
medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]
####################################################
# 3. tree decoding
####################################################
tree_full_logits, past_key_values = self.forward(
tree_candidate_ids,
past_key_values=past_key_values,
attention_mask=medusa_attn_mask,
position_ids=medusa_position_ids,
return_all_pred_logits=True,
multibyte_decoding=True,
)
_medusa_logits, _logits = split_logits(tree_full_logits)
medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]
####################################################
# 4. candidate selection
####################################################
            # if the current iteration, together with the tree tokens, crosses the
            # window boundary, trim the candidate ids so they stay within the
            # window; tokens beyond it (whose predictions would be inaccurate)
            # are not considered
tree_depth = unflattened_candidate_ids.shape[-1]
if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
_trimmed_logits = logits[:, :max_acc_len]
else:
_trimmed_unflattened_candidate_ids = unflattened_candidate_ids
_trimmed_logits = logits
best_candidate, accept_length = evaluate_posterior(
_trimmed_logits,
_trimmed_unflattened_candidate_ids,
temperature,
posterior_threshold,
posterior_alpha,
top_p=top_p,
sampling=sampling,
fast=fast
)
####################################################
# 5. update model inputs and caches
####################################################
input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
input_ids,
medusa_logits,
logits,
unflattened_candidate_ids,
best_candidate,
accept_length,
)
past_key_values = self.multi_byte_pred_update_cache(
past_key_values,
medusa_buffers["retrieve_indices"],
best_candidate,
last_iter_new_tokens,
)
if return_acc_lengths:
acc_lengths.append(last_iter_new_tokens)
if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
if return_acc_lengths:
return input_ids, acc_lengths
else:
return input_ids
if return_acc_lengths:
return input_ids, acc_lengths
else:
return input_ids
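# Usage sketch (assumptions: `model` is an EvaByte causal LM exposing this
# mixin and `input_ids` is a [1, seq_len] LongTensor of byte ids on model.device):
#   output_ids = model.multi_byte_generate(
#       input_ids,
#       temperature=0.0,          # greedy drafting + verification
#       max_new_tokens=256,
#       medusa_choices=evabyte_7b_31,
#   )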