tdoehmen committed on
Commit a3ab7c7
1 Parent(s): dcec0ff

added azure endpoint

Files changed (2)
  1. MODEL_README.md +0 -156
  2. app.py +40 -3
MODEL_README.md DELETED
@@ -1,156 +0,0 @@
- ---
- license: llama2
- inference:
-   parameters:
-     do_sample: false
-     max_length: 200
- widget:
-   - text: "CREATE TABLE stadium (\n stadium_id number,\n location text,\n name text,\n capacity number,\n)\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- how many stadiums in total?\n\nSELECT"
-     example_title: "Number stadiums"
-   - text: "CREATE TABLE work_orders ( ID NUMBER, CREATED_AT TEXT, COST FLOAT, INVOICE_AMOUNT FLOAT, IS_DUE BOOLEAN, IS_OPEN BOOLEAN, IS_OVERDUE BOOLEAN, COUNTRY_NAME TEXT, )\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- how many work orders are open?\n\nSELECT"
-     example_title: "Open work orders"
-   - text: "CREATE TABLE stadium ( stadium_id number, location text, name text, capacity number, highest number, lowest number, average number )\n\nCREATE TABLE singer ( singer_id number, name text, country text, song_name text, song_release_year text, age number, is_male others )\n\nCREATE TABLE concert ( concert_id number, concert_name text, theme text, stadium_id text, year text )\n\nCREATE TABLE singer_in_concert ( concert_id number, singer_id text )\n\n-- Using valid DuckDB SQL, answer the following questions for the tables provided above.\n\n-- What is the maximum, the average, and the minimum capacity of stadiums ?\n\nSELECT"
-     example_title: "Stadium capacity"
- ---
-
- # DuckDB-NSQL-7B
-
- ## Model Description
-
- NSQL is a family of autoregressive open-source large foundation models (FMs) designed specifically for SQL generation tasks.
-
- In this repository we introduce a new member of the NSQL family, DuckDB-NSQL. It is based on Meta's original [Llama-2 7B model](https://huggingface.co/meta-llama/Llama-2-7b), further pre-trained on a dataset of general SQL queries, and then fine-tuned on a dataset of DuckDB text-to-SQL pairs.
-
- ## Training Data
-
- The general SQL queries are the SQL subset of [The Stack](https://huggingface.co/datasets/bigcode/the-stack), containing 1M training samples. These samples were transpiled to DuckDB SQL using [sqlglot](https://github.com/tobymao/sqlglot). The labeled text-to-SQL pairs come from [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL) and were also transpiled to DuckDB SQL, complemented by 200k synthetically generated DuckDB SQL queries based on the DuckDB v0.9.2 documentation.
-
- ## Evaluation Data
-
- We evaluate our models on a DuckDB-specific benchmark that contains 75 text-to-SQL pairs. The benchmark is available [here](https://github.com/NumbersStationAI/DuckDB-NSQL/).
-
- ## Training Procedure
-
- DuckDB-NSQL was trained using cross-entropy loss to maximize the likelihood of sequential inputs. For fine-tuning on text-to-SQL pairs, we only compute the loss over the SQL portion of each pair. The model was trained on 80GB A100s, leveraging data and model parallelism. We pre-trained for 3 epochs and fine-tuned for 10 epochs.
-
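Computing the loss only over the SQL portion of a pair is commonly done by masking the prompt tokens in the label tensor; a minimal sketch with illustrative token ids (the card does not specify the exact implementation):

```python
import torch

# Sketch: mask prompt tokens with -100 so PyTorch's cross-entropy loss
# ignores them and only the SQL portion contributes to the loss.
# The token ids and prompt length below are illustrative.
input_ids = torch.tensor([[101, 102, 103, 104, 105, 106]])  # prompt + SQL tokens
prompt_len = 3  # number of prompt tokens (assumption)

labels = input_ids.clone()
labels[:, :prompt_len] = -100  # positions with label -100 are skipped by the loss
```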
- ## Intended Use and Limitations
-
- The model was designed for text-to-SQL generation from a given table schema and a natural-language prompt. It works best with the prompt format defined below. In contrast to existing text-to-SQL models, generation is not constrained to `SELECT` statements: the model can produce any valid DuckDB SQL statement, including statements for official DuckDB extensions.
-
- ## How to Use
-
- Example 1:
-
- ```python
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- tokenizer = AutoTokenizer.from_pretrained("motherduckdb/nsql-duckdb-7B")
- model = AutoModelForCausalLM.from_pretrained("motherduckdb/nsql-duckdb-7B", torch_dtype=torch.bfloat16)
-
- text = """CREATE TABLE stadium (
-     stadium_id number,
-     location text,
-     name text,
-     capacity number,
-     highest number,
-     lowest number,
-     average number
- )
-
- CREATE TABLE singer (
-     singer_id number,
-     name text,
-     country text,
-     song_name text,
-     song_release_year text,
-     age number,
-     is_male others
- )
-
- CREATE TABLE concert (
-     concert_id number,
-     concert_name text,
-     theme text,
-     stadium_id text,
-     year text
- )
-
- CREATE TABLE singer_in_concert (
-     concert_id number,
-     singer_id text
- )
-
- -- Using valid DuckDB SQL, answer the following questions for the tables provided above.
-
- -- What is the maximum, the average, and the minimum capacity of stadiums ?
-
- SELECT"""
-
- input_ids = tokenizer(text, return_tensors="pt").input_ids
-
- generated_ids = model.generate(input_ids, max_length=500)
- print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
- ```
-
- Example 2:
-
- ```python
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- tokenizer = AutoTokenizer.from_pretrained("motherduckdb/nsql-duckdb-7B")
- model = AutoModelForCausalLM.from_pretrained("motherduckdb/nsql-duckdb-7B", torch_dtype=torch.bfloat16)
-
- text = """CREATE TABLE stadium (
-     stadium_id number,
-     location text,
-     name text,
-     capacity number,
- )
-
- -- Using valid DuckDB SQL, answer the following questions for the tables provided above.
-
- -- how many stadiums in total?
-
- SELECT"""
-
- input_ids = tokenizer(text, return_tensors="pt").input_ids
-
- generated_ids = model.generate(input_ids, max_length=500)
- print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
- ```
-
- Example 3:
-
- ```python
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- tokenizer = AutoTokenizer.from_pretrained("motherduckdb/nsql-duckdb-7B")
- model = AutoModelForCausalLM.from_pretrained("motherduckdb/nsql-duckdb-7B", torch_dtype=torch.bfloat16)
-
- text = """CREATE TABLE work_orders (
-     ID NUMBER,
-     CREATED_AT TEXT,
-     COST FLOAT,
-     INVOICE_AMOUNT FLOAT,
-     IS_DUE BOOLEAN,
-     IS_OPEN BOOLEAN,
-     IS_OVERDUE BOOLEAN,
-     COUNTRY_NAME TEXT,
- )
-
- -- Using valid DuckDB SQL, answer the following questions for the tables provided above.
-
- -- how many work orders are open?
-
- SELECT"""
-
- input_ids = tokenizer(text, return_tensors="pt").input_ids
-
- generated_ids = model.generate(input_ids, max_length=500)
- print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
- ```
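In all three examples, `tokenizer.decode` returns the prompt followed by the completion, and generation may run past the end of the query. A small helper can recover just the SQL; this is an editorial sketch (the cut at the first semicolon is an assumption about where the query ends):

```python
def extract_sql(prompt: str, decoded: str) -> str:
    """Strip the echoed prompt and cut the completion at the first semicolon.

    Assumes the prompt ends with "SELECT", as in the examples above.
    """
    completion = decoded[len(prompt):]
    return ("SELECT" + completion).split(";")[0].strip() + ";"
```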
-
- For more information (e.g., how to run the model against your local database), please find examples in [this repository](https://github.com/NumbersStationAI/DuckDB-NSQL).
app.py CHANGED
@@ -3,12 +3,24 @@ import requests
 import subprocess
 import re
 import sys
+import urllib.request
+import json
+import os
+import ssl
+import time
 
 PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
 INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}"""  # noqa: E501
 ERROR_MESSAGE = ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not able to craft a correct SQL query for this request. \nSorry my duck friend. ]\n\n:red[If the question is about your own database, make sure to set the correct schema. Otherwise, try to rephrase your request. ]\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```"
 STOP_TOKENS = ["###", ";", "--", "```"]
 
+
+def allowSelfSignedHttps(allowed):
+    # bypass server certificate verification on the client side
+    if allowed and not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None):
+        ssl._create_default_https_context = ssl._create_unverified_context
+
+
+allowSelfSignedHttps(True)  # needed if the scoring service uses a self-signed certificate
+
 
 def generate_prompt(question, schema):
     input = ""
@@ -34,10 +46,35 @@ def generate_prompt(question, schema):
     )
     return prompt
 
+
+def generate_sql_azure(question, schema):
+    prompt = generate_prompt(question, schema)
+    start = time.time()
+
+    data = {
+        "input_data": {
+            "input_string": [prompt],
+            "parameters": {
+                "top_p": 0.9,
+                "temperature": 0.1,
+                "max_new_tokens": 200,
+                "do_sample": True
+            }
+        }
+    }
+    body = str.encode(json.dumps(data))
+
+    url = 'https://motherduck-eu-west2-xbdfd.westeurope.inference.ml.azure.com/score'
+    headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + st.secrets['azure_ai_token'], 'azureml-model-deployment': 'motherduckdb-duckdb-nsql-7b-v-1'}
+    req = urllib.request.Request(url, body, headers)
+    raw_resp = urllib.request.urlopen(req)
+    resp = json.loads(raw_resp.read().decode("utf-8"))[0]["0"]
+    sql_query = resp[len(prompt):]
+    print(time.time() - start)
+    return sql_query
+
 
 def generate_sql(question, schema):
     prompt = generate_prompt(question, schema)
-
+    start = time.time()
     s = requests.Session()
     api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
     url = f"{api_base}/v1/completions"
@@ -52,7 +89,7 @@ def generate_sql(question, schema):
     headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
     with s.post(url, json=body, headers=headers) as resp:
         sql_query = resp.json()["choices"][0]["text"]
-
+    print(time.time() - start)
     return sql_query
@@ -192,7 +229,7 @@ text_prompt = st.text_input(
 )
 
 if text_prompt:
-    sql_query = generate_sql(text_prompt, schema)
+    sql_query = generate_sql_azure(text_prompt, schema)
     valid, msg = validate_sql(sql_query, schema)
     if not valid:
         st.markdown(ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg))
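Both endpoints receive the same prompt, which `generate_prompt` assembles from `PROMPT_TEMPLATE` and `INSTRUCTION_TEMPLATE`. A simplified, self-contained sketch of that assembly (the `has_schema` wording is a hypothetical stand-in; the real function also processes the schema):

```python
# Templates copied verbatim from app.py above.
PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n"""
INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}"""


def build_prompt(question: str, schema: str = "") -> str:
    # Simplified stand-in for app.py's generate_prompt; the schema-suffix
    # wording below is an assumption for illustration.
    has_schema = ", given a duckdb database schema." if schema else "."
    instruction = INSTRUCTION_TEMPLATE.format(has_schema=has_schema)
    return PROMPT_TEMPLATE.format(instruction=instruction, input=schema, question=question)


print(build_prompt("how many stadiums are there?"))
```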