benediktstroebl committed on
Commit
5cbaf0e
·
1 Parent(s): 221fb8a

refactoring

Browse files
Files changed (2) hide show
  1. app.py +20 -2
  2. utils/data.py +0 -20
app.py CHANGED
@@ -6,7 +6,7 @@ from pathlib import Path
6
  import pandas as pd
7
  import os
8
  import json
9
- from utils.data import parse_json_files, preprocess_traces
10
  from utils.viz import create_scatter_plot, create_flow_chart
11
  from utils.processing import check_and_process_uploads
12
  from huggingface_hub import snapshot_download
@@ -36,9 +36,27 @@ def download_latest_results():
36
 
37
  abs_path = Path(__file__).parent
38
 
39
-
40
  # Global variable to store preprocessed data
41
  preprocessed_traces = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def get_analyzed_traces(agent_name, benchmark_name):
44
  return preprocessed_traces.get(benchmark_name, {}).get(agent_name)
 
6
  import pandas as pd
7
  import os
8
  import json
9
+ from utils.data import parse_json_files
10
  from utils.viz import create_scatter_plot, create_flow_chart
11
  from utils.processing import check_and_process_uploads
12
  from huggingface_hub import snapshot_download
 
36
 
37
  abs_path = Path(__file__).parent
38
 
 
39
  # Global variable to store preprocessed data
40
  preprocessed_traces = {}
41
def preprocess_traces():
    """Load the raw logging traces from every JSON file in ``evals_live``.

    Fills the module-level ``preprocessed_traces`` mapping so that
    ``preprocessed_traces[benchmark_name][agent_name]`` holds the file's
    ``raw_logging_results`` dict, or ``None`` when that field is missing or
    is not a dict.

    Files that cannot be read, are not valid JSON, or lack the
    ``config.agent_name`` / ``config.benchmark_name`` fields are reported
    and skipped: without both names there is no slot to record them under,
    and reusing the names of a previously processed file would clobber
    that file's entry.
    """
    global preprocessed_traces
    processed_dir = Path("evals_live")
    for file in processed_dir.glob('*.json'):
        # Step 1: identify the (benchmark, agent) pair. Any failure here
        # means the file cannot be attributed, so it is skipped entirely.
        try:
            with open(file, 'r') as f:
                data = json.load(f)
            config = data['config']
            agent_name = config['agent_name']
            benchmark_name = config['benchmark_name']
        except Exception as e:
            print(f"Error preprocessing {file}: {e}")
            continue

        # Step 2: validate the payload with an explicit check rather than
        # ``assert``, which is stripped when Python runs with -O.
        raw_results = data.get('raw_logging_results')
        if not isinstance(raw_results, dict):
            print(
                f"Error preprocessing {file}: "
                f"Invalid format for raw_logging_results: {type(raw_results)}"
            )
            raw_results = None

        preprocessed_traces.setdefault(benchmark_name, {})[agent_name] = raw_results
60
 
61
def get_analyzed_traces(agent_name, benchmark_name):
    """Return the cached traces for ``agent_name`` on ``benchmark_name``.

    Reads the module-level ``preprocessed_traces`` mapping. Returns ``None``
    when the benchmark or the agent has no entry.
    """
    traces_for_benchmark = preprocessed_traces.get(benchmark_name, {})
    return traces_for_benchmark.get(agent_name)
utils/data.py CHANGED
@@ -6,26 +6,6 @@ from utils.pareto import Agent, compute_pareto_frontier
6
  import plotly.graph_objects as go
7
  import textwrap
8
 
9
- def preprocess_traces():
10
- global preprocessed_traces
11
- processed_dir = "evals_live"
12
- for file in processed_dir.glob('*.json'):
13
- try:
14
- with open(file, 'r') as f:
15
- data = json.load(f)
16
- agent_name = data['config']['agent_name']
17
- benchmark_name = data['config']['benchmark_name']
18
- if benchmark_name not in preprocessed_traces:
19
- preprocessed_traces[benchmark_name] = {}
20
-
21
- assert type(data['raw_logging_results']) == dict, f"Invalid format for raw_logging_results: {type(data['raw_logging_results'])}"
22
- preprocessed_traces[benchmark_name][agent_name] = data['raw_logging_results']
23
- except AssertionError as e:
24
- preprocessed_traces[benchmark_name][agent_name] = None
25
- except Exception as e:
26
- print(f"Error preprocessing {file}: {e}")
27
- preprocessed_traces[benchmark_name][agent_name] = None
28
-
29
  def parse_json_files(folder_path, benchmark_name):
30
  # Convert folder path to Path object
31
  folder = Path(folder_path)
 
6
  import plotly.graph_objects as go
7
  import textwrap
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def parse_json_files(folder_path, benchmark_name):
10
  # Convert folder path to Path object
11
  folder = Path(folder_path)