# tools/routers/tool_calls.py
# Commit 25dbca2 by Germano Cavalcante: "Fix tool_calls calling wiki_search without groups"
# (file-viewer residue: raw / history blame / 3.73 kB)
import json
import logging
from fastapi import APIRouter, Body
from typing import List, Dict
from pydantic import BaseModel
# The tool helpers are looked up first as top-level modules and, if that
# fails, under the `routers` package.
# NOTE(review): presumably this lets the file work both when run directly
# from inside `routers/` and when imported from the project root -- confirm.
try:
    from tool_gpu_checker import gpu_checker_get_message
    from tool_bpy_doc import bpy_doc_get_documentation
    from tool_find_related import find_related
    from tool_wiki_search import wiki_search
except ImportError:
    # Only a failed import should trigger the fallback. A bare `except:`
    # would also swallow KeyboardInterrupt/SystemExit and hide genuine
    # errors raised while the tool modules themselves are being loaded.
    from routers.tool_gpu_checker import gpu_checker_get_message
    from routers.tool_bpy_doc import bpy_doc_get_documentation
    from routers.tool_find_related import find_related
    from routers.tool_wiki_search import wiki_search
# Request-body model for the "function" part of a tool call: which tool to
# run and the arguments to run it with.
# (Documented with comments rather than a class docstring so the generated
# model schema stays exactly as it was.)
class ToolCallFunction(BaseModel):
    # Tool identifier; process_tool_call compares it against the tool names
    # it knows how to dispatch.
    name: str

    # Tool arguments as a JSON-encoded string; process_tool_call decodes it
    # with json.loads before use.
    arguments: str
# Request-body model for a single tool call handled by this router.
# (Documented with comments rather than a class docstring so the generated
# model schema stays exactly as it was.)
class ToolCallInput(BaseModel):
    # Identifier of the call; echoed back as "tool_call_id" in the result
    # built by process_tool_call.
    id: str

    # Kind of tool call. The sample data in this file uses "function";
    # nothing in this module branches on the value.
    type: str

    # Name and JSON-encoded arguments of the tool to execute.
    function: ToolCallFunction
# Router that the /function_call endpoint defined in this module is registered on.
router = APIRouter()
def process_tool_call(tool_call: ToolCallInput) -> Dict:
    """Execute a single tool call and package its result.

    The call's JSON-encoded arguments are decoded and the matching tool
    helper is invoked. Problems with the request itself (malformed JSON,
    arguments that are not a JSON object, missing required arguments, an
    unknown tool name) are logged and reported in the ``"output"`` field
    instead of raising, so one bad call cannot abort a whole batch.

    Args:
        tool_call: The tool call to run. ``function.name`` selects the tool
            and ``function.arguments`` holds its JSON-encoded arguments.

    Returns:
        A dict with two keys: ``"tool_call_id"`` (copied from
        ``tool_call.id``) and ``"output"`` (the tool's result, or an error
        message describing why the tool was not run).

    Raises:
        Exceptions raised inside the tool helpers themselves are not
        caught here and propagate to the caller.
    """
    output = {"tool_call_id": tool_call.id, "output": ""}
    function_name = tool_call.function.name
    raw_arguments = tool_call.function.arguments

    def fail(message: str, log_prefix: str) -> Dict:
        # Report a request problem to both the caller and the log.
        output["output"] = message
        logging.error("%s in process_tool_call: %s", log_prefix, message)
        return output

    # Keep the try body minimal: only the decode step can raise JSONDecodeError.
    try:
        function_args = json.loads(raw_arguments)
    except json.JSONDecodeError as e:
        return fail(
            f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {raw_arguments}",
            "JSONDecodeError",
        )

    # Valid JSON is not necessarily an object ("[]", "null", "3", ...); the
    # dispatch below relies on dict access and ** unpacking.
    if not isinstance(function_args, dict):
        return fail(
            "Tool call arguments must be a JSON object, got "
            f"{type(function_args).__name__}: {raw_arguments}",
            "Invalid arguments",
        )

    if function_name == "get_bpy_api_info":
        output["output"] = bpy_doc_get_documentation(
            function_args.get("api", ""))
    elif function_name == "check_gpu":
        output["output"] = gpu_checker_get_message(
            function_args.get("gpu", ""))
    elif function_name == "find_related":
        # Both arguments are mandatory; report what is missing rather than
        # letting a KeyError escape.
        missing = [key for key in ("repo", "number")
                   if key not in function_args]
        if missing:
            return fail(
                "Missing required argument(s) for find_related: "
                + ", ".join(missing),
                "Invalid arguments",
            )
        output["output"] = find_related(
            function_args["repo"], function_args["number"])
    elif function_name == "wiki_search":
        # Arguments are forwarded as keywords so optional ones may be omitted.
        output["output"] = wiki_search(**function_args)
    else:
        # Previously an unrecognised name produced a silent empty output.
        return fail(f"Unknown tool function: {function_name}",
                    "Unknown function")

    return output
@router.post("/function_call", response_model=List[Dict])
def function_call(tool_calls: List[ToolCallInput] = Body(..., description="List of tool calls in the request body")):
    """
    Endpoint to process tool calls.

    Args:
        tool_calls (List[ToolCallInput]): List of tool calls.

    Returns:
        List[Dict]: List of tool outputs with tool_call_id and output.
    """
    # Every call is handled on its own by process_tool_call; the results
    # come back in the same order as the incoming list.
    return list(map(process_tool_call, tool_calls))
if __name__ == "__main__":
    # Manual smoke test: one sample call per supported tool, pushed through
    # the endpoint function directly (no HTTP request involved).
    # Each row is (call id, tool name, JSON-encoded arguments).
    sample_calls = (
        (
            "call_abc123",
            "get_bpy_api_info",
            '{"api":"bpy.context.scene.world"}',
        ),
        (
            "call_abc456",
            "check_gpu",
            '{"gpu":"Mesa Intel(R) Iris(R) Plus Graphics 640 (Kaby Lake GT3e) (KBL GT3) Intel 4.6 (Core Profile) Mesa 22.2.5"}',
        ),
        (
            "call_abc789",
            "find_related",
            '{"repo":"blender","number":111434}',
        ),
        (
            "call_abc101112",
            "wiki_search",
            '{"query":"Set Snap Base","groups":["manual"]}',
        ),
    )

    tool_calls = [
        ToolCallInput(
            id=call_id,
            type="function",
            function=ToolCallFunction(name=tool_name, arguments=raw_arguments),
        )
        for call_id, tool_name, raw_arguments in sample_calls
    ]

    print(function_call(tool_calls))