Spaces:
Runtime error
Runtime error
"""
Runtime that accepts a SQL statement and runs it against a SQLite database.
Returns the results of the SQL execution.
"""
import traceback | |
import sqlite3 | |
# Default database file opened by SQLRuntime() when no dbname is given.
# MODIFY THE PATH BELOW FOR YOUR SYSTEM. A relative path is resolved against
# the process's current working directory, not against this file's location.
my_db = "../data/elections.db"
class SQLRuntime:
    """Thin wrapper around a SQLite connection for running SQL statements.

    Every query method shares the single cursor created in ``__init__``.
    """

    def __init__(self, dbname=None):
        """Open a connection to the SQLite database at *dbname*.

        :param dbname: path to the database file; defaults to the module-level
            ``my_db``. ``sqlite3.connect`` creates an empty database if the
            file does not exist.
        """
        if dbname is None:
            dbname = my_db
        # Keep the connection itself so it can be closed explicitly.
        self.conn = sqlite3.connect(dbname)
        self.cursor = self.conn.cursor()  # shared by all query methods

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def list_tables(self):
        """Return the names of all tables in the database, sorted.

        :return: list of table names; empty if the database has no tables.
        """
        rows = self.cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        # Each row is a 1-tuple (name,); iterating rows is safe when empty.
        return sorted(row[0] for row in rows)

    def get_schema_for_table(self, table_name):
        """Return the column names of *table_name* in declaration order.

        :param table_name: name of the table to inspect.
        :return: tuple of column names; empty if the table does not exist.
        """
        # PRAGMA arguments are not bindable parameters, so double any embedded
        # single quotes to keep the name inside the quoted string literal.
        quoted = table_name.replace("'", "''")
        rows = self.cursor.execute("PRAGMA table_info('%s')" % quoted).fetchall()
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        return tuple(row[1] for row in rows)

    def get_schemas(self):
        """Map every table name to its tuple of column names."""
        return {name: self.get_schema_for_table(name) for name in self.list_tables()}

    def execute(self, statement):
        """Execute a single SQL statement and report the outcome.

        :param statement: SQL text containing one statement.
        :return: dict with keys ``code`` (0 on success, -1 on error),
            ``msg`` (dict with ``text``, ``reason``, ``traceback``, ``input``)
            and ``data`` (list of fetched rows on success, ``None`` on error).
        """
        code = 0
        msg = {
            "text": "SUCCESS",
            "reason": None,
            "traceback": None,
        }
        data = None
        try:
            # fetchall() is inside the try: SQLite can also raise while
            # stepping through later result rows, not only on execute().
            data = self.cursor.execute(statement).fetchall()
        except sqlite3.OperationalError:
            code = -1
            msg = {
                "text": "ERROR: SQL execution error",
                "reason": "possibly due to incorrect table/fields names",
                "traceback": traceback.format_exc(),
            }
        except (sqlite3.Error, sqlite3.Warning) as exc:
            # e.g. several statements in one string, which raises
            # ProgrammingError or sqlite3.Warning depending on Python version.
            code = -1
            msg = {
                "text": "ERROR: SQL execution error",
                "reason": str(exc),
                "traceback": traceback.format_exc(),
            }
        msg["input"] = statement
        return {
            "code": code,
            "msg": msg,
            "data": data,
        }

    def execute_batch(self, queries):
        """Run each query via :meth:`execute` and return the result dicts in order."""
        return [self.execute(query) for query in queries]

    def post_process(self, data):
        """
        post process the data so that we can identify any harmful code and remove them.
        Also, llm output may need an output parser.
        :param data: raw data to post-process.
        :return: currently *data* unchanged (not yet implemented).
        """
        # IMPLEMENT YOUR CODE HERE FOR POST-PROCESSING and VALIDATION
        return data
def sql_runtime(statement):
    """
    Instantiates a sql runtime for the default database and executes the
    given sql statement. The connection opened for this call is closed
    before returning, so repeated calls do not accumulate open connections.

    :param statement: sql statement
    :return: the result dict produced by ``SQLRuntime.execute``.
    """
    runtime = SQLRuntime()
    try:
        return runtime.execute(statement)
    finally:
        # Close through the cursor's back-reference to its connection.
        runtime.cursor.connection.close()
if __name__ == '__main__':
    # Smoke test: list every table with its columns, then run a sample query.
    # For an ad-hoc query instead, set: stmt = input("Enter stmt: ")
    sql = SQLRuntime()
    try:
        tables = sql.list_tables()
        print(tables)
        schemas = {}
        for table in tables:
            schemas[table] = sql.get_schema_for_table(table)
            print(f"Table: {table}, Schema: {schemas[table]}\n")

        # Sample question: find the votes polled by NOTA for each instance of
        # Akkalkuwa in the 2019 parliamentary elections.
        stmt = """
        SELECT party_name, SUM(nota_votes)
        FROM elections_2019
        WHERE constituency='Akkalkuwa'
        GROUP BY party_name;
        """
        data1 = sql.execute(stmt)
        dat = data1["data"]
        if data1["code"] != 0:
            # Surface failures instead of silently printing nothing.
            print(data1["msg"]["text"])
            print(data1["msg"]["traceback"])
        elif dat:
            for record in dat:
                print(record)
            print("-" * 100)
    finally:
        sql.cursor.connection.close()