subtest / dummy_funcs.py
DrishtiSharma's picture
Update dummy_funcs.py
70acfe7 verified
def ask_gpt4o_for_visualization(query, df, llm):
    """
    Ask the LLM to suggest one or more visualizations for ``query``.

    Args:
        query: Natural-language question from the user.
        df: DataFrame whose column names are listed in the prompt.
        llm: Object exposing ``generate(prompt)``, which returns the model's reply.

    Returns:
        The parsed JSON reply (normally a list of suggestion dicts, possibly a
        single dict), or None when the reply cannot be parsed as JSON.
    """
    import json
    import re

    # str() each label so non-string column names (e.g. ints) don't break join.
    columns = ', '.join(map(str, df.columns))
    prompt = f"""
Analyze the query and suggest one or more relevant visualizations.
Query: "{query}"
Available Columns: {columns}
Respond in this JSON format (as a list if multiple suggestions):
[
{{
"chart_type": "bar/box/line/scatter",
"x_axis": "column_name",
"y_axis": "column_name",
"group_by": "optional_column_name"
}}
]
Respond with only the JSON, with no explanation and no Markdown code fences.
"""
    response = llm.generate(prompt)

    # Models frequently wrap JSON in a ```json fence (sometimes after a line of
    # prose) despite instructions; extract the fenced payload before parsing.
    text = response.strip() if isinstance(response, str) else response
    if isinstance(text, str):
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
        if fenced:
            text = fenced.group(1)

    try:
        return json.loads(text)
    except (json.JSONDecodeError, TypeError):
        # TypeError: the LLM client returned something that is not str/bytes.
        st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
        return None
def add_stats_to_figure(fig, df, y_axis, chart_type, value_prefix="$"):
    """
    Add statistical annotations for ``df[y_axis]`` to a Plotly figure.

    Behaviour by chart type (case-insensitive):
      * bar / line: a summary box (min, max, mean, median, std) plus horizontal
        reference lines at the min, median, mean and max.
      * scatter: the summary box only.
      * box: nothing (the box already shows the distribution).
      * pie / heatmap: nothing added; an informational message is shown.
      * anything else: a warning is shown.

    Args:
        fig: Plotly figure to annotate in place.
        df: DataFrame holding the plotted data.
        y_axis: Name of the column whose statistics are reported.
        chart_type: Chart kind, e.g. "bar", "line" or "scatter".
        value_prefix: Text placed before every formatted value. Defaults to "$"
            for the salary data this was written for; pass "" for unitless data.

    Returns:
        The same ``fig`` object. If the column is missing, non-numeric or
        all-NaN, the figure is returned unchanged with a warning.
    """
    kind = chart_type.lower() if isinstance(chart_type, str) else chart_type

    if y_axis not in df.columns:
        st.warning(f"⚠️ Cannot compute statistics: column '{y_axis}' not found.")
        return fig
    values = df[y_axis]
    if not pd.api.types.is_numeric_dtype(values):
        st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
        return fig
    if values.dropna().empty:
        # Every statistic would be NaN, and NaN reference lines are meaningless.
        st.warning(f"⚠️ Cannot compute statistics for empty column: {y_axis}")
        return fig

    # pandas reductions skip NaN by default.
    min_val = values.min()
    max_val = values.max()
    avg_val = values.mean()
    median_val = values.median()
    std_dev_val = values.std()

    # Plotly annotation text is a small HTML subset, not Markdown: bold needs
    # <b> (``**`` shows literally) and line breaks need <br>.
    stat_rows = [
        ("Min", min_val),
        ("Max", max_val),
        ("Average", avg_val),
        ("Median", median_val),
        ("Std Dev", std_dev_val),
    ]
    stats_text = "📊 <b>Statistics</b><br><br>" + "<br>".join(
        f"- <b>{label}:</b> {value_prefix}{value:,.2f}" for label, value in stat_rows
    )

    if kind in ("bar", "line", "scatter"):
        # Summary box just to the right of the plotting area (paper coordinates).
        fig.add_annotation(
            text=stats_text,
            xref="paper", yref="paper",
            x=1.02, y=1,
            showarrow=False,
            align="left",
            font=dict(size=12, color="black"),
            bordercolor="gray",
            borderwidth=1,
            bgcolor="rgba(255, 255, 255, 0.85)",
        )
        if kind != "scatter":
            # Reference lines only make sense where y is a plotted magnitude.
            reference_lines = (
                (min_val, "dot", "red", "Min", "bottom right"),
                (median_val, "dash", "orange", "Median", "top right"),
                (avg_val, "dashdot", "green", "Avg", "top right"),
                (max_val, "dot", "blue", "Max", "top right"),
            )
            for level, dash, color, label, position in reference_lines:
                fig.add_hline(
                    y=level,
                    line_dash=dash,
                    line_color=color,
                    annotation_text=label,
                    annotation_position=position,
                )
    elif kind == "box":
        # Box plots inherently show distribution; no extra stats needed
        pass
    elif kind == "pie":
        st.info("📊 Pie charts represent proportions. Additional stats are not applicable.")
    elif kind == "heatmap":
        st.info("📊 Heatmaps inherently reflect distribution. No additional stats added.")
    else:
        st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")
    return fig
# Dynamically generate Plotly visualizations based on GPT-4o suggestions
def generate_visualization(suggestion, df):
    """
    Generate a Plotly figure from a single GPT-4o suggestion.

    ``suggestion`` is a dict with "chart_type", "x_axis", "y_axis" and an
    optional "group_by". A missing "y_axis" is inferred from the numeric
    columns, preferring salary/income-like names. Pie charts are the exception:
    without a values column, Plotly counts occurrences of each label. The
    statistics overlay is added with ``add_stats_to_figure``.

    Args:
        suggestion: Dict describing the chart.
        df: DataFrame to plot.

    Returns:
        The Plotly figure, or None when the suggestion cannot be rendered. A
        Streamlit warning or error explains why.
    """
    # The prompt asks the model for "heatmap", but Plotly Express has no
    # px.heatmap; density_heatmap is the column-based equivalent.
    px_aliases = {"heatmap": "density_heatmap"}

    # `or "bar"` also covers an explicit null chart_type, which .get's default
    # would pass through (and None.lower() would crash).
    chart_type = str(suggestion.get("chart_type") or "bar").lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis")
    group_by = suggestion.get("group_by")
    is_pie = chart_type == "pie"

    # Step 1: Infer Y-axis if not provided (pie charts can count rows instead)
    if not y_axis and not is_pie:
        numeric_columns = df.select_dtypes(include='number').columns.tolist()
        # Avoid using the same column for both axes
        if x_axis in numeric_columns:
            numeric_columns.remove(x_axis)
        # Smart guess: prioritize salary or relevant metrics if available
        priority_columns = ["salary_in_usd", "income", "earnings", "revenue"]
        y_axis = next((col for col in priority_columns if col in numeric_columns), None)
        # Fallback to the first numeric column if no priority columns exist
        if y_axis is None and numeric_columns:
            y_axis = numeric_columns[0]

    # Step 2: Validate axes
    if not x_axis or (not y_axis and not is_pie):
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    # Step 3: Resolve the Plotly Express function (after aliasing)
    plotly_function = getattr(px, px_aliases.get(chart_type, chart_type), None)
    if not callable(plotly_function):
        st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
        return None

    # Step 4: Prepare plot arguments; px.pie takes names/values, not x/y.
    if is_pie:
        plot_args = {"data_frame": df, "names": x_axis}
        if y_axis:
            plot_args["values"] = y_axis
    else:
        plot_args = {"data_frame": df, "x": x_axis, "y": y_axis}
    if group_by and group_by in df.columns:
        plot_args["color"] = group_by

    try:
        # Step 5: Generate the visualization
        fig = plotly_function(**plot_args)
        x_label = str(x_axis).replace('_', ' ').title()
        y_label = str(y_axis).replace('_', ' ').title() if y_axis else "Count"
        fig.update_layout(
            title=f"{chart_type.title()} Plot of {y_label} by {x_label}",
            xaxis_title=x_label,
            yaxis_title=y_label,
        )
        # Step 6: Apply statistics (needs a concrete column to summarize)
        if y_axis:
            fig = add_stats_to_figure(fig, df, y_axis, chart_type)
        return fig
    except Exception as e:
        st.error(f"⚠️ Failed to generate visualization: {e}")
        return None
def generate_multiple_visualizations(suggestions, df):
    """
    Build a Plotly figure for every renderable GPT-4o suggestion.

    Args:
        suggestions: List of suggestion dicts. A single dict is treated as a
            one-element list, and None as an empty list.
        df: DataFrame to plot.

    Returns:
        A list of the figures that generate_visualization could render, in
        suggestion order. The list may be empty.
    """
    if suggestions is None:
        suggestions = []
    elif isinstance(suggestions, dict):
        # Iterating a dict directly would feed its keys in as "suggestions".
        suggestions = [suggestions]

    # generate_visualization is responsible for the statistics overlay (see its
    # Step 6), so it is not re-applied here; doing so would draw every
    # annotation and reference line twice.
    figures = []
    for suggestion in suggestions:
        fig = generate_visualization(suggestion, df)
        if fig is not None:
            figures.append(fig)

    # Re-running a suggestion that already failed cannot succeed, so there is
    # no fallback retry; just report that nothing could be rendered.
    if not figures and suggestions:
        st.warning("⚠️ No valid visualization could be generated from the suggestions.")
    return figures
def handle_visualization_suggestions(suggestions, df):
    """
    Render GPT-4o visualization suggestions as Streamlit charts.

    Dispatch rules:
      * a list with two or more entries -> generate_multiple_visualizations
      * a single dict, or a one-element list -> generate_visualization
      * anything else (None, an empty list, other types) -> no figures

    Shows a warning when no figure results. Otherwise each figure is drawn with
    st.plotly_chart at container width. Returns None.
    """
    as_list = isinstance(suggestions, list)
    figures = []

    if as_list and len(suggestions) > 1:
        figures = generate_multiple_visualizations(suggestions, df)
    elif isinstance(suggestions, dict) or (as_list and len(suggestions) == 1):
        target = suggestions[0] if as_list else suggestions
        rendered = generate_visualization(target, df)
        if rendered:
            figures.append(rendered)

    if not figures:
        st.warning("⚠️ Unable to generate any visualization based on the suggestion.")
        return

    for figure in figures:
        st.plotly_chart(figure, use_container_width=True)
-----------------
def ask_gpt4o_for_visualization(query, df, llm, retries=2):
    """
    Ask the LLM for visualization suggestions and return the valid ones.

    The prompt lists df's numeric and non-numeric columns and gives worked
    examples. The reply is parsed as JSON, tolerating a surrounding Markdown
    code fence. A single object is treated as a one-element list, and each
    suggestion is filtered through ``is_valid_suggestion``.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose columns are described to the model.
        llm: Object exposing ``generate(prompt)``, which returns the reply text.
        retries: Extra attempts after the first when the reply is unusable
            (invalid JSON, no valid suggestion, an unexpected JSON shape, or an
            exception from the call).

    Returns:
        A non-empty list of valid suggestion dicts, or None if all attempts fail.
    """
    import json
    import re

    # Identify numeric and categorical columns; str() so non-string column
    # labels don't break the joins below.
    numeric_columns = [str(c) for c in df.select_dtypes(include='number').columns]
    categorical_columns = [str(c) for c in df.select_dtypes(exclude='number').columns]
    numeric_list = ', '.join(numeric_columns) if numeric_columns else 'None'
    categorical_list = ', '.join(categorical_columns) if categorical_columns else 'None'

    # Enhanced Prompt with Dataset-Specific, Query-Based Examples
    prompt = f"""
Analyze the following query and suggest the most suitable visualization(s) using the dataset.
**Query:** "{query}"
**Dataset Overview:**
- **Numeric Columns (for Y-axis):** {numeric_list}
- **Categorical Columns (for X-axis or grouping):** {categorical_list}
**Expected JSON Response:**
[
{{
"chart_type": "bar/box/line/scatter/pie/heatmap",
"x_axis": "categorical_or_time_column",
"y_axis": "numeric_column",
"group_by": "optional_column_for_grouping",
"title": "Title of the chart",
"description": "Why this chart is suitable"
}}
]
**Query-Based Examples:**
- **Query:** "What is the salary distribution across different job titles?"
**Suggested Visualization:**
{{
"chart_type": "box",
"x_axis": "job_title",
"y_axis": "salary_in_usd",
"group_by": "experience_level",
"title": "Salary Distribution by Job Title and Experience",
"description": "A box plot to show how salaries vary across different job titles and experience levels."
}}
- **Query:** "Show the average salary by company size and employment type."
**Suggested Visualizations:**
[
{{
"chart_type": "bar",
"x_axis": "company_size",
"y_axis": "salary_in_usd",
"group_by": "employment_type",
"title": "Average Salary by Company Size and Employment Type",
"description": "A grouped bar chart comparing average salaries across company sizes and employment types."
}},
{{
"chart_type": "heatmap",
"x_axis": "company_size",
"y_axis": "salary_in_usd",
"group_by": "employment_type",
"title": "Salary Heatmap by Company Size and Employment Type",
"description": "A heatmap showing salary concentration across company sizes and employment types."
}}
]
- **Query:** "How has the average salary changed over the years?"
**Suggested Visualization:**
{{
"chart_type": "line",
"x_axis": "work_year",
"y_axis": "salary_in_usd",
"group_by": "experience_level",
"title": "Average Salary Trend Over Years",
"description": "A line chart showing how the average salary has changed across different experience levels over the years."
}}
- **Query:** "What is the employee distribution by company location?"
**Suggested Visualization:**
{{
"chart_type": "pie",
"x_axis": "company_location",
"y_axis": null,
"group_by": null,
"title": "Employee Distribution by Company Location",
"description": "A pie chart showing the distribution of employees across company locations."
}}
- **Query:** "Is there a relationship between remote work ratio and salary?"
**Suggested Visualization:**
{{
"chart_type": "scatter",
"x_axis": "remote_ratio",
"y_axis": "salary_in_usd",
"group_by": "experience_level",
"title": "Remote Work Ratio vs Salary",
"description": "A scatter plot to analyze the relationship between remote work ratio and salary."
}}
- **Query:** "Which job titles have the highest salaries across regions?"
**Suggested Visualization:**
{{
"chart_type": "heatmap",
"x_axis": "job_title",
"y_axis": "employee_residence",
"group_by": null,
"title": "Salary Heatmap by Job Title and Region",
"description": "A heatmap showing the concentration of high-paying job titles across regions."
}}
Only suggest visualizations that logically match the query and dataset.
Respond with ONLY the JSON array (use an array even for a single suggestion), with no explanation and no Markdown code fences.
"""
    # Models often wrap JSON in a ```json fence (sometimes after a sentence of
    # prose) despite instructions; extract the fenced payload before parsing.
    fence = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)

    # Attempt LLM Response with Retry
    for attempt in range(retries + 1):
        try:
            response = llm.generate(prompt)
            text = response.strip() if isinstance(response, str) else response
            match = fence.search(text) if isinstance(text, str) else None
            if match:
                text = match.group(1)
            suggestions = json.loads(text)

            if isinstance(suggestions, dict):
                suggestions = [suggestions]
            if isinstance(suggestions, list):
                # NOTE(review): is_valid_suggestion is not defined in this file
                # as seen here; confirm it is in scope at runtime, otherwise the
                # NameError is reported as a GPT-4o call error on every attempt.
                valid_suggestions = [s for s in suggestions if is_valid_suggestion(s)]
                if valid_suggestions:
                    return valid_suggestions
                # Parseable but unusable: fall through to the retry logic
                # instead of giving up on the first attempt.
                st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o did not suggest valid visualizations.")
            else:
                st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned JSON that is not a suggestion list.")
        except json.JSONDecodeError:
            st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
        except Exception as e:
            st.error(f"⚠️ Error during GPT-4o call: {e}")
        if attempt < retries:
            st.info("🔄 Retrying visualization suggestion...")
    st.error("❌ Failed to generate a valid visualization after multiple attempts.")
    return None