# Columns preferred as the Y-axis when a suggestion does not name one.
_PRIORITY_Y_COLUMNS = ("salary_in_usd", "income", "earnings", "revenue")

# Static tail of the LLM prompt: the response schema and worked examples.
# Kept as a plain (non f-) string so the JSON braces need no escaping.
_PROMPT_FORMAT_AND_EXAMPLES = """\
**Expected JSON Response:**
[
  {
    "chart_type": "bar/box/line/scatter/pie/heatmap",
    "x_axis": "categorical_or_time_column",
    "y_axis": "numeric_column",
    "group_by": "optional_column_for_grouping",
    "title": "Title of the chart",
    "description": "Why this chart is suitable"
  }
]

**Query-Based Examples:**

- **Query:** "What is the salary distribution across different job titles?"
  **Suggested Visualization:**
  {
    "chart_type": "box",
    "x_axis": "job_title",
    "y_axis": "salary_in_usd",
    "group_by": "experience_level",
    "title": "Salary Distribution by Job Title and Experience",
    "description": "A box plot to show how salaries vary across different job titles and experience levels."
  }

- **Query:** "Show the average salary by company size and employment type."
  **Suggested Visualizations:**
  [
    {
      "chart_type": "bar",
      "x_axis": "company_size",
      "y_axis": "salary_in_usd",
      "group_by": "employment_type",
      "title": "Average Salary by Company Size and Employment Type",
      "description": "A grouped bar chart comparing average salaries across company sizes and employment types."
    },
    {
      "chart_type": "heatmap",
      "x_axis": "company_size",
      "y_axis": "salary_in_usd",
      "group_by": "employment_type",
      "title": "Salary Heatmap by Company Size and Employment Type",
      "description": "A heatmap showing salary concentration across company sizes and employment types."
    }
  ]

- **Query:** "How has the average salary changed over the years?"
  **Suggested Visualization:**
  {
    "chart_type": "line",
    "x_axis": "work_year",
    "y_axis": "salary_in_usd",
    "group_by": "experience_level",
    "title": "Average Salary Trend Over Years",
    "description": "A line chart showing how the average salary has changed across different experience levels over the years."
  }

- **Query:** "What is the employee distribution by company location?"
  **Suggested Visualization:**
  {
    "chart_type": "pie",
    "x_axis": "company_location",
    "y_axis": null,
    "group_by": null,
    "title": "Employee Distribution by Company Location",
    "description": "A pie chart showing the distribution of employees across company locations."
  }

- **Query:** "Is there a relationship between remote work ratio and salary?"
  **Suggested Visualization:**
  {
    "chart_type": "scatter",
    "x_axis": "remote_ratio",
    "y_axis": "salary_in_usd",
    "group_by": "experience_level",
    "title": "Remote Work Ratio vs Salary",
    "description": "A scatter plot to analyze the relationship between remote work ratio and salary."
  }

- **Query:** "Which job titles have the highest salaries across regions?"
  **Suggested Visualization:**
  {
    "chart_type": "heatmap",
    "x_axis": "job_title",
    "y_axis": "employee_residence",
    "group_by": null,
    "title": "Salary Heatmap by Job Title and Region",
    "description": "A heatmap showing the concentration of high-paying job titles across regions."
  }

Only suggest visualizations that logically match the query and dataset.
"""


def _build_visualization_prompt(query, df):
    """Return the LLM prompt for *query*, listing the numeric and other columns of *df*."""
    numeric_columns = df.select_dtypes(include='number').columns.tolist()
    categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
    numeric_text = ', '.join(numeric_columns) if numeric_columns else 'None'
    categorical_text = ', '.join(categorical_columns) if categorical_columns else 'None'
    header = (
        "Analyze the following query and suggest the most suitable "
        "visualization(s) using the dataset.\n\n"
        f'**Query:** "{query}"\n\n'
        "**Dataset Overview:**\n"
        f"- **Numeric Columns (for Y-axis):** {numeric_text}\n"
        f"- **Categorical Columns (for X-axis or grouping):** {categorical_text}\n\n"
    )
    return header + _PROMPT_FORMAT_AND_EXAMPLES


def ask_gpt4o_for_visualization(query, df, llm, retries=2):
    """Ask the LLM which chart(s) best answer *query* for the columns of *df*.

    This single definition replaces two pasted variants of the function; the
    later one (with ``retries``) shadowed the earlier one, so its behaviour is
    the one kept.

    Args:
        query: Natural-language question, embedded verbatim in the prompt.
        df: DataFrame whose column names are listed in the prompt.
        llm: Object exposing ``generate(prompt)``; the result is passed to
            ``json.loads``.
        retries: Number of extra attempts after a failed call or invalid JSON.

    Returns:
        A list of suggestion dicts accepted by ``is_valid_suggestion``, or
        ``None`` when the response contains no valid suggestion or every
        attempt fails.
    """
    import json  # function-scope import, as in the original definition

    prompt = _build_visualization_prompt(query, df)

    for attempt in range(retries + 1):
        try:
            suggestions = json.loads(llm.generate(prompt))

            # ``is_valid_suggestion`` is not defined in this part of the file;
            # it is expected to be provided elsewhere in the module.
            if isinstance(suggestions, list):
                valid_suggestions = [s for s in suggestions if is_valid_suggestion(s)]
                if valid_suggestions:
                    return valid_suggestions
                st.warning("⚠️ GPT-4o did not suggest valid visualizations.")
                return None
            if isinstance(suggestions, dict):
                if is_valid_suggestion(suggestions):
                    return [suggestions]
                st.warning("⚠️ GPT-4o's suggestion is incomplete or invalid.")
                return None
            # Any other JSON value (string, number, ...) falls through to a retry.
        except json.JSONDecodeError:
            st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
        except Exception as e:  # boundary around the external LLM call
            st.error(f"⚠️ Error during GPT-4o call: {e}")

        if attempt < retries:
            st.info("🔄 Retrying visualization suggestion...")

    st.error("❌ Failed to generate a valid visualization after multiple attempts.")
    return None


def _add_stats_annotation(fig, stats_text):
    """Attach the statistics text box just right of the plot area of *fig*."""
    fig.add_annotation(
        text=stats_text,
        xref="paper",
        yref="paper",
        x=1.02,
        y=1,
        showarrow=False,
        align="left",
        font=dict(size=12, color="black"),
        bordercolor="gray",
        borderwidth=1,
        bgcolor="rgba(255, 255, 255, 0.85)",
    )


def add_stats_to_figure(fig, df, y_axis, chart_type):
    """Add statistical annotations to *fig* according to *chart_type*.

    Bar and line charts get a statistics box plus min/median/avg/max reference
    lines; scatter plots get the box only; box, pie and heatmap charts are left
    unchanged.

    Args:
        fig: Figure exposing ``add_annotation`` and ``add_hline``.
        df: DataFrame containing the plotted data.
        y_axis: Name of the column in *df* to summarise.
        chart_type: Lower-case chart type name (e.g. ``"bar"``).

    Returns:
        The same *fig* object, possibly annotated in place.
    """
    values = df[y_axis]
    if not pd.api.types.is_numeric_dtype(values):
        st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
        return fig

    # With no non-null values every statistic is NaN; skip the overlay rather
    # than draw NaN reference lines.
    if values.dropna().empty:
        return fig

    min_val = values.min()
    max_val = values.max()
    avg_val = values.mean()
    median_val = values.median()
    std_dev_val = values.std()

    # Plotly annotation text is rendered as a small HTML subset, not Markdown,
    # so use <b>/<br> instead of ** and "\n".
    stats_text = (
        "<b>📊 Statistics</b><br><br>"
        f"- <b>Min:</b> ${min_val:,.2f}<br>"
        f"- <b>Max:</b> ${max_val:,.2f}<br>"
        f"- <b>Average:</b> ${avg_val:,.2f}<br>"
        f"- <b>Median:</b> ${median_val:,.2f}<br>"
        f"- <b>Std Dev:</b> ${std_dev_val:,.2f}"
    )

    if chart_type in ("bar", "line"):
        _add_stats_annotation(fig, stats_text)
        # (value, dash style, colour, label, label position)
        reference_lines = (
            (min_val, "dot", "red", "Min", "bottom right"),
            (median_val, "dash", "orange", "Median", "top right"),
            (avg_val, "dashdot", "green", "Avg", "top right"),
            (max_val, "dot", "blue", "Max", "top right"),
        )
        for value, dash, color, label, position in reference_lines:
            fig.add_hline(
                y=value,
                line_dash=dash,
                line_color=color,
                annotation_text=label,
                annotation_position=position,
            )
    elif chart_type == "scatter":
        # Annotation only; reference lines are not drawn on scatter plots.
        _add_stats_annotation(fig, stats_text)
    elif chart_type == "box":
        # Box plots inherently show the distribution; nothing to add.
        pass
    elif chart_type == "pie":
        st.info("📊 Pie charts represent proportions. Additional stats are not applicable.")
    elif chart_type == "heatmap":
        st.info("📊 Heatmaps inherently reflect distribution. No additional stats added.")
    else:
        st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")

    return fig


def _infer_y_axis(df, x_axis):
    """Pick a numeric column of *df* other than *x_axis*, preferring known metrics."""
    candidates = [
        col for col in df.select_dtypes(include='number').columns.tolist()
        if col != x_axis
    ]
    for col in _PRIORITY_Y_COLUMNS:
        if col in candidates:
            return col
    return candidates[0] if candidates else None


def generate_visualization(suggestion, df):
    """Build one Plotly Express figure from an LLM suggestion.

    Args:
        suggestion: Dict with ``chart_type``, ``x_axis`` and optional
            ``y_axis``, ``group_by`` and ``title`` keys.
        df: DataFrame to plot.

    Returns:
        The figure with the statistics overlay applied, or ``None`` (after a
        Streamlit warning/error) when the suggestion cannot be rendered.
    """
    if not isinstance(suggestion, dict):
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    # ``or "bar"`` also covers an explicit JSON null, which ``.get`` with a
    # default would pass through as None.
    chart_type = str(suggestion.get("chart_type") or "bar").lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis") or _infer_y_axis(df, x_axis)
    group_by = suggestion.get("group_by")

    if not x_axis or not y_axis:
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    # NOTE(review): plotly.express appears to expose no ``heatmap`` function,
    # and ``px.pie`` takes ``names``/``values`` rather than ``x``/``y``, so the
    # "pie" and "heatmap" types offered in the prompt presumably fail here —
    # confirm and add an explicit mapping if they should be supported.
    plotly_function = None
    if not chart_type.startswith("_"):
        plotly_function = getattr(px, chart_type, None)
    if not callable(plotly_function):
        st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
        return None

    plot_args = {"data_frame": df, "x": x_axis, "y": y_axis}
    if group_by and group_by in df.columns:
        plot_args["color"] = group_by

    try:
        fig = plotly_function(**plot_args)
        x_label = str(x_axis).replace('_', ' ').title()
        y_label = str(y_axis).replace('_', ' ').title()
        # Prefer the title supplied with the suggestion, if any.
        title = suggestion.get("title") or f"{chart_type.title()} Plot of {y_label} by {x_label}"
        fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label)
        return add_stats_to_figure(fig, df, y_axis, chart_type)
    except Exception as e:  # Plotly raises various error types for bad columns/args
        st.error(f"⚠️ Failed to generate visualization: {e}")
        return None


def generate_multiple_visualizations(suggestions, df):
    """Generate a figure for each suggestion, skipping those that fail.

    Statistics are applied once, inside ``generate_visualization``. A failed
    suggestion is not retried, because generation is deterministic for a given
    suggestion and DataFrame.

    Args:
        suggestions: Iterable of suggestion dicts (``None`` is treated as empty).
        df: DataFrame to plot.

    Returns:
        A list of the figures that were generated successfully.
    """
    figures = (generate_visualization(suggestion, df) for suggestion in suggestions or [])
    return [fig for fig in figures if fig is not None]


def handle_visualization_suggestions(suggestions, df):
    """Render every suggested visualization in Streamlit.

    Args:
        suggestions: A single suggestion dict, a list of them, or ``None``.
        df: DataFrame to plot.
    """
    if isinstance(suggestions, dict):
        suggestions = [suggestions]
    elif not isinstance(suggestions, list):
        suggestions = []

    visualizations = generate_multiple_visualizations(suggestions, df)

    if not visualizations:
        st.warning("⚠️ Unable to generate any visualization based on the suggestion.")

    for fig in visualizations:
        st.plotly_chart(fig, use_container_width=True)