import torch
import gc
import os
import time
import random
from typing import Dict, Optional, Sequence, List, Tuple
from transformers.cache_utils import Cache, DynamicCache
from transformers import (
    LlamaModel,
    LlamaForCausalLM,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
import torch.nn.functional as F
def get_jacobian_trajectory(
    model, tokenizer, input_ids, attention_mask, max_new_tokens
):
    bsz = input_ids.shape[0]
    prompt_len = [torch.sum(t) for t in attention_mask]
    max_prompt_len = max(prompt_len)
    total_len = max_prompt_len + max_new_tokens

    # initialize the first point of the jacobian trajectory: keep each prompt
    # unchanged and fill the remaining positions with tokens sampled from it
    tokens = torch.full(
        (bsz, total_len), tokenizer.pad_token_id, dtype=torch.long, device=model.device
    )
    for i in range(bsz):
        tokens[i, :] = torch.tensor(
            random.choices(input_ids[i][attention_mask[i] == 1], k=total_len),
            dtype=torch.long,
            device=model.device,
        )
        tokens[i, : prompt_len[i]] = input_ids[i][: prompt_len[i]].to(
            dtype=torch.long, device=model.device
        )

    itr = 0
    next_generation = tokens
    generate_attention_mask = torch.ones_like(next_generation, device=model.device)
    # length of the prefix that has already converged, per sample
    accurate_lengths = torch.tensor([p.item() for p in prompt_len], device=model.device)
    prev_len = 0
    while True:
        current_generation = next_generation
        with torch.no_grad():
            logits = model(current_generation, generate_attention_mask).logits
        # greedy decoding: argmax is unchanged by softmax/temperature scaling,
        # so take the most likely token at every position directly
        next_generation = torch.argmax(logits, dim=-1)

        # hold the prompt unchanged and shift the predictions right by one
        # position, so the token at position t comes from the logits at t-1
        for i in range(bsz):
            next_generation[i, :] = torch.cat(
                (
                    tokens[i, : prompt_len[i]],
                    next_generation[i, prompt_len[i] - 1 : total_len - 1],
                ),
                dim=0,
            )
        # forced exit: the trajectory converged exactly when the max_new_tokens
        # budget ran out, or an eos token already appears in the accurate prefix
        eos_in_prefix = (
            (current_generation[0, : accurate_lengths[0]] == tokenizer.eos_token_id)
            .any()
            .item()
        )
        if (
            torch.all(torch.eq(next_generation, current_generation)).item()
            and itr == max_new_tokens
        ) or eos_in_prefix:
            return next_generation, itr
        # skip the first iteration: current_generation has not been updated yet
        if itr != 0:
            if torch.all(torch.eq(next_generation, current_generation)).item():
                matched_position = total_len
            else:
                # first position where the trajectory has not yet converged
                matched_position = (
                    ~torch.eq(current_generation, next_generation).squeeze(0)
                ).nonzero(as_tuple=True)[0][0]
            # number of tokens newly accepted in this iteration (informational)
            fast_forward_cnt = matched_position - accurate_lengths[0]
            for i in range(bsz):
                accurate_lengths[i] = matched_position.item()

            # flush and print the newly accepted part of the first sequence
            generated_str = tokenizer.decode(
                next_generation[0, prompt_len[0] : accurate_lengths[0]],
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            print(generated_str[prev_len:], flush=True, end="")
            prev_len = len(generated_str)

            if torch.all(torch.eq(next_generation, current_generation)).item():
                # early termination: the trajectory converged with itr < max_new_tokens
                return next_generation, itr

        itr += 1
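
# Illustrative helper (not part of the original worker code): a minimal sketch
# of calling get_jacobian_trajectory directly on a single prompt. It assumes
# `model` and `tokenizer` are an already-loaded, Jacobi-trained checkpoint with
# `pad_token_id` set; the helper name and block size are hypothetical.
def _demo_jacobian_block(model, tokenizer, prompt, block_size=16):
    enc = tokenizer(prompt, return_tensors="pt").to(model.device)
    converged, n_iters = get_jacobian_trajectory(
        model,
        tokenizer,
        enc["input_ids"],
        enc["attention_mask"],
        max_new_tokens=block_size,
    )
    # the newly accepted tokens sit immediately after the prompt
    new_tokens = converged[0, enc["input_ids"].shape[1] :]
    return tokenizer.decode(new_tokens, skip_special_tokens=True), n_iters
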
def generate_stream_cllm(
    model,
    tokenizer,
    params,
    device,
    context_len,
    stream_interval=2,
    judge_sent_end=False,
):
    # converge_step = []
    prompt = params["prompt"]
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    max_new_tokens = int(params.get("n_token_seq_length", 32))  # Jacobi block size
    max_new_seq_len = int(params.get("max_new_tokens", 1024))  # total new-token budget

    prompt_len = torch.sum(inputs["attention_mask"], dim=-1)
    generation = inputs["input_ids"]
    input_echo_len = inputs["input_ids"].shape[1]  # number of prompt tokens

    ### generation phase
    itr = 0
    eos_reached = False
    while True:
        if itr == 0:
            input_ids = inputs["input_ids"]
            input_masks = inputs["attention_mask"]
        else:
            # attend only to the prompt and the Jacobi blocks accepted so far
            input_masks = torch.ones_like(input_ids).to(device)
            for j in range(bsz):
                input_masks[j][
                    torch.sum(inputs["attention_mask"], dim=-1)[j]
                    + itr * max_new_tokens :
                ] = 0

        bsz = input_ids.shape[0]
        eos_reached = torch.tensor([False] * bsz, device=device)

        # run one Jacobi block: returns the converged sequence and the number
        # of fixed-point iterations it took to converge
        generation, iter_steps = get_jacobian_trajectory(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            attention_mask=input_masks,
            max_new_tokens=max_new_tokens,
        )
        ### inspect <eos>
        for j in range(bsz):
            prompt_len = torch.sum(input_masks, dim=-1)
            eos_positions = torch.where(generation[j] == tokenizer.eos_token_id)[0]

            if len(eos_positions) == 0:
                # no eos yet: pad everything beyond the current block
                generation[j][prompt_len[j] + max_new_tokens :] = tokenizer.pad_token_id
            else:
                # eos generated: mark this sample as finished and pad every
                # token that comes after the first eos
                eos_reached[j] = True
                generation[j, int(eos_positions[0]) + 1 :] = tokenizer.pad_token_id

        itr += 1

        if all(eos_reached) or itr * max_new_tokens >= max_new_seq_len:
            break

        # drop samples that already produced <eos> and feed the rest back in
        input_ids = generation[torch.where(~eos_reached)[0].tolist(), ...]
    if all(eos_reached):
        finish_reason = "eos"
    elif itr * max_new_tokens >= max_new_seq_len:
        finish_reason = "length"
    else:
        finish_reason = "stop"

    # yield the full decoded sequence together with token usage statistics
    output = tokenizer.decode(generation[0], skip_special_tokens=True)

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": itr * max_new_tokens,
            "total_tokens": input_echo_len + itr * max_new_tokens,
        },
        "finish_reason": finish_reason,
    }
    # clean
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
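
# Usage sketch (not part of the original file): drives generate_stream_cllm end
# to end. The checkpoint path is only a placeholder for a Jacobi-trained
# CLLM-style model; "n_token_seq_length" is the Jacobi block size and
# "max_new_tokens" the overall new-token budget, matching the keys the
# function reads from `params` above.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "path/to/your/cllm-checkpoint"  # placeholder, assumed CLLM-style weights
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token_id is None:
        # get_jacobian_trajectory fills sequences with pad_token_id, so one must exist
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16
    ).to("cuda")

    demo_params = {
        "prompt": "Explain Jacobi decoding in one paragraph.",
        "n_token_seq_length": 32,  # tokens per Jacobi block
        "max_new_tokens": 256,  # total new-token budget
    }
    last_chunk = None
    for last_chunk in generate_stream_cllm(
        model, tokenizer, demo_params, device="cuda", context_len=2048
    ):
        pass
    # intermediate text is streamed to stdout inside get_jacobian_trajectory;
    # the final chunk carries the decoded text, usage stats, and finish reason
    print()
    print(last_chunk["finish_reason"], last_chunk["usage"])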