""" | |
Inference code for ChatGLM. | |
Adapted from https://huggingface.co./THUDM/chatglm-6b/blob/main/modeling_chatglm.py. | |
""" | |
import re | |
import torch | |
from transformers.generation.logits_process import LogitsProcessor | |
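
# Guard against degenerate logits: if the scores contain NaN or Inf, zero them out
# and put all probability mass on token id 5 so decoding can continue. This mirrors
# the InvalidScoreLogitsProcessor in the upstream modeling_chatglm.py referenced
# above.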
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


invalid_score_processor = InvalidScoreLogitsProcessor()

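
# Post-process decoded text: replace the "[[训练时间]]" ("training time") placeholder
# with "2023年" ("the year 2023"), and convert ASCII punctuation adjacent to a CJK
# character (the \u4e00-\u9fff range) into its full-width counterpart.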
def process_response(response):
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # (ASCII regex fragment, full-width replacement) pairs; "?" needs a regex escape.
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for item in punkts:
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response

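
# Rebuild a structured message list from a flat prompt string by splitting on the
# <|system|>/<|user|>/<|assistant|> role tokens. The result feeds
# tokenizer.build_chat_input() in the ChatGLM3 branch of generate_stream_chatglm().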
def recover_message_list(prompt):
    role_token_pattern = "|".join(
        [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]]
    )
    role = None
    last_end_idx = -1
    message_list = []
    for match in re.finditer(role_token_pattern, prompt):
        if role:
            message = {}
            if role == "<|system|>":
                message["role"] = "system"
            elif role == "<|user|>":
                message["role"] = "user"
            else:
                message["role"] = "assistant"
            message["content"] = prompt[last_end_idx + 1 : match.start()]
            message_list.append(message)

        role = prompt[match.start() : match.end()]
        last_end_idx = match.end()

    return message_list

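
# Streaming generation entry point. It yields dicts of the form
# {"text": ..., "usage": {...}, "finish_reason": ...}; intermediate chunks carry
# finish_reason=None and only the final chunk reports "stop". device, context_len,
# stream_interval, and judge_sent_end are part of the signature but are not used
# in this function.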
def generate_stream_chatglm(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = params.get("echo", True)

    # PEFT wrappers hide the underlying model class, so look through them before
    # checking for ChatGLM3.
    model_type = str(type(model)).lower()
    if "peft" in model_type:
        model_type = str(type(model.base_model.model)).lower()

    if "chatglm3" in model_type:
        message_list = recover_message_list(prompt)
        inputs = tokenizer.build_chat_input(
            query=message_list[-1]["content"], history=message_list[:-1], role="user"
        ).to(model.device)
    else:
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    gen_kwargs = {
        "max_length": max_new_tokens + input_echo_len,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [invalid_score_processor],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    for total_ids in model.stream_generate(**inputs, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids
        else:
            output_ids = total_ids[input_echo_len:]
        response = tokenizer.decode(output_ids)
        response = process_response(response)

        yield {
            "text": response,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": total_len - input_echo_len,
                "total_tokens": total_len,
            },
            "finish_reason": None,
        }

    # TODO: handle ChatGLM stopping because it reached max length.
    # Only the last streamed result contains a finish_reason; we set it to "stop".
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret
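

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a ChatGLM3 checkpoint whose custom modeling code provides
# stream_generate() and whose tokenizer provides build_chat_input(), as the
# THUDM releases do; the model name, prompt, and sampling parameters below are
# example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_path = "THUDM/chatglm3-6b"  # example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = (
        AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
    )

    params = {
        "prompt": "<|user|>\nHello, who are you?<|assistant|>\n",
        "temperature": 0.8,
        "top_p": 0.8,
        "max_new_tokens": 64,
        "echo": False,
    }

    last_chunk = None
    for last_chunk in generate_stream_chatglm(model, tokenizer, params, device="cuda"):
        pass  # each chunk carries the cumulative text generated so far
    if last_chunk is not None:
        print(last_chunk["text"])
        print(last_chunk["usage"])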