# DuckDB-NSQL-7B / app.py
# (file-viewer page header captured with the source; kept as comments)
# tdoehmen's picture
# Update app.py
# 4375112 verified
# raw
# history blame
# 8.06 kB
import streamlit as st
import requests
import subprocess
import re
import sys
import urllib.request
import json
import os
import ssl
import time
# Full prompt sent to the model. Placeholders: {instruction} (see
# INSTRUCTION_TEMPLATE), {input} (optional schema section, may be empty) and
# {question} (the user's request).
PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
# Instruction sentence; {has_schema} is the sentence ending chosen by
# generate_prompt depending on whether a schema was supplied.
INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
# Markdown shown when validation fails. Placeholders: {sql_query} (the
# generated query) and {error_msg} (the validator's error output).
ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not able to craft a correct SQL query for this request. \nSorry my duck friend. ]\n\n:red[If the question is about your own database, make sure to set the correct schema. Otherwise, try to rephrase your request. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
# NOTE(review): not referenced by any code visible in this file — confirm
# whether it is still needed.
STOP_TOKENS = ["###", ";", "--", "```"]
def allowSelfSignedHttps(allowed):
    """Optionally make ``ssl`` build unverified HTTPS contexts by default.

    When ``allowed`` is truthy, the ``PYTHONHTTPSVERIFY`` environment variable
    is unset or empty, and the ``ssl`` module exposes
    ``_create_unverified_context``, that factory is installed as
    ``ssl._create_default_https_context``. This is a module-level setting, so
    it stays in effect after the call returns. Otherwise nothing is changed.

    Args:
        allowed: Whether server certificate verification may be bypassed.
    """
    if not allowed:
        return
    # An explicit PYTHONHTTPSVERIFY setting takes precedence over the bypass.
    if os.environ.get('PYTHONHTTPSVERIFY', ''):
        return
    if getattr(ssl, '_create_unverified_context', None):
        ssl._create_default_https_context = ssl._create_unverified_context
# Only needed when the scoring service uses a self-signed certificate: unless
# PYTHONHTTPSVERIFY is set, this replaces ssl's default HTTPS context factory
# with the unverified one at import time.
allowSelfSignedHttps(True)
def generate_prompt(question, schema):
    """Build the model prompt for a natural-language question.

    Args:
        question: The user's request in natural language.
        schema: DDL text (``CREATE TABLE ...(...);`` statements) describing
            the database, or an empty string / ``None`` when there is none.

    Returns:
        ``PROMPT_TEMPLATE`` filled in with the instruction sentence, the
        optional schema section and the question.
    """
    schema_section = ""
    if schema:
        # Inside each CREATE TABLE (...) body, lowercase the second word of
        # every "<word> <word>" pair (i.e. "column TYPE" -> "column type").
        for table_body in re.findall(
            r"CREATE TABLE [^(]+\((.*?)\);", schema, flags=re.DOTALL | re.MULTILINE
        ):
            for col_name, col_type in re.findall(r"(\w+) (\w+)", table_body):
                lowered = f"{col_name} {col_type.lower()}"
                # Anchor on word boundaries so that e.g. "id INT" does not
                # rewrite the middle of "valid INTEGER". A callable
                # replacement avoids any escape processing of the text.
                schema = re.sub(
                    rf"\b{re.escape(col_name)} {re.escape(col_type)}\b",
                    lambda _match, lowered=lowered: lowered,
                    schema,
                )
        schema_section = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format(  # noqa: E501
            schema=schema
        )
    prompt = PROMPT_TEMPLATE.format(
        instruction=INSTRUCTION_TEMPLATE.format(
            # Use the same emptiness test as above so a None schema does not
            # announce a schema that is missing from the prompt.
            has_schema="." if not schema else ", given a duckdb database schema."
        ),
        input=schema_section,
        question=question,
    )
    return prompt
def generate_sql_azure(question, schema):
    """Generate SQL for ``question`` via the Azure ML scoring endpoint.

    Builds the prompt with ``generate_prompt``, POSTs it to the hard-coded
    endpoint using the ``azure_ai_token`` Streamlit secret, and prints the
    elapsed wall-clock time in seconds.

    Args:
        question: The user's request in natural language.
        schema: Database schema text passed through to ``generate_prompt``.

    Returns:
        The response text with the first ``len(prompt)`` characters removed
        (assumes the endpoint echoes the prompt before the completion — TODO
        confirm against the deployment).

    Raises:
        urllib.error.URLError: If the HTTP request fails.
    """
    prompt = generate_prompt(question, schema)
    start = time.time()
    data = {
        "input_data": {
            "input_string": [prompt],
            "parameters": {
                "top_p": 0.9,
                "temperature": 0.1,
                "max_new_tokens": 200,
                "do_sample": True,
            },
        }
    }
    body = json.dumps(data).encode("utf-8")
    url = 'https://motherduck-eu-west2-xbdfd.westeurope.inference.ml.azure.com/score'
    headers = {
        'Content-Type': 'application/json',
        'Authorization': ('Bearer ' + st.secrets['azure_ai_token']),
        'azureml-model-deployment': 'motherduckdb-duckdb-nsql-7b-v-1',
    }
    req = urllib.request.Request(url, body, headers)
    # Close the HTTP response deterministically instead of leaking it.
    with urllib.request.urlopen(req) as raw_resp:
        payload = raw_resp.read()
    # The response is indexed as a list whose first item maps "0" to the text.
    resp = json.loads(payload.decode("utf-8"))[0]["0"]
    sql_query = resp[len(prompt):]
    print(time.time() - start)
    return sql_query
def generate_sql(question, schema):
    """Generate SQL for ``question`` via the OctoAI completions endpoint.

    Prints the question, builds the prompt with ``generate_prompt``, POSTs it
    to ``/v1/completions`` using the ``octoml_token`` Streamlit secret, and
    prints the elapsed wall-clock time in seconds.

    Args:
        question: The user's request in natural language.
        schema: Database schema text passed through to ``generate_prompt``.

    Returns:
        The ``text`` of the first entry in the response's ``choices`` list.

    Raises:
        requests.RequestException: If the HTTP request fails.
        KeyError: If the response JSON has no ``choices`` / ``text`` entry.
    """
    print(question)
    prompt = generate_prompt(question, schema)
    start = time.time()
    api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
    url = f"{api_base}/v1/completions"
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        "stop": "<s>",
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
    # Use the session as a context manager so it is closed, not only the
    # individual response.
    with requests.Session() as session:
        with session.post(url, json=body, headers=headers) as resp:
            sql_query = resp.json()["choices"][0]["text"]
    print(time.time() - start)
    return sql_query
def validate_sql(query, schema):
    """Check a generated query by running ``./validate_sql.py`` in a child process.

    The helper script is started with the current interpreter and receives
    ``query`` and ``schema`` as command-line arguments. Anything it writes to
    stderr is treated as a validation error (presumably parser/binder errors
    from the script — verify against ``validate_sql.py``).

    Args:
        query: The SQL text to validate.
        schema: The schema text the query should be checked against.

    Returns:
        A ``(valid, message)`` tuple. ``message`` is always a string: empty
        when ``valid`` is True, otherwise the stderr output with the first
        three lines (the traceback header) dropped when there are more than
        three lines.
    """
    process = subprocess.Popen(
        [sys.executable, './validate_sql.py', query, schema],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    try:
        # Get output and potential parser/binder error message.
        _stdout, stderr = process.communicate(timeout=0.5)
    except subprocess.TimeoutExpired:
        # Timeout reached, so parsing and binding was very likely successful.
        process.kill()
        # Finish communication after the kill so the child is reaped and its
        # pipes are closed instead of being left behind.
        process.communicate()
        return True, ""

    if not stderr:
        return True, ""

    lines = stderr.decode('utf8').split("\n")
    # Skip the traceback header when there is one.
    if len(lines) > 3:
        lines = lines[3:]
    # Always hand back a string; short outputs used to leak through as a list.
    return False, "\n".join(lines)
st.title("DuckDB-NSQL-7B Demo")

# Expander holding the instructions for exporting a schema and, further down,
# the editable schema text area.
expander = st.expander("Customize Schema (Optional)")
expander.markdown(
    "If your DuckDB database is `database.duckdb`, execute this query in your terminal to get your current schema:"
)
# Shell one-liner that dumps the schema and reflows it with sed.
expander.markdown(
    """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n /g' | sed 's/, /,\\n /g' | sed 's/);/\\n);\\n/g'\n```""",
)
# Schema pre-filled in the "Current schema:" text area: three CREATE TABLE
# statements (rideshare, service_requests, taxi).
default_schema = """CREATE TABLE rideshare(
hvfhs_license_num VARCHAR,
dispatching_base_num VARCHAR,
originating_base_num VARCHAR,
request_datetime TIMESTAMP,
on_scene_datetime TIMESTAMP,
pickup_datetime TIMESTAMP,
dropoff_datetime TIMESTAMP,
PULocationID BIGINT,
DOLocationID BIGINT,
trip_miles DOUBLE,
trip_time BIGINT,
base_passenger_fare DOUBLE,
tolls DOUBLE,
bcf DOUBLE,
sales_tax DOUBLE,
congestion_surcharge DOUBLE,
airport_fee DOUBLE,
tips DOUBLE,
driver_pay DOUBLE,
shared_request_flag VARCHAR,
shared_match_flag VARCHAR,
access_a_ride_flag VARCHAR,
wav_request_flag VARCHAR,
wav_match_flag VARCHAR
);
CREATE TABLE service_requests(
unique_key BIGINT,
created_date TIMESTAMP,
closed_date TIMESTAMP,
agency VARCHAR,
agency_name VARCHAR,
complaint_type VARCHAR,
descriptor VARCHAR,
location_type VARCHAR,
incident_zip VARCHAR,
incident_address VARCHAR,
street_name VARCHAR,
cross_street_1 VARCHAR,
cross_street_2 VARCHAR,
intersection_street_1 VARCHAR,
intersection_street_2 VARCHAR,
address_type VARCHAR,
city VARCHAR,
landmark VARCHAR,
facility_type VARCHAR,
status VARCHAR,
due_date TIMESTAMP,
resolution_description VARCHAR,
resolution_action_updated_date TIMESTAMP,
community_board VARCHAR,
bbl VARCHAR,
borough VARCHAR,
x_coordinate_state_plane VARCHAR,
y_coordinate_state_plane VARCHAR,
open_data_channel_type VARCHAR,
park_facility_name VARCHAR,
park_borough VARCHAR,
vehicle_type VARCHAR,
taxi_company_borough VARCHAR,
taxi_pick_up_location VARCHAR,
bridge_highway_name VARCHAR,
bridge_highway_direction VARCHAR,
road_ramp VARCHAR,
bridge_highway_segment VARCHAR,
latitude DOUBLE,
longitude DOUBLE
);
CREATE TABLE taxi(
VendorID BIGINT,
tpep_pickup_datetime TIMESTAMP,
tpep_dropoff_datetime TIMESTAMP,
passenger_count DOUBLE,
trip_distance DOUBLE,
RatecodeID DOUBLE,
store_and_fwd_flag VARCHAR,
PULocationID BIGINT,
DOLocationID BIGINT,
payment_type BIGINT,
fare_amount DOUBLE,
extra DOUBLE,
mta_tax DOUBLE,
tip_amount DOUBLE,
tolls_amount DOUBLE,
improvement_surcharge DOUBLE,
total_amount DOUBLE,
congestion_surcharge DOUBLE,
airport_fee DOUBLE,
drivers VARCHAR[],
speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[],
other_violations JSON
);"""
# Editable schema, pre-filled with the default schema, inside the expander.
schema = expander.text_area("Current schema:", value=default_schema, height=500)

# Free-text question from the user; a sample request is pre-filled.
text_prompt = st.text_input(
    "What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv"
)

# Generate a query for any non-empty question, validate it, and render either
# the query or the error message as markdown.
if text_prompt:
    sql_query = generate_sql(text_prompt, schema)
    valid, msg = validate_sql(sql_query, schema)
    if valid:
        rendered = f"""```sql\n{sql_query}\n```"""
    else:
        rendered = ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg)
    st.markdown(rendered)