Spaces:
Running
Running
File size: 6,417 Bytes
57816fc 8a8e4f3 1d0f3f8 8a8e4f3 57816fc 1d0f3f8 a148b10 8a8e4f3 1d0f3f8 a148b10 1d0f3f8 a148b10 1d0f3f8 a148b10 1d0f3f8 a148b10 57816fc 1d0f3f8 57816fc 1d0f3f8 57816fc a148b10 57816fc 1d0f3f8 57816fc a148b10 1d0f3f8 57816fc b3eb06a 8a8e4f3 1d0f3f8 b3eb06a 57816fc 1d0f3f8 57816fc 1d0f3f8 57816fc 1d0f3f8 57816fc 1d0f3f8 b3eb06a 1d0f3f8 b3eb06a 1d0f3f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import streamlit as st
import requests
import subprocess
import re
import sys
# Prompt skeleton sent to the model; generate_prompt() fills in the
# instruction, the optional schema section and the user's question.
PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
# {has_schema} is "." when no schema is supplied and
# ", given a duckdb database schema." otherwise (see generate_prompt()).
INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501
# Markdown shown by the UI when validation fails; formatted with the
# generated {sql_query} and the validator's {error_msg}.
ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not able to craft a correct SQL query for this request. \nSorry my duck friend. ]\n\n:red[ Try rephrasing the question/instruction. And if the question is about your own database, make sure to set the correct schema. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
# NOTE(review): not referenced in the visible code -- generate_sql() sends
# stop="<s>" to the endpoint instead. Confirm whether these are still needed.
STOP_TOKENS = ["###", ";", "--", "```"]

# One "CREATE TABLE name( ... );" statement, split into head / column list /
# tail so that only the column list is rewritten.
_CREATE_TABLE_RE = re.compile(
    r"(?P<head>CREATE TABLE [^(]+\()(?P<body>.*?)(?P<tail>\);)", re.DOTALL
)
# A "<word> <word>" pair inside a column list, e.g. "trip_miles DOUBLE".
_WORD_PAIR_RE = re.compile(r"(\w+) (\w+)")


def _lowercase_column_types(schema):
    """Return *schema* with the second word of each "<word> <word>" pair lowercased.

    Only the parenthesised column list of each ``CREATE TABLE ...(...);``
    statement is rewritten, so text outside those lists (comments, other
    statements) is left exactly as written.
    """

    def _lower_pair(pair):
        return f"{pair.group(1)} {pair.group(2).lower()}"

    def _lower_statement(stmt):
        body = _WORD_PAIR_RE.sub(_lower_pair, stmt.group("body"))
        return stmt.group("head") + body + stmt.group("tail")

    return _CREATE_TABLE_RE.sub(_lower_statement, schema)


def generate_prompt(question, schema):
    """Build the completion prompt for *question*, optionally including *schema*.

    Args:
        question: Natural-language request from the user.
        schema: DDL text for the target database. ``None``, ``""`` or a
            whitespace-only string all mean "no schema".

    Returns:
        ``PROMPT_TEMPLATE`` filled in. When a schema is given, the type
        keyword after each column name inside ``CREATE TABLE`` column lists
        is lowercased and the schema is placed in the ``### Input`` section;
        the instruction sentence then mentions the schema.
    """
    # One emptiness test drives both the input section and the instruction,
    # so they can never disagree (e.g. for a None schema).
    has_schema = bool(schema and schema.strip())
    schema_section = ""
    if has_schema:
        schema_section = (
            "Here is the database schema that the SQL query will run on:\n"
            f"{_lowercase_column_types(schema)}\n"
        )
    return PROMPT_TEMPLATE.format(
        instruction=INSTRUCTION_TEMPLATE.format(
            has_schema=", given a duckdb database schema." if has_schema else "."
        ),
        input=schema_section,
        question=question,
    )
def generate_sql(question, schema, *, timeout=120):
    """Ask the hosted DuckDB-NSQL completion endpoint to write SQL for *question*.

    Args:
        question: Natural-language request from the user.
        schema: DDL text describing the target database; may be empty.
        timeout: Seconds passed to ``requests`` as the connect and read
            timeout, so a stalled endpoint raises instead of hanging forever.

    Returns:
        The ``text`` of the first completion choice returned by the endpoint.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
    """
    prompt = generate_prompt(question, schema)
    api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
    url = f"{api_base}/v1/completions"
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        # NOTE(review): the module-level STOP_TOKENS list is not used here;
        # confirm "<s>" is the intended stop sequence.
        "stop": "<s>",
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
    # Use the session as a context manager so its connection pool is closed
    # after this single request.
    with requests.Session() as session:
        with session.post(url, json=body, headers=headers, timeout=timeout) as resp:
            # Surface error statuses as HTTPError rather than a KeyError on
            # a missing "choices" field in an error payload.
            resp.raise_for_status()
            return resp.json()["choices"][0]["text"]
def validate_sql(query, schema, *, timeout=0.5):
    """Check *query* against *schema* by running ``./validate_sql.py`` in a child process.

    The helper is invoked as ``<python> ./validate_sql.py <query> <schema>``,
    with the path resolved relative to the current working directory. Anything
    the helper writes to stderr is treated as a parser/binder error.

    Args:
        query: SQL text to validate.
        schema: DDL text passed through to the helper script.
        timeout: Seconds to wait for the helper before giving up.

    Returns:
        ``(False, stderr_text)`` if the helper wrote to stderr; otherwise
        ``(True, "")``. A helper that exceeds *timeout* is killed and the
        query is reported as valid.
    """
    try:
        # subprocess.run() kills *and reaps* the child on timeout, so no
        # zombie process or open pipes are left behind.
        result = subprocess.run(
            [sys.executable, './validate_sql.py', query, schema],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # timeout reached, so parsing and binding was very likely successful
        # NOTE(review): the timeout also covers interpreter start-up of the
        # helper, so a slow start is reported as success as well.
        return True, ""
    if result.stderr:
        # errors="replace" keeps a stray non-UTF-8 byte from raising here.
        return False, result.stderr.decode('utf8', errors='replace')
    return True, ""
# Page title plus a collapsible "Customize Schema" panel that explains how to
# dump the schema of a local DuckDB file with the duckdb CLI.
st.title("DuckDB-NSQL-7B Demo")
expander = st.expander("Customize Schema (Optional)")
expander.markdown(
    "If your DuckDB database is `database.duckdb`, run this command in your terminal to get your current schema:"
)
# The sed calls put each column on its own line so the output matches the
# layout of the default schema shown in the text area.
expander.markdown(
    """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n /g' | sed 's/, /,\\n /g' | sed 's/);/\\n);\\n/g'\n```""",
)
# Default DDL pre-filled in the "Current schema:" text area: three example
# tables (rideshare, service_requests, taxi) written as CREATE TABLE statements.
default_schema = """CREATE TABLE rideshare(
hvfhs_license_num VARCHAR,
dispatching_base_num VARCHAR,
originating_base_num VARCHAR,
request_datetime TIMESTAMP,
on_scene_datetime TIMESTAMP,
pickup_datetime TIMESTAMP,
dropoff_datetime TIMESTAMP,
PULocationID BIGINT,
DOLocationID BIGINT,
trip_miles DOUBLE,
trip_time BIGINT,
base_passenger_fare DOUBLE,
tolls DOUBLE,
bcf DOUBLE,
sales_tax DOUBLE,
congestion_surcharge DOUBLE,
airport_fee DOUBLE,
tips DOUBLE,
driver_pay DOUBLE,
shared_request_flag VARCHAR,
shared_match_flag VARCHAR,
access_a_ride_flag VARCHAR,
wav_request_flag VARCHAR,
wav_match_flag VARCHAR
);
CREATE TABLE service_requests(
unique_key BIGINT,
created_date TIMESTAMP,
closed_date TIMESTAMP,
agency VARCHAR,
agency_name VARCHAR,
complaint_type VARCHAR,
descriptor VARCHAR,
location_type VARCHAR,
incident_zip VARCHAR,
incident_address VARCHAR,
street_name VARCHAR,
cross_street_1 VARCHAR,
cross_street_2 VARCHAR,
intersection_street_1 VARCHAR,
intersection_street_2 VARCHAR,
address_type VARCHAR,
city VARCHAR,
landmark VARCHAR,
facility_type VARCHAR,
status VARCHAR,
due_date TIMESTAMP,
resolution_description VARCHAR,
resolution_action_updated_date TIMESTAMP,
community_board VARCHAR,
bbl VARCHAR,
borough VARCHAR,
x_coordinate_state_plane VARCHAR,
y_coordinate_state_plane VARCHAR,
open_data_channel_type VARCHAR,
park_facility_name VARCHAR,
park_borough VARCHAR,
vehicle_type VARCHAR,
taxi_company_borough VARCHAR,
taxi_pick_up_location VARCHAR,
bridge_highway_name VARCHAR,
bridge_highway_direction VARCHAR,
road_ramp VARCHAR,
bridge_highway_segment VARCHAR,
latitude DOUBLE,
longitude DOUBLE
);
CREATE TABLE taxi(
VendorID BIGINT,
tpep_pickup_datetime TIMESTAMP,
tpep_dropoff_datetime TIMESTAMP,
passenger_count DOUBLE,
trip_distance DOUBLE,
RatecodeID DOUBLE,
store_and_fwd_flag VARCHAR,
PULocationID BIGINT,
DOLocationID BIGINT,
payment_type BIGINT,
fare_amount DOUBLE,
extra DOUBLE,
mta_tax DOUBLE,
tip_amount DOUBLE,
tolls_amount DOUBLE,
improvement_surcharge DOUBLE,
total_amount DOUBLE,
congestion_surcharge DOUBLE,
airport_fee DOUBLE,
drivers VARCHAR[],
speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[],
other_violations JSON
);"""
# Editable schema (inside the "Customize Schema" expander), pre-filled with
# default_schema.
schema = expander.text_area("Current schema:", value=default_schema, height=500)
# Input field for text prompt
text_prompt = st.text_input(
    "What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv"
)
if text_prompt:
    # Generate SQL remotely, then check it against the schema before display.
    try:
        with st.spinner("Generating SQL..."):
            sql_query = generate_sql(text_prompt, schema)
    except requests.RequestException as err:
        # Show endpoint/network failures as a readable message instead of a
        # raw traceback in the page.
        st.error(f"Could not reach the SQL generation service: {err}")
    else:
        valid, msg = validate_sql(sql_query, schema)
        if not valid:
            st.markdown(ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg))
        else:
            st.markdown(f"""```sql\n{sql_query}\n```""")
|