DrishtiSharma committed on
Commit
5cc7611
·
verified ·
1 Parent(s): c95d3e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -40
app.py CHANGED
@@ -20,7 +20,7 @@ from langchain_community.utilities.sql_database import SQLDatabase
20
  from datasets import load_dataset
21
  import tempfile
22
 
23
- # Environment setup
24
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
25
 
26
  # LLM Callback Logger
@@ -37,17 +37,18 @@ class LLMCallbackHandler(BaseCallbackHandler):
37
  with self.log_path.open("a", encoding="utf-8") as file:
38
  file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
39
 
40
- # Initialize the LLM
41
  llm = ChatGroq(
42
  temperature=0,
43
  model_name="mixtral-8x7b-32768",
44
  callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
45
  )
46
 
 
47
  st.title("SQL-RAG Using CrewAI 🚀")
48
  st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
49
 
50
- # Input Options
51
  input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
52
  df = None
53
 
@@ -55,13 +56,13 @@ if input_option == "Use Hugging Face Dataset":
55
  dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
56
  if st.button("Load Dataset"):
57
  try:
58
- with st.spinner("Loading Hugging Face dataset..."):
59
  dataset = load_dataset(dataset_name, split="train")
60
  df = pd.DataFrame(dataset)
61
  st.success(f"Dataset '{dataset_name}' loaded successfully!")
62
  st.dataframe(df.head())
63
  except Exception as e:
64
- st.error(f"Error loading dataset: {e}")
65
  else:
66
  uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
67
  if uploaded_file:
@@ -77,79 +78,66 @@ if df is not None:
77
  df.to_sql("salaries", connection, if_exists="replace", index=False)
78
  db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
79
 
80
- # Tools with proper docstrings
81
  @tool("list_tables")
82
  def list_tables() -> str:
83
- """List all tables in the SQLite database."""
84
  return ListSQLDatabaseTool(db=db).invoke("")
85
 
86
  @tool("tables_schema")
87
  def tables_schema(tables: str) -> str:
88
- """
89
- Get the schema and sample rows for specific tables in the database.
90
-
91
- Input: Comma-separated table names.
92
- Example: 'salaries'
93
- """
94
  return InfoSQLDatabaseTool(db=db).invoke(tables)
95
 
96
  @tool("execute_sql")
97
  def execute_sql(sql_query: str) -> str:
98
- """
99
- Execute a valid SQL query on the database and return the results.
100
-
101
- Input: A SQL query string.
102
- Example: 'SELECT * FROM salaries LIMIT 5;'
103
- """
104
  return QuerySQLDataBaseTool(db=db).invoke(sql_query)
105
 
106
  @tool("check_sql")
107
  def check_sql(sql_query: str) -> str:
108
- """
109
- Check the validity of a SQL query before execution.
110
-
111
- Input: A SQL query string.
112
- Example: 'SELECT salary FROM salaries WHERE salary > 10000;'
113
- """
114
  return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
115
 
116
  # Agents
117
  sql_dev = Agent(
118
- role="Database Developer",
119
- goal="Extract relevant data by executing SQL queries.",
 
120
  llm=llm,
121
  tools=[list_tables, tables_schema, execute_sql, check_sql],
122
  )
123
 
124
  data_analyst = Agent(
125
- role="Data Analyst",
126
- goal="Analyze the extracted data and generate detailed insights.",
 
127
  llm=llm,
128
  )
129
 
130
  report_writer = Agent(
131
- role="Report Writer",
132
- goal="Summarize the analysis into an executive report.",
 
133
  llm=llm,
134
  )
135
 
136
  # Tasks
137
  extract_data = Task(
138
- description="Extract data for the query: {query}.",
139
- expected_output="Database query results.",
140
  agent=sql_dev,
141
  )
142
 
143
  analyze_data = Task(
144
- description="Analyze the query results for: {query}.",
145
- expected_output="Analysis report.",
146
  agent=data_analyst,
147
  context=[extract_data],
148
  )
149
 
150
  write_report = Task(
151
- description="Summarize the analysis into an executive summary.",
152
- expected_output="Markdown-formatted report.",
153
  agent=report_writer,
154
  context=[analyze_data],
155
  )
@@ -161,9 +149,9 @@ if df is not None:
161
  verbose=True,
162
  )
163
 
164
- query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary by experience level?'")
165
  if st.button("Submit Query"):
166
- with st.spinner("Processing your query with CrewAI..."):
167
  inputs = {"query": query}
168
  result = crew.kickoff(inputs=inputs)
169
  st.markdown("### Analysis Report:")
@@ -171,4 +159,4 @@ if df is not None:
171
 
172
  temp_dir.cleanup()
173
  else:
174
- st.info("Load a dataset to proceed.")
 
20
  from datasets import load_dataset
21
  import tempfile
22
 
23
+ # API Key
24
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
25
 
26
  # LLM Callback Logger
 
37
  with self.log_path.open("a", encoding="utf-8") as file:
38
  file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
39
 
40
+ # Initialize LLM
41
  llm = ChatGroq(
42
  temperature=0,
43
  model_name="mixtral-8x7b-32768",
44
  callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
45
  )
46
 
47
+ # Streamlit UI
48
  st.title("SQL-RAG Using CrewAI 🚀")
49
  st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
50
 
51
+ # Dataset Input
52
  input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
53
  df = None
54
 
 
56
  dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
57
  if st.button("Load Dataset"):
58
  try:
59
+ with st.spinner("Loading dataset..."):
60
  dataset = load_dataset(dataset_name, split="train")
61
  df = pd.DataFrame(dataset)
62
  st.success(f"Dataset '{dataset_name}' loaded successfully!")
63
  st.dataframe(df.head())
64
  except Exception as e:
65
+ st.error(f"Error: {e}")
66
  else:
67
  uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
68
  if uploaded_file:
 
78
  df.to_sql("salaries", connection, if_exists="replace", index=False)
79
  db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
80
 
 
81
  @tool("list_tables")
82
  def list_tables() -> str:
83
+ """List all tables in the database."""
84
  return ListSQLDatabaseTool(db=db).invoke("")
85
 
86
  @tool("tables_schema")
87
  def tables_schema(tables: str) -> str:
88
+ """Get schema and sample rows for given tables."""
 
 
 
 
 
89
  return InfoSQLDatabaseTool(db=db).invoke(tables)
90
 
91
  @tool("execute_sql")
92
  def execute_sql(sql_query: str) -> str:
93
+ """Execute a SQL query against the database."""
 
 
 
 
 
94
  return QuerySQLDataBaseTool(db=db).invoke(sql_query)
95
 
96
  @tool("check_sql")
97
  def check_sql(sql_query: str) -> str:
98
+ """Check the validity of a SQL query."""
 
 
 
 
 
99
  return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
100
 
101
  # Agents
102
  sql_dev = Agent(
103
+ role="Senior Database Developer",
104
+ goal="Extract data using optimized SQL queries.",
105
+ backstory="An expert in writing optimized SQL queries for complex databases.",
106
  llm=llm,
107
  tools=[list_tables, tables_schema, execute_sql, check_sql],
108
  )
109
 
110
  data_analyst = Agent(
111
+ role="Senior Data Analyst",
112
+ goal="Analyze the data and produce insights.",
113
+ backstory="A seasoned analyst who identifies trends and patterns in datasets.",
114
  llm=llm,
115
  )
116
 
117
  report_writer = Agent(
118
+ role="Technical Report Writer",
119
+ goal="Summarize the insights into a clear report.",
120
+ backstory="An expert in summarizing data insights into readable reports.",
121
  llm=llm,
122
  )
123
 
124
  # Tasks
125
  extract_data = Task(
126
+ description="Extract data based on the query: {query}.",
127
+ expected_output="Database results matching the query.",
128
  agent=sql_dev,
129
  )
130
 
131
  analyze_data = Task(
132
+ description="Analyze the extracted data for query: {query}.",
133
+ expected_output="Analysis text summarizing findings.",
134
  agent=data_analyst,
135
  context=[extract_data],
136
  )
137
 
138
  write_report = Task(
139
+ description="Summarize the analysis into an executive report.",
140
+ expected_output="Markdown report of insights.",
141
  agent=report_writer,
142
  context=[analyze_data],
143
  )
 
149
  verbose=True,
150
  )
151
 
152
+ query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary for senior employees?'")
153
  if st.button("Submit Query"):
154
+ with st.spinner("Processing query..."):
155
  inputs = {"query": query}
156
  result = crew.kickoff(inputs=inputs)
157
  st.markdown("### Analysis Report:")
 
159
 
160
  temp_dir.cleanup()
161
  else:
162
+ st.info("Please load a dataset to proceed.")