lawyer_assitance / examples /test /test_assistant /test_assistant_function_call.py
qgyd2021's picture
[update]add function calling test
44e1a5b
raw
history blame
6.39 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://platform.openai.com/docs/assistants/tools/function-calling
"""
import argparse
import json
import time
from openai import OpenAI
from openai.pagination import SyncCursorPage
from openai.types.beta.threads import ThreadMessage
from openai.types.beta.assistant import Assistant
from project_settings import environment, project_path
def get_args():
    """
    Parse the command-line options of this script.

    The only option is ``--openai_api_key``; when it is omitted the value
    falls back to the ``openai_api_key`` entry looked up in the project
    ``environment`` object (``None`` if that lookup yields nothing).

    :return: the ``argparse.Namespace`` holding the parsed options.
    """
    # Resolve the fallback first so the add_argument call stays on one line.
    fallback_key = environment.get("openai_api_key", default=None, dtype=str)

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--openai_api_key", type=str, default=fallback_key)
    return arg_parser.parse_args()
def get_current_weather(location, unit="fahrenheit"):
    """
    Return a canned weather report for ``location`` as a JSON string.

    This is a stub used to exercise function calling; no real weather
    service is queried. The match on ``location`` is a case-insensitive
    substring test, checked in the order Tokyo, San Francisco, then a
    generic fallback.

    :param location: free-form place name, echoed back unchanged in the result.
    :param unit: accepted for interface compatibility; the canned answers
        carry their own fixed unit and this argument is not consulted.
    :return: JSON object string with ``location``, ``temperature`` and ``unit``.
    """
    lowered = location.lower()

    # (keyword, temperature, unit) triples, tried in order; first hit wins.
    canned_reports = (
        ("tokyo", "10", "celsius"),
        ("san francisco", "72", "fahrenheit"),
    )

    # Fallback used when no keyword matches.
    temperature, reported_unit = "22", "celsius"
    for keyword, known_temperature, known_unit in canned_reports:
        if keyword in lowered:
            temperature, reported_unit = known_temperature, known_unit
            break

    report = {"location": location, "temperature": temperature, "unit": reported_unit}
    return json.dumps(report)
# Dispatch table: maps the function name the assistant asks for to the local
# Python callable that implements it.
available_functions = {get_current_weather.__name__: get_current_weather}
def _execute_tool_call(tool_call):
    """
    Run the local function requested by a single assistant tool call.

    :param tool_call: one entry of ``run.required_action.submit_tool_outputs.tool_calls``.
    :return: dict with ``tool_call_id`` and ``output``, in the shape expected
        by ``runs.submit_tool_outputs``.
    :raises KeyError: if the requested name is not in ``available_functions``.
    """
    function_name = tool_call.function.name
    if function_name not in available_functions:
        raise KeyError("assistant requested an unknown function: {}".format(function_name))

    function_args = json.loads(tool_call.function.arguments)
    # Forward only what the model actually supplied, so omitted optional
    # parameters keep their Python defaults instead of being passed as None.
    function_response = available_functions[function_name](**function_args)
    return {
        "tool_call_id": tool_call.id,
        "output": function_response,
    }


def _wait_for_run(client, thread_id, run, poll_interval=0.3, timeout=600.0):
    """
    Poll a run until it reaches a terminal status, answering tool calls on the way.

    :param client: the ``OpenAI`` client used for the API calls.
    :param thread_id: id of the thread the run belongs to.
    :param run: the run object returned by ``runs.create``.
    :param poll_interval: seconds to sleep between two polls.
    :param timeout: local upper bound, in seconds, on the total wait.
    :return: the last retrieved run object (its ``status`` is terminal).
    :raises TimeoutError: if no terminal status is seen within ``timeout``.
    :raises AssertionError: if the run asks for an action other than
        ``submit_tool_outputs``.
    """
    # Decide on `status`, not on the *_at timestamps: per the Runs API
    # reference `expires_at` is the scheduled expiry time and is therefore
    # already set on a live run, so it cannot signal completion.
    terminal_statuses = ("completed", "failed", "cancelled", "expired", "incomplete")
    deadline = time.monotonic() + timeout

    while run.status not in terminal_statuses:
        # Safety net so a stuck run cannot keep this script polling forever.
        if time.monotonic() > deadline:
            raise TimeoutError(
                "run {} still '{}' after {} seconds".format(run.id, run.status, timeout)
            )

        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(
            thread_id=thread_id,
            run_id=run.id
        )
        print("run.status: {}".format(run.status))
        print("run.required_action: {}".format(run.required_action))

        if run.required_action is None:
            continue
        if run.required_action.type != "submit_tool_outputs":
            raise AssertionError(
                "unexpected required_action type: {}".format(run.required_action.type)
            )

        tool_outputs = [
            _execute_tool_call(tool_call)
            for tool_call in run.required_action.submit_tool_outputs.tool_calls
        ]
        run = client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread_id,
            run_id=run.id,
            tool_outputs=tool_outputs
        )

    return run


def main():
    """
    End-to-end function-calling demo against the Assistants API.

    Creates an assistant with the ``get_current_weather`` tool, opens a
    thread with one user question, starts a run, answers the tool calls the
    run requests, waits for the run to finish, and prints the thread
    messages as JSON.

    Sample identifiers printed by an earlier run:
    assistant.id: asst_9iUOSeG3dUgzBxYqfygvtKLi
    thread.id: thread_9C4dDj5i4jDCtkMCujyBleOc
    """
    args = get_args()

    client = OpenAI(
        api_key=args.openai_api_key
    )

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    # Print the tool schema with its double quotes escaped (handy for pasting
    # into another string literal). The former `exit(0)` right after this
    # print made everything below unreachable and has been removed.
    tools_ = json.dumps(tools, ensure_ascii=False)
    print(tools_.replace("\"", "\\\""))

    assistant = client.beta.assistants.create(
        instructions="You are a weather bot. Use the provided functions to answer questions.",
        model="gpt-4-1106-preview",
        tools=tools
    )
    print(f"assistant.id: {assistant.id}")

    thread = client.beta.threads.create()
    print(f"thread.id: {thread.id}")

    # The returned message object is not needed; the message lives in the thread.
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content="what's the whether San Francisco"
    )

    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
        instructions="Please address the user as Jane Doe. The user has a premium account."
    )

    run = _wait_for_run(client, thread.id, run, poll_interval=0.3)

    # Timestamps of the finished run, kept as diagnostics.
    print(run.created_at)
    print(run.started_at)
    print(run.completed_at)
    print(run.failed_at)
    print(run.expires_at)
    print(run.cancelled_at)

    # messages
    messages = client.beta.threads.messages.list(
        thread_id=thread.id
    )
    messages = messages.model_dump(mode="json")
    messages = json.dumps(messages, indent=4, ensure_ascii=False)
    print(messages)
    return
def main2():
    """
    Inspect an already existing run and list the messages of its thread.

    The thread and run identifiers are hard-coded below. The function
    retrieves the run, prints its id, its ``required_action`` and the
    Python ``id()`` of the retrieved object, then prints the message page
    of the thread.

    Sample output captured from earlier invocations:

    assistant.id: asst_OrPcAueQLrLYxtksFaPVVeJo
    thread.id: thread_2oJCtoSCYgguOhdssafJM7ab
    run: run_cA8DtX8EnoVGhmvu4VrvF63O
    run.required_action: None
    run id: 2090622954288
    run: run_cA8DtX8EnoVGhmvu4VrvF63O
    run.required_action: None
    run id: 2090623149056
    run.required_action: RequiredAction(submit_tool_outputs=RequiredActionSubmitToolOutputs(tool_calls=[RequiredActionFunctionToolCall(id='call_jalze5uKemfrnkPiJPRehVt0', function=Function(arguments='{"location":"San Francisco, CA"}', name='getCurrentWeather'), type='function')]), type='submit_tool_outputs')
    """
    cli_args = get_args()
    openai_client = OpenAI(api_key=cli_args.openai_api_key)

    # Identifiers of the thread/run to look at.
    existing_thread_id = "thread_2oJCtoSCYgguOhdssafJM7ab"
    existing_run_id = "run_cA8DtX8EnoVGhmvu4VrvF63O"

    threads_api = openai_client.beta.threads

    run = threads_api.runs.retrieve(thread_id=existing_thread_id, run_id=existing_run_id)
    print("run: {}".format(run.id))
    print("run.required_action: {}".format(run.required_action))
    # id() is the identity of the local Python object, not an API identifier.
    print("run id: {}".format(id(run)))

    message_page = threads_api.messages.list(thread_id=existing_thread_id)
    print(message_page)
# Script entry point: run the function-calling demo when executed directly.
if __name__ == "__main__":
    main()