Spaces:
Runtime error
Runtime error
File size: 7,199 Bytes
e67043b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
#!/usr/bin/env python
# coding=utf-8
import json
import os
from ..tool import Tool
from swarms.tools.database.utils.db_parser import get_conf
from swarms.tools.database.utils.database import DBArgs, Database
import openai
from typing import Optional, List, Mapping, Any
import requests, json
def build_database_tool(config) -> Tool:
    """Build the "Database" tool backed by a PostgreSQL connection.

    The returned ``Tool`` registers four GET endpoints:

    * ``/get_database_schema``  - fetch the schema via ``Database.compute_table_schema``
      and remember it for later translation;
    * ``/translate_nlp_to_sql`` - ask an OpenAI chat model to turn a
      natural-language description into SQL against that schema;
    * ``/select_database_data`` - run a SQL query and report the number of rows;
    * ``/rewrite_sql``          - send SQL to a remote rewrite service.

    Args:
        config: Not used. The connection settings are always loaded with
            ``get_conf(<dir of this file>/my_config.ini, "postgresql")``; the
            parameter is kept only so existing callers keep working.

    Returns:
        The configured ``Tool``.
    """
    tool = Tool(
        "Data in a database",
        "Look up user data",
        name_for_model="Database",
        description_for_model="Plugin for querying the data in a database",
        logo_url="https://commons.wikimedia.org/wiki/File:Postgresql_elephant.svg",
        contact_email="[email protected]",
        legal_info_url="[email protected]",
    )

    URL_REWRITE = "http://8.131.229.55:5114/rewrite"
    # Seconds to wait for the rewrite service; without a timeout ``requests``
    # blocks indefinitely on an unresponsive server.
    REWRITE_TIMEOUT = 60

    # Load the DB settings from the ini file next to this module. A separate
    # name is used so the ``config`` argument is not silently overwritten.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    db_config = get_conf(os.path.join(script_dir, "my_config.ini"), "postgresql")
    dbargs = DBArgs("postgresql", config=db_config)  # TODO: assign database name
    # NOTE(review): the meaning of timeout=-1 is defined by Database;
    # presumably "no timeout" - confirm.
    db = Database(dbargs, timeout=-1)

    # State shared by the endpoints below. These are closure variables (updated
    # with ``nonlocal``), so each tool built by this factory has its own copy.
    schema = None  # last value returned by db.compute_table_schema(), if fetched
    query = ""  # last SQL text produced by translate_nlp_to_sql

    @tool.get("/get_database_schema")
    def get_database_schema(db_name: str = "tpch10x"):
        """Fetch the database schema, remember it for translate_nlp_to_sql, and return it as text.
        db_name is currently only logged; the connection always uses the database configured in my_config.ini.
        """
        nonlocal schema
        # TODO: simplify the schema based on the query
        print("=========== database name:", db_name)
        schema = db.compute_table_schema()
        print("========schema:", schema)
        return f"The database schema is:\n{schema}"

    @tool.get("/translate_nlp_to_sql")
    def translate_nlp_to_sql(description: str):
        """translate_nlp_to_sql(description: str) translates the input nlp string into sql query based on the database schema, and the sql query is the input of rewrite_sql and select_database_data API.
        description is a string that represents the description of the result data.
        The database schema used is the one fetched by get_database_schema (it is fetched automatically if that has not been called yet).
        Final answer should be complete.
        This is an example:
        Thoughts: Now that I have the database schema, I will use the 'translate_nlp_to_sql' command to generate the SQL query based on the given description and schema, and take the SQL query as the input of the 'rewrite_sql' and 'select_database_data' commands.
        Reasoning: I need to generate the SQL query accurately based on the given description. I will use the 'translate_nlp_to_sql' command to obtain the SQL query based on the given description and schema, and take the SQL query as the input of the 'select_database_data' command.
        Plan: - Use the 'translate_nlp_to_sql' command to generate the SQL query.\\n- Use the 'finish' command to signal that I have completed all my objectives.
        Command: {"name": "translate_nlp_to_sql", "args": {"description": "Retrieve the comments of suppliers . The results should be sorted in descending order based on the comments of the suppliers."}}
        Result: Command translate_nlp_to_sql returned: "SELECT s_comment FROM supplier ORDER BY s_comment DESC"
        """
        # Declared after the docstring so the string above stays this function's __doc__.
        nonlocal schema, query
        if schema is None:
            # Let this endpoint work even when get_database_schema was not called first.
            schema = db.compute_table_schema()
        # compute_table_schema() may not return a str (get_database_schema
        # formats it via str()), so convert once before string concatenation.
        schema_text = str(schema)

        openai.api_key = os.environ["OPENAI_API_KEY"]
        prompt = (
            "Translate the natural language description into an semantic equivalent SQL query.\n"
            "The table and column names used in the sql must exactly appear in the schema. "
            "Any other table and column names are unacceptable.\n"
            "The schema is:\n\n"
            f"{schema_text}\n"
            "The description is:\n\n"
            f"{description}\n"
            "The SQL query is:\n"
        )
        model_engine = "gpt-3.5-turbo"
        # ``model`` selects the model on the OpenAI API; ``engine`` is the
        # Azure/legacy deployment selector.
        prompt_response = openai.ChatCompletion.create(
            model=model_engine,
            messages=[
                {
                    "role": "assistant",
                    "content": "The table schema is as follows: " + schema_text,
                },
                {"role": "user", "content": prompt},
            ],
        )
        # Strip surrounding whitespace so the SQL can be passed straight to
        # select_database_data / rewrite_sql.
        output_text = prompt_response["choices"][0]["message"]["content"].strip()
        query = output_text
        return output_text

    @tool.get("/select_database_data")
    def select_database_data(query: str):
        """select_database_data(query : str) Read the data stored in database based on the SQL query from the translate_nlp_to_sql API.
        query : str is a string that represents the SQL query outputted by the translate_nlp_to_sql API.
        Final answer should be complete.
        This is an example:
        Thoughts: Now that I have the database schema and SQL query, I will use the 'select_database_data' command to retrieve the data from the database based on the SQL query
        Reasoning: I will use the 'select_database_data' command to retrieve the data from the database based on the SQL query
        Plan: - Use the 'select_database_data' command to retrieve the data from the database based on the SQL query.\\n- Use the 'finish' command to signal that I have completed all my objectives.
        Command: {"name": "select_database_data", "args": {"query": "SELECT s_comment FROM supplier ORDER BY s_comment DESC"}}
        Result: Command select_database_data returned: "The number of result rows is: 394"
        """
        # Reject whitespace-only SQL as well as the empty string.
        if not query.strip():
            raise RuntimeError("SQL query is empty")
        print("=========== database query:", query)
        # Checked below: either a list of rows or the sentinel string "<fail>".
        res_completion = db.pgsql_results(query)
        if res_completion == "<fail>":
            raise RuntimeError("Database query failed")
        # Non-list results are reported verbatim, as before.
        row_count = (
            len(res_completion) if isinstance(res_completion, list) else res_completion
        )
        return f"The number of result rows is: {row_count}"

    @tool.get("/rewrite_sql")
    def rewrite_sql(
        sql: str = "select distinct l_orderkey, sum(l_extendedprice + 3 + (1 - l_discount)) as revenue, o_orderkey, o_shippriority from customer, orders, lineitem where c_mktsegment = 'BUILDING' and c_custkey = o_custkey and l_orderkey = o_orderkey and o_orderdate < date '1995-03-15' and l_shipdate > date '1995-03-15' group by l_orderkey, o_orderkey, o_shippriority order by revenue desc, o_orderkey;",
    ):
        """Rewrite the input sql query
        The query is sent to a remote SQL rewrite service and the rewritten SQL is returned.
        """
        param = {"sql": sql}
        print("Rewriter param:", param)
        headers = {"Content-Type": "application/json"}
        res_completion = requests.post(
            URL_REWRITE,
            data=json.dumps(param),
            headers=headers,
            timeout=REWRITE_TIMEOUT,
        )
        # Surface HTTP errors directly instead of failing later on a non-JSON body.
        res_completion.raise_for_status()
        data = json.loads(res_completion.text.strip()).get("data")
        if not isinstance(data, dict) or "rewritten_sql" not in data:
            raise RuntimeError(
                "Unexpected response from SQL rewriter: " + res_completion.text[:200]
            )
        return f"Rewritten sql is:\n{data['rewritten_sql']}"

    return tool
|