|
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from openai import OpenAI
|
""" |
|
EXAMPLE OUTPUT: |
|
|
|
**************************************** |
|
RUNNING QUERY: What's the weather for Paris, TX in fahrenheit? |
|
Step 1 |
|
---------------------------------------- |
|
|
|
Executing: get_geo_coordinates |
|
Arguments: {'city': 'Paris', 'state': 'TX'} |
|
Response: The coordinates for Paris, TX are: latitude 33.6609, longitude 95.5555 |
|
|
|
Step 2 |
|
---------------------------------------- |
|
|
|
Executing: get_current_weather |
|
Arguments: {'latitude': [33.6609], 'longitude': [95.5555], 'unit': 'fahrenheit'} |
|
Response: The weather is 85 degrees fahrenheit. It is partly cloudy, with highs in the 90's. |
|
|
|
Step 3 |
|
---------------------------------------- |
|
Conversation Complete |
|
|
|
|
|
**************************************** |
|
RUNNING QUERY: Who won the most recent PGA? |
|
Step 1 |
|
---------------------------------------- |
|
|
|
Executing: no_relevant_function |
|
Arguments: {'user_query_span': 'Who won the most recent PGA?'} |
|
Response: No relevant function for your request was found. We will stop here. |
|
|
|
Step 2 |
|
---------------------------------------- |
|
Conversation Complete |
|
""" |
|
|
|
@dataclass
class WeatherConfig:
    """Settings that tell WeatherAgent where to connect and how long to loop.

    Field order matters: instances may be built positionally as
    ``WeatherConfig(api_key, api_base, model, max_steps)``.
    """

    # Credential handed to the OpenAI client as ``api_key``.
    api_key: str = ""
    # OpenAI-compatible endpoint handed to the client as ``base_url``.
    api_base: str = ""
    # Chat model id; when left unset the agent picks the first model the
    # endpoint lists and writes it back into this field.
    model: Optional[str] = None
    # Upper bound on model round-trips within a single conversation.
    max_steps: int = 5
|
|
|
class WeatherTools:
    """Collection of available tools/functions for the weather agent.

    Every tool is a static method that returns a human-readable string,
    which the agent passes back to the model as the tool result.
    """

    # (city, state) -> (latitude, longitude) in decimal degrees.
    # Keys are stored normalized (casefolded city, upper-case state) so that
    # lookups are case-insensitive. Western-hemisphere longitudes are negative.
    _COORDINATES: Dict[tuple, tuple] = {
        ("dallas", "TX"): (32.7767, -96.7970),
        ("san francisco", "CA"): (37.7749, -122.4194),
        # Paris, TX lies west of Greenwich, so its longitude is negative.
        ("paris", "TX"): (33.6609, -95.5555),
    }

    @staticmethod
    def get_current_weather(latitude: List[float], longitude: List[float], unit: str) -> str:
        """Return a canned weather report in the requested unit.

        This is a stub: ``latitude`` and ``longitude`` are accepted to match
        the tool schema but are not used, and the reported values are fixed.
        """
        return f"The weather is 85 degrees {unit}. It is partly cloudy, with highs in the 90's."

    @staticmethod
    def get_geo_coordinates(city: str, state: str) -> str:
        """Look up the coordinates of a city from a small built-in table.

        Args:
            city: City name; matched case-insensitively, surrounding
                whitespace ignored.
            state: Two-letter state abbreviation; matched case-insensitively.

        Returns:
            A sentence with the latitude and longitude. For a city that is
            not in the table, a "could not be found" sentence instead of
            made-up coordinates, so the model does not fetch weather for the
            wrong place.
        """
        key = (str(city).strip().casefold(), str(state).strip().upper())
        coords = WeatherTools._COORDINATES.get(key)
        if coords is None:
            return f"The coordinates for {city}, {state} could not be found."

        lat, lon = coords
        return f"The coordinates for {city}, {state} are: latitude {lat}, longitude {lon}"

    @staticmethod
    def no_relevant_function(user_query_span: str) -> str:
        """Fallback tool for requests that no other tool can serve.

        ``user_query_span`` (the unanswerable part of the query) is required
        by the tool schema but is not used here.
        """
        return "No relevant function for your request was found. We will stop here."
|
|
|
class ToolRegistry:
    """Registry of available tools and their schemas.

    ``available_functions`` maps each tool name to the Python callable that
    implements it; ``tool_schemas`` describes the same tools as function-tool
    definitions for the Chat Completions ``tools`` parameter. Both must list
    the same names.
    """

    @property
    def available_functions(self) -> Dict[str, Callable[..., str]]:
        """Tool name -> implementation; every tool returns a string."""
        return {
            "get_current_weather": WeatherTools.get_current_weather,
            "get_geo_coordinates": WeatherTools.get_geo_coordinates,
            "no_relevant_function": WeatherTools.no_relevant_function,
        }

    @property
    def tool_schemas(self) -> List[Dict[str, Any]]:
        """Function-tool definitions advertised to the model.

        A fresh list is built on every access, so callers may mutate the
        result without affecting later calls.
        """
        # latitude/longitude are arrays because get_current_weather takes
        # List[float]; "items" tells the model the element type, otherwise
        # the array contents are unconstrained.
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location. Use exact coordinates.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "latitude": {
                                "type": "array",
                                "items": {"type": "number"},
                                "description": "The latitude for the city.",
                            },
                            "longitude": {
                                "type": "array",
                                "items": {"type": "number"},
                                "description": "The longitude for the city.",
                            },
                            "unit": {
                                "type": "string",
                                "description": "The unit to fetch the temperature in",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["latitude", "longitude", "unit"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_geo_coordinates",
                    "description": "Get the latitude and longitude for a given city",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city": {"type": "string", "description": "The city to find coordinates for"},
                            "state": {"type": "string", "description": "The two-letter state abbreviation"},
                        },
                        "required": ["city", "state"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "no_relevant_function",
                    "description": "Call this when no other provided function can be called to answer the user query.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "user_query_span": {
                                "type": "string",
                                "description": "The part of the user_query that cannot be answered by any other function calls.",
                            },
                        },
                        "required": ["user_query_span"],
                    },
                },
            },
        ]
|
|
|
class WeatherAgent:
    """Main agent class that handles the conversation and tool execution.

    Each ``run_conversation`` call starts a fresh message history and asks
    the model for a reply. It executes any requested tools from
    ``ToolRegistry`` and feeds their results back. This repeats until the
    model answers without tool calls or ``config.max_steps`` requests have
    been made.
    """

    def __init__(self, config: "WeatherConfig"):
        """Create the OpenAI client and resolve which model to use.

        Args:
            config: Connection and loop settings. If ``config.model`` is
                empty, it is overwritten in place with the first model id the
                endpoint lists.

        Raises:
            RuntimeError: If no model is configured and the endpoint lists none.
        """
        self.config = config
        # An empty string is not a usable base URL; pass None instead so the
        # SDK falls back to its own default endpoint.
        self.client = OpenAI(api_key=config.api_key, base_url=config.api_base or None)
        self.tools = ToolRegistry()
        self.messages: List[Dict[str, Any]] = []

        if not config.model:
            models = self.client.models.list()
            if not models.data:
                raise RuntimeError(
                    "The endpoint reported no models; set WeatherConfig.model explicitly."
                )
            self.config.model = models.data[0].id

    def _serialize_tool_call(self, tool_call) -> Dict[str, Any]:
        """Convert an SDK tool-call object into a plain dict for the history."""
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

    def _call_tool(self, function_name: str, raw_arguments: Optional[str]) -> Any:
        """Parse model-supplied arguments and invoke the named tool.

        Model output is untrusted. Malformed JSON, unknown tool names and
        argument mismatches are therefore returned as an error string (fed
        back to the model) instead of aborting the conversation.
        """
        try:
            # Some servers send an empty string for argument-less calls.
            function_args = json.loads(raw_arguments or "{}")
        except json.JSONDecodeError as err:
            return f"Error: arguments for {function_name} are not valid JSON ({err})."
        print(f"Arguments: {function_args}")

        function = self.tools.available_functions.get(function_name)
        if function is None:
            return f"Error: unknown function {function_name!r}."
        try:
            return function(**function_args)
        except TypeError as err:
            # Missing/unexpected keyword arguments, or JSON that is not an object.
            return f"Error: invalid arguments for {function_name} ({err})."

    def process_tool_calls(self, message) -> None:
        """Execute every tool call on an assistant message.

        Appends one ``"tool"`` message per call to ``self.messages``.
        """
        for tool_call in message.tool_calls or []:
            function_name = tool_call.function.name
            print(f"\nExecuting: {function_name}")

            function_response = self._call_tool(function_name, tool_call.function.arguments)
            print(f"Response: {function_response}")

            # Send text verbatim; json.dumps on a str would wrap it in quotes
            # and escape its contents.
            if isinstance(function_response, str):
                content = function_response
            else:
                content = json.dumps(function_response)

            self.messages.append({
                "role": "tool",
                "content": content,
                "tool_call_id": tool_call.id,
                "name": function_name,
            })

    def run_conversation(self, initial_query: str) -> None:
        """Run the main conversation loop for one user query.

        Resets ``self.messages``. It then alternates model requests and tool
        execution for at most ``config.max_steps`` requests.
        """
        self.messages = [{"role": "user", "content": initial_query}]

        print("\n" * 5)
        print("*" * 40)
        print(f"RUNNING QUERY: {initial_query}")

        for step in range(self.config.max_steps):
            print(f"\nStep {step + 1}")
            print("-" * 40)

            response = self.client.chat.completions.create(
                messages=self.messages,
                model=self.config.model,
                tools=self.tools.tool_schemas,
                temperature=0.0,
            )
            message = response.choices[0].message

            if not message.tool_calls:
                # Keep the final answer so the history is complete.
                self.messages.append({"role": "assistant", "content": message.content})
                print("Conversation Complete")
                break

            self.messages.append({
                "role": "assistant",
                # Content is usually None alongside tool calls; send an empty
                # string rather than the JSON text "null".
                "content": message.content or "",
                "tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls],
            })
            self.process_tool_calls(message)
        else:
            # Runs only when the loop was not broken, i.e. the step budget
            # was exhausted (including max_steps <= 0).
            print("Maximum steps reached")
|
|
|
def main():
    """Demo entry point: run the agent on two sample queries in turn.

    The first query needs the geocoding and weather tools. The second is out
    of scope for every weather tool, so the model is presumably expected to
    pick ``no_relevant_function``.
    """
    agent = WeatherAgent(WeatherConfig())

    sample_queries = (
        "What's the weather for Paris, TX in fahrenheit?",
        "Who won the most recent PGA?",
    )
    for query in sample_queries:
        agent.run_conversation(query)


if __name__ == "__main__":
    main()
|
|