DrishtiSharma committed on
Commit
e8e485a
·
verified ·
1 Parent(s): 7752a10

Create interim_radio.py

Browse files
Files changed (1) hide show
  1. interim_radio.py +171 -0
interim_radio.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import sqlite3
4
+ import os
5
+ import json
6
+ from pathlib import Path
7
+ from datetime import datetime, timezone
8
+ from crewai import Agent, Crew, Process, Task
9
+ from crewai_tools import tool
10
+ from langchain_groq import ChatGroq
11
+ from langchain.schema.output import LLMResult
12
+ from langchain_core.callbacks.base import BaseCallbackHandler
13
+ from langchain_community.tools.sql_database.tool import (
14
+ InfoSQLDatabaseTool,
15
+ ListSQLDatabaseTool,
16
+ QuerySQLCheckerTool,
17
+ QuerySQLDataBaseTool,
18
+ )
19
+ from langchain_community.utilities.sql_database import SQLDatabase
20
+ from datasets import load_dataset
21
+ import tempfile
22
+
23
# Environment setup.
# Export the Groq API key from Streamlit secrets so the Groq client can read it
# from the environment. Only export a non-empty secret: writing "" here would
# clobber a key that was already exported in the process environment.
groq_api_key = st.secrets.get("GROQ_API_KEY", "")
if groq_api_key:
    os.environ["GROQ_API_KEY"] = groq_api_key
25
+
26
+ # LLM Callback Logger
27
class LLMCallbackHandler(BaseCallbackHandler):
    """Append every LLM prompt and completion to a JSON Lines log file.

    Each record is a single JSON object on its own line with the keys
    ``event`` (``"llm_start"`` or ``"llm_end"``), ``text`` and ``timestamp``
    (timezone-aware ISO 8601, UTC).
    """

    def __init__(self, log_path: Path):
        """Store the destination of the log records.

        Args:
            log_path: File the records are appended to. It is opened in append
                mode on every event, so it is created on first use.
        """
        self.log_path = log_path

    def _append_record(self, event: str, text) -> None:
        """Serialize one event as JSON and append it to ``self.log_path``."""
        record = {
            "event": event,
            "text": text,
            # Aware UTC timestamp: unambiguous regardless of the server's zone.
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        with self.log_path.open("a", encoding="utf-8") as file:
            file.write(json.dumps(record) + "\n")

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Log the first prompt of the batch sent to the LLM.

        Args:
            serialized: Serialized description of the LLM (unused).
            prompts: Prompts sent to the LLM; only the first one is logged.
                An empty list is logged as an empty string instead of raising.
            **kwargs: Extra callback metadata (unused).
        """
        self._append_record("llm_start", prompts[0] if prompts else "")

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Log the last generation of the LLM response.

        Args:
            response: Result object whose ``generations`` is a list of lists;
                the last candidate of the last prompt is logged.
            **kwargs: Extra callback metadata (unused).
        """
        text = ""
        if response.generations and response.generations[-1]:
            generation = response.generations[-1][-1]
            # Chat generations carry a ``message``; plain completions only
            # expose ``text``, so fall back instead of raising AttributeError.
            message = getattr(generation, "message", None)
            if message is not None:
                text = message.content
            else:
                text = getattr(generation, "text", "")
        self._append_record("llm_end", text)
39
+
40
# Build the shared Groq chat model. Every prompt/completion pair is mirrored to
# a JSONL file through the callback handler; temperature 0 is requested for the
# SQL-generation use case.
prompt_logger = LLMCallbackHandler(Path("prompts.jsonl"))
llm = ChatGroq(
    model_name="mixtral-8x7b-32768",
    temperature=0,
    callbacks=[prompt_logger],
)
46
+
47
st.title("SQL-RAG Using CrewAI 🚀")
st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")

# Input Options
input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
df = None

if input_option == "Use Hugging Face Dataset":
    dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
    if st.button("Load Dataset"):
        try:
            with st.spinner("Loading Hugging Face dataset..."):
                dataset = load_dataset(dataset_name, split="train")
                # Streamlit reruns the whole script on every interaction and a
                # button is only True on the rerun right after its click, so
                # the frame must live in session state to survive the later
                # "Submit Query" click.
                st.session_state["hf_df"] = pd.DataFrame(dataset)
                st.session_state["hf_dataset_name"] = dataset_name
            st.success(f"Dataset '{dataset_name}' loaded successfully!")
        except Exception as e:
            # Drop any previously cached frame so a failed load is not
            # silently answered with stale data.
            st.session_state.pop("hf_df", None)
            st.session_state.pop("hf_dataset_name", None)
            st.error(f"Error loading dataset: {e}")

    # Reuse the cached frame only while it matches the name in the text box.
    if st.session_state.get("hf_dataset_name") == dataset_name:
        df = st.session_state.get("hf_df")
    if df is not None:
        st.dataframe(df.head())
else:
    uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
    if uploaded_file:
        # The uploader keeps its file across reruns, so no caching is needed.
        try:
            df = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
            st.error(f"Error reading CSV file: {e}")
        else:
            st.success("File uploaded successfully!")
            st.dataframe(df.head())
72
# SQL-RAG Analysis
if df is not None:
    # The context manager removes the temporary database even if a tool, the
    # crew, or Streamlit's own rerun/stop control flow raises part-way through.
    with tempfile.TemporaryDirectory() as temp_dir:
        db_path = os.path.join(temp_dir, "data.db")

        # The raw sqlite3 connection is only needed to load the DataFrame;
        # close it explicitly so no file handle outlives this step.
        connection = sqlite3.connect(db_path)
        try:
            df.to_sql("salaries", connection, if_exists="replace", index=False)
            connection.commit()
        finally:
            connection.close()

        db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

        # Tools with proper docstrings (the docstring is what the agent sees
        # as the tool description).
        @tool("list_tables")
        def list_tables() -> str:
            """List all tables in the SQLite database."""
            return ListSQLDatabaseTool(db=db).invoke("")

        @tool("tables_schema")
        def tables_schema(tables: str) -> str:
            """
            Get the schema and sample rows for specific tables in the database.
            Input: Comma-separated table names.
            Example: 'salaries'
            """
            return InfoSQLDatabaseTool(db=db).invoke(tables)

        @tool("execute_sql")
        def execute_sql(sql_query: str) -> str:
            """
            Execute a valid SQL query on the database and return the results.
            Input: A SQL query string.
            Example: 'SELECT * FROM salaries LIMIT 5;'
            """
            return QuerySQLDataBaseTool(db=db).invoke(sql_query)

        @tool("check_sql")
        def check_sql(sql_query: str) -> str:
            """
            Check the validity of a SQL query before execution.
            Input: A SQL query string.
            Example: 'SELECT salary FROM salaries WHERE salary > 10000;'
            """
            return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})

        # Agents. CrewAI documents `backstory` as a required Agent attribute
        # next to `role` and `goal`, so each agent gets a short one.
        sql_dev = Agent(
            role="Database Developer",
            goal="Extract relevant data by executing SQL queries.",
            backstory=(
                "An experienced database engineer who inspects the available "
                "tables and their schema before writing and validating SQL."
            ),
            llm=llm,
            tools=[list_tables, tables_schema, execute_sql, check_sql],
        )

        data_analyst = Agent(
            role="Data Analyst",
            goal="Analyze the extracted data and generate detailed insights.",
            backstory=(
                "A meticulous analyst who turns raw query results into clear, "
                "well-supported findings."
            ),
            llm=llm,
        )

        report_writer = Agent(
            role="Report Writer",
            goal="Summarize the analysis into an executive report.",
            backstory=(
                "A concise business writer who distills analyses into short "
                "executive summaries."
            ),
            llm=llm,
        )

        # Tasks: each later task receives the previous task's output as context.
        extract_data = Task(
            description="Extract data for the query: {query}.",
            expected_output="Database query results.",
            agent=sql_dev,
        )

        analyze_data = Task(
            description="Analyze the query results for: {query}.",
            expected_output="Analysis report.",
            agent=data_analyst,
            context=[extract_data],
        )

        write_report = Task(
            description="Summarize the analysis into an executive summary.",
            expected_output="Markdown-formatted report.",
            agent=report_writer,
            context=[analyze_data],
        )

        crew = Crew(
            agents=[sql_dev, data_analyst, report_writer],
            tasks=[extract_data, analyze_data, write_report],
            process=Process.sequential,
            verbose=2,
        )

        query = st.text_area("Enter Query:", placeholder="e.g., 'What is the average salary by experience level?'")
        if st.button("Submit Query"):
            if not query.strip():
                # Do not spend LLM calls on an empty question.
                st.warning("Please enter a query before submitting.")
            else:
                with st.spinner("Processing your query with CrewAI..."):
                    result = crew.kickoff(inputs={"query": query})
                st.markdown("### Analysis Report:")
                # NOTE(review): depending on the CrewAI version, kickoff may
                # return a result object rather than a plain string; str()
                # is a no-op for strings and keeps both cases renderable.
                st.markdown(str(result))
else:
    st.info("Load a dataset to proceed.")