# MahaNeta / utils/load_maha_election_dataset.py
# Last commit: 10757ec ("init") by ankush-003
"""
Loads Maharashtra assembly 2019 dataset
"""
import json
import sqlite3
import pandas as pd
import csv
def load_data_from_csv(name, end=58925):
    """Read a CSV file into a list of row dictionaries.

    The first row is treated as the header. Every following row becomes a
    ``dict`` mapping header names to that row's string values. Any row in
    which some field equals ``"TURNOUT"`` (ignoring surrounding whitespace)
    is skipped.

    Args:
        name: Path to the CSV file, read as UTF-8.
        end: Currently ignored; kept so existing callers keep working.
            NOTE(review): callers pass it as if it limited the number of
            rows read (e.g. ``end=5``). Confirm the intended semantics
            before implementing it.

    Returns:
        A list of dicts, one per kept data row. If a row has more or fewer
        fields than the header, ``zip`` truncates it to the shorter length.
        An empty file yields an empty list.
    """
    data = []
    # newline="" lets the csv module handle line endings itself, as the csv
    # docs require; otherwise line breaks inside quoted fields get rewritten.
    with open(name, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        keys = next(reader, None)
        if keys is None:  # empty file: no header and no rows
            return data
        for line in reader:
            # Skip rows carrying a "TURNOUT" marker in any field.
            if any(field.strip() == "TURNOUT" for field in line):
                continue
            data.append(dict(zip(keys, line)))
    return data
def clean_dataframe(df):
    """Normalise whitespace in a DataFrame and replace missing values with 0.

    The frame is modified in place and also returned, so both
    ``clean_dataframe(df)`` and ``df = clean_dataframe(df)`` work.

    Cleaning steps:
      - strip leading/trailing whitespace from string column labels;
      - in ``object``-dtype columns, strip leading/trailing whitespace from
        string cells, leaving non-string cells unchanged (no lowercasing);
      - replace all remaining null values (``NaN``/``None``) with ``0``.

    Args:
        df: The ``pandas.DataFrame`` to clean.

    Returns:
        The same ``DataFrame`` object, cleaned.
    """
    # Strip labels without assuming all are strings (Index.str raises otherwise).
    df.columns = [c.strip() if isinstance(c, str) else c for c in df.columns]
    # Strip only str cells: Series.str.strip() turns non-string cells of a
    # mixed object column into NaN, which fillna below would then make 0.
    for col in df.select_dtypes(include='object').columns:
        df[col] = df[col].map(lambda v: v.strip() if isinstance(v, str) else v)
    # Fill null values with 0
    df.fillna(0, inplace=True)
    return df
def load_data_from_csv_to_db(name, conn, table_name="maha_2019"):
    """Load the Maharashtra 2019 results CSV into a SQL table.

    The CSV is read with pandas and cleaned with ``clean_dataframe``. Its
    columns are then given fixed snake_case names, and the frame is written
    to ``table_name``, replacing any existing table of that name.

    Args:
        name: Path to the UTF-8 CSV file. It must have exactly 14 columns in
            the order listed below; otherwise pandas raises ``ValueError``
            when the column names are assigned.
        conn: Database connection accepted by ``DataFrame.to_sql`` (a
            ``sqlite3.Connection`` in this module's ``__main__``).
        table_name: Destination table name; defaults to ``"maha_2019"``.

    Returns:
        Whatever ``DataFrame.to_sql`` returns (the number of rows affected
        on pandas >= 1.4; ``None`` on older versions).
    """
    # Pass the path so pandas opens and closes the file itself; the former
    # open(...) handle was never closed.
    df = pd.read_csv(name, encoding="utf-8")
    df = clean_dataframe(df)
    # Positional rename: assumes the CSV's column order matches this list.
    df.columns = [
        'state', 'constituency_number', 'constituency', 'candidate_name', 'sex', 'age',
        'category', 'party_name', 'party_symbol', 'evm_votes', 'postal_votes', 'total_votes',
        'vote_percentage', 'total_electors'
    ]
    # Save the dataframe as a database table (default name: maha_2019).
    return df.to_sql(table_name, conn, if_exists='replace', index=False)
def query_sql(conn, query, params=None):
    """Execute ``query`` on ``conn`` and return all result rows.

    The result set's column names are printed to stdout. An empty list is
    printed for statements that produce no result set, such as DDL or
    INSERT.

    Args:
        conn: An open DB-API connection (``sqlite3.Connection`` here).
        query: SQL text to execute.
        params: Optional values bound to the query's placeholders (``?``
            for sqlite3). Prefer this over formatting values into the SQL.

    Returns:
        The list of rows returned by ``cursor.fetchall()``.
    """
    cursor = conn.cursor()
    try:
        if params is None:
            cursor.execute(query)
        else:
            cursor.execute(query, params)
        result = cursor.fetchall()
        # description is None for statements without a result set.
        field_names = [column[0] for column in (cursor.description or ())]
        print(field_names)
        return result
    finally:
        cursor.close()
if __name__ == '__main__':
    from pathlib import Path

    # Resolve data paths relative to this file (utils/ -> ../data) instead of
    # the current working directory, so the script runs from anywhere.
    data_dir = Path(__file__).resolve().parent.parent / "data"
    filename = data_dir / "maha_results_2019.csv"
    # Create a connection to the SQLite db called elections.db.
    conn = sqlite3.connect(data_dir / "elections.db")
    try:
        load_data_from_csv_to_db(filename, conn)
        query = "SELECT * FROM maha_2019 LIMIT 5;"
        results = query_sql(conn, query)
        print(results)
    finally:
        # sqlite3 connections are not closed by garbage collection reliably;
        # close explicitly even if loading or querying fails.
        conn.close()