augray committed
Commit 57f2e80
Parent(s): e602593

Working; needs refinements
app.py
CHANGED
@@ -14,6 +14,20 @@ logger = logging.getLogger(__name__)
 example = HuggingfaceHubSearch().example_value()


+SYSTEM_PROMPT_TEMPLATE = (
+    "You are a SQL query expert assistant that returns a DuckDB SQL queries "
+    "based on the user's natural language query and dataset features. "
+    "You might need to use DuckDB functions for lists and aggregations, "
+    "given the features. Only return the SQL query, no other text. The "
+    "user may ask you to make various adjustments to the query. Every "
+    "time your response should only include the refined SQL query and "
+    "nothing else.\n\n"
+    "The table being queried is named: {table_name}.\n\n"
+    "# Features\n"
+    "{features}"
+)
+
+
 def get_iframe(hub_repo_id, sql_query=None):
     if not hub_repo_id:
         raise ValueError("Hub repo id is required")
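Note: the new constant is a plain str.format template with exactly two placeholders, {table_name} and {features}. A minimal sketch of how it gets rendered, assuming the SYSTEM_PROMPT_TEMPLATE constant from the hunk above is in scope; the features value below is made up for illustration:

```python
# Illustration only: a made-up features mapping, rendered the same way
# get_system_prompt() does later in this commit.
features = {"text": {"dtype": "string"}, "label": {"dtype": "int64"}}

system_prompt = SYSTEM_PROMPT_TEMPLATE.format(  # constant added in the hunk above
    table_name="train",
    features=features,
)
print(system_prompt)  # ends with the "# Features" section followed by the dict repr
```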
@@ -75,9 +89,8 @@ def get_table_name(
     return table_name.lower()


-def get_prompt_messages(
+def get_system_prompt(
     card_data: dict[str, Any],
-    natural_language_query: str,
     config: str | None,
     split: str | None,
 ):
@@ -86,23 +99,10 @@ def get_prompt_messages(

     table_name = get_table_name(config, split, config_choices, split_choices)
     features = card_data[config]["features"]
-    messages = [
-        {
-            "role": "system",
-            "content": "You are a SQL query expert assistant that returns a DuckDB SQL queries based on the user's natural language query and dataset features. You might need to use DuckDB functions for lists and aggregations, given the features. Only return the SQL query, no other text.",
-        },
-        {
-            "role": "user",
-            "content": f"""table {table_name}
-# Features
-{features}
-
-# Query
-{natural_language_query}
-""",
-        },
-    ]
-    return messages
+    return SYSTEM_PROMPT_TEMPLATE.format(
+        table_name=table_name,
+        features=features,
+    )


 def get_config_choices(card_data: dict[str, Any]) -> list[str]:
@@ -117,9 +117,32 @@ def get_split_choices(card_data: dict[str, Any]) -> list[str]:
     return list(splits)


-def query_dataset(hub_repo_id, card_data, query, config, split):
+def query_dataset(hub_repo_id, card_data, query, config, split, history):
+    if card_data is None or len(card_data) == 0:
+        return "", get_iframe(hub_repo_id), []
     card_data = json.loads(card_data)
-    messages = get_prompt_messages(card_data, query, config, split)
+    system_prompt = get_system_prompt(card_data, config, split)
+    messages = [{"role": "system", "content": system_prompt}]
+    for turn in history:
+        user, assistant = turn
+        messages.append(
+            {
+                "role": "user",
+                "content": user,
+            }
+        )
+        messages.append(
+            {
+                "role": "assistant",
+                "content": assistant,
+            }
+        )
+    messages.append(
+        {
+            "role": "user",
+            "content": query,
+        }
+    )
     api_key = os.environ["API_KEY_TOGETHER_AI"].strip()
     response = requests.post(
         "https://api.together.xyz/v1/chat/completions",
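The rewritten query_dataset flattens the Chatbot history (tuple-style, [(user, assistant), ...]) into OpenAI-style chat messages before posting to the Together AI endpoint. A minimal sketch of the same flattening in isolation; build_messages is a hypothetical helper name, since the commit inlines this logic directly in query_dataset:

```python
# Sketch of the history -> messages flattening shown in the hunk above.
# `history` mirrors Gradio's tuple-style Chatbot value: [(user, assistant), ...].
def build_messages(system_prompt: str, history: list[tuple[str, str]], query: str) -> list[dict]:
    messages = [{"role": "system", "content": system_prompt}]
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": query})
    return messages


# Made-up example values:
msgs = build_messages(
    "You are a SQL query expert assistant ...",
    [("count the rows", "SELECT COUNT(*) FROM train;")],
    "now group the count by label",
)
assert [m["role"] for m in msgs] == ["system", "user", "assistant", "user"]
```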
@@ -142,7 +165,8 @@ def query_dataset(hub_repo_id, card_data, query, config, split):
     response_dict = response.json()
     duck_query = response_dict["choices"][0]["message"]["content"]
     duck_query = _sanitize_duck_query(duck_query)
-    return duck_query, get_iframe(hub_repo_id, duck_query)
+    history.append((query, duck_query))
+    return duck_query, get_iframe(hub_repo_id, duck_query), history


 def _sanitize_duck_query(duck_query: str) -> str:
@@ -176,10 +200,8 @@ with gr.Blocks() as demo:
         sumbit_on_select=True,
     )
     with gr.Row():
-        query = gr.Textbox(
-            label="Query Description",
-            placeholder="Enter a natural language query to generate SQL",
-        )
+        show_btn = gr.Button("Show Dataset")
+    with gr.Row():
         sql_out = gr.Code(
             label="DuckDB SQL Query",
             interactive=True,
@@ -212,15 +234,22 @@ with gr.Blocks() as demo:
             label="Split Name", choices=split_choices, value=initial_split
         )

+    with gr.Accordion("Query Suggestion History.", open=False) as accordion:
+        chatbot = gr.Chatbot(height=200, layout="bubble")
+        with gr.Row():
+            query = gr.Textbox(
+                label="Query Description",
+                placeholder="Enter a natural language query to generate SQL",
+            )
     with gr.Row():
         with gr.Column():
-
+            query_btn = gr.Button("Get Suggested Query")
         with gr.Column():
-
+            clear = gr.ClearButton([query, chatbot], value="Reset Query History")
     with gr.Row():
         search_out = gr.HTML(label="Search Results")
     gr.on(
-        [
+        [show_btn.click, search_in.submit],
         fn=get_iframe,
         inputs=[search_in],
         outputs=[search_out],
@@ -230,11 +259,19 @@ with gr.Blocks() as demo:
         outputs=[card_data],
     )
     gr.on(
-        [
+        [query_btn.click, query.submit],
         fn=query_dataset,
-        inputs=[search_in, card_data, query, config_selection, split_selection],
-        outputs=[sql_out, search_out],
+        inputs=[
+            search_in,
+            card_data,
+            query,
+            config_selection,
+            split_selection,
+            chatbot,
+        ],
+        outputs=[sql_out, search_out, chatbot],
     )
+    gr.on([query_btn.click], fn=lambda: gr.update(open=True), outputs=[accordion])


 if __name__ == "__main__":
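For context on the event wiring: gr.on accepts a list of triggers, so a button click and a textbox submit can drive the same callback, and returning gr.update(open=True) to the Accordion pops the history open when a suggestion is requested. A minimal standalone sketch of the same pattern; the component names and the echo handler below are illustrative, not the ones in app.py:

```python
# Minimal sketch of the multi-trigger gr.on wiring used in the commit above
# (tuple-style Chatbot history, as in query_dataset).
import gradio as gr


def echo(text, history):
    # Append a (user, assistant) pair, mirroring history.append((query, duck_query)).
    return history + [(text, f"you said: {text}")]


with gr.Blocks() as demo:
    with gr.Accordion("History", open=False) as acc:
        chat = gr.Chatbot()
    box = gr.Textbox(label="Say something")
    btn = gr.Button("Send")

    # One handler, two triggers: clicking the button or submitting the textbox.
    gr.on([btn.click, box.submit], fn=echo, inputs=[box, chat], outputs=[chat])
    # Open the accordion whenever the button is clicked.
    gr.on([btn.click], fn=lambda: gr.update(open=True), outputs=[acc])

if __name__ == "__main__":
    demo.launch()
```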