# Athene-V2-Agent / example /vllm_v2_weather_agent.py
# venkat-srinivasan-nexusflow's picture
# Update example/vllm_v2_weather_agent.py
# 5e35ce5 verified
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from openai import OpenAI
"""
EXAMPLE OUTPUT:
****************************************
RUNNING QUERY: What's the weather for Paris, TX in fahrenheit?
Step 1
----------------------------------------
Executing: get_geo_coordinates
Arguments: {'city': 'Paris', 'state': 'TX'}
Response: The coordinates for Paris, TX are: latitude 33.6609, longitude 95.5555
Step 2
----------------------------------------
Executing: get_current_weather
Arguments: {'latitude': [33.6609], 'longitude': [95.5555], 'unit': 'fahrenheit'}
Response: The weather is 85 degrees fahrenheit. It is partly cloudy, with highs in the 90's.
Step 3
----------------------------------------
Conversation Complete
****************************************
RUNNING QUERY: Who won the most recent PGA?
Step 1
----------------------------------------
Executing: no_relevant_function
Arguments: {'user_query_span': 'Who won the most recent PGA?'}
Response: No relevant function for your request was found. We will stop here.
Step 2
----------------------------------------
Conversation Complete
"""
@dataclass
class WeatherConfig:
    """Connection and runtime settings consumed by the weather agent.

    Fields, in positional order:
        api_key: credential passed to the OpenAI-compatible client.
        api_base: base URL passed to the OpenAI-compatible client.
        model: model id to query; when left falsy the agent substitutes
            the first model id the endpoint lists.
        max_steps: upper bound on chat-completion requests per query.
    """

    # FILL IN WITH YOUR VLLM_ENDPOINT_KEY
    api_key: str = ""
    # FILL IN WITH YOUR VLLM_ENDPOINT
    api_base: str = ""
    model: Optional[str] = None
    max_steps: int = 5
class WeatherTools:
    """Collection of available tools/functions for the weather agent.

    Every tool is a static method that returns a plain-text result. All of
    them are mocks: no network request is made.
    """

    @staticmethod
    def get_current_weather(latitude: List[float], longitude: List[float], unit: str) -> str:
        """Return a canned weather report.

        Args:
            latitude: Latitude value(s) of the location (ignored by the mock).
            longitude: Longitude value(s) of the location (ignored by the mock).
            unit: Temperature unit name, echoed verbatim into the report.

        Returns:
            A fixed weather sentence mentioning ``unit``.
        """
        # We are mocking the weather here, but in the real world, you will submit a request here.
        return f"The weather is 85 degrees {unit}. It is partly cloudy, with highs in the 90's."

    @staticmethod
    def get_geo_coordinates(city: str, state: str) -> str:
        """Look up latitude/longitude for a city in a small built-in table.

        Args:
            city: City name; must match a table key exactly (case-sensitive).
            state: Two-letter state abbreviation; must match exactly.

        Returns:
            A sentence with the coordinates, or a sentence saying that no
            coordinates are known when the city/state pair is not in the
            table.
        """
        # We are mocking the geocoding lookup here, but in the real world,
        # you will submit a request here.
        # Western-hemisphere longitudes are negative.
        coordinates = {
            "Dallas": {"TX": (32.7767, -96.7970)},
            "San Francisco": {"CA": (37.7749, -122.4194)},
            "Paris": {"TX": (33.6609, -95.5555)},
        }
        location = coordinates.get(city, {}).get(state)
        if location is None:
            # Say so explicitly instead of reporting made-up (0, 0)
            # coordinates as if they were a real lookup result.
            return f"No coordinates are known for {city}, {state}."
        lat, lon = location
        return f"The coordinates for {city}, {state} are: latitude {lat}, longitude {lon}"

    @staticmethod
    def no_relevant_function(user_query_span: str) -> str:
        """Return the fixed reply used when no other tool fits the query.

        Args:
            user_query_span: The part of the user query that no tool can
                answer (not used in the reply).
        """
        return "No relevant function for your request was found. We will stop here."
class ToolRegistry:
    """Registry of available tools and their schemas.

    ``available_functions`` maps tool names to Python callables and
    ``tool_schemas`` describes the same tools in the function-calling
    schema format sent with each chat request. The names in both must stay
    in sync, because the agent looks up the callable by the name the model
    returns.
    """

    @property
    def available_functions(self) -> Dict[str, Callable[..., str]]:
        """Map each tool name to the callable that implements it."""
        return {
            "get_current_weather": WeatherTools.get_current_weather,
            "get_geo_coordinates": WeatherTools.get_geo_coordinates,
            "no_relevant_function": WeatherTools.no_relevant_function,
        }

    @property
    def tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the function-calling schema for every registered tool."""
        return [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location. Use exact coordinates.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "latitude": {"type": "array", "description": "The latitude for the city."},
                            "longitude": {"type": "array", "description": "The longitude for the city."},
                            "unit": {
                                "type": "string",
                                "description": "The unit to fetch the temperature in",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["latitude", "longitude", "unit"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_geo_coordinates",
                    "description": "Get the latitude and longitude for a given city",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city": {"type": "string", "description": "The city to find coordinates for"},
                            "state": {"type": "string", "description": "The two-letter state abbreviation"},
                        },
                        "required": ["city", "state"],
                    },
                },
            },
            {
                # Fallback tool: lets the model signal an out-of-domain query.
                "type": "function",
                "function": {
                    "name": "no_relevant_function",
                    "description": "Call this when no other provided function can be called to answer the user query.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "user_query_span": {
                                "type": "string",
                                "description": "The part of the user_query that cannot be answered by any other function calls.",
                            },
                        },
                        "required": ["user_query_span"],
                    },
                },
            },
        ]
class WeatherAgent:
"""Main agent class that handles the conversation and tool execution"""
def __init__(self, config: WeatherConfig):
self.config = config
self.client = OpenAI(api_key=config.api_key, base_url=config.api_base)
self.tools = ToolRegistry()
self.messages = []
if not config.model:
models = self.client.models.list()
self.config.model = models.data[0].id
def _serialize_tool_call(self, tool_call) -> Dict[str, Any]:
"""Convert tool call to serializable format"""
return {
"id": tool_call.id,
"type": tool_call.type,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
def process_tool_calls(self, message) -> None:
"""Process and execute tool calls from assistant"""
for tool_call in message.tool_calls:
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
print(f"\nExecuting: {function_name}")
print(f"Arguments: {function_args}")
function_response = self.tools.available_functions[function_name](**function_args)
print(f"Response: {function_response}")
self.messages.append({
"role": "tool",
"content": json.dumps(function_response),
"tool_call_id": tool_call.id,
"name": function_name
})
def run_conversation(self, initial_query: str) -> None:
"""Run the main conversation loop"""
self.messages = [{"role": "user", "content": initial_query}]
print ("\n" * 5)
print ("*" * 40)
print (f"RUNNING QUERY: {initial_query}")
for step in range(self.config.max_steps):
print(f"\nStep {step + 1}")
print("-" * 40)
response = self.client.chat.completions.create(
messages=self.messages,
model=self.config.model,
tools=self.tools.tool_schemas,
temperature=0.0,
)
message = response.choices[0].message
if not message.tool_calls:
print("Conversation Complete")
break
self.messages.append({
"role": "assistant",
"content": json.dumps(message.content),
"tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls]
})
self.process_tool_calls(message)
if step >= self.config.max_steps - 1:
print("Maximum steps reached")
def main():
    """Run the two demo queries through one agent built from the default config."""
    agent = WeatherAgent(WeatherConfig())
    demo_queries = (
        # Example usage
        "What's the weather for Paris, TX in fahrenheit?",
        # Example OOD usage
        "Who won the most recent PGA?",
    )
    for query in demo_queries:
        agent.run_conversation(query)


if __name__ == "__main__":
    main()