Support tool calling in Transformers
Hi all! This PR modifies the chat template to add support for tool use with our new API. In testing, results from applying the chat template seem to be identical to results from passing the same set of messages, tools, and tool calls / tool responses to the Mistral tokenizers (modulo a couple of spaces around special tokens that should get folded into the special token by our tokenizer).
Here is some sample code to try it. Please install the latest version of `transformers` from `main` before running it:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/35")
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")

def get_current_temperature(location: str, unit: str) -> float:
    """
    Get the current temperature at a location.

    Args:
        location: The location to get the temperature for, in the format "City, Country"
        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])

    Returns:
        The current temperature at the specified location in the specified units, as a float.
    """
    return 22.  # A real function should probably actually get the temperature!

def get_current_wind_speed(location: str) -> float:
    """
    Get the current wind speed in km/h at a given location.

    Args:
        location: The location to get the wind speed for, in the format "City, Country"

    Returns:
        The current wind speed at the given location in km/h, as a float.
    """
    return 6.  # A real function should probably actually get the wind speed!

tools = [get_current_temperature, get_current_wind_speed]

messages = [
    {"role": "system", "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location."},
    {"role": "user", "content": "Hey, what's the temperature in Paris right now?"}
]

# Add a tool call and tool response to the history
tool_call_id = "abcdef123"  # Random ID, should be unique for each tool call
tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}
messages.append({"role": "assistant", "tool_calls": [{"id": tool_call_id, "type": "function", "function": tool_call}]})
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": "get_current_temperature", "content": "22.0"})

# Tokenize and generate a response
inputs = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
out = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
```
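If you want to compare against the reference output from `mistral-common` directly, here's a rough sketch for a cut-down version of the conversation above (this assumes `mistral-common`'s `MistralTokenizer.v3()` / `ChatCompletionRequest` API and isn't part of this PR):

```python
# Rough sketch: render a simplified version of the same request with mistral-common
# so its output can be compared against apply_chat_template above.
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

mistral_tokenizer = MistralTokenizer.v3()
request = ChatCompletionRequest(
    tools=[
        Tool(function=Function(
            name="get_current_temperature",
            description="Get the current temperature at a location.",
            parameters={
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": 'The location, in the format "City, Country"'},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "unit"],
            },
        ))
    ],
    messages=[UserMessage(content="Hey, what's the temperature in Paris right now?")],
)
tokenized = mistral_tokenizer.encode_chat_completion(request)
print(tokenized.text)    # reference prompt string from mistral-common
print(tokenized.tokens)  # reference token ids
```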
Working on implementing tool use in vLLM's OpenAI-compatible API and this would be awesome! Digging into it now.
Copying this from the PR on Hermes 2 Pro:
Currently the tokenizer config has a `bos_token` field in the tokenizer_config.json. Would it be possible to add a similar `bot_token` field (bot -> Beginning of Tool call), or maybe `tool_token`, to the tokenizer config that indicates when a tool call is being generated by the model as opposed to a chat completion?
For Mistral 7B Instruct v0.3 it would be `"bot_token": "[TOOL_CALLS]"`.
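For illustration, here's a minimal sketch of what discovering that field could look like on the serving side (the `bot_token` key is just the proposal here - it doesn't exist in any tokenizer_config.json today):

```python
import json

# Hypothetical: read the proposed field from the model's tokenizer_config.json.
with open("tokenizer_config.json") as f:
    tok_config = json.load(f)

# For mistralai/Mistral-7B-Instruct-v0.3 this would be "[TOOL_CALLS]";
# None would mean the model declares no tool-call marker.
bot_token = tok_config.get("bot_token")
```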
Why am I asking for this?
This will provide a standardized way for open-source serving tools & frameworks to determine which token to look for when inferring whether the response the model has started generating is a chat response or a tool call, so that the tool/framework can return an appropriate response to the client.
This has been a sticking point for me when trying to implement OpenAI API-compatible tool calling into vLLM, and in particular trying to implement SSE streaming of said tool calls in an OpenAI-compatible way.
Depending on the model you're using, you have to know which token indicates the start of a tool call (and sometimes the end of one as well!), and right now I'm stuck hard-coding it per model, since there's nowhere in the tokenizer or tokenizer config to look it up. The tokens themselves are in the tokenizer, but there's no field that tells me which ones to look for, so I can't auto-detect them. I need a consistent way, provided in the tokenizer/tokenizer config, to look up which token indicates the start of a tool call, so that I can send either a chat completion response or a tool call response in the normal case, and, more importantly, handle streaming, since the OpenAI API uses different SSE formats for streaming tool calls vs. chat completions.
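To make this concrete, here's a rough sketch (not vLLM code) of the routing logic I mean; `route_stream` is a hypothetical helper and the `bot_token` value would come from the proposed config field:

```python
def route_stream(text_chunks, bot_token: str):
    """Classify a streamed response as a tool call or a chat reply, then yield
    ("tool_call" | "chat", text) pairs that a server can map onto the matching
    OpenAI SSE delta format. Sketch only; a real implementation would also need
    to handle end-of-stream before classification, multiple tool calls, etc."""
    seen = ""
    mode = None  # undecided until enough text has arrived
    for chunk in text_chunks:
        if mode is None:
            seen += chunk
            stripped = seen.lstrip()
            if stripped.startswith(bot_token):
                mode = "tool_call"
                chunk, seen = stripped[len(bot_token):], ""  # drop the marker itself
            elif len(stripped) >= len(bot_token):
                mode = "chat"  # long enough that it can no longer be the marker
                chunk, seen = seen, ""
            else:
                continue  # marker might still be split across chunks, keep buffering
        yield mode, chunk


# Example with Mistral's marker split across two decode steps:
chunks = ["[TOOL_", 'CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France"}}]']
for kind, text in route_stream(chunks, bot_token="[TOOL_CALLS]"):
    print(kind, text)
```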
I anticipate that this will be a problem for maintainers and contributors of other open-source serving frameworks that try to provide an OpenAI API-compatible interface as well, so I think it would be really great to have this!
Please let me know your thoughts on this :)
Hi @Rocketknight1 - unfortunately I am finding that this template doesn't reliably work. I think the Mistral team had some notes on this in #38.
Hi @kmistele, yes, sorry! These PRs are on hold while we merge a batch of more critical PRs to all the Mistral models that align the behaviour of the HF tokenizers with the `mistral-common` tokenizers. After that's done I'll rebase and revive the tool use template PRs - until then there will be issues with the templates because of special token handling.