DrishtiSharma commited on
Commit
58b0285
Β·
verified Β·
1 Parent(s): 43546ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -139
app.py CHANGED
@@ -25,7 +25,7 @@ from datasets import load_dataset
25
  import tempfile
26
 
27
  st.title("SQL-RAG Using CrewAI πŸš€")
28
- st.write("Analyze datasets using natural language queries.")
29
 
30
  # Initialize LLM
31
  llm = None
@@ -87,152 +87,30 @@ if st.session_state.df is not None and st.session_state.show_preview:
87
  st.dataframe(st.session_state.df.head())
88
 
89
 
90
- # Helper Function for Validation
91
- def is_valid_suggestion(suggestion):
92
- chart_type = suggestion.get("chart_type", "").lower()
93
 
94
- if chart_type in ["bar", "line", "box", "scatter"]:
95
- return all(k in suggestion for k in ["chart_type", "x_axis", "y_axis"])
96
 
97
- elif chart_type == "pie":
98
- return all(k in suggestion for k in ["chart_type", "x_axis"])
99
-
100
- elif chart_type == "heatmap":
101
- return all(k in suggestion for k in ["chart_type", "x_axis", "y_axis"])
102
-
103
- else:
104
- return False
105
-
106
- def ask_gpt4o_for_visualization(query, df, llm, retries=2):
107
- import json
108
-
109
- # Identify numeric and categorical columns
110
- numeric_columns = df.select_dtypes(include='number').columns.tolist()
111
- categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
112
-
113
- # Prompt with Dataset-Specific, Query-Based Examples
114
  prompt = f"""
115
- Analyze the following query and suggest the most suitable visualization(s) using the dataset.
116
- **Query:** "{query}"
117
- **Dataset Overview:**
118
- - **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
119
- - **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
120
- Suggest visualizations in this exact JSON format:
121
  [
122
  {{
123
- "chdart_type": "bar/box/line/scatter/pie/heatmap",
124
- "x_axis": "categorical_or_time_column",
125
- "y_axis": "numeric_column",
126
- "group_by": "optional_column_for_grouping",
127
- "title": "Title of the chart",
128
- "description": "Why this chart is suitable"
129
  }}
130
  ]
131
- **Query-Based Examples:**
132
- - **Query:** "What is the salary distribution across different job titles?"
133
- **Suggested Visualization:**
134
- {{
135
- "chart_type": "box",
136
- "x_axis": "job_title",
137
- "y_axis": "salary_in_usd",
138
- "group_by": "experience_level",
139
- "title": "Salary Distribution by Job Title and Experience",
140
- "description": "A box plot to show how salaries vary across different job titles and experience levels."
141
- }}
142
- - **Query:** "Show the average salary by company size and employment type."
143
- **Suggested Visualizations:**
144
- [
145
- {{
146
- "chart_type": "bar",
147
- "x_axis": "company_size",
148
- "y_axis": "salary_in_usd",
149
- "group_by": "employment_type",
150
- "title": "Average Salary by Company Size and Employment Type",
151
- "description": "A grouped bar chart comparing average salaries across company sizes and employment types."
152
- }},
153
- {{
154
- "chart_type": "heatmap",
155
- "x_axis": "company_size",
156
- "y_axis": "salary_in_usd",
157
- "group_by": "employment_type",
158
- "title": "Salary Heatmap by Company Size and Employment Type",
159
- "description": "A heatmap showing salary concentration across company sizes and employment types."
160
- }}
161
- ]
162
- - **Query:** "How has the average salary changed over the years?"
163
- **Suggested Visualization:**
164
- {{
165
- "chart_type": "line",
166
- "x_axis": "work_year",
167
- "y_axis": "salary_in_usd",
168
- "group_by": "experience_level",
169
- "title": "Average Salary Trend Over Years",
170
- "description": "A line chart showing how the average salary has changed across different experience levels over the years."
171
- }}
172
- - **Query:** "What is the employee distribution by company location?"
173
- **Suggested Visualization:**
174
- {{
175
- "chart_type": "pie",
176
- "x_axis": "company_location",
177
- "y_axis": null,
178
- "group_by": null,
179
- "title": "Employee Distribution by Company Location",
180
- "description": "A pie chart showing the distribution of employees across company locations."
181
- }}
182
- - **Query:** "Is there a relationship between remote work ratio and salary?"
183
- **Suggested Visualization:**
184
- {{
185
- "chart_type": "scatter",
186
- "x_axis": "remote_ratio",
187
- "y_axis": "salary_in_usd",
188
- "group_by": "experience_level",
189
- "title": "Remote Work Ratio vs Salary",
190
- "description": "A scatter plot to analyze the relationship between remote work ratio and salary."
191
- }}
192
- - **Query:** "Which job titles have the highest salaries across regions?"
193
- **Suggested Visualization:**
194
- {{
195
- "chart_type": "heatmap",
196
- "x_axis": "job_title",
197
- "y_axis": "employee_residence",
198
- "group_by": null,
199
- "title": "Salary Heatmap by Job Title and Region",
200
- "description": "A heatmap showing the concentration of high-paying job titles across regions."
201
- }}
202
- Only suggest visualizations that logically match the query and dataset.
203
  """
204
-
205
- for attempt in range(retries + 1):
206
- try:
207
- response = llm.generate(prompt)
208
- suggestions = json.loads(response)
209
-
210
- if isinstance(suggestions, list):
211
- valid_suggestions = [s for s in suggestions if is_valid_suggestion(s)]
212
- if valid_suggestions:
213
- return valid_suggestions
214
- else:
215
- st.warning("⚠️ GPT-4o did not suggest valid visualizations.")
216
- return None
217
-
218
- elif isinstance(suggestions, dict):
219
- if is_valid_suggestion(suggestions):
220
- return [suggestions]
221
- else:
222
- st.warning("⚠️ GPT-4o's suggestion is incomplete or invalid.")
223
- return None
224
-
225
- except json.JSONDecodeError:
226
- st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
227
- except Exception as e:
228
- st.error(f"⚠️ Error during GPT-4o call: {e}")
229
-
230
- if attempt < retries:
231
- st.info("πŸ”„ Retrying visualization suggestion...")
232
-
233
- st.error("❌ Failed to generate a valid visualization after multiple attempts.")
234
- return None
235
-
236
 
237
  def add_stats_to_figure(fig, df, y_axis, chart_type):
238
  """
@@ -429,6 +307,91 @@ def handle_visualization_suggestions(suggestions, df):
429
  st.plotly_chart(fig, use_container_width=True)
430
 
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  def escape_markdown(text):
433
  # Ensure text is a string
434
  text = str(text)
@@ -573,6 +536,28 @@ if st.session_state.df is not None:
573
  safe_conclusion = escape_markdown(conclusion_result if conclusion_result else "⚠️ No Conclusion Generated.")
574
  st.markdown(safe_conclusion)
575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
 
577
  # Sidebar Reference
578
  with st.sidebar:
 
25
  import tempfile
26
 
27
  st.title("SQL-RAG Using CrewAI πŸš€")
28
+ st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
29
 
30
  # Initialize LLM
31
  llm = None
 
87
  st.dataframe(st.session_state.df.head())
88
 
89
 
 
 
 
90
 
 
 
91
 
92
+ def ask_gpt4o_for_visualization(query, df, llm):
93
+ columns = ', '.join(df.columns)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  prompt = f"""
95
+ Analyze the query and suggest one or more relevant visualizations.
96
+ Query: "{query}"
97
+ Available Columns: {columns}
98
+ Respond in this JSON format (as a list if multiple suggestions):
 
 
99
  [
100
  {{
101
+ "chart_type": "bar/box/line/scatter",
102
+ "x_axis": "column_name",
103
+ "y_axis": "column_name",
104
+ "group_by": "optional_column_name"
 
 
105
  }}
106
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  """
108
+ response = llm.generate(prompt)
109
+ try:
110
+ return json.loads(response)
111
+ except json.JSONDecodeError:
112
+ st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
113
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def add_stats_to_figure(fig, df, y_axis, chart_type):
116
  """
 
307
  st.plotly_chart(fig, use_container_width=True)
308
 
309
 
310
+
311
+ # Function to create TXT file
312
+ def create_text_report_with_viz_temp(report, conclusion, visualizations):
313
+ content = f"### Analysis Report\n\n{report}\n\n### Visualizations\n"
314
+
315
+ for i, fig in enumerate(visualizations, start=1):
316
+ fig_title = fig.layout.title.text if fig.layout.title.text else f"Visualization {i}"
317
+ x_axis = fig.layout.xaxis.title.text if fig.layout.xaxis.title.text else "X-axis"
318
+ y_axis = fig.layout.yaxis.title.text if fig.layout.yaxis.title.text else "Y-axis"
319
+
320
+ content += f"\n{i}. {fig_title}\n"
321
+ content += f" - X-axis: {x_axis}\n"
322
+ content += f" - Y-axis: {y_axis}\n"
323
+
324
+ if fig.data:
325
+ trace_types = set(trace.type for trace in fig.data)
326
+ content += f" - Chart Type(s): {', '.join(trace_types)}\n"
327
+ else:
328
+ content += " - No data available in this visualization.\n"
329
+
330
+ content += f"\n\n\n{conclusion}"
331
+
332
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode='w', encoding='utf-8') as temp_txt:
333
+ temp_txt.write(content)
334
+ return temp_txt.name
335
+
336
+
337
+
338
+ # Function to create PDF with report text and visualizations
339
+ def create_pdf_report_with_viz(report, conclusion, visualizations):
340
+ pdf = FPDF()
341
+ pdf.set_auto_page_break(auto=True, margin=15)
342
+ pdf.add_page()
343
+ pdf.set_font("Arial", size=12)
344
+
345
+ # Title
346
+ pdf.set_font("Arial", style="B", size=18)
347
+ pdf.cell(0, 10, "πŸ“Š Analysis Report", ln=True, align="C")
348
+ pdf.ln(10)
349
+
350
+ # Report Content
351
+ pdf.set_font("Arial", style="B", size=14)
352
+ pdf.cell(0, 10, "Analysis", ln=True)
353
+ pdf.set_font("Arial", size=12)
354
+ pdf.multi_cell(0, 10, report)
355
+
356
+ pdf.ln(10)
357
+ pdf.set_font("Arial", style="B", size=14)
358
+ pdf.cell(0, 10, "Conclusion", ln=True)
359
+ pdf.set_font("Arial", size=12)
360
+ pdf.multi_cell(0, 10, conclusion)
361
+
362
+ # Add Visualizations
363
+ pdf.add_page()
364
+ pdf.set_font("Arial", style="B", size=16)
365
+ pdf.cell(0, 10, "πŸ“ˆ Visualizations", ln=True)
366
+ pdf.ln(5)
367
+
368
+ with tempfile.TemporaryDirectory() as temp_dir:
369
+ for i, fig in enumerate(visualizations, start=1):
370
+ fig_title = fig.layout.title.text if fig.layout.title.text else f"Visualization {i}"
371
+ x_axis = fig.layout.xaxis.title.text if fig.layout.xaxis.title.text else "X-axis"
372
+ y_axis = fig.layout.yaxis.title.text if fig.layout.yaxis.title.text else "Y-axis"
373
+
374
+ # Save each visualization as a PNG image
375
+ img_path = os.path.join(temp_dir, f"viz_{i}.png")
376
+ fig.write_image(img_path)
377
+
378
+ # Insert Title and Description
379
+ pdf.set_font("Arial", style="B", size=14)
380
+ pdf.multi_cell(0, 10, f"{i}. {fig_title}")
381
+ pdf.set_font("Arial", size=12)
382
+ pdf.multi_cell(0, 10, f"X-axis: {x_axis} | Y-axis: {y_axis}")
383
+ pdf.ln(3)
384
+
385
+ # Embed Visualization
386
+ pdf.image(img_path, w=170)
387
+ pdf.ln(10)
388
+
389
+ # Save PDF
390
+ temp_pdf = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
391
+ pdf.output(temp_pdf.name)
392
+
393
+ return temp_pdf
394
+
395
  def escape_markdown(text):
396
  # Ensure text is a string
397
  text = str(text)
 
536
  safe_conclusion = escape_markdown(conclusion_result if conclusion_result else "⚠️ No Conclusion Generated.")
537
  st.markdown(safe_conclusion)
538
 
539
+ # Full Data Visualization Tab
540
+ with tab2:
541
+ st.subheader("πŸ“Š Comprehensive Data Visualizations")
542
+
543
+ fig1 = px.histogram(st.session_state.df, x="job_title", title="Job Title Frequency")
544
+ st.plotly_chart(fig1)
545
+
546
+ fig2 = px.bar(
547
+ st.session_state.df.groupby("experience_level")["salary_in_usd"].mean().reset_index(),
548
+ x="experience_level", y="salary_in_usd",
549
+ title="Average Salary by Experience Level"
550
+ )
551
+ st.plotly_chart(fig2)
552
+
553
+ fig3 = px.box(st.session_state.df, x="employment_type", y="salary_in_usd",
554
+ title="Salary Distribution by Employment Type")
555
+ st.plotly_chart(fig3)
556
+
557
+ temp_dir.cleanup()
558
+ else:
559
+ st.info("Please load a dataset to proceed.")
560
+
561
 
562
  # Sidebar Reference
563
  with st.sidebar: