feat/tools-in-chat-template

#21
by lcahill - opened

Tool tokens (and the related fine-tuning) are a major asset of this model. Being able to add available tools, tool invocations, and tool results as messages in a chat template will make tools much easier to use via the transformers library.

I tested this updated chat template using the following code. Please excuse its length: it first tests with hardcoded prompts, then runs a chat loop for testing multi-turn conversations.

from transformers import LlamaTokenizerFast, MistralForCausalLM
import torch
import json
import multiprocessing
import traceback


def restricted_exec(code, return_dict):
    try:
        # Define a restricted set of built-ins
        allowed_builtins = {
            'print': print,
            'range': range,
            'len': len,
            'int': int,
            'float': float,
            'str': str,
            'bool': bool,
            'list': list,
            'dict': dict,
            'set': set,
            'tuple': tuple,
            'abs': abs,
            'sum': sum,
            'min': min,
            'max': max,
            'sorted': sorted,
            'zip': zip,
            'enumerate': enumerate,
            'map': map,
            'filter': filter,
            'all': all,
            'any': any,
        }

        # Create a restricted global and local namespace
        restricted_globals = {
            '__builtins__': allowed_builtins,
        }
        restricted_locals = {}

        execution_code = f'output = {code}'
        # Execute the code in the restricted environment
        exec(execution_code, restricted_globals, restricted_locals)
        return_dict['output'] = restricted_locals.get('output', '')
    except Exception as e:
        return_dict['error'] = traceback.format_exc()


def python_eval_function(python_code_string):
    """
    Executes the given Python code string in a restricted environment using multiprocessing.

    :param python_code_string: The Python code to execute.
    :return: The output of the executed code.
    """
    manager = multiprocessing.Manager()
    return_dict = manager.dict()

    process = multiprocessing.Process(target=restricted_exec, args=(python_code_string, return_dict))
    process.start()
    process.join(timeout=30)  # Set a timeout to prevent long-running code

    if process.is_alive():
        process.terminate()
        raise RuntimeError("Code execution timed out")

    if 'error' in return_dict:
        raise RuntimeError(f"Error executing code: {return_dict['error']}")

    return return_dict.get('output', '')


def dummy_weather_function(location, format):
    print(f"location passed to dummy weather function: {location}")
    print(f"format passed to dummy weather function: {format}")
    return f"Fine, with a chance of showers."


tools_dict = {
    "python_eval_function": python_eval_function,
    'get_current_weather': dummy_weather_function,
}

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_id = 'mistralai/Mistral-7B-Instruct-v0.3'

    tokenizer = LlamaTokenizerFast.from_pretrained(model_id)
    model = MistralForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16,
                                               device_map=device)

    model.generation_config.max_new_tokens = 3000
    model.generation_config.pad_token_id = tokenizer.eos_token_id

    available_tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        }
                    },
                    "required": ["location", "format"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "python_eval_function",
                "description": "Execute a single line of arbitrary python code. The result of this execution will be returned.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "python_code_string": {
                            "type": "string",
                            "description": "Single line of python code to execute."
                        }
                    },
                    "required": ["python_code_string"]
                }
            }
        }
    ]

    available_tools_json = json.dumps(available_tools)

    messages_presented = [
        {
            'role': 'user',
            'content': "What's 34529 times 12049581?",
        },
    ]

    messages_raw = [
        {
            'role': 'available_tools',
            'content': available_tools_json,
        },
        *messages_presented
    ]

    prompt_string = tokenizer.apply_chat_template(messages_raw, tokenize=False)

    prompt_tokens = tokenizer.encode(prompt_string, add_special_tokens=False)  # the template already adds <s>

    prompt_tensor = torch.tensor([prompt_tokens]).to(device)

    result_tokens = model.generate(prompt_tensor)[0]

    new_tokens = result_tokens[len(prompt_tokens):]

    result_string = tokenizer.decode(new_tokens)

    TOOL_CALL_TOKEN = 5
    EOS_TOKEN = 2

    if new_tokens[-1] != EOS_TOKEN:
        raise Exception("ERROR: Incomplete response")

    if new_tokens[0] == TOOL_CALL_TOKEN:  # if this is a tool call
        tool_call_content_json = new_tokens[1:-1]  # remove the tool call token and the eos token
        tool_json = tokenizer.decode(tool_call_content_json)
        # print(f"INFO: {tool_json=}")
        tool_call_object = json.loads(tool_json)[0]  # TODO: Handle invalid json.
        if tool_call_object.get('name') in tools_dict:
            tool_function = tools_dict[tool_call_object.get('name')]
            tool_call_result = tool_function(
                **tool_call_object.get('arguments'))  # TODO: Handle invalid arguments or nulls, or no arguments key.
            print(f"Tool call result: {tool_call_result}")
            messages_raw += [
                {
                    'role': 'tool_call',
                    'content': tool_json,
                },
                {
                    'role': 'tool_results',
                    'content': str(tool_call_result),
                }
            ]
        else:
            raise Exception(f"TODO: Handling of invalid tools.")
    else:
        raise Exception(f"TODO: Populate with logic to just get new user response.")

    post_tool_prompt_string = tokenizer.apply_chat_template(messages_raw, tokenize=False)

    post_tool_prompt_tokens = tokenizer.encode(post_tool_prompt_string, add_special_tokens=False)  # the template already adds <s>

    post_tool_prompt_tensor = torch.tensor([post_tool_prompt_tokens]).to(device)

    post_tool_result_tokens = model.generate(post_tool_prompt_tensor)[0]

    post_tool_new_tokens = post_tool_result_tokens[len(post_tool_prompt_tokens):]

    post_tool_result_string = tokenizer.decode(post_tool_new_tokens)

    print(f"Final Output after using tool: '{post_tool_result_string}'")

    print(f"INFO: Tool tested successfully. Proceeding with chat loop for continued testing over longer conversations.")

    messages_presented = []
    skip_user_input = False  # if the previous loop just invoked a tool, skip the user input and allow the model to generate using the result.

    while True:

        if skip_user_input:
            skip_user_input = False  # if skipping this time, reset the flag so the next is not skipped.

            messages_raw = [
                {
                    'role': 'available_tools',
                    'content': available_tools_json,
                },
                *messages_presented
            ]
        else:
            new_user_message = input(f"User: ")
            messages_presented.append(
                {
                    'role': 'user',
                    'content': new_user_message,
                }
            )
            # if taking a new user message, available tools should be before the last user message.
            messages_raw = [
                *messages_presented[:-1],
                {
                    'role': 'available_tools',
                    'content': available_tools_json,
                },
                messages_presented[-1]
            ]

        prompt_string = tokenizer.apply_chat_template(messages_raw, tokenize=False)

        prompt_tokens = tokenizer.encode(prompt_string, add_special_tokens=False)  # the template already adds <s>

        prompt_tensor = torch.tensor([prompt_tokens]).to(device)

        TOOL_CALL_TOKEN = 5
        EOS_TOKEN = 2
        CLOSE_SQUARE_BRACKET_TOKEN = 29561 # used to determine the end of a tool calls list.

        def determine_if_tool_invoked(prompt_tensor, model):
            previous_max_new_tokens = model.generation_config.max_new_tokens
            model.generation_config.max_new_tokens = 1
            first_token = model.generate(prompt_tensor)[0][len(prompt_tokens)].item() # get the first generated token.
            model.generation_config.max_new_tokens = previous_max_new_tokens # reset generation config.
            return first_token, first_token == TOOL_CALL_TOKEN

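        first_token, tool_was_invoked = determine_if_tool_invoked(prompt_tensor, model)
        if tool_was_invoked:
            # The first generated token was [TOOL_CALLS]; add it to the prompt so it is not re-computed.
            # Only do this for tool calls: appending [TOOL_CALLS] unconditionally would steer the model towards a tool call
            # even when it chose to answer directly, and appending a normal first token would drop it from new_tokens below.
            prompt_tokens.append(first_token)
            prompt_tensor = torch.tensor([prompt_tokens]).to(device)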

        if tool_was_invoked:
            previous_eos_token = model.generation_config.eos_token_id
            model.generation_config.eos_token_id = CLOSE_SQUARE_BRACKET_TOKEN # If a tool was invoked, stop generation at end of tool invocation list.
            result_tokens = model.generate(prompt_tensor)[0]
            model.generation_config.eos_token_id = previous_eos_token
            #TODO: restrict generation length to something reasonable for tools, and raise/handle exception if reached.
        else:
            # if no tool was invoked, simply generate a response
            result_tokens = model.generate(prompt_tensor)[0]

        new_tokens = result_tokens[len(prompt_tokens):]

        result_string = tokenizer.decode(new_tokens)

        if tool_was_invoked:  # if this is a tool call
            tool_call_content_json = new_tokens  # generation stopped at ']', so there is no eos token to strip
            tool_json = tokenizer.decode(tool_call_content_json)
            tool_call_object = json.loads(tool_json)[0]  # TODO: Handle invalid json.
            if tool_call_object.get('name') in tools_dict:
                tool_function = tools_dict[tool_call_object.get('name')]
                tool_call_result = tool_function(
                    **tool_call_object.get(
                        'arguments'))  # TODO: Handle invalid arguments or nulls, or no arguments key.
                print(f"Tool call result: {tool_call_result}")
                messages_presented += [
                    {
                        'role': 'tool_call',
                        'content': tool_json,
                    },
                    {
                        'role': 'tool_results',
                        'content': str(tool_call_result),
                    },
                ]
                skip_user_input = True  # Skip the next user input if tool invoked to allow assistant to create a response using tool output.
            else:
                raise Exception(f"TODO: Handling of invalid tools.")
        else:
            assistant_response_cleaned = tokenizer.decode(
                new_tokens[:-1])  # exclude eos token. This will be added back by chat template on next loop.
            messages_presented.append({
                'role': 'assistant',
                'content': assistant_response_cleaned,
            })
            print(f"Assistant: {assistant_response_cleaned}")
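The loop above leaves several TODOs around the tool call itself: invalid JSON, a missing or null arguments key, and unknown tool names. Below is a minimal sketch of how those could be handled. The helpers parse_tool_calls and dispatch_tool_calls and the error-message strings are hypothetical, not part of transformers or the Mistral libraries; the sketch only assumes the [{"name": ..., "arguments": {...}}] shape used by the tool_call messages in this thread. Returning an error string instead of raising is a design choice that lets the loop pass the problem back to the model as a tool result. Because parse_tool_calls returns None for text that is not yet a complete JSON list, it could also help detect the case where generation stops early at a ']' that belongs to an argument value.

```python
import json
from typing import Any, Callable, Dict, List, Optional


def parse_tool_calls(tool_json: str) -> Optional[List[Dict[str, Any]]]:
    """Parse the text generated after [TOOL_CALLS].

    Returns None if the text is not a JSON list of objects, e.g. because
    generation stopped at a ']' nested inside the arguments.
    """
    try:
        calls = json.loads(tool_json)
    except json.JSONDecodeError:
        return None
    if not isinstance(calls, list) or not all(isinstance(c, dict) for c in calls):
        return None
    return calls


def dispatch_tool_calls(calls: List[Dict[str, Any]],
                        tools: Dict[str, Callable[..., Any]]) -> List[str]:
    """Run each call and return one result string per call.

    Failures are reported as strings rather than raised (an assumption about
    how you want to recover), so they can be sent back as tool results.
    """
    results = []
    for call in calls:
        name = call.get("name")
        arguments = call.get("arguments") or {}  # missing or null arguments -> no arguments
        if name not in tools:
            results.append(f"Error: unknown tool {name!r}")
            continue
        if not isinstance(arguments, dict):
            results.append(f"Error: arguments for {name!r} must be a JSON object")
            continue
        try:
            results.append(str(tools[name](**arguments)))
        except Exception as exc:  # wrong argument names, tool failures, timeouts, ...
            results.append(f"Error calling {name!r}: {exc}")
    return results


if __name__ == "__main__":
    def fake_weather(location, format):
        return f"Fine in {location} ({format})"

    demo_tools = {"get_current_weather": fake_weather}
    text = '[{"name": "get_current_weather", "arguments": {"location": "Auckland, NZ", "format": "celsius"}}]'
    print(dispatch_tool_calls(parse_tool_calls(text), demo_tools))
    print(parse_tool_calls('[{"name": "f", "arguments": {"xs": [1, 2]'))  # truncated list -> None
```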

Note that I have tested the template using the chat loop with multi-turn conversations, using both tools multiple times, with successful use of these tools by the model. The output of this multi-turn chat was:

Connected to pydev debugger (build 233.14475.56)
Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.92s/it]
Tool call result: 416059982349
Final Output after using tool: 'The result of 34529 times 12049581 is 416059982349.</s>'
INFO: Tool tested successfully. Proceeding with chat loop for continued testing over longer conversations.
User: >? Hey, what's the weather like in auckland right now? 
location passed to dummy weather function: Auckland, NZ
format passed to dummy weather function: celsius
Tool call result: Fine, with a chance of showers.
Assistant: It seems that it's currently fine with a chance of showers in Auckland, New Zealand. Enjoy your day! If you need any other information, feel free to ask.
User: >? Thanks! And what's 58273957 times 293847?
Tool call result: 17123627442579
Assistant: The result of 58273957 times 293847 is 17123627442579.
User: >? Thanks! what about 91827464 / 2817236 ?
Tool call result: 32.59487810037924
Assistant: The result of 91827464 divided by 2817236 is approximately 32.59487810037924.
User: >? fun stuff. How about the weather in melbourne?
location passed to dummy weather function: Melbourne, AU
format passed to dummy weather function: celsius
Tool call result: Fine, with a chance of showers.
Assistant: It seems that it's currently fine with a chance of showers in Melbourne, Australia. Have a great day! If you need any other information, feel free to ask.
lcahill changed pull request status to open

@lcahill
How did you manage to load tokenizer?
LlamaTokenizerFast.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") gives an exception:
Exception: data did not match any variant of untagged enum PyPreTokenizerTypeWrapper at line 6952 column 3

@lcahill
Hey I am trying to figure out the prompt structure of the model and I haven't found any information. It seems like you have a good grip on this. If you can provide some resources on that it would be much appreciated.

Here are some additional questions:
How do you know that LlamaTokenizerFast is compatible with mistralai/Mistral-7B-Instruct-v0.3?
How do you know that the below message type is supported:

                {
                    'role': 'available_tools',
                    'content': available_tools_json,
                },

@lcahill
How did you manage to load tokenizer?
LlamaTokenizerFast.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") gives an exception:
Exception: data did not match any variant of untagged enum PyPreTokenizerTypeWrapper at line 6952 column 3

Hi @Alordan,

Strange, I am not getting the same exception. Running on Windows 11 with transformers version 4.41.1. Maybe you need to update your transformers version?
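If you hit that error, it may help to first confirm which versions are actually installed. The PyPreTokenizerTypeWrapper message comes from the tokenizers library, which parses tokenizer.json for fast tokenizers, so its version is worth checking alongside transformers. Below is a small standard-library-only sketch. The helper names are hypothetical, and 4.41.1 is simply the transformers version reported as working above, not a documented minimum.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# transformers version reported as working in this thread (not an official minimum).
REPORTED_WORKING_TRANSFORMERS = (4, 41, 1)


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def parse_release(v: str) -> Tuple[int, ...]:
    """Keep the leading numeric release components, e.g. '4.42.0.dev0' -> (4, 42, 0)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # e.g. '1rc1': stop after the numeric prefix
            break
    return tuple(parts)


def at_least(installed: Optional[str], minimum: Tuple[int, ...]) -> bool:
    """True if a package is installed and its release is >= minimum."""
    return installed is not None and parse_release(installed) >= minimum


if __name__ == "__main__":
    for package in ("transformers", "tokenizers"):
        print(f"{package}: {installed_version(package)}")
    print("transformers >= 4.41.1:",
          at_least(installed_version("transformers"), REPORTED_WORKING_TRANSFORMERS))
```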

@lcahill
Hey I am trying to figure out the prompt structure of the model and I haven't found any information. It seems like you have a good grip on this. If you can provide some resources on that it would be much appreciated.

Hi @SebastianS, I figured out the chat template by using Mistral's mistral_inference and mistral_common Python libraries. The README of this model has some instructions for using them.

How do you know that LlamaTokenizerFast is compatible with mistralai/Mistral-7B-Instruct-v0.3?

I found this by loading the tokenizer and checking its type. AutoTokenizer picks the correct tokenizer class based on the model's config.

import transformers
from transformers import AutoTokenizer, LlamaTokenizerFast

print(f"{transformers.__version__=}")

model_id = 'mistralai/Mistral-7B-Instruct-v0.3'

# Load both tokenizers and ensure they are the same type

tokenizer_from_pretrained = AutoTokenizer.from_pretrained(model_id)

print(f"{type(tokenizer_from_pretrained)=}")

tokenizer_from_llama_tokenizer_fast = LlamaTokenizerFast.from_pretrained(model_id)

print(f"{type(tokenizer_from_llama_tokenizer_fast)=}")

assert type(tokenizer_from_pretrained) == type(tokenizer_from_llama_tokenizer_fast)

This outputs:

transformers.__version__='4.41.1'
type(tokenizer_from_pretrained)=<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>
type(tokenizer_from_llama_tokenizer_fast)=<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>

How do you know that the below message type is supported:

                {
                    'role': 'available_tools',
                    'content': available_tools_json,
                },

The available_tools role is only supported once the chat template is changed, which is what this PR does. Note that you can also change the template at runtime with something like the code below.

# apply custom chat template to model as an example.
from transformers import LlamaTokenizerFast, AutoModelForCausalLM
import json
import torch

new_chat_template = """{{ bos_token }}{% for message in messages %}{% if loop.index0 == 0 and message['role'] not in ['available_tools', 'user'] %}{{ raise_exception('The first message must be either available_tools or user role!') }}{% endif %}{% if message['role'] not in ['user', 'assistant', 'available_tools', 'tool_call', 'tool_results'] %}{{ raise_exception('Only user, assistant, available_tools, tool_call, and tool_results roles are supported!') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% elif message['role'] == 'available_tools' %}{{ '[AVAILABLE_TOOLS] ' + message['content'] + '[/AVAILABLE_TOOLS]' }}{% elif message['role'] == 'tool_call' %}{{ '[TOOL_CALLS]' + message['content'] + eos_token }}{% elif message['role'] == 'tool_results' %}{{ '[TOOL_RESULTS]' + message['content'] + '[/TOOL_RESULTS]' }}{% endif %}{% endfor %}"""

available_tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location."
                    }
                },
                "required": ["location", "format"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "python_eval_function",
            "description": "Execute a single line of arbitrary python code. The result of this execution will be returned.",
            "parameters": {
                "type": "object",
                "properties": {
                    "python_code_string": {
                        "type": "string",
                        "description": "Single line of python code to execute."
                    }
                },
                "required": ["python_code_string"]
            }
        }
    }
]

available_tools_json = json.dumps(available_tools)


messages = [
    {
        'role': 'user',
        'content': "Hey! What's the weather like in auckland?"
    },
    {
        'role': 'tool_call',
        'content': '[{"name": "get_current_weather", "arguments": {"location": "Auckland, NZ", "format": "celsius"}}]'
    },
    {
        'role': 'tool_results',
        'content': 'Fine, with a chance of showers.'
    },
    {
        'role': 'assistant',
        'content': "It looks like it's going to be a bit rainy in Auckland today. Be sure to take an umbrella with you!"
    },
    {
        'role': 'user',
        'content': 'Thanks! And whats 9817249382934 times 116263212356?'
    },
    {
        'role': 'tool_call',
        'content': '[{"name": "python_eval_function", "arguments": {"python_code_string": "9817249382934 * 116263212356"}}]'
    },
    {
        'role': 'tool_results',
        'content': '1141384949759865604332504'
    },
    {
        'role': 'assistant',
        'content': "That's a very large number! It's 1,141,384,949,975,986,560,433,250,040 when written in standard format."
    },
    {
        'role': 'available_tools',
        'content': available_tools_json,
    },
    {
        'role': 'user',
        'content': 'interesting that you failed to correctly provide the tool result. I guess there are some limitations to small LLMs! What is the actual result?'
    },
]

model_id = 'mistralai/Mistral-7B-Instruct-v0.3'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = LlamaTokenizerFast.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)

model.generation_config.max_new_tokens = 3000
model.generation_config.pad_token_id = tokenizer.eos_token_id

input_string = tokenizer.apply_chat_template(
    messages,
    chat_template=new_chat_template,
    tokenize=False,
)
print(f"{input_string=}")

input_tokens = tokenizer.encode(input_string, add_special_tokens=False, return_tensors='pt').to(device)  # the template already adds <s>

result_tokens = model.generate(input_tokens)[0]

new_tokens = result_tokens[len(input_tokens[0]):]

result_string = tokenizer.decode(new_tokens)

print(f"{result_string=}")

This outputs:

Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.92s/it]
input_string='<s>[INST] Hey! What\'s the weather like in auckland?[/INST][TOOL_CALLS][{"name": "get_current_weather", "arguments": {"location": "Auckland, NZ", "format": "celsius"}}]</s>[TOOL_RESULTS]Fine, with a chance of showers.[/TOOL_RESULTS]It looks like it\'s going to be a bit rainy in Auckland today. Be sure to take an umbrella with you!</s>[INST] Thanks! And whats 9817249382934 times 116263212356?[/INST][TOOL_CALLS][{"name": "python_eval_function", "arguments": {"python_code_string": "9817249382934 * 116263212356"}}]</s>[TOOL_RESULTS]1141384949759865604332504[/TOOL_RESULTS]That\'s a very large number! It\'s 1,141,384,949,975,986,560,433,250,040 when written in standard format.</s>[AVAILABLE_TOOLS] [{"type": "function", "function": {"name": "get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}}, "required": ["location", "format"]}}}, {"type": "function", "function": {"name": "python_eval_function", "description": "Execute a single line of arbitrary python code. The result of this execution will be returned.", "parameters": {"type": "object", "properties": {"python_code_string": {"type": "string", "description": "Single line of python code to execute."}}, "required": ["python_code_string"]}}}][/AVAILABLE_TOOLS][INST] interesting that you failed to correctly provide the tool result. I guess there are some limitations to small LLMs! What is the actual result?[/INST]'
result_string='Apologies for the confusion. The result of the multiplication operation is 1,141,384,949,975,986,560,433,250,040.</s>'

Process finished with exit code 0

Note the model still gets this wrong. This could be an inherent limitation of a 7B model, or it could be due to my ordering of the messages. I am putting the available tools before the last user message because this comment mentioned that is how the Mistral libraries do it.

https://huggingface.co./mistralai/Mistral-7B-Instruct-v0.3/discussions/17#66519aade6ddaa4f89902342
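For reference, the tool itself returned the right answer; the mistake is in the model restating it. A quick check in plain Python, with the operands, the tool result, and the model's answer copied from the conversation above, confirms the product and shows what the comma-formatted value should have been. The model's version has 28 digits instead of 25.

```python
# Operands, tool result and model answer copied from the conversation above.
a = 9817249382934
b = 116263212356
tool_result = "1141384949759865604332504"
model_answer = "1,141,384,949,975,986,560,433,250,040"

product = a * b
print(product == int(tool_result))  # True: python_eval_function returned the correct product
print(f"{product:,}")               # 1,141,384,949,759,865,604,332,504
print(len(str(product)), len(model_answer.replace(",", "")))  # 25 28
```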

Regardless of how the messages need to be ordered, I think adding tool roles to the chat template makes tool use easier to implement.

My chat_template implementation may still have issues, so I am keen for people to try it out and provide feedback. I have noticed the model can be quite sensitive to the layout of tokens, so even missing or additional spaces in the chat template can affect output quality.
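Since small spacing differences matter, it can help to see the exact string the template produces without loading the model. The sketch below is a plain-Python mirror of the new_chat_template Jinja string from earlier in this thread: same roles, same two validation checks in the same order (raising ValueError where the template calls raise_exception), and the same spacing, i.e. a space after [INST] and [AVAILABLE_TOOLS] but none after [TOOL_CALLS]. render_prompt is a hypothetical helper for inspection only; the Jinja template remains the source of truth.

```python
from typing import Dict, List

ALLOWED_ROLES = ("user", "assistant", "available_tools", "tool_call", "tool_results")


def render_prompt(messages: List[Dict[str, str]],
                  bos_token: str = "<s>", eos_token: str = "</s>") -> str:
    """Plain-Python mirror of new_chat_template, for inspecting exact spacing."""
    parts = [bos_token]
    for index, message in enumerate(messages):
        role, content = message["role"], message["content"]
        # Same two checks, in the same order, as the template's raise_exception calls.
        if index == 0 and role not in ("available_tools", "user"):
            raise ValueError("The first message must be either available_tools or user role!")
        if role not in ALLOWED_ROLES:
            raise ValueError("Only user, assistant, available_tools, tool_call, "
                             "and tool_results roles are supported!")
        if role == "user":
            parts.append("[INST] " + content + "[/INST]")  # space after [INST], none before [/INST]
        elif role == "assistant":
            parts.append(content + eos_token)
        elif role == "available_tools":
            parts.append("[AVAILABLE_TOOLS] " + content + "[/AVAILABLE_TOOLS]")  # space after opening tag
        elif role == "tool_call":
            parts.append("[TOOL_CALLS]" + content + eos_token)  # no space after [TOOL_CALLS]
        else:  # tool_results
            parts.append("[TOOL_RESULTS]" + content + "[/TOOL_RESULTS]")
    return "".join(parts)


print(render_prompt([{"role": "user", "content": "What's 34529 times 12049581?"}]))
# <s>[INST] What's 34529 times 12049581?[/INST]
```

Comparing this output against a variant with one space added or removed makes it easy to see exactly which characters a template change affects.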

Hi, Matt from Hugging Face here. We've actually added a new tool-use API that will come out in the next version of Transformers. You can see the API docs here. We've also made a PR to Mistral to enable the new API, and you can see it here.
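A rough sketch of how the custom roles in this PR might map onto the message layout described in the Transformers tool-use documentation: tool schemas are passed separately (through the tools argument of apply_chat_template), a tool call becomes an assistant message carrying a tool_calls list, and a tool result becomes a message with the tool role. This is an unofficial reading of that API, not a converter shipped by either library. to_standard_tool_messages is a hypothetical name, and some templates also require a tool-call ID on both the call and its result, which this sketch does not generate.

```python
import json
from typing import Any, Dict, List, Optional, Tuple


def to_standard_tool_messages(
    messages: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
    """Map this PR's custom roles to assistant/tool_calls and tool messages.

    Returns (converted_messages, tools); tools are meant to be passed separately.
    """
    converted: List[Dict[str, Any]] = []
    tools = None
    last_call_names: List[str] = []
    for message in messages:
        role, content = message["role"], message["content"]
        if role == "available_tools":
            tools = json.loads(content)  # JSON string of tool schemas -> list of dicts
        elif role == "tool_call":
            calls = json.loads(content)
            last_call_names = [call["name"] for call in calls]
            converted.append({
                "role": "assistant",
                "tool_calls": [
                    {"type": "function",
                     "function": {"name": call["name"], "arguments": call.get("arguments") or {}}}
                    for call in calls
                ],
            })
        elif role == "tool_results":
            # The PR stores one result string per turn; attribute it to the first
            # call of the preceding tool_call message (an assumption).
            name = last_call_names[0] if last_call_names else None
            converted.append({"role": "tool", "name": name, "content": content})
        else:
            converted.append({"role": role, "content": content})
    return converted, tools
```

The resulting pair would then be passed along the lines of tokenizer.apply_chat_template(converted, tools=tools, ...), subject to whatever extra fields the specific model's template expects.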

Cannot merge
This branch has merge conflicts in the following files:
  • tokenizer_config.json
