|
from copy import deepcopy |
|
import json |
|
from typing import Any, Dict, List, Literal, Optional, Union |
|
|
|
import jsonref |
|
from pydantic import BaseModel, Field, model_validator |
|
from typing_extensions import Self |
|
|
|
from transformers.tokenization_utils_base import BatchEncoding |
|
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
|
from transformers.utils import TensorType, logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
# System prompt injected when none of the supplied tools is a code interpreter.
SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "The assistant calls functions with appropriate input when necessary"
)

# System prompt injected instead when a code-interpreter tool is supplied.
CODE_INTERPRETER_SYSTEM_PROMPT = (
    "When you send a message containing Python code to python, it will be executed in a "
    "stateful Jupyter notebook environment. python will respond with the output of the "
    "execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to "
    "save and persist user files."
)
|
|
|
class Function(BaseModel):
    """Declaration of a single function that can be offered to the model.

    Attributes:
        name: Identifier of the function.
        description: Optional free-text summary; an empty string when omitted.
        parameters: Optional dict describing the arguments; ``None`` when the
            function takes none. (Presumably a JSON-schema object with
            ``properties`` / ``required`` keys -- confirm against callers.)
    """

    name: str
    description: Optional[str] = ""
    parameters: Optional[dict] = None
|
|
|
|
|
class Tool(BaseModel):
    """A tool entry: either a described function or the code interpreter.

    Attributes:
        type: ``"function"`` or ``"code_interpreter"``.
        function: The function declaration; required for ``"function"`` tools
            and forbidden otherwise.
    """

    type: Literal["function", "code_interpreter"]
    function: Optional[Function] = None

    @model_validator(mode="after")
    def check_type_function_matches(self) -> Self:
        """Check that ``function`` is present exactly when ``type`` is "function".

        Raises:
            ValueError: If the combination of ``type`` and ``function`` is
                inconsistent. Pydantic reports it as a validation error.
        """
        # Explicit raises instead of `assert`: assertions are removed when
        # Python runs with -O, which would silently disable this validation.
        if self.type == "function":
            if self.function is None:
                raise ValueError(
                    '"function" must contain function description when `"type": "function"`'
                )
        elif self.function is not None:
            raise ValueError(
                '"function" must not be provided when `"type": "code_interpreter"`'
            )
        return self
|
|
|
|
|
def convert_data_type(param_type: str) -> str:
    """Translate a schema type name into the TypeScript name used in prompts.

    Args:
        param_type (str): type name as written in the schema

    Returns:
        str: ``"number"`` for ``"integer"`` and ``"float"``; every other value
        is returned unchanged.
    """
    numeric_aliases = ("integer", "float")
    return "number" if param_type in numeric_aliases else param_type
|
|
|
|
|
def get_param_type(param: Dict) -> str:
    """Return the TypeScript type expression for one schema property.

    The type is taken from ``param["type"]`` when present (a single name or a
    list of names), otherwise from the ``type`` of each ``oneOf`` alternative.
    Several names are joined with ``" | "``.

    Args:
        param (Dict): param dict in properties

    Returns:
        str: the converted type expression, or ``"any"`` when no type
        information can be found.
    """
    if "type" in param:
        raw_param_type = param["type"]
        if not isinstance(raw_param_type, list):
            return convert_data_type(raw_param_type)
        candidates = raw_param_type
    else:
        candidates = [
            item["type"] for item in param.get("oneOf", []) if "type" in item
        ]

    # Convert every member of the union (a list-valued "type" used to be joined
    # unconverted) and de-duplicate while keeping first-seen order. A set would
    # make the order depend on string hash randomisation, so the generated
    # prompt could differ between runs.
    union_members = list(dict.fromkeys(convert_data_type(name) for name in candidates))
    if not union_members:
        return "any"
    return " | ".join(union_members)
|
|
|
|
|
def get_format_param(param: Dict) -> Optional[str]:
    """Look up the "format" of a property, falling back to its oneOf entries.

    Args:
        param (Dict): schema dict of one property

    Returns:
        Optional[str]: ``param["format"]`` when that key exists; otherwise the
        formats found in the ``oneOf`` alternatives joined with ``" or "``;
        ``None`` when no format is declared anywhere.
    """
    if "format" in param:
        return param["format"]

    nested_formats = [
        option["format"] for option in param.get("oneOf", []) if "format" in option
    ]
    return " or ".join(nested_formats) if nested_formats else None
|
|
|
|
|
def get_param_info(param: Dict) -> Optional[str]:
    """Build the one-line ``// ...`` comment that annotates a parameter.

    The comment collects, in this order: the description (with a trailing
    period added if missing), the default value (quoted when the declared type
    is "string"), the format, and the maximum / minimum / maxLength /
    minLength constraints. Newlines in the result are replaced by spaces.

    Args:
        param (Dict): schema dict of one property

    Returns:
        Optional[str]: the comment line, or ``None`` when the schema provides
        none of the fields above.
    """
    pieces = []

    if "description" in param:
        text = param["description"]
        pieces.append(text if text.endswith(".") else text + ".")

    if "default" in param:
        fallback = param["default"]
        if param.get("type", "any") == "string":
            fallback = f'"{fallback}"'
        pieces.append(f"Default={fallback}.")

    declared_format = get_format_param(param)
    if declared_format is not None:
        pieces.append("Format=" + declared_format)

    constraint_labels = {
        "maximum": "Maximum",
        "minimum": "Minimum",
        "maxLength": "Maximum length",
        "minLength": "Minimum length",
    }
    pieces.extend(
        f"{label}=" + str(param[key])
        for key, label in constraint_labels.items()
        if key in param
    )

    if not pieces:
        return None
    return ("// " + " ".join(pieces)).replace("\n", " ")
|
|
|
|
|
def append_new_param_info(
    info_list: List[str],
    param_declaration: str,
    comment_info: Optional[str],
    examples_info: List,
    depth: int,
):
    """Append one parameter, preceded by its comment and examples, to info_list.

    Lines are added in the order: comment (if any), example lines (if any),
    declaration. Each line is prefixed with one space per nesting level.

    Example lines are emitted even when there is no comment. Previously they
    were dropped in that case, unlike the object and array branches of
    ``get_parameter_typescript``, which always emit them.

    Args:
        info_list (List[str]): list of output lines, extended in place
        param_declaration (str): param: type
        comment_info (Optional[str]): comment line, or None when there is none
        examples_info (List): example lines to show before the declaration
        depth (int): level of nested param
    """
    offset = " " * depth if depth >= 1 else ""
    if comment_info is not None:
        info_list.append(f"{offset}{comment_info}")
    info_list.extend(f"{offset}{example}" for example in examples_info)
    info_list.append(f"{offset}{param_declaration}")
|
|
|
|
|
def get_examples_info(param_name: str, examples: List) -> List:
    """Render the example values of a parameter as comment lines.

    Args:
        param_name (str): name shown in the header line
        examples (List): example values; dicts and lists are JSON-encoded,
            anything else is converted with ``str``

    Returns:
        List: a ``// Example <name>:`` header followed by one ``// <value>``
        line per example, with newlines inside a value escaped as ``\\n``.
    """

    def render(sample) -> str:
        if isinstance(sample, (dict, list)):
            text = json.dumps(sample, ensure_ascii=False)
        else:
            text = str(sample)
        return "// " + text.replace("\n", "\\n")

    return [f"// Example {param_name}:", *map(render, examples)]
|
|
|
|
|
def get_enum_option_str(enum_options: List) -> str:
    """Join enum options into a union expression separated by " | ".

    Args:
        enum_options (List): list of options

    Returns:
        str: the options joined by " | "; values whose exact type is ``str``
        are wrapped in double quotes, all others are rendered with ``str``.
    """
    rendered = []
    for option in enum_options:
        if type(option) is str:
            rendered.append(f'"{option}"')
        else:
            rendered.append(str(option))
    return " | ".join(rendered)
|
|
|
|
|
def get_array_typescript(
    param_name: Optional[str], param_dic: dict, depth: int = 0
) -> str:
    """Recursively render the TypeScript declaration of an array schema.

    Args:
        param_name (Optional[str]): name of the parameter; ``None`` renders an
            anonymous array type (used for arrays nested inside arrays)
        param_dic (dict): schema of the array; its "items" entry describes the
            elements
        depth (int, optional): nested level. Defaults to 0.

    Returns:
        str: the declaration text. Only the named, non-enum primitive case ends
        with a trailing comma; callers add one when it is missing.
    """
    indent = "".join(" " for _ in range(depth)) if depth >= 1 else ""
    named = param_name is not None
    element_schema = param_dic.get("items", {})

    # No element description at all: an untyped array.
    if len(element_schema) == 0:
        return f"{indent}{param_name}: []" if named else "[]"

    element_type = get_param_type(element_schema)

    if element_type == "object":
        members = get_parameter_typescript(
            element_schema.get("properties", {}),
            element_schema.get("required", []),
            depth + 1,
        )
        if named:
            opening = f"{indent}{param_name}" + ": {"
        else:
            opening = f"{indent}" + "{"
        return "\n".join([opening, *members, f"{indent}" + "}[]"])

    if element_type == "array":
        inner = get_array_typescript(None, element_schema, depth + 1)
        if not named:
            return f"{inner}[]"
        return f"{indent}{param_name}: {inner.strip()}[]"

    if "enum" in element_schema:
        choices = get_enum_option_str(element_schema["enum"])
        if named:
            return f"{indent}{param_name}: ({choices})[]"
        return f"({choices})[]"

    if named:
        return f"{indent}{param_name}: {element_type}[],"
    return f"{element_type}[]"
|
|
|
|
|
def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
    """Recursively render schema properties as TypeScript-style declaration lines.

    Each property contributes an optional comment line, optional example
    lines, and its declaration. Nested objects and arrays are rendered
    recursively one level deeper. The lines are meant to be placed in the
    prompt.

    Args:
        properties: mapping of parameter name to its schema dict; entries whose
            value is not a dict are skipped
        required_params: list of required parameter names; when it is a list,
            names absent from it are marked optional with "?"
        depth (int, optional): the depth of params (nested level). Defaults to 0.

    Returns:
        List[str]: lines describing all parameters
    """
    indent = "".join(" " for _ in range(depth)) if depth >= 1 else ""
    output = []

    def emit_annotations(comment, example_lines):
        # Comment first, then examples, both at this nesting level.
        if comment is not None:
            output.append(f"{indent}{comment}")
        for example_line in example_lines:
            output.append(f"{indent}{example_line}")

    for param_name, param in properties.items():
        if not isinstance(param, dict):
            continue

        comment_info = get_param_info(param)
        if "examples" in param:
            examples_info = get_examples_info(param_name, param["examples"])
        else:
            examples_info = []

        is_optional = (
            isinstance(required_params, list) and param_name not in required_params
        )
        declaration = f"{param_name}?" if is_optional else f"{param_name}"
        param_type = get_param_type(param)

        if param_type == "object":
            nested_lines = get_parameter_typescript(
                param.get("properties", {}), param.get("required", []), depth + 1
            )
            emit_annotations(comment_info, examples_info)
            output.append(f"{indent}{declaration}" + ": {")
            output.extend(nested_lines)
            output.append(f"{indent}" + "},")
            continue

        if param_type == "array":
            if "type" not in param.get("items", {}):
                # Elements carry no type: declare an untyped array.
                append_new_param_info(
                    output, declaration + ": [],", comment_info, examples_info, depth
                )
                continue
            rendered_array = get_array_typescript(declaration, param, depth)
            if not rendered_array.endswith(","):
                rendered_array += ","
            emit_annotations(comment_info, examples_info)
            output.append(rendered_array)
            continue

        # Primitive: an enum replaces the type, nullable extends it.
        if "enum" in param:
            param_type = get_enum_option_str(param["enum"])
        if "nullable" in param and param["nullable"] is True:
            param_type += " | null"
        append_new_param_info(
            output, declaration + f": {param_type},", comment_info, examples_info, depth
        )

    return output
|
|
|
def generate_schema_from_functions(
    functions: List[Function], namespace="functions"
) -> str:
    """Convert function declarations into the TypeScript-like schema for the prompt.

    Each function becomes a ``// <description>`` comment followed by a
    ``type <name> = (...) => any;`` declaration inside ``namespace <namespace>``.

    Args:
        functions: function declarations, given as dicts or as objects
            exposing ``model_dump()``. Entries without a name are skipped.
        namespace: name of the enclosing namespace. Defaults to "functions".

    Returns:
        str: the schema text.
    """
    chunks = [
        "// Supported function definitions that should be called when necessary.\n",
        f"namespace {namespace} {{\n\n",
    ]

    for function in functions:
        if not isinstance(function, dict):
            function = function.model_dump()

        function_name = function.get("name", None)
        if function_name is None:
            continue

        # The description may be explicitly None; render an empty comment in
        # that case instead of the literal text "None".
        description = function.get("description") or ""
        chunks.append(f"// {description}\n")
        chunks.append(f"type {function_name}")

        parameters = function.get("parameters", None)
        if parameters is not None and parameters.get("properties") is not None:
            # Resolve "$ref" entries, then copy before rendering.
            parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
            tp_lines = get_parameter_typescript(
                parameters.get("properties"),
                parameters.get("required", []),
                0,
            )
            chunks.append(" = (_: {\n")
            chunks.append("\n".join(tp_lines))
            chunks.append("\n}) => any;\n\n")
        else:
            # No parameters declared: a zero-argument function.
            chunks.append(" = () => any;\n\n")

    chunks.append(f"}} // namespace {namespace}")
    return "".join(chunks)
|
|
|
class FunctionaryTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that prepends the tool schema and a system prompt to chats."""

    @staticmethod
    def _functionary_system_messages(tools) -> List[Dict[str, str]]:
        """Build the two system messages prepended to every conversation.

        The first carries the schema generated from all "function" tools. The
        second is the code-interpreter prompt when any tool is not a function,
        otherwise the default system prompt. ``None`` is treated as no tools.
        """
        functions_to_render = []
        has_code_interpreter = False
        for tool in tools or []:
            tool_pydantic = Tool.model_validate(tool)
            if tool_pydantic.type == "function":
                functions_to_render.append(tool_pydantic.function)
            else:
                has_code_interpreter = True

        system_prompt_to_use = (
            CODE_INTERPRETER_SYSTEM_PROMPT if has_code_interpreter else SYSTEM_PROMPT
        )
        return [
            {"role": "system", "content": generate_schema_from_functions(functions_to_render)},
            {"role": "system", "content": system_prompt_to_use},
        ]

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
        tools: Optional[List[Dict[str, Any]]] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Render conversation(s) with the chat template, optionally tokenizing.

        Two system messages (the tool schema and a system prompt) are placed
        in front of every conversation before rendering. The caller's
        conversation objects are not modified.

        Args:
            conversation: one conversation (list of message dicts, or an object
                with a ``messages`` attribute) or a batch of them.
            tools: tool descriptions validated with ``Tool``; ``None`` means
                no tools.
            chat_template: template string, or the name of an entry when the
                tokenizer holds a dict of templates.
            add_generation_prompt: forwarded to the template.
            tokenize: when False, return the rendered text instead of tokens.
            padding, truncation, max_length, return_tensors: forwarded to the
                tokenizer call.
            return_dict: return the full tokenizer output instead of only
                ``input_ids``.
            tokenizer_kwargs: extra keyword arguments for the tokenizer call.
            **kwargs: extra variables made available to the template.

        Returns:
            The rendered string(s) when ``tokenize`` is False; otherwise the
            tokenizer output, or its ``"input_ids"`` when ``return_dict`` is
            False.

        Raises:
            ValueError: if ``return_dict`` is set without ``tokenize``, or if
                the tokenizer has several templates, none named "default",
                and no template was selected.
        """
        if return_dict and not tokenize:
            raise ValueError(
                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
                "of tokenizer outputs to return."
            )

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        using_default_template = False

        # Resolve which template to use: a named entry of a template dict, the
        # tokenizer's own template, or the class-level default.
        if isinstance(self.chat_template, dict) or (
            self.chat_template is None and isinstance(self.default_chat_template, dict)
        ):
            if self.chat_template is not None:
                template_dict = self.chat_template
                using_default_dict = False
            else:
                template_dict = self.default_chat_template
                using_default_dict = True
            if chat_template is not None and chat_template in template_dict:
                # The argument is the name of a template in the dict.
                chat_template = template_dict[chat_template]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template_dict.keys())}."
                )
        elif chat_template is None:
            if self.chat_template is not None:
                chat_template = self.chat_template
            else:
                chat_template = self.default_chat_template
                using_default_template = True

        if using_default_template:
            logger.warning_once(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        system_messages = self._functionary_system_messages(tools)

        compiled_template = self._compile_jinja_template(chat_template)

        # Detect batching on the untouched input. Inserting the system messages
        # first would put a dict at index 0 and make a batch look like a single
        # conversation.
        if (
            isinstance(conversation, (list, tuple))
            and len(conversation) > 0
            and (
                isinstance(conversation[0], (list, tuple))
                or hasattr(conversation[0], "messages")
            )
        ):
            conversations = conversation
            is_batched = True
        else:
            conversations = [conversation]
            is_batched = False

        rendered = []
        template_kwargs = {**self.special_tokens_map, **kwargs}
        for chat in conversations:
            if hasattr(chat, "messages"):
                chat = chat.messages
            # Build a new message list per conversation so the caller's list is
            # never mutated and repeated calls do not accumulate system messages.
            messages = [dict(message) for message in system_messages] + list(chat)
            rendered_chat = compiled_template.render(
                messages=messages, add_generation_prompt=add_generation_prompt, **template_kwargs
            )
            rendered.append(rendered_chat)

        if not is_batched:
            rendered = rendered[0]

        if not tokenize:
            return rendered

        out = self(
            rendered,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=False,
            return_tensors=return_tensors,
            **tokenizer_kwargs,
        )
        return out if return_dict else out["input_ids"]