inoki-giskard commited on
Commit
85095eb
·
1 Parent(s): 01c4e21

Add features, label mapping in text classification

Browse files
Files changed (1) hide show
  1. app.py +118 -42
app.py CHANGED
@@ -7,6 +7,7 @@ import time
7
  from pathlib import Path
8
 
9
  import json
 
10
 
11
  import pandas as pd
12
 
@@ -64,16 +65,20 @@ def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
64
  for model_label in id2label_mapping.keys():
65
  if model_label.upper() == label.upper():
66
  return model_label, label
 
67
 
68
 
69
  def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
70
  id2label_mapping = {id2label[k]: None for k in id2label.keys()}
 
71
  for feature in dataset_features.values():
72
  if not isinstance(feature, datasets.ClassLabel):
73
  continue
74
  if len(feature.names) != len(id2label_mapping.keys()):
75
  continue
76
 
 
 
77
  # Try to match labels
78
  for label in feature.names:
79
  if label in id2label_mapping.keys():
@@ -81,9 +86,86 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features)
81
  else:
82
  # Try to find case unsensative
83
  model_label, label = text_classificaiton_match_label_case_unsensative(id2label_mapping, label)
84
- id2label_mapping[model_label] = label
 
 
 
85
 
86
- return id2label_mapping
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping):
@@ -133,7 +215,7 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
133
  )
134
 
135
  # TODO: Validate column mapping by running once
136
- prediction_result = {}
137
  id2label_df = None
138
  if isinstance(ppl, TextClassificationPipeline):
139
  try:
@@ -141,39 +223,32 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
141
  except Exception:
142
  column_mapping = {}
143
 
144
- # Retrieve all labels
145
- id2label_mapping = {}
146
- try:
147
- results = ppl({"text": "Test"}, top_k=None)
148
- prediction_result = {
149
- result["label"]: result["score"] for result in results
150
- }
151
- except Exception as e:
152
- # Pipeline is not executable
153
- pass
154
-
155
- # We assume dataset is ok here
156
- ds = datasets.load_dataset(d_id, config)[split]
157
- try:
158
- id2label = ppl.model.config.id2label
159
- id2label_mapping = text_classification_map_model_and_dataset_labels(ppl.model.config.id2label, ds.features)
160
- id2label_df = pd.DataFrame({
161
- "ID": [i for i in id2label.keys()],
162
- "Model labels": [id2label[label] for label in id2label.keys()],
163
- "Dataset labels": [id2label_mapping[id2label[label]] for label in id2label.keys()],
164
- })
165
- if "label" not in column_mapping.keys():
166
- column_mapping["label"] = {
167
- i: id2label_mapping[id2label[i]] for i in id2label.keys()
168
- }
169
- except AttributeError:
170
- # Dataset does not have features
171
- pass
172
 
173
  column_mapping = json.dumps(column_mapping, indent=2)
174
 
175
  del ppl
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
178
 
179
  return (
@@ -248,7 +323,6 @@ with gr.Blocks(theme=theme) as iface:
248
  ],
249
  value=0,
250
  )
251
- run_local = gr.Checkbox(value=True, label="Run in this Space")
252
  example_labels = gr.Label(label='Model pipeline test prediction result', visible=False)
253
 
254
  with gr.Column():
@@ -278,16 +352,18 @@ with gr.Blocks(theme=theme) as iface:
278
  id2label_mapping_dataframe = gr.DataFrame(visible=False)
279
 
280
  with gr.Row():
281
- column_mapping_input = gr.Textbox(
282
- value="",
283
- lines=5,
284
- label="Column mapping",
285
- placeholder="Description of mapping of columns in model to dataset, in json format, e.g.:\n"
286
- '{\n'
287
- ' "text": "context",\n'
288
- ' "label": {0: "Positive", 1: "Negative"}\n'
289
- '}',
290
- )
 
 
291
 
292
  with gr.Row():
293
  validate_btn = gr.Button("Validate model and dataset", variant="primary")
 
7
  from pathlib import Path
8
 
9
  import json
10
+ import logging
11
 
12
  import pandas as pd
13
 
 
65
  for model_label in id2label_mapping.keys():
66
  if model_label.upper() == label.upper():
67
  return model_label, label
68
+ return None, label
69
 
70
 
71
  def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
72
  id2label_mapping = {id2label[k]: None for k in id2label.keys()}
73
+ dataset_labels = None
74
  for feature in dataset_features.values():
75
  if not isinstance(feature, datasets.ClassLabel):
76
  continue
77
  if len(feature.names) != len(id2label_mapping.keys()):
78
  continue
79
 
80
+ dataset_labels = feature.names
81
+
82
  # Try to match labels
83
  for label in feature.names:
84
  if label in id2label_mapping.keys():
 
86
  else:
87
  # Try to find case unsensative
88
  model_label, label = text_classificaiton_match_label_case_unsensative(id2label_mapping, label)
89
+ if model_label is not None:
90
+ id2label_mapping[model_label] = label
91
+
92
+ return id2label_mapping, dataset_labels
93
 
94
+
95
+ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
96
+ # We assume dataset is ok here
97
+ ds = datasets.load_dataset(d_id, config)[split]
98
+
99
+ try:
100
+ dataset_features = ds.features
101
+ except AttributeError:
102
+ # Dataset does not have features, need to provide everything
103
+ return None, None, None
104
+
105
+ # Check whether we need to infer the text input column
106
+ infer_text_input_column = True
107
+ if "text" in column_mapping.keys():
108
+ dataset_text_column = column_mapping["text"]
109
+ if dataset_text_column in dataset_features.keys():
110
+ infer_text_input_column = False
111
+ else:
112
+ logging.warning(f"Provided {dataset_text_column} is not in Dataset columns")
113
+
114
+ if infer_text_input_column:
115
+ # Try to retrieve one
116
+ candidates = [f for f in dataset_features if dataset_features[f].dtype == "string"]
117
+ if len(candidates) > 0:
118
+ logging.debug(f"Candidates are {candidates}")
119
+ column_mapping["text"] = candidates[0]
120
+ else:
121
+ # Not found a text feature
122
+ return column_mapping, None, None
123
+
124
+ # Load dataset as DataFrame
125
+ df = ds.to_pandas()
126
+
127
+ # Retrieve all labels
128
+ id2label_mapping = {}
129
+ id2label = ppl.model.config.id2label
130
+ label2id = {v: k for k, v in id2label.items()}
131
+ prediction_result = None
132
+ try:
133
+ # Use the first item to test prediction
134
+ results = ppl({"text": df.head(1).at[0, column_mapping["text"]]}, top_k=None)
135
+ prediction_result = {
136
+ f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
137
+ }
138
+ except Exception:
139
+ # Pipeline prediction failed, need to provide labels
140
+ return column_mapping, None, None
141
+
142
+ # Infer labels
143
+ id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
144
+ if "label" in column_mapping.keys():
145
+ if not isinstance(column_mapping["label"], dict) or set(column_mapping["label"].values()) != set(dataset_labels):
146
+ logging.warning(f'Provided {column_mapping["label"]} does not match labels in Dataset')
147
+ return column_mapping, prediction_result, None
148
+
149
+ if isinstance(column_mapping["label"], dict):
150
+ for model_label in id2label_mapping.keys():
151
+ id2label_mapping[model_label] = column_mapping["label"][str(label2id[model_label])]
152
+ elif None in id2label_mapping.values():
153
+ column_mapping["label"] = {
154
+ i: None for i in id2label.keys()
155
+ }
156
+ return column_mapping, prediction_result, None
157
+
158
+ id2label_df = pd.DataFrame({
159
+ "ID": [i for i in id2label.keys()],
160
+ "Model labels": [id2label[label] for label in id2label.keys()],
161
+ "Dataset labels": [id2label_mapping[id2label[label]] for label in id2label.keys()],
162
+ })
163
+ if "label" not in column_mapping.keys():
164
+ column_mapping["label"] = {
165
+ i: id2label_mapping[id2label[i]] for i in id2label.keys()
166
+ }
167
+
168
+ return column_mapping, prediction_result, id2label_df
169
 
170
 
171
  def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping):
 
215
  )
216
 
217
  # TODO: Validate column mapping by running once
218
+ prediction_result = None
219
  id2label_df = None
220
  if isinstance(ppl, TextClassificationPipeline):
221
  try:
 
223
  except Exception:
224
  column_mapping = {}
225
 
226
+ column_mapping, prediction_result, id2label_df = \
227
+ text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  column_mapping = json.dumps(column_mapping, indent=2)
230
 
231
  del ppl
232
 
233
+ if prediction_result is None:
234
+ gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
235
+ return (
236
+ config, split,
237
+ gr.update(interactive=False), # Submit button
238
+ gr.update(visible=False), # Model prediction preview
239
+ gr.update(visible=False), # Label mapping preview
240
+ gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
241
+ )
242
+ elif id2label_df is None:
243
+ gr.Warning('The prediction result does not conform the labels in the dataset. Please provide label mappings in "Advance" settings.')
244
+ return (
245
+ config, split,
246
+ gr.update(interactive=False), # Submit button
247
+ gr.update(value=prediction_result, visible=True), # Model prediction preview
248
+ gr.update(visible=False), # Label mapping preview
249
+ gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
250
+ )
251
+
252
  gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
253
 
254
  return (
 
323
  ],
324
  value=0,
325
  )
 
326
  example_labels = gr.Label(label='Model pipeline test prediction result', visible=False)
327
 
328
  with gr.Column():
 
352
  id2label_mapping_dataframe = gr.DataFrame(visible=False)
353
 
354
  with gr.Row():
355
+ with gr.Accordion("Advance", open=False):
356
+ run_local = gr.Checkbox(value=True, label="Run in this Space")
357
+ column_mapping_input = gr.Textbox(
358
+ value="",
359
+ lines=5,
360
+ label="Column mapping",
361
+ placeholder="Description of mapping of columns in model to dataset, in json format, e.g.:\n"
362
+ '{\n'
363
+ ' "text": "context",\n'
364
+ ' "label": {0: "Positive", 1: "Negative"}\n'
365
+ '}',
366
+ )
367
 
368
  with gr.Row():
369
  validate_btn = gr.Button("Validate model and dataset", variant="primary")