bug-fix-label-mapping-align-with-correct-idx

#80
by ZeroCommand - opened
app.py CHANGED
@@ -12,12 +12,10 @@ try:
12
  with gr.Tab("Text Classification"):
13
  get_demo_text_classification()
14
  with gr.Tab("Leaderboard") as leaderboard_tab:
15
- get_demo_leaderboard()
16
  with gr.Tab("Logs(Debug)"):
17
  get_demo_debug()
18
 
19
- leaderboard_tab.select(fn=get_demo_leaderboard)
20
-
21
  start_process_run_job()
22
 
23
  demo.queue(max_size=1000)
 
12
  with gr.Tab("Text Classification"):
13
  get_demo_text_classification()
14
  with gr.Tab("Leaderboard") as leaderboard_tab:
15
+ get_demo_leaderboard(leaderboard_tab)
16
  with gr.Tab("Logs(Debug)"):
17
  get_demo_debug()
18
 
 
 
19
  start_process_run_job()
20
 
21
  demo.queue(max_size=1000)
app_leaderboard.py CHANGED
@@ -73,8 +73,11 @@ def get_display_df(df):
73
  )
74
  return display_df
75
 
 
 
 
76
 
77
- def get_demo():
78
  logger.info("Loading leaderboard records")
79
  leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
80
  records = leaderboard.records
@@ -116,6 +119,8 @@ def get_demo():
116
  with gr.Row():
117
  leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
118
 
 
 
119
  @gr.on(
120
  triggers=[
121
  model_select.change,
 
73
  )
74
  return display_df
75
 
76
+ def update_leaderboard_records():
77
+ logger.info("Updating leaderboard records")
78
+ leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
79
 
80
+ def get_demo(leaderboard_tab):
81
  logger.info("Loading leaderboard records")
82
  leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
83
  records = leaderboard.records
 
119
  with gr.Row():
120
  leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
121
 
122
+ leaderboard_tab.select(fn=update_leaderboard_records)
123
+
124
  @gr.on(
125
  triggers=[
126
  model_select.change,
text_classification_ui_helpers.py CHANGED
@@ -30,7 +30,6 @@ MAX_FEATURES = 20
30
  ds_dict = None
31
  ds_config = None
32
 
33
-
34
  def get_related_datasets_from_leaderboard(model_id):
35
  records = leaderboard.records
36
  model_records = records[records["model_id"] == model_id]
@@ -100,7 +99,7 @@ def export_mappings(all_mappings, key, subkeys, values):
100
  if subkeys is None:
101
  subkeys = list(all_mappings[key].keys())
102
 
103
- if not subkeys:
104
  logging.debug(f"subkeys is empty for {key}")
105
  return all_mappings
106
 
@@ -121,6 +120,8 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
121
  ds_labels = ds_labels[:MAX_LABELS]
122
  gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
123
 
 
 
124
  ds_labels.sort()
125
  model_labels.sort()
126
 
@@ -293,17 +294,20 @@ def check_column_mapping_keys_validity(all_mappings):
293
  return (gr.update(interactive=True), gr.update(visible=False))
294
 
295
 
296
- def construct_label_and_feature_mapping(all_mappings):
297
  label_mapping = {}
298
- for i, label in zip(
299
- range(len(all_mappings["labels"].keys())), all_mappings["labels"].keys()
300
- ):
301
- # FIXME: What's the order during the save
 
 
 
 
302
  label_mapping.update({str(i): all_mappings["labels"][label]})
303
 
304
  if "features" not in all_mappings.keys():
305
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
306
- return (gr.update(interactive=True), gr.update(visible=False))
307
  feature_mapping = all_mappings["features"]
308
  return label_mapping, feature_mapping
309
 
@@ -311,7 +315,11 @@ def construct_label_and_feature_mapping(all_mappings):
311
  def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
312
  all_mappings = read_column_mapping(uid)
313
  check_column_mapping_keys_validity(all_mappings)
314
- label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)
 
 
 
 
315
 
316
  eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
317
  save_job_to_pipe(
 
30
  ds_dict = None
31
  ds_config = None
32
 
 
33
  def get_related_datasets_from_leaderboard(model_id):
34
  records = leaderboard.records
35
  model_records = records[records["model_id"] == model_id]
 
99
  if subkeys is None:
100
  subkeys = list(all_mappings[key].keys())
101
 
102
+ if not subkeys:
103
  logging.debug(f"subkeys is empty for {key}")
104
  return all_mappings
105
 
 
120
  ds_labels = ds_labels[:MAX_LABELS]
121
  gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
122
 
123
+ # sort labels to make sure the order is consistent
124
+ # prediction gives the order based on probability
125
  ds_labels.sort()
126
  model_labels.sort()
127
 
 
294
  return (gr.update(interactive=True), gr.update(visible=False))
295
 
296
 
297
+ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
298
  label_mapping = {}
299
+ if len(all_mappings["labels"].keys()) != len(ds_labels):
300
+ gr.Warning("Label mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
301
+
302
+ if len(all_mappings["features"].keys()) != len(ds_features):
303
+ gr.Warning("Feature mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
304
+
305
+ for i, label in zip(range(len(ds_labels)), ds_labels):
306
+ # align the saved labels with dataset labels order
307
  label_mapping.update({str(i): all_mappings["labels"][label]})
308
 
309
  if "features" not in all_mappings.keys():
310
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
 
311
  feature_mapping = all_mappings["features"]
312
  return label_mapping, feature_mapping
313
 
 
315
  def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
316
  all_mappings = read_column_mapping(uid)
317
  check_column_mapping_keys_validity(all_mappings)
318
+
319
+ # get ds labels and features again for alignment
320
+ ds = datasets.load_dataset(d_id, config)[split]
321
+ ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
322
+ label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features)
323
 
324
  eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
325
  save_job_to_pipe(