Spaces:
Running
Running
import json | |
import logging | |
from fastapi import APIRouter, Body | |
from typing import List, Dict | |
from pydantic import BaseModel | |
try: | |
from .tool_gpu_checker import gpu_checker_get_message | |
from .tool_bpy_doc import bpy_doc_get_documentation | |
from .tool_find_related import find_related | |
from .tool_wiki_search import wiki_search | |
except: | |
from tool_gpu_checker import gpu_checker_get_message | |
from tool_bpy_doc import bpy_doc_get_documentation | |
from tool_find_related import find_related | |
from tool_wiki_search import wiki_search | |
class ToolCallFunction(BaseModel):
    """Name and raw arguments of the function a tool call asks to run."""

    # Tool name; process_tool_call dispatches on this value.
    name: str
    # Arguments as a JSON-encoded string (not yet decoded);
    # process_tool_call decodes it with json.loads.
    arguments: str
class ToolCallInput(BaseModel):
    """One tool call, as accepted in the request body of function_call."""

    # Caller-assigned identifier; echoed back as "tool_call_id" in the result.
    id: str
    # Kind of call; not inspected by process_tool_call (the __main__ samples
    # use "function").
    type: str
    # The function name and its JSON-encoded arguments.
    function: ToolCallFunction
# FastAPI router intended to carry this module's endpoint(s).
# NOTE(review): nothing in this file registers a route on it (function_call
# has no @router decorator) — confirm it is wired up elsewhere.
router = APIRouter()
# Arguments each tool cannot run without. Missing ones are reported back in
# the tool output instead of raising KeyError and failing the whole request.
_REQUIRED_ARGUMENTS: Dict[str, tuple] = {
    "find_related": ("repo", "number"),
    "wiki_search": ("query",),
}


def _tool_error(output: Dict, message: str, log_label: str) -> Dict:
    """Store *message* as the tool output, log it under *log_label*, return *output*."""
    output["output"] = message
    # Logging the error for further investigation
    logging.error("%s in process_tool_call: %s", log_label, message)
    return output


def process_tool_call(tool_call: ToolCallInput) -> Dict:
    """Dispatch one tool call to its tool function and package the result.

    Args:
        tool_call: The call to run. ``tool_call.function.arguments`` must be
            a JSON-encoded object.

    Returns:
        A dict ``{"tool_call_id": tool_call.id, "output": ...}`` where
        ``output`` is the tool's return value, or an error message when the
        arguments are not valid JSON, are not a JSON object, lack a required
        key, or the function name is not recognised.

    Exceptions raised by the tool functions themselves are not caught here
    and propagate to the caller.
    """
    output = {"tool_call_id": tool_call.id, "output": ""}
    function_name = tool_call.function.name

    # Guard only the decoding step, so a JSONDecodeError raised *inside* a
    # tool is not misreported as malformed call arguments.
    try:
        function_args = json.loads(tool_call.function.arguments)
    except json.JSONDecodeError as e:
        error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {tool_call.function.arguments}"
        return _tool_error(output, error_message, "JSONDecodeError")

    # Valid JSON such as "[1, 2]" or "42" has no keys to read arguments from.
    if not isinstance(function_args, dict):
        return _tool_error(
            output,
            f"Arguments for '{function_name}' must be a JSON object, "
            f"got {type(function_args).__name__}",
            "InvalidArguments",
        )

    missing = [
        key
        for key in _REQUIRED_ARGUMENTS.get(function_name, ())
        if key not in function_args
    ]
    if missing:
        return _tool_error(
            output,
            f"Missing required argument(s) for '{function_name}': "
            f"{', '.join(missing)}",
            "InvalidArguments",
        )

    if function_name == "get_bpy_api_info":
        output["output"] = bpy_doc_get_documentation(function_args.get("api", ""))
    elif function_name == "check_gpu":
        output["output"] = gpu_checker_get_message(function_args.get("gpu", ""))
    elif function_name == "find_related":
        output["output"] = find_related(function_args["repo"], function_args["number"])
    elif function_name == "wiki_search":
        output["output"] = wiki_search(function_args["query"])
    else:
        # Report unknown names explicitly rather than returning empty output.
        return _tool_error(output, f"Unknown function: {function_name}", "UnknownFunction")

    return output
def function_call(tool_calls: List[ToolCallInput] = Body(..., description="List of tool calls in the request body")):
    """
    Run every tool call from the request body and collect the results.

    Args:
        tool_calls (List[ToolCallInput]): Tool calls supplied in the request body.

    Returns:
        List[Dict]: One {"tool_call_id", "output"} dict per input call, in the
        same order as ``tool_calls``.

    NOTE(review): documented as an endpoint, but no @router decorator is
    applied in this file — confirm it is registered on ``router`` elsewhere.
    """
    return list(map(process_tool_call, tool_calls))
if __name__ == "__main__":
    # Manual smoke run: one sample call per tool (find_related and the docs /
    # GPU lookups may need their backing data or network), then print results.
    samples = (
        ("call_abc123", "get_bpy_api_info", '{"api":"bpy.context.scene.world"}'),
        (
            "call_abc456",
            "check_gpu",
            '{"gpu":"Mesa Intel(R) Iris(R) Plus Graphics 640 (Kaby Lake GT3e) '
            '(KBL GT3) Intel 4.6 (Core Profile) Mesa 22.2.5"}',
        ),
        ("call_abc789", "find_related", '{"repo":"blender","number":111434}'),
    )
    calls = [
        ToolCallInput(
            id=call_id,
            type="function",
            function=ToolCallFunction(name=name, arguments=arguments),
        )
        for call_id, name, arguments in samples
    ]
    results = function_call(calls)
    print(results)