augray commited on
Commit
57f2e80
1 Parent(s): e602593

Working; needs refinements

Browse files
Files changed (1) hide show
  1. app.py +69 -32
app.py CHANGED
@@ -14,6 +14,20 @@ logger = logging.getLogger(__name__)
14
  example = HuggingfaceHubSearch().example_value()
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def get_iframe(hub_repo_id, sql_query=None):
18
  if not hub_repo_id:
19
  raise ValueError("Hub repo id is required")
@@ -75,9 +89,8 @@ def get_table_name(
75
  return table_name.lower()
76
 
77
 
78
- def get_prompt_messages(
79
  card_data: dict[str, Any],
80
- natural_language_query: str,
81
  config: str | None,
82
  split: str | None,
83
  ):
@@ -86,23 +99,10 @@ def get_prompt_messages(
86
 
87
  table_name = get_table_name(config, split, config_choices, split_choices)
88
  features = card_data[config]["features"]
89
- messages = [
90
- {
91
- "role": "system",
92
- "content": "You are a SQL query expert assistant that returns a DuckDB SQL queries based on the user's natural language query and dataset features. You might need to use DuckDB functions for lists and aggregations, given the features. Only return the SQL query, no other text.",
93
- },
94
- {
95
- "role": "user",
96
- "content": f"""table {table_name}
97
- # Features
98
- {features}
99
-
100
- # Query
101
- {natural_language_query}
102
- """,
103
- },
104
- ]
105
- return messages
106
 
107
 
108
  def get_config_choices(card_data: dict[str, Any]) -> list[str]:
@@ -117,9 +117,32 @@ def get_split_choices(card_data: dict[str, Any]) -> list[str]:
117
  return list(splits)
118
 
119
 
120
- def query_dataset(hub_repo_id, card_data, query, config, split):
 
 
121
  card_data = json.loads(card_data)
122
- messages = get_prompt_messages(card_data, query, config, split)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  api_key = os.environ["API_KEY_TOGETHER_AI"].strip()
124
  response = requests.post(
125
  "https://api.together.xyz/v1/chat/completions",
@@ -142,7 +165,8 @@ def query_dataset(hub_repo_id, card_data, query, config, split):
142
  response_dict = response.json()
143
  duck_query = response_dict["choices"][0]["message"]["content"]
144
  duck_query = _sanitize_duck_query(duck_query)
145
- return duck_query, get_iframe(hub_repo_id, duck_query)
 
146
 
147
 
148
  def _sanitize_duck_query(duck_query: str) -> str:
@@ -176,10 +200,8 @@ with gr.Blocks() as demo:
176
  sumbit_on_select=True,
177
  )
178
  with gr.Row():
179
- query = gr.Textbox(
180
- label="Natural Language Query",
181
- placeholder="Enter a natural language query to generate SQL",
182
- )
183
  sql_out = gr.Code(
184
  label="DuckDB SQL Query",
185
  interactive=True,
@@ -212,15 +234,22 @@ with gr.Blocks() as demo:
212
  label="Split Name", choices=split_choices, value=initial_split
213
  )
214
 
 
 
 
 
 
 
 
215
  with gr.Row():
216
  with gr.Column():
217
- btn = gr.Button("Show Dataset")
218
  with gr.Column():
219
- btn2 = gr.Button("Query Dataset")
220
  with gr.Row():
221
  search_out = gr.HTML(label="Search Results")
222
  gr.on(
223
- [btn.click, search_in.submit],
224
  fn=get_iframe,
225
  inputs=[search_in],
226
  outputs=[search_out],
@@ -230,11 +259,19 @@ with gr.Blocks() as demo:
230
  outputs=[card_data],
231
  )
232
  gr.on(
233
- [btn2.click, query.submit],
234
  fn=query_dataset,
235
- inputs=[search_in, card_data, query, config_selection, split_selection],
236
- outputs=[sql_out, search_out],
 
 
 
 
 
 
 
237
  )
 
238
 
239
 
240
  if __name__ == "__main__":
 
14
  example = HuggingfaceHubSearch().example_value()
15
 
16
 
17
+ SYSTEM_PROMPT_TEMPLATE = (
18
+ "You are a SQL query expert assistant that returns a DuckDB SQL queries "
19
+ "based on the user's natural language query and dataset features. "
20
+ "You might need to use DuckDB functions for lists and aggregations, "
21
+ "given the features. Only return the SQL query, no other text. The "
22
+ "user may ask you to make various adjustments to the query. Every "
23
+ "time your response should only include the refined SQL query and "
24
+ "nothing else.\n\n"
25
+ "The table being queried is named: {table_name}.\n\n"
26
+ "# Features\n"
27
+ "{features}"
28
+ )
29
+
30
+
31
  def get_iframe(hub_repo_id, sql_query=None):
32
  if not hub_repo_id:
33
  raise ValueError("Hub repo id is required")
 
89
  return table_name.lower()
90
 
91
 
92
+ def get_system_prompt(
93
  card_data: dict[str, Any],
 
94
  config: str | None,
95
  split: str | None,
96
  ):
 
99
 
100
  table_name = get_table_name(config, split, config_choices, split_choices)
101
  features = card_data[config]["features"]
102
+ return SYSTEM_PROMPT_TEMPLATE.format(
103
+ table_name=table_name,
104
+ features=features,
105
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
  def get_config_choices(card_data: dict[str, Any]) -> list[str]:
 
117
  return list(splits)
118
 
119
 
120
+ def query_dataset(hub_repo_id, card_data, query, config, split, history):
121
+ if card_data is None or len(card_data) == 0:
122
+ return "", get_iframe(hub_repo_id), []
123
  card_data = json.loads(card_data)
124
+ system_prompt = get_system_prompt(card_data, config, split)
125
+ messages = [{"role": "system", "content": system_prompt}]
126
+ for turn in history:
127
+ user, assistant = turn
128
+ messages.append(
129
+ {
130
+ "role": "user",
131
+ "content": user,
132
+ }
133
+ )
134
+ messages.append(
135
+ {
136
+ "role": "assistant",
137
+ "content": assistant,
138
+ }
139
+ )
140
+ messages.append(
141
+ {
142
+ "role": "user",
143
+ "content": query,
144
+ }
145
+ )
146
  api_key = os.environ["API_KEY_TOGETHER_AI"].strip()
147
  response = requests.post(
148
  "https://api.together.xyz/v1/chat/completions",
 
165
  response_dict = response.json()
166
  duck_query = response_dict["choices"][0]["message"]["content"]
167
  duck_query = _sanitize_duck_query(duck_query)
168
+ history.append((query, duck_query))
169
+ return duck_query, get_iframe(hub_repo_id, duck_query), history
170
 
171
 
172
  def _sanitize_duck_query(duck_query: str) -> str:
 
200
  sumbit_on_select=True,
201
  )
202
  with gr.Row():
203
+ show_btn = gr.Button("Show Dataset")
204
+ with gr.Row():
 
 
205
  sql_out = gr.Code(
206
  label="DuckDB SQL Query",
207
  interactive=True,
 
234
  label="Split Name", choices=split_choices, value=initial_split
235
  )
236
 
237
+ with gr.Accordion("Query Suggestion History.", open=False) as accordion:
238
+ chatbot = gr.Chatbot(height=200, layout="bubble")
239
+ with gr.Row():
240
+ query = gr.Textbox(
241
+ label="Query Description",
242
+ placeholder="Enter a natural language query to generate SQL",
243
+ )
244
  with gr.Row():
245
  with gr.Column():
246
+ query_btn = gr.Button("Get Suggested Query")
247
  with gr.Column():
248
+ clear = gr.ClearButton([query, chatbot], value="Reset Query History")
249
  with gr.Row():
250
  search_out = gr.HTML(label="Search Results")
251
  gr.on(
252
+ [show_btn.click, search_in.submit],
253
  fn=get_iframe,
254
  inputs=[search_in],
255
  outputs=[search_out],
 
259
  outputs=[card_data],
260
  )
261
  gr.on(
262
+ [query_btn.click, query.submit],
263
  fn=query_dataset,
264
+ inputs=[
265
+ search_in,
266
+ card_data,
267
+ query,
268
+ config_selection,
269
+ split_selection,
270
+ chatbot,
271
+ ],
272
+ outputs=[sql_out, search_out, chatbot],
273
  )
274
+ gr.on([query_btn.click], fn=lambda: gr.update(open=True), outputs=[accordion])
275
 
276
 
277
  if __name__ == "__main__":