DrishtiSharma committed on
Commit
b945491
·
verified ·
1 Parent(s): 0449832

Create super_flwed_dynamic_viz_v2.py

Browse files
Files changed (1) hide show
  1. mylab/super_flwed_dynamic_viz_v2.py +614 -0
mylab/super_flwed_dynamic_viz_v2.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import sqlite3
4
+ import tempfile
5
+ from fpdf import FPDF
6
+ import threading
7
+ import time
8
+ import os
9
+ import re
10
+ import json
11
+ from pathlib import Path
12
+ import plotly.express as px
13
+ from datetime import datetime, timezone
14
+ from crewai import Agent, Crew, Process, Task
15
+ from crewai.tools import tool
16
+ from langchain_groq import ChatGroq
17
+ from langchain_openai import ChatOpenAI
18
+ from langchain.schema.output import LLMResult
19
+ from langchain_community.tools.sql_database.tool import (
20
+ InfoSQLDatabaseTool,
21
+ ListSQLDatabaseTool,
22
+ QuerySQLCheckerTool,
23
+ QuerySQLDataBaseTool,
24
+ )
25
+ from langchain_community.utilities.sql_database import SQLDatabase
26
+ from datasets import load_dataset
27
+ import tempfile
28
+
29
# Page header.
st.title("SQL-RAG Using CrewAI πŸš€")
st.write("Analyze datasets using natural language queries.")

# The model handle stays None unless the selected provider's API key is set.
llm = None

model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)

# Provider credentials are read from the environment.
groq_api_key = os.getenv("GROQ_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")

if model_choice == "llama-3.3-70b":
    if groq_api_key:
        llm = ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
    else:
        st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
elif model_choice == "GPT-4o":
    if openai_api_key:
        llm = ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
    else:
        st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")

if llm is None:
    st.error("❌ LLM is not initialized. Please check your API keys and model selection.")

# Seed the per-session slots on the first run so later reads never fail.
for _state_key, _initial_value in (("df", None), ("show_preview", False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Dataset source: a Hugging Face dataset name or an uploaded CSV file.
input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])

if input_option == "Use Hugging Face Dataset":
    dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
    if st.button("Load Dataset"):
        try:
            with st.spinner("Loading dataset..."):
                dataset = load_dataset(dataset_name, split="train")
                st.session_state.df = pd.DataFrame(dataset)
                # Flag the preview so it is rendered once data is available.
                st.session_state.show_preview = True
                st.success(f"Dataset '{dataset_name}' loaded successfully!")
        except Exception as e:
            st.error(f"Error: {e}")

elif input_option == "Upload CSV File":
    uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
    if uploaded_file:
        try:
            st.session_state.df = pd.read_csv(uploaded_file)
            # Flag the preview so it is rendered once data is available.
            st.session_state.show_preview = True
            st.success("File uploaded successfully!")
        except Exception as e:
            st.error(f"Error loading file: {e}")

# The preview is shown only after a dataset has been stored in the session.
_preview_ready = st.session_state.df is not None and st.session_state.show_preview
if _preview_ready:
    st.subheader("πŸ“‚ Dataset Preview")
    st.dataframe(st.session_state.df.head())
94
+
95
+
96
+ # Helper Function for Validation
97
def is_valid_suggestion(suggestion):
    """Return True when a chart suggestion names a supported chart and has its required keys.

    Args:
        suggestion: One parsed suggestion from the LLM reply. Anything that is
            not a dict is rejected instead of raising.

    Returns:
        bool: True if ``chart_type`` (compared case-insensitively) is one of
        bar/line/box/scatter/heatmap and both ``x_axis`` and ``y_axis`` keys
        are present, or if it is ``pie`` and the ``x_axis`` key is present.
        Only key presence is checked; values may still be ``None``.
    """
    # LLM output is untrusted: a bare string/number or a null chart_type must
    # be reported as "invalid" rather than crash on .get()/.lower().
    if not isinstance(suggestion, dict):
        return False

    chart_type = suggestion.get("chart_type")
    if not isinstance(chart_type, str):
        return False
    chart_type = chart_type.lower()

    if chart_type in ("bar", "line", "box", "scatter", "heatmap"):
        required_keys = ("x_axis", "y_axis")
    elif chart_type == "pie":
        # Pie charts only need the category column.
        required_keys = ("x_axis",)
    else:
        return False

    return all(key in suggestion for key in required_keys)
111
+
112
def _extract_json_text(raw_text):
    """Return the JSON portion of an LLM reply, dropping a surrounding Markdown code fence."""
    text = str(raw_text).strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()
    return text


def ask_gpt4o_for_visualization(query, df, llm, retries=2):
    """Ask the LLM which chart(s) best answer ``query`` for the given dataset.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose numeric / non-numeric column names are listed in
            the prompt.
        llm: Chat model object; it is called through ``invoke(prompt)`` and the
            reply text is read from the result's ``content`` attribute when
            present.
        retries: Number of additional attempts after a failed call or an
            unparsable reply.

    Returns:
        list[dict] | None: The suggestions accepted by ``is_valid_suggestion``
        (a single dict reply is wrapped in a list), or ``None`` when the reply
        parsed but held no valid suggestion, or when every attempt failed.
    """
    import json

    # Identify numeric and categorical columns
    numeric_columns = df.select_dtypes(include='number').columns.tolist()
    categorical_columns = df.select_dtypes(exclude='number').columns.tolist()

    # Prompt with dataset-specific, query-based examples.
    # The schema key must be "chart_type": that is the key the validation and
    # plotting code read.
    prompt = f"""
    Analyze the following query and suggest the most suitable visualization(s) using the dataset.
    **Query:** "{query}"
    **Dataset Overview:**
    - **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
    - **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
    Suggest visualizations in this exact JSON format:
    [
    {{
    "chart_type": "bar/box/line/scatter/pie/heatmap",
    "x_axis": "categorical_or_time_column",
    "y_axis": "numeric_column",
    "group_by": "optional_column_for_grouping",
    "title": "Title of the chart",
    "description": "Why this chart is suitable"
    }}
    ]
    **Query-Based Examples:**
    - **Query:** "What is the salary distribution across different job titles?"
    **Suggested Visualization:**
    {{
    "chart_type": "box",
    "x_axis": "job_title",
    "y_axis": "salary_in_usd",
    "group_by": "experience_level",
    "title": "Salary Distribution by Job Title and Experience",
    "description": "A box plot to show how salaries vary across different job titles and experience levels."
    }}
    - **Query:** "Show the average salary by company size and employment type."
    **Suggested Visualizations:**
    [
    {{
    "chart_type": "bar",
    "x_axis": "company_size",
    "y_axis": "salary_in_usd",
    "group_by": "employment_type",
    "title": "Average Salary by Company Size and Employment Type",
    "description": "A grouped bar chart comparing average salaries across company sizes and employment types."
    }},
    {{
    "chart_type": "heatmap",
    "x_axis": "company_size",
    "y_axis": "salary_in_usd",
    "group_by": "employment_type",
    "title": "Salary Heatmap by Company Size and Employment Type",
    "description": "A heatmap showing salary concentration across company sizes and employment types."
    }}
    ]
    - **Query:** "How has the average salary changed over the years?"
    **Suggested Visualization:**
    {{
    "chart_type": "line",
    "x_axis": "work_year",
    "y_axis": "salary_in_usd",
    "group_by": "experience_level",
    "title": "Average Salary Trend Over Years",
    "description": "A line chart showing how the average salary has changed across different experience levels over the years."
    }}
    - **Query:** "What is the employee distribution by company location?"
    **Suggested Visualization:**
    {{
    "chart_type": "pie",
    "x_axis": "company_location",
    "y_axis": null,
    "group_by": null,
    "title": "Employee Distribution by Company Location",
    "description": "A pie chart showing the distribution of employees across company locations."
    }}
    - **Query:** "Is there a relationship between remote work ratio and salary?"
    **Suggested Visualization:**
    {{
    "chart_type": "scatter",
    "x_axis": "remote_ratio",
    "y_axis": "salary_in_usd",
    "group_by": "experience_level",
    "title": "Remote Work Ratio vs Salary",
    "description": "A scatter plot to analyze the relationship between remote work ratio and salary."
    }}
    - **Query:** "Which job titles have the highest salaries across regions?"
    **Suggested Visualization:**
    {{
    "chart_type": "heatmap",
    "x_axis": "job_title",
    "y_axis": "employee_residence",
    "group_by": null,
    "title": "Salary Heatmap by Job Title and Region",
    "description": "A heatmap showing the concentration of high-paying job titles across regions."
    }}
    Only suggest visualizations that logically match the query and dataset.
    """

    for attempt in range(retries + 1):
        try:
            # invoke() takes the prompt string directly; the reply text is on
            # `.content` for chat models (plain-string replies pass through).
            response = llm.invoke(prompt)
            reply_text = getattr(response, "content", response)
            suggestions = json.loads(_extract_json_text(reply_text))

            if isinstance(suggestions, list):
                valid_suggestions = [s for s in suggestions if is_valid_suggestion(s)]
                if valid_suggestions:
                    return valid_suggestions
                st.warning("⚠️ GPT-4o did not suggest valid visualizations.")
                return None

            if isinstance(suggestions, dict):
                if is_valid_suggestion(suggestions):
                    return [suggestions]
                st.warning("⚠️ GPT-4o's suggestion is incomplete or invalid.")
                return None

            # Any other JSON value (string, number, null) falls through to a retry.

        except json.JSONDecodeError:
            st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
        except Exception as e:
            # UI boundary: report the failure and let the retry loop continue.
            st.error(f"⚠️ Error during GPT-4o call: {e}")

        if attempt < retries:
            st.info("πŸ”„ Retrying visualization suggestion...")

    st.error("❌ Failed to generate a valid visualization after multiple attempts.")
    return None
241
+
242
+
243
def add_stats_to_figure(fig, df, y_axis, chart_type):
    """Overlay summary statistics of ``df[y_axis]`` on a Plotly figure.

    Bar and line charts get an annotation box plus min/median/avg/max
    reference lines; scatter plots get the annotation box only; box, pie and
    heatmap charts are returned without overlays.

    Args:
        fig: Figure to annotate (must provide ``add_annotation``/``add_hline``).
        df: DataFrame the figure was built from.
        y_axis: Name of the column to summarise. ``None`` or a name that is not
            a column of ``df`` means there is nothing to annotate.
        chart_type: Chart kind, compared case-insensitively.

    Returns:
        The same ``fig`` object, annotated where applicable.
    """
    # Pie suggestions legitimately carry no y-axis; indexing df with None or an
    # unknown name would raise KeyError, so bail out quietly instead.
    if fig is None or y_axis is None or y_axis not in df.columns:
        return fig

    # Check if the y-axis column is numeric
    if not pd.api.types.is_numeric_dtype(df[y_axis]):
        st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
        return fig

    values = df[y_axis].dropna()
    if values.empty:
        # All-NaN column: every statistic would be NaN, so add nothing.
        return fig

    chart_type = str(chart_type or "").lower()

    # Compute statistics for numeric data
    min_val = values.min()
    max_val = values.max()
    avg_val = values.mean()
    median_val = values.median()
    std_dev_val = values.std()

    # Plotly annotation text renders a small HTML subset (<b>, <br>), not
    # Markdown, and ignores "\n" -- so the box is built with HTML tags.
    stats_text = (
        "πŸ“Š <b>Statistics</b><br><br>"
        f"- <b>Min:</b> ${min_val:,.2f}<br>"
        f"- <b>Max:</b> ${max_val:,.2f}<br>"
        f"- <b>Average:</b> ${avg_val:,.2f}<br>"
        f"- <b>Median:</b> ${median_val:,.2f}<br>"
        f"- <b>Std Dev:</b> ${std_dev_val:,.2f}"
    )

    if chart_type in ("bar", "line", "scatter"):
        # Annotation box pinned just right of the plot area.
        fig.add_annotation(
            text=stats_text,
            xref="paper", yref="paper",
            x=1.02, y=1,
            showarrow=False,
            align="left",
            font=dict(size=12, color="black"),
            bordercolor="gray",
            borderwidth=1,
            bgcolor="rgba(255, 255, 255, 0.85)",
        )

        if chart_type != "scatter":
            # Horizontal reference lines (bar and line charts only).
            fig.add_hline(y=min_val, line_dash="dot", line_color="red", annotation_text="Min", annotation_position="bottom right")
            fig.add_hline(y=median_val, line_dash="dash", line_color="orange", annotation_text="Median", annotation_position="top right")
            fig.add_hline(y=avg_val, line_dash="dashdot", line_color="green", annotation_text="Avg", annotation_position="top right")
            fig.add_hline(y=max_val, line_dash="dot", line_color="blue", annotation_text="Max", annotation_position="top right")

    elif chart_type == "box":
        # Box plots inherently show distribution; no extra stats needed
        pass

    elif chart_type == "pie":
        # Pie charts represent proportions, not suitable for stats
        st.info("πŸ“Š Pie charts represent proportions. Additional stats are not applicable.")

    elif chart_type == "heatmap":
        # Heatmaps already reflect data intensity
        st.info("πŸ“Š Heatmaps inherently reflect distribution. No additional stats added.")

    else:
        st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")

    return fig
321
+
322
+
323
+ # Dynamically generate Plotly visualizations based on GPT-4o suggestions
324
def generate_visualization(suggestion, df):
    """Build a Plotly Express figure from one chart suggestion.

    Args:
        suggestion: Dict with ``chart_type``, ``x_axis`` and optionally
            ``y_axis``, ``group_by`` and ``title``.
        df: DataFrame to plot.

    Returns:
        The figure, or ``None`` (after a Streamlit warning/error) when the
        columns cannot be resolved, the chart type is unsupported, or Plotly
        raises while building the chart.

    Notes:
        * A missing y-axis is inferred from the numeric columns, except for
          pie charts, which are then drawn from row counts per category.
        * ``heatmap`` is drawn with ``px.density_heatmap`` and ``pie`` with
          ``names``/``values``, since Plotly Express has no ``px.heatmap`` and
          ``px.pie`` does not take ``x``/``y``.
    """
    def _label(column):
        # Column names become human-readable axis/title labels.
        return str(column).replace("_", " ").title()

    # `.get(..., "bar")` alone would not cover an explicit null chart_type.
    chart_type = str(suggestion.get("chart_type") or "bar").lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis")
    group_by = suggestion.get("group_by")
    is_pie = chart_type == "pie"

    # Step 1: Infer Y-axis if not provided (pie charts fall back to counts).
    if not y_axis and not is_pie:
        numeric_columns = df.select_dtypes(include='number').columns.tolist()

        # Avoid using the same column for both axes
        if x_axis in numeric_columns:
            numeric_columns.remove(x_axis)

        # Smart guess: prioritize salary or relevant metrics if available,
        # otherwise fall back to the first numeric column.
        priority_columns = ["salary_in_usd", "income", "earnings", "revenue"]
        y_axis = next(
            (col for col in priority_columns if col in numeric_columns),
            numeric_columns[0] if numeric_columns else None,
        )

    # Step 2: Validate axes
    if not x_axis or (not y_axis and not is_pie):
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    unknown_columns = [str(col) for col in (x_axis, y_axis) if col and col not in df.columns]
    if unknown_columns:
        st.warning(f"⚠️ Column(s) not found in dataset: {', '.join(unknown_columns)}")
        return None

    # Step 3: Select the Plotly Express function. Names are resolved on `px`
    # as before, with an alias for chart kinds whose px name differs; private
    # or non-callable attributes are rejected.
    px_name = {"heatmap": "density_heatmap"}.get(chart_type, chart_type)
    plotly_function = None if px_name.startswith("_") else getattr(px, px_name, None)
    if not callable(plotly_function):
        st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
        return None

    # Step 4: Prepare plot arguments in the shape each chart family expects.
    if is_pie:
        plot_args = {"data_frame": df, "names": x_axis}
        if y_axis:
            plot_args["values"] = y_axis
    else:
        plot_args = {"data_frame": df, "x": x_axis, "y": y_axis}
        # density_heatmap has no `color` argument, so grouping is skipped there.
        if chart_type != "heatmap" and group_by and group_by in df.columns:
            plot_args["color"] = group_by

    try:
        # Step 5: Generate the visualization
        fig = plotly_function(**plot_args)

        if y_axis:
            default_title = f"{chart_type.title()} Plot of {_label(y_axis)} by {_label(x_axis)}"
        else:
            default_title = f"{chart_type.title()} Plot of {_label(x_axis)}"
        layout = {"title": suggestion.get("title") or default_title}
        if not is_pie:
            layout["xaxis_title"] = _label(x_axis)
            layout["yaxis_title"] = _label(y_axis)
        fig.update_layout(**layout)

        # Step 6: Apply statistics (only meaningful when a y column exists).
        if y_axis:
            fig = add_stats_to_figure(fig, df, y_axis, chart_type)

        return fig

    except Exception as e:
        # UI boundary: surface the Plotly/pandas failure instead of crashing the app.
        st.error(f"⚠️ Failed to generate visualization: {e}")
        return None
386
+
387
+
388
def generate_multiple_visualizations(suggestions, df):
    """Build one figure per suggestion and return those that could be generated.

    Args:
        suggestions: Iterable of suggestion dicts; ``None`` or an empty
            sequence yields an empty list.
        df: DataFrame passed through to ``generate_visualization``.

    Returns:
        list: The successfully generated figures, in suggestion order. The
        list is empty when nothing could be drawn; ``generate_visualization``
        reports each individual failure in the UI.
    """
    visualizations = []

    for suggestion in suggestions or []:
        fig = generate_visualization(suggestion, df)
        # Statistics overlays belong to generate_visualization (its "Step 6"),
        # so they are not applied again here. This also avoids indexing
        # suggestion["y_axis"], which pie suggestions may not have.
        if fig is not None:
            visualizations.append(fig)

    # No second attempt at suggestions[0]: it has already failed above, and
    # retrying only passed a None figure on to the stats step.
    return visualizations
410
+
411
+
412
def handle_visualization_suggestions(suggestions, df):
    """
    Render the chart(s) described by ``suggestions`` in the Streamlit page.

    A list with more than one entry goes through
    ``generate_multiple_visualizations``; a dict, or a list with exactly one
    entry, goes through ``generate_visualization``. A warning is shown when
    no figure results; otherwise each figure is plotted at container width.
    """
    figures = []
    came_as_list = isinstance(suggestions, list)
    single_in_list = came_as_list and len(suggestions) == 1

    if came_as_list and len(suggestions) > 1:
        # Several suggestions: delegate to the batch generator.
        figures = generate_multiple_visualizations(suggestions, df)
    elif single_in_list or isinstance(suggestions, dict):
        # Exactly one suggestion, either bare or wrapped in a list.
        only_suggestion = suggestions[0] if came_as_list else suggestions
        rendered = generate_visualization(only_suggestion, df)
        if rendered:
            figures.append(rendered)

    # Nothing could be drawn.
    if not figures:
        st.warning("⚠️ Unable to generate any visualization based on the suggestion.")

    # Show every figure that was produced.
    for rendered in figures:
        st.plotly_chart(rendered, use_container_width=True)
436
+
437
+
438
def escape_markdown(text):
    """Return ``str(text)`` with the Markdown characters ``*``, ``_``, `` ` `` and ``~`` backslash-escaped."""
    # One alternative per Markdown control character; the capture group is
    # re-emitted behind a literal backslash.
    markdown_specials = r"(\*|_|`|~)"
    as_string = str(text)
    return re.sub(markdown_specials, r"\\\1", as_string)
444
+
445
+
446
+ # SQL-RAG Analysis
447
if st.session_state.df is not None:
    # Load the DataFrame into a scratch SQLite file that the SQL tools query.
    temp_dir = tempfile.TemporaryDirectory()
    db_path = os.path.join(temp_dir.name, "data.db")
    connection = sqlite3.connect(db_path)
    try:
        st.session_state.df.to_sql("salaries", connection, if_exists="replace", index=False)
    finally:
        # SQLDatabase opens its own connection from the URI below, so this
        # loader connection is closed instead of being leaked on every rerun.
        connection.close()
    db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

    @tool("list_tables")
    def list_tables() -> str:
        """List all tables in the database."""
        return ListSQLDatabaseTool(db=db).invoke("")

    @tool("tables_schema")
    def tables_schema(tables: str) -> str:
        """Get the schema and sample rows for the specified tables."""
        return InfoSQLDatabaseTool(db=db).invoke(tables)

    @tool("execute_sql")
    def execute_sql(sql_query: str) -> str:
        """Execute a SQL query against the database and return the results."""
        return QuerySQLDataBaseTool(db=db).invoke(sql_query)

    @tool("check_sql")
    def check_sql(sql_query: str) -> str:
        """Validate the SQL query syntax and structure before execution."""
        return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})

    # Agents for SQL data extraction and analysis
    sql_dev = Agent(
        role="Senior Database Developer",
        goal="Extract data using optimized SQL queries.",
        backstory="An expert in writing optimized SQL queries for complex databases.",
        llm=llm,
        tools=[list_tables, tables_schema, execute_sql, check_sql],
    )

    data_analyst = Agent(
        role="Senior Data Analyst",
        goal="Analyze the data and produce insights.",
        backstory="A seasoned analyst who identifies trends and patterns in datasets.",
        llm=llm,
    )

    report_writer = Agent(
        role="Technical Report Writer",
        goal="Write a structured report with Introduction and Key Insights. DO NOT include any Conclusion or Summary.",
        backstory="Specializes in detailed analytical reports without conclusions.",
        llm=llm,
    )

    conclusion_writer = Agent(
        role="Conclusion Specialist",
        goal="Summarize findings into a clear and concise 3-5 line Conclusion highlighting only the most important insights.",
        backstory="An expert in crafting impactful and clear conclusions.",
        llm=llm,
    )

    # Define tasks for report and conclusion
    extract_data = Task(
        description="Extract data based on the query: {query}.",
        expected_output="Database results matching the query.",
        agent=sql_dev,
    )

    analyze_data = Task(
        description="Analyze the extracted data for query: {query}.",
        expected_output="Key Insights and Analysis without any Introduction or Conclusion.",
        agent=data_analyst,
        context=[extract_data],
    )

    write_report = Task(
        description="Write the analysis report with Introduction and Key Insights. DO NOT include any Conclusion or Summary.",
        expected_output="Markdown-formatted report excluding Conclusion.",
        agent=report_writer,
        context=[analyze_data],
    )

    write_conclusion = Task(
        # The two literals are concatenated, so the first must end with a
        # space to keep the sentences apart.
        description="Summarize the key findings in 3-5 impactful lines, highlighting the maximum, minimum, and average salaries. "
        "Emphasize significant insights on salary distribution and influential compensation trends for strategic decision-making.",
        expected_output="Markdown-formatted Conclusion section with key insights and statistics.",
        agent=conclusion_writer,
        context=[analyze_data],
    )

    # Separate Crews for report and conclusion
    crew_report = Crew(
        agents=[sql_dev, data_analyst, report_writer],
        tasks=[extract_data, analyze_data, write_report],
        process=Process.sequential,
        verbose=True,
    )

    crew_conclusion = Crew(
        agents=[data_analyst, conclusion_writer],
        tasks=[write_conclusion],
        process=Process.sequential,
        verbose=True,
    )

    # Tabs for Query Results and Visualizations.
    # NOTE(review): nothing is rendered into tab2 in this section.
    tab1, tab2 = st.tabs(["πŸ” Query Insights + Viz", "πŸ“Š Full Data Viz"])

    # Query Insights + Visualization
    with tab1:
        query = st.text_area("Enter Query:", value="Provide insights into the salary of a Principal Data Scientist.")
        if st.button("Submit Query"):
            if llm is None:
                # Without a model every crew/tool call below would fail.
                st.error("❌ LLM is not initialized. Please check your API keys and model selection.")
            else:
                # Streamlit calls made from a worker thread need the script-run
                # context attached; the helper lives in an internal module, so
                # degrade gracefully if it is unavailable.
                try:
                    from streamlit.runtime.scriptrunner import add_script_run_ctx
                except ImportError:
                    add_script_run_ctx = None

                result_container = {"report": None, "conclusion": None, "visuals": None}
                worker_errors = {}
                progress_bar = st.progress(0, text="πŸš€ Starting Analysis...")

                # Read session data on the main thread and hand it to the worker.
                current_df = st.session_state.df

                def generate_visuals():
                    # Runs in the background; failures are recorded so the
                    # main thread can report them instead of losing them.
                    try:
                        result_container['visuals'] = ask_gpt4o_for_visualization(query, current_df, llm)
                    except Exception as exc:
                        worker_errors['visuals'] = exc

                # Chart suggestions do not depend on the crews, so they are
                # fetched in parallel with the report/conclusion pipeline.
                thread_visuals = threading.Thread(target=generate_visuals)
                if add_script_run_ctx is not None:
                    add_script_run_ctx(thread_visuals)
                thread_visuals.start()

                try:
                    progress_bar.progress(20, text="πŸ“ Generating Analysis Report...")
                    report_inputs = {"query": query + " Provide detailed analysis but DO NOT include Conclusion."}
                    result_container['report'] = crew_report.kickoff(inputs=report_inputs)
                    progress_bar.progress(40, text="βœ… Analysis Report Ready!")

                    # write_conclusion takes analyze_data as context, and
                    # analyze_data only runs inside crew_report, so the
                    # conclusion crew is started after the report crew finishes.
                    progress_bar.progress(40, text="πŸ“ Crafting Conclusion...")
                    conclusion_inputs = {"query": query + " Provide ONLY the most important insights in 3-5 concise lines."}
                    result_container['conclusion'] = crew_conclusion.kickoff(inputs=conclusion_inputs)
                    progress_bar.progress(60, text="βœ… Conclusion Ready!")
                except Exception as exc:
                    # UI boundary: show the failure and still render what exists.
                    st.error(f"⚠️ Analysis failed: {exc}")

                progress_bar.progress(60, text="πŸ“Š Creating Visualizations...")
                thread_visuals.join()
                if 'visuals' in worker_errors:
                    st.error(f"⚠️ Visualization suggestion failed: {worker_errors['visuals']}")
                progress_bar.progress(80, text="βœ… Visualizations Ready!")

                progress_bar.progress(100, text="βœ… Full Analysis Complete!")
                time.sleep(0.5)
                progress_bar.empty()

                # Display Report
                st.markdown("## πŸ“Š Analysis Report")
                report = result_container['report']
                st.markdown(str(report) if report else "⚠️ No Report Generated.")

                # Display Visual Insights
                st.markdown("## πŸ“ˆ Visual Insights")
                if result_container['visuals']:
                    handle_visualization_suggestions(result_container['visuals'], current_df)
                else:
                    st.warning("⚠️ No suitable visualizations to display.")

                # Display Conclusion
                st.markdown("## πŸ“ Conclusion")
                safe_conclusion = escape_markdown(result_container['conclusion'] if result_container['conclusion'] else "⚠️ No Conclusion Generated.")
                st.markdown(safe_conclusion)
609
+
610
+
611
+ # Sidebar Reference
612
# Reference link, rendered through the sidebar's object-style element API.
st.sidebar.header("πŸ“š Reference:")
st.sidebar.markdown("[SQL Agents w CrewAI & Llama 3 - Plaban Nayak](https://github.com/plaban1981/Agents/blob/main/SQL_Agents_with_CrewAI_and_Llama_3.ipynb)")