# Source: MahaNeta/utils/sql_runtime.py (commit 10757ec "init", by ankush-003)
"""
Runtime that accepts an SQL statement and runs it against a SQLite database.
Returns the results of the SQL execution.
"""
import traceback
import sqlite3
# MODIFY THE PATH BELOW FOR YOUR SYSTEM
# Default SQLite database file used by SQLRuntime when no ``dbname`` is given.
# The path is relative, so it is resolved against the current working
# directory (not against this file's directory). Note that sqlite3.connect
# creates an empty database file if nothing exists at that path.
my_db = r"../data/elections.db"
class SQLRuntime(object):
    """Thin wrapper around a SQLite connection that executes SQL statements.

    All statements run on a single cursor created in ``__init__``. Call
    :meth:`close` (or use the instance as a context manager) to release the
    underlying connection.
    """

    def __init__(self, dbname=None):
        """Open a connection to *dbname*.

        :param dbname: path to the SQLite database file (or ``":memory:"``);
            ``None`` falls back to the module-level ``my_db``.
        """
        if dbname is None:
            dbname = my_db
        # Keep the connection on the instance so it can be closed explicitly.
        self.conn = sqlite3.connect(dbname)
        self.cursor = self.conn.cursor()  # we need the cursor to execute statements

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def list_tables(self):
        """Return the names of all tables in the database.

        :return: sorted list of table names; empty if there are no tables.
        """
        rows = self.cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        # Each row is a 1-tuple. Iterating rows (instead of zip(*rows)[0])
        # also works when the query returns no rows at all.
        return sorted(row[0] for row in rows)

    def get_schema_for_table(self, table_name):
        """Return the column names of *table_name* in declaration order.

        :param table_name: name of the table to inspect.
        :return: tuple of column names; empty if the table does not exist.
        """
        # The PRAGMA takes the table name as a string literal, so double any
        # embedded single quotes to keep the name inside that literal.
        quoted = table_name.replace("'", "''")
        rows = self.cursor.execute("PRAGMA table_info('%s')" % quoted).fetchall()
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        return tuple(row[1] for row in rows)

    def get_schemas(self):
        """Map every table name to its tuple of column names."""
        return {name: self.get_schema_for_table(name) for name in self.list_tables()}

    @staticmethod
    def _error_msg(reason):
        """Build the ``msg`` dict for a failed statement; call inside ``except``."""
        return {
            "text": "ERROR: SQL execution error",
            "reason": reason,
            "traceback": traceback.format_exc(),
        }

    def execute(self, statement):
        """Execute one SQL statement and fetch all resulting rows.

        Errors raised by SQLite are caught and reported in the result instead
        of being propagated.

        :param statement: SQL text containing a single statement.
        :return: dict with keys
            ``code`` -- 0 on success, -1 on failure;
            ``msg`` -- dict with ``text``, ``reason``, ``traceback`` and
            ``input`` (the statement as given);
            ``data`` -- list of row tuples on success, ``None`` on failure.
        """
        code = 0
        msg = {
            "text": "SUCCESS",
            "reason": None,
            "traceback": None,
        }
        data = None
        try:
            self.cursor.execute(statement)
            # sqlite3 steps through rows lazily, so an error raised while
            # computing a later row surfaces here rather than in execute().
            data = self.cursor.fetchall()
        except sqlite3.OperationalError:
            code = -1
            msg = self._error_msg("possibly due to incorrect table/fields names")
        except (sqlite3.Error, sqlite3.Warning) as exc:
            # e.g. several statements in one call (ProgrammingError, or Warning
            # on older Pythons) or constraint violations (IntegrityError).
            code = -1
            msg = self._error_msg("%s: %s" % (type(exc).__name__, exc))
        msg["input"] = statement
        return {
            "code": code,
            "msg": msg,
            "data": data,
        }

    def execute_batch(self, queries):
        """Run :meth:`execute` on each query and return the list of results."""
        return [self.execute(query) for query in queries]

    def post_process(self, data):
        """
        Hook to post-process/validate data before or after execution.

        Intended to identify and remove harmful code and to parse LLM output.
        Not implemented yet: *data* is returned unchanged.

        :param data: the value to post-process.
        :return: the (currently unchanged) data.
        """
        # IMPLEMENT YOUR CODE HERE FOR POST-PROCESSING and VALIDATION
        return data
def sql_runtime(statement, dbname=None):
    """
    Instantiate a SQL runtime, execute one SQL statement, then close the connection.

    :param statement: SQL statement to execute.
    :param dbname: optional database path; ``None`` keeps SQLRuntime's default.
    :return: the result dict produced by ``SQLRuntime.execute``.
    """
    runtime = SQLRuntime(dbname)
    try:
        return runtime.execute(statement)
    finally:
        # Close explicitly instead of relying on garbage collection
        # (Python 3.13+ warns about unclosed sqlite3 connections).
        runtime.cursor.connection.close()
if __name__ == '__main__':
    sql = SQLRuntime()
    try:
        # Show every table together with its column names.
        tables = sql.list_tables()
        print(tables)
        for table in tables:
            print(f"Table: {table}, Schema: {sql.get_schema_for_table(table)}\n")

        # Sample question: find out the votes polled by NOTA for each instance
        # of Akkalkuwa in the parliamentary elections 2019.
        # To try another query, replace `stmt` (e.g. stmt = input("Enter stmt: ")).
        stmt = """
        SELECT party_name, SUM(nota_votes)
        FROM elections_2019
        WHERE constituency='Akkalkuwa'
        GROUP BY party_name;
        """
        result = sql.execute(stmt)
        if result["code"] != 0:
            # Report failures instead of silently printing nothing.
            msg = result["msg"]
            print(msg["text"], "-", msg["reason"])
            print(msg["traceback"])
        elif result["data"]:
            for record in result["data"]:
                print(record)
            print("-" * 100)
    finally:
        sql.cursor.connection.close()