DrishtiSharma committed on
Commit 77389d5 · verified · 1 Parent(s): c86cb4d

Update app.py

Files changed (1)
  1. app.py +101 -105
app.py CHANGED
@@ -18,12 +18,11 @@ from langchain_community.tools.sql_database.tool import (
      QuerySQLDataBaseTool,
  )
  from langchain_community.utilities.sql_database import SQLDatabase
  import tempfile

- # Setup GROQ API Key
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")

- # Callback handler for logging LLM responses
  class Event:
      def __init__(self, event, text):
          self.event = event
@@ -43,112 +42,109 @@ class LLMCallbackHandler(BaseCallbackHandler):
          with self.log_path.open("a", encoding="utf-8") as file:
              file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")

- # LLM Setup
  llm = ChatGroq(
      temperature=0,
      model_name="mixtral-8x7b-32768",
      callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
  )

- # App Header
- st.title("Dynamic Query Analysis with CrewAI 🚀")
- st.write("Provide your query, and the app will extract, analyze, and summarize the data dynamically.")
-
- # File Upload for Dataset
- uploaded_file = st.file_uploader("Upload your dataset (CSV file)", type=["csv"])
-
- if uploaded_file:
-     st.success("File uploaded successfully!")
-
-     # Temporary directory for SQLite DB
-     temp_dir = tempfile.TemporaryDirectory()
-     db_path = os.path.join(temp_dir.name, "data.db")
-
-     # Create SQLite database
-     df = pd.read_csv(uploaded_file)
-     connection = sqlite3.connect(db_path)
-     df.to_sql("data_table", connection, if_exists="replace", index=False)
-
-     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
-
-     # Tools
-     @tool("list_tables")
-     def list_tables() -> str:
-         return ListSQLDatabaseTool(db=db).invoke("")
-
-     @tool("tables_schema")
-     def tables_schema(tables: str) -> str:
-         return InfoSQLDatabaseTool(db=db).invoke(tables)
-
-     @tool("execute_sql")
-     def execute_sql(sql_query: str) -> str:
-         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
-
-     @tool("check_sql")
-     def check_sql(sql_query: str) -> str:
-         return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
-
-     # Agents
-     sql_dev = Agent(
-         role="Senior Database Developer",
-         goal="Extract data from the database based on user query",
-         llm=llm,
-         tools=[list_tables, tables_schema, execute_sql, check_sql],
-         allow_delegation=False,
-     )
-
-     data_analyst = Agent(
-         role="Senior Data Analyst",
-         goal="Analyze the database response and provide insights",
-         llm=llm,
-         allow_delegation=False,
-     )
-
-     report_writer = Agent(
-         role="Senior Report Editor",
-         goal="Summarize the analysis into a short report",
-         llm=llm,
-         allow_delegation=False,
-     )
-
-     # Tasks
-     extract_data = Task(
-         description="Extract data required for the query: {query}.",
-         expected_output="Database result for the query",
-         agent=sql_dev,
-     )
-
-     analyze_data = Task(
-         description="Analyze the data and generate insights for: {query}.",
-         expected_output="Detailed analysis text",
-         agent=data_analyst,
-         context=[extract_data],
-     )
-
-     write_report = Task(
-         description="Summarize the analysis into a concise executive report.",
-         expected_output="Markdown report",
-         agent=report_writer,
-         context=[analyze_data],
-     )
-
-     # Crew
-     crew = Crew(
-         agents=[sql_dev, data_analyst, report_writer],
-         tasks=[extract_data, analyze_data, write_report],
-         process=Process.sequential,
-         verbose=2,
-         memory=False,
-     )
-
-     # User Input Query
-     query = st.text_input("Enter your query:")
-     if query:
-         with st.spinner("Processing your query..."):
-             inputs = {"query": query}
-             result = crew.kickoff(inputs=inputs)
-             st.markdown("### Analysis Report:")
-             st.markdown(result)
-
-     # Clean up
-     temp_dir.cleanup()
 
  QuerySQLDataBaseTool,
  )
  from langchain_community.utilities.sql_database import SQLDatabase
+ from datasets import load_dataset
  import tempfile

  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")

  class Event:
      def __init__(self, event, text):
          self.event = event

          with self.log_path.open("a", encoding="utf-8") as file:
              file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")

  llm = ChatGroq(
      temperature=0,
      model_name="mixtral-8x7b-32768",
      callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
  )

+ st.title("SQL-RAG using CrewAI 🚀")
+ st.write("Analyze and summarize Hugging Face datasets using natural language queries with SQL-based retrieval.")
+
+ default_dataset = "datascience/ds-salaries"
+ st.text("Example dataset: `datascience/ds-salaries` (You can enter your own dataset name)")
+
+ dataset_name = st.text_input("Enter Hugging Face dataset name:", value=default_dataset)
+
+ if dataset_name:
+     with st.spinner("Loading dataset..."):
+         try:
+             dataset = load_dataset(dataset_name, split="train")
+             df = pd.DataFrame(dataset)
+             st.success(f"Dataset '{dataset_name}' loaded successfully!")
+             st.write("Preview of the dataset:")
+             st.dataframe(df.head())
+
+             temp_dir = tempfile.TemporaryDirectory()
+             db_path = os.path.join(temp_dir.name, "data.db")
+             connection = sqlite3.connect(db_path)
+             df.to_sql("data_table", connection, if_exists="replace", index=False)
+             db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
+
+             @tool("list_tables")
+             def list_tables() -> str:
+                 return ListSQLDatabaseTool(db=db).invoke("")
+
+             @tool("tables_schema")
+             def tables_schema(tables: str) -> str:
+                 return InfoSQLDatabaseTool(db=db).invoke(tables)
+
+             @tool("execute_sql")
+             def execute_sql(sql_query: str) -> str:
+                 return QuerySQLDataBaseTool(db=db).invoke(sql_query)
+
+             @tool("check_sql")
+             def check_sql(sql_query: str) -> str:
+                 return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
+
+             sql_dev = Agent(
+                 role="Database Developer",
+                 goal="Extract data from the database.",
+                 llm=llm,
+                 tools=[list_tables, tables_schema, execute_sql, check_sql],
+                 allow_delegation=False,
+             )
+
+             data_analyst = Agent(
+                 role="Data Analyst",
+                 goal="Analyze and provide insights.",
+                 llm=llm,
+                 allow_delegation=False,
+             )
+
+             report_writer = Agent(
+                 role="Report Editor",
+                 goal="Summarize the analysis.",
+                 llm=llm,
+                 allow_delegation=False,
+             )
+
+             extract_data = Task(
+                 description="Extract data required for the query: {query}.",
+                 expected_output="Database result for the query",
+                 agent=sql_dev,
+             )
+
+             analyze_data = Task(
+                 description="Analyze the data for: {query}.",
+                 expected_output="Detailed analysis text",
+                 agent=data_analyst,
+                 context=[extract_data],
+             )
+
+             write_report = Task(
+                 description="Summarize the analysis into a short report.",
+                 expected_output="Markdown report",
+                 agent=report_writer,
+                 context=[analyze_data],
+             )
+
+             crew = Crew(
+                 agents=[sql_dev, data_analyst, report_writer],
+                 tasks=[extract_data, analyze_data, write_report],
+                 process=Process.sequential,
+                 verbose=2,
+                 memory=False,
+             )
+
+             query = st.text_input("Enter your query:", placeholder="e.g., 'How does salary vary by company size?'")
+             if query:
+                 with st.spinner("Processing your query..."):
+                     inputs = {"query": query}
+                     result = crew.kickoff(inputs=inputs)
+                     st.markdown("### Analysis Report:")
+                     st.markdown(result)
+
+                 temp_dir.cleanup()
+         except Exception as e:
+             st.error(f"Error loading dataset: {e}")