File size: 7,199 Bytes
e67043b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python
# coding=utf-8

import json
import os

from ..tool import Tool
from swarms.tools.database.utils.db_parser import get_conf
from swarms.tools.database.utils.database import DBArgs, Database

import openai

from typing import Optional, List, Mapping, Any
import requests, json


def build_database_tool(config) -> Tool:
    """Build the "Database" tool and register its HTTP GET endpoints.

    The returned tool exposes four routes:

    * ``/get_database_schema``  - compute and cache the table schema.
    * ``/translate_nlp_to_sql`` - turn a natural-language description into SQL
      with an OpenAI chat model, using the cached schema.
    * ``/select_database_data`` - execute a SQL query and report the row count.
    * ``/rewrite_sql``          - send a SQL query to the remote rewrite service.

    Args:
        config: Accepted for interface compatibility but currently ignored;
            the connection settings are always read from the ``postgresql``
            section of ``my_config.ini`` located next to this module.

    Returns:
        The configured ``Tool`` instance with all endpoints registered.
    """
    # NOTE(review): the contact/legal values below look like redaction
    # placeholders rather than real addresses/URLs - confirm before shipping.
    tool = Tool(
        "Data in a database",
        "Look up user data",
        name_for_model="Database",
        description_for_model="Plugin for querying the data in a database",
        logo_url="https://commons.wikimedia.org/wiki/File:Postgresql_elephant.svg",
        contact_email="[email protected]",
        legal_info_url="[email protected]",
    )

    URL_REWRITE = "http://8.131.229.55:5114/rewrite"
    # Upper bound (seconds) for the rewrite service, so that an unreachable
    # host cannot block the endpoint forever.
    REWRITE_TIMEOUT_SECONDS = 60

    # Load db settings. A dedicated local name is used so the `config`
    # argument is no longer silently shadowed.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    db_config = get_conf(os.path.join(script_dir, "my_config.ini"), "postgresql")
    dbargs = DBArgs("postgresql", config=db_config)  # todo assign database name

    # send request to database
    db = Database(dbargs, timeout=-1)

    # State shared between the endpoints below. It lives in this closure, so
    # the nested functions must rebind it with `nonlocal` (not `global`).
    schema = ""
    query = ""

    @tool.get("/get_database_schema")
    def get_database_schema(db_name: str = "tpch10x"):
        """Compute the table schema of the connected database and cache it for translate_nlp_to_sql.

        db_name is only used for logging; the schema is always read from the
        database configured for this tool.
        """
        nonlocal schema

        # todo simplify the schema based on the query
        print("=========== database name:", db_name)
        schema = db.compute_table_schema()

        print("========schema:", schema)

        return "The database schema is:\n" + str(schema)

    @tool.get("/translate_nlp_to_sql")
    def translate_nlp_to_sql(description: str):
        """translate_nlp_to_sql(description: str) translates the input nlp string into sql query based on the database schema, and the sql query is the input of rewrite_sql and select_database_data API.
            description is a string that represents the description of the result data.
            schema is a string that represents the database schema.
            Final answer should be complete.

            This is an example:
            Thoughts: Now that I have the database schema, I will use the \\\'translate_nlp_to_sql\\\' command to generate the SQL query based on the given description and schema, and take the SQL query as the input of the \\\'rewrite_sql\\\' and  \\\'select_database_data\\\' commands.
            Reasoning: I need to generate the SQL query accurately based on the given description. I will use the \\\'translate_nlp_to_sql\\\' command to obtain the SQL query based on the given description and schema, and take the SQL query as the input of the \\\'select_database_data\\\' command.
            Plan: - Use the \\\'translate_nlp_to_sql\\\' command to generate the SQL query. \\\\n- Use the \\\'finish\\\' command to signal that I have completed all my objectives.
            Command: {"name": "translate_nlp_to_sql", "args": {"description": "Retrieve the comments of suppliers . The results should be sorted in descending order based on the comments of the suppliers."}}
            Result: Command translate_nlp_to_sql returned: "SELECT s_comment FROM supplier ORDER BY s_comment DESC"
        """
        nonlocal schema, query

        # Raises KeyError when the API key is not configured in the environment.
        openai.api_key = os.environ["OPENAI_API_KEY"]

        # Load the schema on demand so this endpoint also works when
        # /get_database_schema has not been called yet.
        if not schema:
            schema = db.compute_table_schema()
        # compute_table_schema() may not return a str; normalise it before
        # concatenating it into the messages below.
        schema_text = str(schema)

        prompt = """Translate the natural language description into an semantic equivalent SQL query.
        The table and column names used in the sql must exactly appear in the schema. Any other table and column names are unacceptable.
        The schema is:\n
        {}
        
        The description is:\n
        {}

        The SQL query is:
        """.format(
            schema_text, description
        )

        # Set up the OpenAI chat model
        model_engine = "gpt-3.5-turbo"

        # NOTE(review): `engine=` addresses an Azure-style deployment in the
        # legacy openai SDK; the public OpenAI endpoint documents `model=`.
        # Confirm which backend this tool targets before changing it.
        prompt_response = openai.ChatCompletion.create(
            engine=model_engine,
            messages=[
                {
                    "role": "assistant",
                    "content": "The table schema is as follows: " + schema_text,
                },
                {"role": "user", "content": prompt},
            ],
        )
        output_text = prompt_response["choices"][0]["message"]["content"]

        # Remember the most recently generated query in the shared state.
        query = output_text

        return output_text

    @tool.get("/select_database_data")
    def select_database_data(query: str):
        """select_database_data(query : str) Read the data stored in database based on the SQL query from the translate_nlp_to_sql API.
        query : str is a string that represents the SQL query outputted by the translate_nlp_to_sql API.
        Final answer should be complete.

        This is an example:
        Thoughts: Now that I have the database schema and SQL query, I will use the \\\'select_database_data\\\' command to retrieve the data from the database based on the SQL query
        Reasoning: I will use the \\\'select_database_data\\\' command to retrieve the data from the database based on the SQL query
        Plan: - Use the \\\'select_database_data\\\' command to retrieve the data from the database based on the SQL query.\\\\n- Use the \\\'finish\\\' command to signal that I have completed all my objectives.
        Command: {"name": "select_database_data", "args": {"query": "SELECT s_comment FROM supplier ORDER BY s_comment DESC"}}
        Result: Command select_database_data returned: "The number of result rows is: 394"
        """
        # Reject empty and whitespace-only queries before touching the database.
        if not query or not query.strip():
            raise RuntimeError("SQL query is empty")

        print("=========== database query:", query)
        rows = db.pgsql_results(query)  # list format

        # The database helper signals failure with this sentinel string.
        if rows == "<fail>":
            raise RuntimeError("Database query failed")

        # A list result is summarised by its length; anything else is echoed
        # back as-is.
        row_count = len(rows) if isinstance(rows, list) else rows
        return "The number of result rows is: " + str(row_count)

    @tool.get("/rewrite_sql")
    def rewrite_sql(
        sql: str = "select distinct l_orderkey, sum(l_extendedprice + 3 + (1 - l_discount)) as revenue, o_orderkey, o_shippriority from customer, orders, lineitem where c_mktsegment = 'BUILDING' and c_custkey = o_custkey and l_orderkey = o_orderkey and o_orderdate < date '1995-03-15' and l_shipdate > date '1995-03-15' group by l_orderkey, o_orderkey, o_shippriority order by revenue desc, o_orderkey;",
    ):
        """Rewrite the input sql query using the remote rewrite service and return the rewritten sql."""

        param = {"sql": sql}
        print("Rewriter param:", param)
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            URL_REWRITE,
            data=json.dumps(param),
            headers=headers,
            timeout=REWRITE_TIMEOUT_SECONDS,
        )
        # Surface HTTP errors directly instead of failing later while parsing
        # an error page as JSON.
        response.raise_for_status()

        payload = json.loads(response.text.strip())
        data = payload.get("data") if isinstance(payload, dict) else None
        rewritten_sql = data.get("rewritten_sql") if isinstance(data, dict) else None
        if not isinstance(rewritten_sql, str):
            raise RuntimeError("SQL rewrite service returned no rewritten sql")

        return "Rewritten sql is:\n" + rewritten_sql

    return tool