benediktstroebl commited on
Commit
9250161
·
1 Parent(s): f5fc72d

Added sorting to heatmap

Browse files
Files changed (3) hide show
  1. app.py +0 -2
  2. utils/db.py +1 -0
  3. utils/viz.py +14 -1
app.py CHANGED
@@ -224,8 +224,6 @@ with gr.Blocks() as demo:
224
  with gr.Tab("USACO"):
225
  with gr.Row():
226
  with gr.Column(scale=2):
227
- print(parse_json_files(os.path.join(abs_path, "evals_live"), 'usaco').columns)
228
- print(parse_json_files(os.path.join(abs_path, "evals_live"), 'mlagentbench').columns)
229
  Leaderboard(
230
  value=parse_json_files(os.path.join(abs_path, "evals_live"), 'usaco'),
231
  select_columns=SelectColumns(
 
224
  with gr.Tab("USACO"):
225
  with gr.Row():
226
  with gr.Column(scale=2):
 
 
227
  Leaderboard(
228
  value=parse_json_files(os.path.join(abs_path, "evals_live"), 'usaco'),
229
  select_columns=SelectColumns(
utils/db.py CHANGED
@@ -63,6 +63,7 @@ class TracePreprocessor:
63
  ''')
64
 
65
  def preprocess_traces(self, processed_dir="evals_live"):
 
66
  processed_dir = Path(processed_dir)
67
  for file in processed_dir.glob('*.json'):
68
  with open(file, 'r') as f:
 
63
  ''')
64
 
65
  def preprocess_traces(self, processed_dir="evals_live"):
66
+ self.create_tables()
67
  processed_dir = Path(processed_dir)
68
  for file in processed_dir.glob('*.json'):
69
  with open(file, 'r') as f:
utils/viz.py CHANGED
@@ -5,9 +5,22 @@ import plotly.graph_objects as go
5
  import textwrap
6
 
7
  def create_task_success_heatmap(df, benchmark_name):
 
 
 
 
 
 
8
  # Pivot the dataframe to create a matrix of agents vs tasks
9
  pivot_df = df.pivot(index='Agent Name', columns='Task ID', values='Success')
10
 
 
 
 
 
 
 
 
11
  # Create the heatmap
12
  fig = go.Figure(data=go.Heatmap(
13
  z=pivot_df.values,
@@ -23,7 +36,7 @@ def create_task_success_heatmap(df, benchmark_name):
23
  # Update the layout
24
  fig.update_layout(
25
  xaxis_title='Task ID',
26
- height=600,
27
  width=1300,
28
  yaxis=dict(
29
  autorange='reversed',
 
5
  import textwrap
6
 
7
  def create_task_success_heatmap(df, benchmark_name):
8
+ # Calculate agent accuracy
9
+ agent_accuracy = df.groupby('Agent Name')['Success'].mean().sort_values(ascending=False)
10
+
11
+ # Calculate task success rate
12
+ task_success_rate = df.groupby('Task ID')['Success'].mean().sort_values(ascending=False)
13
+
14
  # Pivot the dataframe to create a matrix of agents vs tasks
15
  pivot_df = df.pivot(index='Agent Name', columns='Task ID', values='Success')
16
 
17
+ # Sort the pivot table
18
+ pivot_df = pivot_df.reindex(index=agent_accuracy.index, columns=task_success_rate.index)
19
+
20
+ num_agents = len(pivot_df.index)
21
+ row_height = 30 # Fixed height for each row in pixels
22
+ total_height = num_agents * row_height
23
+
24
  # Create the heatmap
25
  fig = go.Figure(data=go.Heatmap(
26
  z=pivot_df.values,
 
36
  # Update the layout
37
  fig.update_layout(
38
  xaxis_title='Task ID',
39
+ height=total_height,
40
  width=1300,
41
  yaxis=dict(
42
  autorange='reversed',