File size: 4,768 Bytes
8d2f9d4
 
 
 
 
 
 
 
 
f169c98
 
8d2f9d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import httpx
from dotenv import load_dotenv
from typing import Dict, Any, Optional, List, Iterable
from datetime import datetime
import logging
import asyncio
import json
import google.generativeai as genai
import PIL.Image

# Import custom modules
from app.utils.load_env import ACCESS_TOKEN, WHATSAPP_API_URL, GEMNI_API, OPENAI_API
from app.utils.system_prompt import system_prompt, agentic_prompt
from google.generativeai.types import content_types
from testtool import ToolCallParser, FunctionExecutor
from app.services.search_engine import google_search, set_light_values

# Pull any variables defined in a local .env file into the process environment.
load_dotenv()

# Root logging configuration: INFO and above, each line prefixed with a
# timestamp and the level name. A module-level logger is used throughout.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def tool_config_from_mode(mode: str, fns: Iterable[str] = ()):
    """
    Build a Gemini tool config that fixes the function-calling mode.

    Args:
        mode: Function-calling mode string handed to the SDK (this file
            passes "any").
        fns: Names of the functions the model is allowed to call; passed
            through unchanged as ``allowed_function_names``.

    Returns:
        Whatever ``content_types.to_tool_config`` produces for the dict.
    """
    calling_config = {"mode": mode, "allowed_function_names": fns}
    return content_types.to_tool_config({"function_calling_config": calling_config})

def transform_result_to_response(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collapse tool-execution results into a ``{function_name: payload}`` map.

    Each item is read through its ``status``, ``function`` and ``result``
    keys (the shape returned by ``process_tool_calls``). A ``status`` of
    ``"success"`` maps the function name to its ``result`` (``None`` when
    absent). Any other status maps it to the generic payload
    ``{"error": "Function execution failed."}``.

    Items with a missing or empty ``function`` name are filed under
    ``"unknown_function"`` on both the success and the failure path.

    NOTE: the output is keyed by function name. If the same function appears
    more than once in ``results``, the later entry overwrites the earlier one.

    Args:
        results: Result dicts, one per executed tool call.

    Returns:
        Dict mapping each function name to its result or error payload.
        An empty ``results`` list yields an empty dict.
    """
    response: Dict[str, Any] = {}
    for res in results:
        # Resolve the name once so both branches share the same fallback;
        # a None key could not be used as a function-response name later.
        function_name = res.get("function") or "unknown_function"
        if res.get("status") == "success":
            response[function_name] = res.get("result")
        else:
            # Individual failures are reported with a generic error payload.
            response[function_name] = {"error": "Function execution failed."}
    return response

async def process_tool_calls(input_string: str) -> List[Dict[str, Any]]:
    """
    Parse every tool call out of ``input_string`` and execute each one.

    Calls run one after another, in the order ``ToolCallParser`` returns
    them, and the returned list lines up index-for-index with that order.
    Exceptions from the parser or the executor propagate to the caller.
    """
    calls = ToolCallParser.extract_tool_calls(input_string)
    logger.info(f"Extracted tool_calls: {calls}")
    # Awaiting inside the comprehension keeps execution strictly sequential.
    return [await FunctionExecutor.call_function(call) for call in calls]

def _build_function_response_parts(responses: Dict[str, Any]) -> List[Any]:
    """Wrap each ``{function_name: value}`` entry as a Gemini function-response Part."""
    return [
        genai.protos.Part(
            function_response=genai.protos.FunctionResponse(
                name=fn,
                response={"result": val},
            )
        )
        for fn, val in responses.items()
    ]


async def main(
    prompt: str = "find the cheapest flight price from Medan to Jakarta on 1st January 2025",
) -> None:
    """
    Run one round of Gemini function calling end to end.

    Steps:
      1. Configure a Gemini chat with the ``google_search`` and
         ``set_light_values`` tools, forcing a function call (mode "any").
      2. Send ``prompt`` and stringify the parts of the model's reply.
      3. Parse and execute the tool calls via ``process_tool_calls``.
      4. Send the function results back and print the model's final text.

    Every failure is logged (with traceback) and ends the run early;
    nothing is raised to the caller.

    Args:
        prompt: The user message to send. Defaults to the original flight
            query, so ``main()`` with no arguments behaves as before.
    """
    # Only the tools registered on the model below may be called.
    available_functions = ["google_search", "set_light_values"]
    config = tool_config_from_mode("any", fns=available_functions)

    # Seed chat history with a single placeholder user turn.
    history = [{"role": "user", "parts": "This is the chat history so far"}]

    genai.configure(api_key=GEMNI_API)
    model = genai.GenerativeModel(
        "gemini-1.5-pro-002",
        system_instruction=agentic_prompt,
        tools=[google_search, set_light_values],
    )
    chat = model.start_chat(history=history)

    try:
        response = chat.send_message(prompt, tool_config=config)
    except Exception:
        logger.exception("Error sending message")
        return

    # `hasattr` only suppresses AttributeError, but the SDK's `parts`
    # accessor may raise ValueError (e.g. when there are no candidates),
    # so read it explicitly under try/except.
    try:
        parts = response.parts
    except (AttributeError, ValueError):
        logger.exception("Invalid response format: 'parts' attribute is missing or not iterable.")
        return
    if not isinstance(parts, Iterable):
        logger.error("Invalid response format: 'parts' attribute is missing or not iterable.")
        return

    # The tool-call parser works on text, so flatten the parts into one string.
    input_string = "\n".join(str(part) for part in parts)
    logger.info("Input string for tool processing: %s", input_string)

    try:
        results = await process_tool_calls(input_string)
    except Exception:
        logger.exception("Error processing tool calls")
        return

    logger.info("Results from tool calls:")
    for result in results:
        # Display only: default=str keeps a non-JSON-serializable tool
        # result from aborting the run before it is sent back to the model.
        rendered = json.dumps(result, indent=4, default=str)
        logger.info(rendered)
        print(rendered)

    responses = transform_result_to_response(results)
    # With nothing to report, sending an empty message would only fail;
    # say so explicitly instead.
    if not responses:
        logger.warning("No tool call results to send back to the model.")
        return

    try:
        response_parts = _build_function_response_parts(responses)
    except Exception:
        logger.exception("Error building response parts")
        return

    # `.text` raises if the model replies with something other than text
    # (e.g. another function call); that is logged like any send failure.
    try:
        final_response = chat.send_message(response_parts)
        print(final_response.text)
    except Exception:
        logger.exception("Error sending final response")

if __name__ == "__main__":
    asyncio.run(main())