"""LangChain tools that let an agent inspect and query the elections database.

Each tool opens its own ``SQLRuntime`` on the configured ``db_path`` for every
call and reports failures as plain strings so the agent can read them and
retry instead of crashing the run.
"""

from typing import List, Dict, Any, Optional, Type

import pandas as pd
from dotenv import load_dotenv
from langchain.agents import AgentExecutor, create_react_agent
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from .load_llm import load_llm
from .sql_runtime import SQLRuntime
from react import run_agent_executor
from prompts import react_prompt

# Database used when no explicit path is supplied (the path the tools
# previously hard-coded).
DEFAULT_DB_PATH = '../data/elections.db'


def _resolve_name(requested: str, known: List[str], kind: str) -> str:
    """Return the entry of *known* that matches *requested*.

    Table and column names cannot be bound as SQL parameters, so the tools
    only ever interpolate names that the database itself reported. Matching
    ignores surrounding whitespace and letter case; the spelling from *known*
    is what gets returned.

    Raises:
        ValueError: if *requested* matches nothing in *known*. The message
            lists the valid names so the agent can correct itself.
    """
    wanted = str(requested).strip().casefold()
    for candidate in known:
        if candidate.casefold() == wanted:
            return candidate
    raise ValueError(
        f"Unknown {kind} {requested!r}. Available: {', '.join(known)}"
    )


# Defining the input schemas.
# NOTE: these models are documented with comments rather than docstrings
# because pydantic copies a model docstring into the generated JSON schema.

# Input for SQLQueryTool: one raw SQL statement.
class QueryInput(BaseModel):
    query: str = Field(..., description="The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries")


# Input for TableInfoTool: the table to describe.
class TableNameInput(BaseModel):
    table_name: str = Field(..., description="The name of the table to analyze")


# Input for ColumnValuesTool: which column of which table, and how many values.
class ColumnSearchInput(BaseModel):
    table_name: str = Field(..., description="The name of the table to search")
    column_name: str = Field(..., description="The name of the column to search")
    limit: int = Field(default=10, description="Maximum number of distinct values to return")


class SQLQueryTool(BaseTool):
    """Run an arbitrary SQL statement and return the rows as text."""

    name: str = "sql_query"
    description: str = """
    Execute a SQL query and return the results.
    Use this when you need to run a specific SQL query on the elections database.
    The query should be a valid SQL statement and should end with a semicolon.
    There should be no harmful queries executed.
    There are three tables in the database: elections_2019, elections_2024, maha_2019
    """
    args_schema: Type[BaseModel] = QueryInput
    # Path handed to SQLRuntime on every call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, query: str) -> str:
        """Execute *query* and return the result table, or an error message.

        NOTE(review): nothing here stops destructive statements; the only
        safeguard is the instruction in the tool description. Confirm whether
        SQLRuntime opens the database read-only.
        """
        sql_runtime = SQLRuntime(self.db_path)
        try:
            result = sql_runtime.execute(query)
            if result["code"] != 0:
                return f"Error executing query: {result['msg']['reason']}"

            # A DataFrame gives an aligned, readable text rendering of the rows.
            df = pd.DataFrame(result["data"])
            if not df.empty:
                return df.to_string()
            return "Query returned no results"
        except Exception as e:
            # Tool boundary: surface the failure to the agent as text.
            return f"Error: {str(e)}"


class TableInfoTool(BaseTool):
    """Describe one table: its columns, row count and a few sample rows."""

    name: str = "get_table_info"
    description: str = """
    Get information about a specific table including its schema and basic statistics.
    Use this when you need to understand the structure of a table or get basic statistics about it.
    """
    args_schema: Type[BaseModel] = TableNameInput
    # Path handed to SQLRuntime on every call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, table_name: str) -> str:
        """Return a text summary of *table_name*, or an error message."""
        sql_runtime = SQLRuntime(self.db_path)
        try:
            # Only interpolate a table name the database itself reported.
            table = _resolve_name(table_name, list(sql_runtime.list_tables()), "table")

            # Get schema
            schema = sql_runtime.get_schema_for_table(table)

            # Get row count
            count_result = sql_runtime.execute(f"SELECT COUNT(*) FROM {table}")
            row_count = count_result["data"][0][0] if count_result["code"] == 0 else "Error"

            # Get sample data
            sample_result = sql_runtime.execute(f"SELECT * FROM {table} LIMIT 3")
            if sample_result["code"] == 0:
                sample_text = pd.DataFrame(sample_result["data"], columns=schema).to_string()
            else:
                sample_text = "Error getting sample data"

            column_list = ", ".join(schema)
            return f"""
            Table: {table}
            Columns: {column_list}
            Row Count: {row_count}
            Sample Data:
            {sample_text}
            """
        except Exception as e:
            # Tool boundary: surface the failure to the agent as text.
            return f"Error getting table info: {str(e)}"


class ColumnValuesTool(BaseTool):
    """List distinct values stored in one column of one table."""

    name: str = "find_column_values"
    description: str = """
    Find distinct values in a specific column of a table.
    Use this when you need to know what unique values exist in a particular column.
    """
    args_schema: Type[BaseModel] = ColumnSearchInput
    # Path handed to SQLRuntime on every call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, table_name: str, column_name: str, limit: int = 10) -> str:
        """Return up to *limit* distinct values of the column, or an error message."""
        sql_runtime = SQLRuntime(self.db_path)
        try:
            # Only interpolate names the database itself reported, and force
            # the limit to an integer in case _run is called without the
            # args_schema validation.
            table = _resolve_name(table_name, list(sql_runtime.list_tables()), "table")
            column = _resolve_name(
                column_name, list(sql_runtime.get_schema_for_table(table)), "column"
            )
            max_rows = int(limit)

            query = f"""
            SELECT DISTINCT {column}
            FROM {table}
            LIMIT {max_rows}
            """
            result = sql_runtime.execute(query)
            if result["code"] != 0:
                return f"Error finding values: {result['msg']['reason']}"

            values = [row[0] for row in result["data"]]
            return f"Distinct values in {column}: {', '.join(map(str, values))}"
        except Exception as e:
            # Tool boundary: surface the failure to the agent as text.
            return f"Error: {str(e)}"


class ListTablesTool(BaseTool):
    """List the tables available in the database."""

    name: str = "list_tables"
    description: str = """
    List all available tables in the database.
    Use this when you need to know what tables are available to query.
    """
    # Path handed to SQLRuntime on every call.
    db_path: str = DEFAULT_DB_PATH

    def _run(self, *args, **kwargs) -> str:
        # Returns "Available tables: a, b, ..." or an error message. Any
        # arguments the agent passes are accepted and ignored.
        # Kept as a comment rather than a docstring: with no args_schema the
        # input schema is inferred from this function.
        sql_runtime = SQLRuntime(self.db_path)
        try:
            tables = sql_runtime.list_tables()
            return f"Available tables: {', '.join(tables)}"
        except Exception as e:
            # Tool boundary: surface the failure to the agent as text.
            return f"Error listing tables: {str(e)}"


def create_sql_agent_tools(db_path: Optional[str] = '../data/elections.db') -> List[BaseTool]:
    """Create the list of SQL tools for use with a LangChain agent.

    Args:
        db_path: Database path every returned tool will open. ``None`` (or an
            empty string) falls back to ``DEFAULT_DB_PATH``.

    Returns:
        The query, table-info and list-tables tools, all bound to *db_path*.
    """
    path = db_path or DEFAULT_DB_PATH
    return [
        SQLQueryTool(db_path=path),
        TableInfoTool(db_path=path),
        # ColumnValuesTool is intentionally left out, as in the original list.
        # ColumnValuesTool(db_path=path),
        ListTablesTool(db_path=path),
    ]


def main() -> None:
    """Build the ReAct agent over the SQL tools and run one sample question."""
    load_dotenv()

    tools = create_sql_agent_tools()
    for tool in tools:
        print(f"Tool: {tool.name}")
        print(f"Description: {tool.description}")

    # Create the llm
    llm = load_llm()

    agent = create_react_agent(llm, tools, react_prompt)

    # Create an agent executor by passing in the agent and tools
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
    print("Agent created successfully")

    # Run the agent; verbose=True makes the executor print its own trace.
    agent_executor.invoke({"input": "who won elections in maharashtra in Nandurbar in elections 2019?"})


if __name__ == "__main__":
    main()