File size: 3,822 Bytes
10757ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""

Runtime that accepts an SQL statement and runs it against a SQLite database.

Returns the results of the SQL execution.

"""
import traceback
import sqlite3

# MODIFY THE PATH BELOW FOR YOUR SYSTEM
# Default SQLite database file, used by SQLRuntime when no dbname is passed.
# NOTE: this is a relative path, so it is resolved against the current
# working directory of the running process, not against this file's location.
my_db = r"../data/elections.db"

class SQLRuntime(object):
    """Thin wrapper around a SQLite connection for running raw SQL statements.

    One connection and one cursor are opened at construction time and reused
    for every call. Statements are executed verbatim, and results are returned
    as plain dicts so callers can inspect failures instead of catching
    exceptions.

    The instance can be used as a context manager, which closes the
    underlying connection on exit::

        with SQLRuntime(":memory:") as rt:
            rt.execute("SELECT 1")
    """

    def __init__(self, dbname=None):
        """Open a connection to *dbname*.

        :param dbname: path of the SQLite database file (or ``":memory:"``).
            When ``None``, the module-level ``my_db`` path is used.
        """
        if dbname is None:
            dbname = my_db
        # Keep a handle on the connection so it can be closed explicitly
        # instead of leaking until garbage collection.
        self.conn = sqlite3.connect(dbname)
        # A single cursor is shared by every statement this runtime executes.
        self.cursor = self.conn.cursor()

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def list_tables(self):
        """Return the sorted names of all tables in the database.

        :return: list of table names; an empty list when there are no tables.
        """
        rows = self.cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        # Iterating the rows (instead of ``list(zip(*rows))[0]``) also copes
        # with an empty result set.
        return sorted(row[0] for row in rows)

    def get_schema_for_table(self, table_name):
        """Return the column names of *table_name* in declaration order.

        :param table_name: name of the table to inspect.
        :return: tuple of column names; an empty tuple if the table does not
            exist.
        """
        # PRAGMA arguments cannot be bound as query parameters, so the name is
        # embedded as an SQL string literal with single quotes doubled. This
        # keeps names such as "it's" from breaking (or injecting into) the
        # statement.
        quoted = str(table_name).replace("'", "''")
        rows = self.cursor.execute("PRAGMA table_info('%s')" % quoted).fetchall()
        # Field 1 of each ``PRAGMA table_info`` row is the column name.
        return tuple(row[1] for row in rows)

    def get_schemas(self):
        """Return a mapping of every table name to its tuple of column names."""
        return {name: self.get_schema_for_table(name) for name in self.list_tables()}

    def execute(self, statement):
        """Execute a single SQL *statement* and fetch all resulting rows.

        Only ``sqlite3.OperationalError`` (e.g. an unknown table or column, or
        a syntax error) is turned into an error result. Other exceptions
        propagate to the caller.

        :param statement: SQL text to run.
        :return: dict with keys
            ``code`` (0 on success, -1 on error),
            ``msg`` (dict with ``text``, ``reason``, ``traceback`` and
            ``input``, the last echoing *statement*), and
            ``data`` (list of row tuples on success, ``None`` on error).
        """
        code = 0
        msg = {
            "text": "SUCCESS",
            "reason": None,
            "traceback": None,
        }
        data = None

        try:
            self.cursor.execute(statement)
            # Fetch inside the try: errors can also surface while later rows
            # are being stepped through, not only when the statement starts.
            data = self.cursor.fetchall()
        except sqlite3.OperationalError:
            code = -1
            data = None
            msg = {
                "text": "ERROR: SQL execution error",
                "reason": "possibly due to incorrect table/fields names",
                "traceback": traceback.format_exc(),
            }

        msg["input"] = statement

        return {
            "code": code,
            "msg": msg,
            "data": data,
        }

    def execute_batch(self, queries):
        """Run each statement in *queries* in order.

        :param queries: iterable of SQL statements.
        :return: list of result dicts, one per query, as produced by ``execute``.
        """
        return [self.execute(query) for query in queries]

    def post_process(self, data):
        """Post-process *data* so that harmful code can be identified and removed.

        LLM output may also need an output parser. This is currently a
        pass-through placeholder.

        :param data: data to post-process.
        :return: *data* unchanged.
        """
        # IMPLEMENT YOUR CODE HERE FOR POST-PROCESSING and VALIDATION
        return data


def sql_runtime(statement, dbname=None):
    """Instantiate a SQL runtime, execute *statement* and return the result.

    A fresh connection is opened for this call and closed before returning.

    :param statement: SQL statement to execute.
    :param dbname: optional database path, forwarded to ``SQLRuntime``
        (``None`` means the runtime's default database).
    :return: the result dict produced by ``SQLRuntime.execute``.
    """
    runtime = SQLRuntime(dbname)
    try:
        return runtime.execute(statement)
    finally:
        # The result rows are already fully fetched, so the connection can be
        # closed here; otherwise every call would leak one open connection.
        runtime.cursor.connection.close()


if __name__ == '__main__':
    # Demo: print every table with its columns, then run a sample query.
    sql = SQLRuntime()

    tables = sql.list_tables()
    print(tables)

    for table in tables:
        print(f"Table: {table}, Schema: {sql.get_schema_for_table(table)}\n")

    # Sample question: find the votes polled by NOTA for each instance of
    # Akkalkuwa in the 2019 parliamentary elections.
    stmt = """

    SELECT party_name, SUM(nota_votes) 

    FROM elections_2019 

    WHERE constituency='Akkalkuwa' 

    GROUP BY party_name;

    """

    data1 = sql.execute(stmt)

    # data is None when execution failed; an empty list means no matching rows.
    dat = data1["data"]
    if dat:
        for record in dat:
            print(record)
            print("-" * 100)

    # Release the database connection explicitly.
    sql.cursor.connection.close()