Spaces:
Running
Running
import json | |
from pathlib import Path | |
import pandas as pd | |
import plotly.express as px | |
from pareto_utils import Agent, compute_pareto_frontier | |
import plotly.graph_objects as go | |
def parse_json_files(folder_path):
    """Load every ``*.json`` result file in *folder_path* into one DataFrame.

    Each file is expected to contain a top-level ``config`` mapping with
    ``agent_name``, ``benchmark_name`` and ``date`` keys, plus a ``results``
    mapping. One row is produced per file: the three config fields followed
    by every ``results`` entry, each key prefixed with ``results_``.

    Loading is best-effort: a file that cannot be read, parsed, or lacks the
    expected keys is reported on stdout and skipped, so one bad file does not
    hide the others.

    Args:
        folder_path: Directory to scan (str or path-like). Only files directly
            inside it are considered; subdirectories are not searched.

    Returns:
        A ``pandas.DataFrame`` with one row per successfully parsed file,
        ordered by file path. It has no columns if no file could be parsed.
    """
    folder = Path(folder_path)
    data_list = []
    # Path.glob yields entries in filesystem order, which differs between
    # platforms and runs; sort so the resulting row order is reproducible.
    for json_file in sorted(folder.glob('*.json')):
        try:
            # JSON text is UTF-8; don't depend on the platform's default encoding.
            with open(json_file, 'r', encoding='utf-8') as file:
                data = json.load(file)
            config = data['config']
            results = data['results']
            combined_data = {
                'agent_name': config['agent_name'],
                'benchmark_name': config['benchmark_name'],
                'date': config['date'],
            }
            # Prefix result keys so they cannot collide with the config columns.
            combined_data.update({f'results_{key}': value for key, value in results.items()})
        except Exception as e:
            # Deliberately broad: any malformed file is skipped, not fatal.
            print(f"Error processing {json_file}: {e}. Skipping!")
            continue
        data_list.append(combined_data)
    return pd.DataFrame(data_list)
def create_scatter_plot(df, x: str, y: str, x_label: str = None, y_label: str = None, hover_data: list = None):
    """Scatter-plot two columns of *df* and overlay their Pareto frontier.

    Every row of *df* becomes one point. The agents returned by
    ``compute_pareto_frontier`` are joined by a dashed black line labelled
    "Pareto Frontier".

    Args:
        df: DataFrame with one row per agent. Must contain columns *x* and *y*
            (and any column listed in *hover_data*).
        x: Column for the x-axis. Each value is passed to ``Agent`` as its
            first argument and read back as ``agent.total_cost``.
        y: Column for the y-axis. Each value is passed to ``Agent`` as its
            second argument and read back as ``agent.accuracy``.
        x_label: X-axis title, forwarded to ``update_layout`` as ``xaxis_title``.
        y_label: Y-axis title, forwarded to ``update_layout`` as ``yaxis_title``.
        hover_data: Extra hover columns, forwarded to ``px.scatter``.

    Returns:
        The plotly figure: 600x500 px, white background, black axis lines,
        no grid, both axes starting at zero, legend in the bottom-right corner.
    """
    # Build the frontier from the plotted columns, not fixed column names, so
    # the frontier line is drawn in the same coordinates as the points.
    # NOTE(review): assumes compute_pareto_frontier treats the first Agent field
    # as lower-is-better and the second as higher-is-better — confirm in
    # pareto_utils before plotting other metric pairs.
    # Rows missing either value are skipped. NaN compares False with every
    # number, so such a row cannot be ranked and would make the sort order
    # below undefined.
    complete = df.dropna(subset=[x, y])
    agents = [Agent(x_val, y_val) for x_val, y_val in zip(complete[x], complete[y])]
    pareto_frontier = compute_pareto_frontier(agents)

    fig = px.scatter(df, x=x, y=y, hover_data=hover_data)

    # Sort by x so the connecting line is drawn left to right, not zig-zagging.
    pareto_points = sorted(
        ((agent.total_cost, agent.accuracy) for agent in pareto_frontier),
        key=lambda point: point[0],
    )
    fig.add_trace(go.Scatter(
        x=[point[0] for point in pareto_points],
        y=[point[1] for point in pareto_points],
        mode='lines',
        name='Pareto Frontier',
        line=dict(color='black', width=2, dash='dash'),
    ))

    fig.update_yaxes(rangemode="tozero")
    fig.update_xaxes(rangemode="tozero")
    fig.update_layout(
        width=600,
        height=500,
        xaxis_title=x_label,
        yaxis_title=y_label,
        xaxis=dict(
            showline=True,
            linecolor='black',
            showgrid=False,
        ),
        yaxis=dict(
            showline=True,
            showgrid=False,
            linecolor='black',
        ),
        plot_bgcolor='white',
        # Legend anchored in the bottom-right corner of the plot area.
        legend=dict(
            yanchor="bottom",
            y=0.01,
            xanchor="right",
            x=0.98,
            bgcolor="rgba(255, 255, 255, 0.5)",  # semi-transparent white background
        ),
    )
    return fig