|
import time |
|
|
|
import tiktoken |
|
import torch |
|
import torch.nn.functional as F |
|
import yaml |
|
from extensions.openai.defaults import clamp, default, get_default_req_params |
|
from extensions.openai.errors import InvalidRequestError |
|
from extensions.openai.utils import debug_msg, end_line |
|
from modules import shared |
|
from modules.text_generation import decode, encode, generate_reply |
|
from transformers import LogitsProcessor, LogitsProcessorList |
|
|
|
|
|
|
|
class LogitsBiasProcessor(LogitsProcessor):
    """Logits processor that adds a fixed bias to selected token ids.

    ``logit_bias`` maps token ids (string or int keys, e.g. ``{'50256': -100}``)
    to the amount added to that token's logit on every generation step.
    """

    def __init__(self, logit_bias=None):
        # None replaces the old shared mutable ``{}`` default; both mean "no bias".
        self.logit_bias = logit_bias if logit_bias is not None else {}
        if self.logit_bias:
            # Pair keys and values from the same items() walk. Re-looking values
            # up with str(int(key)) raised KeyError for keys such as ' 12' or '012'.
            self.keys = [int(key) for key in self.logit_bias]
            values = list(self.logit_bias.values())
            self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
            debug_msg(repr(self))

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        """Add the configured biases in place to every batch row and return ``logits``."""
        if self.logit_bias:
            # No-op when already on the right device; guards against logits
            # living on a different device than shared.model.
            bias = self.values.to(device=logits.device)
            debug_msg(logits[0, self.keys], " + ", bias)
            # Bias every row (e.g. every beam), not only row 0.
            logits[:, self.keys] += bias
            debug_msg(" --> ", logits[0, self.keys])
            debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
|
|
|
|
|
class LogprobProcessor(LogitsProcessor):
    """Pass-through logits processor that records the most likely next tokens.

    After each call, ``token_alternatives`` maps the decoded text of the top
    ``logprobs + 1`` candidates for batch row 0 to their natural-log
    probabilities. The logits are returned unchanged.
    """

    def __init__(self, logprobs=None):
        # Number of alternatives requested; None disables recording.
        self.logprobs = logprobs
        # {token_text: log_prob} for the most recent step; overwritten on every call.
        self.token_alternatives = {}

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logprobs is not None:
            # assumes logits is (batch, vocab), as the dim=1 softmax implies
            log_e_probabilities = F.log_softmax(logits, dim=1)
            # torch.topk raises if k < 1 or k > vocab size; keep k in range.
            k = max(1, min(self.logprobs + 1, log_e_probabilities.shape[-1]))
            top_values, top_indices = torch.topk(log_e_probabilities, k=k)
            top_tokens = [decode(tok) for tok in top_indices[0]]
            top_probs = [float(x) for x in top_values[0]]
            # Candidates that decode to identical text collapse into one entry.
            self.token_alternatives = dict(zip(top_tokens, top_probs))
            debug_msg(repr(self))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
|
|
|
|
|
def convert_logprobs_to_tiktoken(model, logprobs):
    """Translate logprobs into the token space of the requested model.

    Placeholder: no translation happens yet. The ``logprobs`` argument is
    handed back as-is (the very same object), regardless of ``model``.

    Args:
        model: Name of the model the client asked for (currently ignored).
        logprobs: Mapping of token text to log-probability.

    Returns:
        ``logprobs``, unchanged.
    """
    converted = logprobs  # TODO: map onto tiktoken tokens for `model`.
    return converted
|
|
|
|
|
def _map_logit_bias_to_model_tokens(logit_bias, requested_model):
    """Translate tiktoken token ids in ``logit_bias`` into the loaded model's token ids.

    Each OpenAI token id is decoded with the tiktoken encoding of
    ``requested_model`` and re-encoded with the local tokenizer; every resulting
    local id receives the same bias. If tiktoken does not know
    ``requested_model``, the ids are assumed to already be native and
    ``logit_bias`` is returned unchanged.

    Raises:
        InvalidRequestError: a key is not a valid token id for the tiktoken encoding.
    """
    try:
        encoder = tiktoken.encoding_for_model(requested_model)
    except KeyError:
        # Only an unknown model falls back; conversion errors below are reported.
        return logit_bias

    new_logit_bias = {}
    for logit, bias in logit_bias.items():
        try:
            text = encoder.decode([int(logit)])
        except (KeyError, ValueError, OverflowError) as err:
            raise InvalidRequestError(message=f"logit_bias: invalid token id {logit!r}", param='logit_bias') from err

        for x in encode(text, add_special_tokens=False)[0]:
            # Skip ids 0/1/2/29871, presumably LLaMA's unk/bos/eos and leading-space
            # token; verify this list for other tokenizers.
            if int(x) in [0, 1, 2, 29871]:
                continue
            new_logit_bias[str(int(x))] = bias

    debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
    return new_logit_bias


def marshal_common_params(body):
    """Translate the request fields shared by all completion endpoints into generation params.

    Args:
        body: Parsed JSON request body.

    Returns:
        The generation parameter dict, also carrying bookkeeping keys:
        ``requested_model`` always, and ``logprob_proc`` when logprobs were
        requested. Callers pop both before generating.

    Raises:
        InvalidRequestError: ``n`` is not 1, or ``logit_bias`` holds an invalid token id.
    """
    req_params = get_default_req_params()

    # Server-side settings.
    req_params['truncation_length'] = shared.settings['truncation_length']
    req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
    req_params['seed'] = shared.settings.get('seed', req_params['seed'])
    req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']

    # OpenAI request parameters.
    req_params['requested_model'] = body.get('model', shared.model_name)
    req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
    req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99)
    req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)

    n = default(body, 'n', 1)
    if n != 1:
        raise InvalidRequestError(message="Only n = 1 is supported.", param='n')

    if 'stop' in body:
        stop = body['stop']
        if isinstance(stop, str):
            req_params['stopping_strings'] = [stop]
        elif isinstance(stop, list):
            # Copy: stopping_strings is extended later and must not alias the request body.
            req_params['stopping_strings'] = list(stop)

    # Non-OpenAI extensions.
    req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
    req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])

    logits_processor = []

    logit_bias = body.get('logit_bias', None)
    if logit_bias:
        logit_bias = _map_logit_bias_to_model_tokens(logit_bias, req_params['requested_model'])
        logits_processor.append(LogitsBiasProcessor(logit_bias))

    # "logprobs": null / false mean "not requested", the same as omitting the field.
    logprobs_arg = body.get('logprobs', None)
    if logprobs_arg is not None and logprobs_arg is not False:
        req_params['logprob_proc'] = LogprobProcessor(default(body, 'logprobs', 0))
        logits_processor.append(req_params['logprob_proc'])

    if logits_processor:
        req_params['logits_processor'] = LogitsProcessorList(logits_processor)

    return req_params
|
|
|
|
|
def _load_instruction_role_formats(template_name):
    """Load ``instruction-templates/<template_name>.yaml`` and derive role formats.

    Returns:
        ``(role_formats, stopping_strings)`` where ``role_formats`` has the keys
        'user', 'assistant', 'system', 'context' and 'prompt'.

    Raises:
        Any error from reading or parsing the template (callers fall back to defaults).
    """
    with open(f"instruction-templates/{template_name}.yaml", 'r', encoding='utf-8') as f:
        instruct = yaml.safe_load(f)

    template = instruct['turn_template']
    bot_start = template.find('<|bot|>')
    # Split the turn template at <|bot|> into the user half and the bot half.
    user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
    bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
    # Generation is primed with the bot prefix, up to where its message would go.
    bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

    role_formats = {
        'user': user_message_template,
        'assistant': bot_message_template,
        'system': "{message}",
        'context': instruct.get('context', ''),
        'prompt': bot_prompt,
    }

    if 'Alpaca' in template_name:
        stopping_strings = ['\n###']
    elif instruct['user']:
        stopping_strings = ['\n' + instruct['user'], instruct['user']]
    else:
        stopping_strings = []

    return role_formats, stopping_strings


def messages_to_prompt(body: dict, req_params: dict, max_tokens):
    """Render OpenAI chat ``messages`` into one prompt string.

    Uses the instruction template named in ``shared.settings`` when one is set
    and loads cleanly; otherwise a built-in "User:/Assistant:" format is used.
    As a side effect, the template's stop strings are appended to
    ``req_params['stopping_strings']`` (which is replaced by a copy first).

    Args:
        body: Parsed request body; must contain ``messages``.
        req_params: Generation params; reads ``truncation_length``.
        max_tokens: Requested completion length; 0 means "not specified".

    Returns:
        ``(prompt, token_count)``.

    Raises:
        InvalidRequestError: unsupported functions/roles, malformed messages,
            or a prompt that does not fit the context length.
    """
    if body.get('functions', []):
        raise InvalidRequestError(message="functions is not supported.", param='functions')
    if body.get('function_call', ''):
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')
    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')

    messages = body['messages']

    role_formats = {
        'user': 'User: {message}\n',
        'assistant': 'Assistant: {message}\n',
        'system': '{message}',
        'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
        'prompt': 'Assistant:',
    }

    # Copy so extending never mutates a list shared with the request body.
    req_params['stopping_strings'] = list(req_params.get('stopping_strings', []))

    template_name = shared.settings['instruction_template']
    if template_name:
        try:
            template_role_formats, template_stops = _load_instruction_role_formats(template_name)
        except Exception as e:
            # Best effort: any template problem falls back to the built-in format.
            req_params['stopping_strings'].extend(['\nUser:', 'User:'])
            print(f"Exception: When loading instruction-templates/{template_name}.yaml: {repr(e)}")
            print("Warning: Loaded default instruction-following template for model.")
        else:
            role_formats = template_role_formats
            req_params['stopping_strings'].extend(template_stops)
            debug_msg(f"Loaded instruction role format: {template_name}")
    else:
        req_params['stopping_strings'].extend(['\nUser:', 'User:'])
        print("Warning: Loaded default instruction-following template for model.")

    system_msgs = []
    chat_msgs = []

    # The template's default context, wrapped in the system format.
    context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
    context_msg = end_line(context_msg)

    # A 'prompt' field in a chat request is prepended as an extra system message.
    if 'prompt' in body:
        context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg

    for m in messages:
        if 'role' not in m:
            raise InvalidRequestError(message="messages: missing role", param='messages')
        if 'content' not in m:
            raise InvalidRequestError(message="messages: missing content", param='messages')

        role = m['role']
        content = m['content']

        # Validate the role before looking up its format; unknown roles used to
        # raise KeyError, which also made the 'function' check unreachable.
        if role == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')
        if role not in ('system', 'user', 'assistant'):
            raise InvalidRequestError(message=f"messages: unsupported role {role!r}", param='messages')

        msg = role_formats[role].format(message=content)
        if role == 'system':
            system_msgs.append(msg)
        else:
            chat_msgs.append(msg)

    system_msg = end_line('\n'.join(system_msgs))

    prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']

    token_count = len(encode(prompt)[0])

    if token_count >= req_params['truncation_length']:
        err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
        raise InvalidRequestError(message=err_msg, param='messages')

    if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
        err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
        print(f"Warning: {err_msg}")

    return prompt, token_count
|
|
|
|
|
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
    """Produce a complete (non-streamed) chat reply for an OpenAI-style request.

    Args:
        body: Parsed JSON request body (``messages``, ``max_tokens``/``length``, ...).
        is_legacy: Use the legacy field names: ``data`` for the choices list and
            ``length`` for the token limit.

    Returns:
        A ``chat.completions`` response dict with a single choice and token usage.

    Raises:
        InvalidRequestError: propagated from parameter marshalling and prompt building.
    """
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    choices_key = 'data' if is_legacy else 'choices'

    req_params = marshal_common_params(body)
    req_params['stream'] = False
    requested_model = req_params.pop('requested_model')
    logprob_proc = req_params.pop('logprob_proc', None)
    req_params['top_k'] = 20

    context_limit = req_params['truncation_length']
    limit_field = 'length' if is_legacy else 'max_tokens'
    if limit_field not in body:
        # No explicit limit: allow the whole context (trimmed after prompting).
        max_tokens = 0
        req_params['max_new_tokens'] = context_limit
    else:
        max_tokens = default(body, limit_field, context_limit)
        req_params['max_new_tokens'] = max_tokens

    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)

    # Never request more new tokens than fit after the prompt.
    if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
        req_params['max_new_tokens'] = req_params['truncation_length'] - token_count

    stopping_strings = req_params.pop('stopping_strings', [])

    debug_msg({'prompt': prompt, 'req_params': req_params})

    # The generator yields the cumulative reply; keep only the last value.
    answer = ''
    for answer in generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False):
        pass

    if answer.startswith(' '):
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    hit_limit = (token_count + completion_token_count >= req_params['truncation_length']
                 or completion_token_count >= req_params['max_new_tokens'])
    stop_reason = "length" if hit_limit else "stop"

    choice = {
        "index": 0,
        "finish_reason": stop_reason,
        "message": {"role": "assistant", "content": answer},
    }
    if logprob_proc:
        top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
        choice["logprobs"] = {'top_logprobs': [top_logprobs]}

    return {
        "id": cmpl_id,
        "object": 'chat.completions',
        "created": created_time,
        "model": shared.model_name,
        choices_key: [choice],
        "usage": {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count,
        },
    }
|
|
|
|
|
|
|
def stream_chat_completions(body: dict, is_legacy: bool = False):
    """Stream a chat completion as a sequence of ``chat.completions.chunk`` dicts.

    Yields an initial empty chunk, one chunk per newly generated piece of text,
    and a final empty chunk carrying ``finish_reason`` and token ``usage``.

    Args:
        body: Parsed JSON request body.
        is_legacy: Use ``data``/``length`` instead of ``choices``/``max_tokens``.

    Raises:
        InvalidRequestError: propagated from marshalling/prompt building (raised
            on first iteration, since this is a generator).
    """
    stream_object_type = 'chat.completions.chunk'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    req_params = marshal_common_params(body)
    req_params['stream'] = True
    requested_model = req_params.pop('requested_model')
    logprob_proc = req_params.pop('logprob_proc', None)
    req_params['top_k'] = 20

    # max_tokens == 0 means "not given": generation may use the rest of the context.
    max_tokens = 0
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    if max_tokens_str in body:
        max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
        req_params['max_new_tokens'] = max_tokens
    else:
        req_params['max_new_tokens'] = req_params['truncation_length']

    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)

    # Never request more new tokens than fit after the prompt.
    if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
        req_params['max_new_tokens'] = req_params['truncation_length'] - token_count

    def chat_streaming_chunk(content):
        """Build one stream chunk whose text is ``content``."""
        chunk = {
            "id": cmpl_id,
            "object": stream_object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                # Content goes under both keys, presumably for clients reading either.
                "message": {'role': 'assistant', 'content': content},
                "delta": {'role': 'assistant', 'content': content},
            }],
        }

        if logprob_proc is not None:
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

        return chunk

    yield chat_streaming_chunk('')

    debug_msg({'prompt': prompt, 'req_params': req_params})

    stopping_strings = req_params.pop('stopping_strings', [])

    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    answer = ''
    seen_content = ''

    # The generator yields the cumulative reply; stream only the new suffix.
    for a in generator:
        answer = a
        len_seen = len(seen_content)
        new_content = answer[len_seen:]

        # Hold back empty deltas and text containing U+FFFD (presumably a
        # partially decoded multi-byte character) until more tokens arrive.
        if not new_content or chr(0xfffd) in new_content:
            continue

        seen_content = answer

        # Drop the leading space of the first piece, matching the final trim below.
        if len_seen == 0 and new_content[0] == ' ':
            new_content = new_content[1:]

        yield chat_streaming_chunk(new_content)

    # Flush anything still held back when generation ended (e.g. a U+FFFD that
    # never resolved), so the streamed text matches the answer counted in usage.
    if len(answer) > len(seen_content) and answer.startswith(seen_content):
        unsent = answer[len(seen_content):]
        if not seen_content and unsent[0] == ' ':
            unsent = unsent[1:]
        if unsent:
            yield chat_streaming_chunk(unsent)

    if answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
        stop_reason = "length"

    chunk = chat_streaming_chunk('')
    chunk[resp_list][0]['finish_reason'] = stop_reason
    chunk['usage'] = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    yield chunk
|
|
|
|
|
def completions(body: dict, is_legacy: bool = False):
    """Run a (non-streamed) text completion for one prompt or a batch of prompts.

    The prompt (``prompt``, or ``context`` when ``is_legacy``) may be a string,
    a list of token ids, or a list of either; each entry yields one choice.

    Returns:
        A ``text_completion`` response dict with one choice per prompt and
        usage summed over all prompts.

    Raises:
        InvalidRequestError: missing/empty prompt, or prompt tokens plus
            ``max_tokens`` exceeding the context length.
    """
    object_type = 'text_completion'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    prompt_str = 'context' if is_legacy else 'prompt'
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    prompt_arg = body[prompt_str]
    # An empty batch has nothing to answer (and prompt_arg[0] below would fail).
    if prompt_arg is None or (isinstance(prompt_arg, list) and not prompt_arg):
        raise InvalidRequestError("Missing required input", param=prompt_str)
    # Normalize to a batch: a lone string or lone token-id list becomes a one-item list.
    if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
        prompt_arg = [prompt_arg]

    req_params = marshal_common_params(body)
    req_params['stream'] = False
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
    req_params['max_new_tokens'] = max_tokens
    requested_model = req_params.pop('requested_model')
    logprob_proc = req_params.pop('logprob_proc', None)
    stopping_strings = req_params.pop('stopping_strings', [])

    req_params['echo'] = default(body, 'echo', req_params['echo'])
    # The OpenAI best_of parameter is mapped onto top_k.
    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

    resp_list_data = []
    total_completion_token_count = 0
    total_prompt_token_count = 0

    for idx, prompt in enumerate(prompt_arg):
        # Token-id prompts are decoded to text first. Checking the type (not
        # prompt[0]) avoids an IndexError on an empty-string prompt.
        if isinstance(prompt, list):
            # NOTE(review): decode(...) is indexed with [0] here, but LogprobProcessor
            # uses decode() as a plain string; if it returns str this keeps only the
            # first character — confirm against modules.text_generation.decode.
            if not prompt:
                prompt = ''
            elif requested_model == shared.model_name:
                prompt = decode(prompt)[0]
            else:
                try:
                    encoder = tiktoken.encoding_for_model(requested_model)
                    prompt = encoder.decode(prompt)
                except KeyError:
                    prompt = decode(prompt)[0]

        token_count = len(encode(prompt)[0])
        total_prompt_token_count += token_count

        if token_count + max_tokens > req_params['truncation_length']:
            err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
            raise InvalidRequestError(message=err_msg, param=max_tokens_str)

        debug_msg({'prompt': prompt, 'req_params': req_params})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

        # The generator yields the cumulative reply; keep the last value.
        answer = ''
        for a in generator:
            answer = a

        if answer and answer[0] == ' ':
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])
        total_completion_token_count += completion_token_count
        stop_reason = "stop"
        if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
            stop_reason = "length"

        if logprob_proc is not None:
            # Same conversion hook as the chat endpoints.
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            logprobs = {'top_logprobs': [top_logprobs]}
        else:
            logprobs = None

        resp_list_data.append({
            "index": idx,
            "finish_reason": stop_reason,
            "text": answer,
            "logprobs": logprobs,
        })

    resp = {
        "id": cmpl_id,
        "object": object_type,
        "created": created_time,
        "model": shared.model_name,
        resp_list: resp_list_data,
        "usage": {
            "prompt_tokens": total_prompt_token_count,
            "completion_tokens": total_completion_token_count,
            "total_tokens": total_prompt_token_count + total_completion_token_count
        }
    }

    return resp
|
|
|
|
|
|
|
def stream_completions(body: dict, is_legacy: bool = False):
    """Stream a text completion as ``text_completion.chunk`` dicts.

    Only a single prompt is supported: a string, a list of token ids, or a
    one-element list wrapping either. Yields an initial empty chunk, one chunk
    per newly generated piece of text, and a final empty chunk carrying
    ``finish_reason`` and ``usage``.

    Raises:
        InvalidRequestError: missing/empty prompt, a batch of several prompts, or
            prompt tokens plus ``max_tokens`` exceeding the context length.
    """
    stream_object_type = 'text_completion.chunk'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    prompt_str = 'context' if is_legacy else 'prompt'
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    prompt = body[prompt_str]
    req_params = marshal_common_params(body)
    requested_model = req_params.pop('requested_model')

    if isinstance(prompt, list):
        if not prompt:
            raise InvalidRequestError("Missing required input", param=prompt_str)
        # A one-element batch (['text'] or [[ids...]]) is really a single prompt.
        if len(prompt) == 1 and isinstance(prompt[0], (str, list)):
            prompt = prompt[0]

    if isinstance(prompt, list):
        if not prompt:
            prompt = ''
        elif isinstance(prompt[0], int):
            # Token-id prompt: decode natively when it targets the loaded model
            # (consistent with completions()), else via the model's tiktoken encoding.
            # NOTE(review): decode(...)[0] — if decode() returns str this keeps only
            # the first character; confirm against modules.text_generation.decode.
            if requested_model == shared.model_name:
                prompt = decode(prompt)[0]
            else:
                try:
                    encoder = tiktoken.encoding_for_model(requested_model)
                    prompt = encoder.decode(prompt)
                except KeyError:
                    prompt = decode(prompt)[0]
        else:
            raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)

    req_params['stream'] = True
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
    req_params['max_new_tokens'] = max_tokens
    logprob_proc = req_params.pop('logprob_proc', None)
    stopping_strings = req_params.pop('stopping_strings', [])

    req_params['echo'] = default(body, 'echo', req_params['echo'])
    # The OpenAI best_of parameter is mapped onto top_k.
    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

    token_count = len(encode(prompt)[0])

    if token_count + max_tokens > req_params['truncation_length']:
        err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
        raise InvalidRequestError(message=err_msg, param=max_tokens_str)

    def text_streaming_chunk(content):
        """Build one stream chunk whose text is ``content``."""
        if logprob_proc is not None:
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            logprobs = {'top_logprobs': [top_logprobs]}
        else:
            logprobs = None

        chunk = {
            "id": cmpl_id,
            "object": stream_object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "text": content,
                "logprobs": logprobs,
            }],
        }

        return chunk

    yield text_streaming_chunk('')

    debug_msg({'prompt': prompt, 'req_params': req_params})
    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

    answer = ''
    seen_content = ''

    # The generator yields the cumulative reply; stream only the new suffix.
    for a in generator:
        answer = a
        len_seen = len(seen_content)
        new_content = answer[len_seen:]

        # Hold back empty deltas and text containing U+FFFD (presumably a
        # partially decoded multi-byte character) until more tokens arrive.
        if not new_content or chr(0xfffd) in new_content:
            continue

        seen_content = answer

        # Drop the leading space of the first piece, matching the final trim below.
        if len_seen == 0 and new_content[0] == ' ':
            new_content = new_content[1:]

        yield text_streaming_chunk(new_content)

    # Flush anything still held back when generation ended, so the streamed
    # text matches the answer counted in usage.
    if len(answer) > len(seen_content) and answer.startswith(seen_content):
        unsent = answer[len(seen_content):]
        if not seen_content and unsent[0] == ' ':
            unsent = unsent[1:]
        if unsent:
            yield text_streaming_chunk(unsent)

    if answer and answer[0] == ' ':
        answer = answer[1:]

    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
        stop_reason = "length"

    chunk = text_streaming_chunk('')
    chunk[resp_list][0]["finish_reason"] = stop_reason
    chunk["usage"] = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    yield chunk
|
|