ZeroCommand commited on
Commit
44ab78a
·
1 Parent(s): fc7c452

add predict button

Browse files
app_text_classification.py CHANGED
@@ -5,10 +5,11 @@ import gradio as gr
5
  from io_utils import (get_logs_file, read_scanners, write_scanners)
6
  from text_classification_ui_helpers import (check_dataset_and_get_config,
7
  check_dataset_and_get_split,
8
- check_model_and_show_prediction,
9
  deselect_run_inference,
10
  select_run_mode, try_submit,
11
- write_column_mapping_to_config)
 
12
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
13
 
14
  MAX_LABELS = 40
@@ -40,6 +41,13 @@ def get_demo():
40
  dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
41
  dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)
42
 
 
 
 
 
 
 
 
43
  with gr.Row():
44
  example_input = gr.HTML(visible=False)
45
  with gr.Row():
@@ -90,7 +98,7 @@ def get_demo():
90
  run_btn = gr.Button(
91
  "Get Evaluation Result",
92
  variant="primary",
93
- interactive=True,
94
  size="lg",
95
  )
96
 
@@ -122,7 +130,7 @@ def get_demo():
122
  inputs=[run_local],
123
  outputs=[inference_token, run_inference],
124
  )
125
-
126
  gr.on(
127
  triggers=[label.change for label in column_mappings],
128
  fn=write_column_mapping_to_config,
@@ -147,9 +155,21 @@ def get_demo():
147
  model_id_input.change,
148
  dataset_id_input.change,
149
  dataset_config_input.change,
150
- dataset_split_input.change,
 
 
 
 
 
 
 
 
 
 
 
 
151
  ],
152
- fn=check_model_and_show_prediction,
153
  inputs=[
154
  model_id_input,
155
  dataset_id_input,
@@ -161,6 +181,7 @@ def get_demo():
161
  example_input,
162
  example_prediction,
163
  column_mapping_accordion,
 
164
  *column_mappings,
165
  ],
166
  )
@@ -188,13 +209,10 @@ def get_demo():
188
 
189
  gr.on(
190
  triggers=[
191
- model_id_input.change,
192
- dataset_config_input.change,
193
- dataset_split_input.change,
194
- run_inference.change,
195
- run_local.change,
196
- inference_token.change,
197
- scanners.change,
198
  ],
199
  fn=enable_run_btn,
200
  inputs=None,
@@ -202,8 +220,8 @@ def get_demo():
202
  )
203
 
204
  gr.on(
205
- triggers=[label.change for label in column_mappings],
206
  fn=enable_run_btn,
207
- inputs=None,
208
  outputs=[run_btn],
209
  )
 
5
  from io_utils import (get_logs_file, read_scanners, write_scanners)
6
  from text_classification_ui_helpers import (check_dataset_and_get_config,
7
  check_dataset_and_get_split,
8
+ align_columns_and_show_prediction,
9
  deselect_run_inference,
10
  select_run_mode, try_submit,
11
+ write_column_mapping_to_config,
12
+ precheck_model_ds_enable_example_btn)
13
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
14
 
15
  MAX_LABELS = 40
 
41
  dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False)
42
  dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False)
43
 
44
+ with gr.Row():
45
+ example_btn = gr.Button(
46
+ "Auto-align Columns & Get Sample Prediction",
47
+ visible=True,
48
+ variant="primary",
49
+ interactive=False)
50
+
51
  with gr.Row():
52
  example_input = gr.HTML(visible=False)
53
  with gr.Row():
 
98
  run_btn = gr.Button(
99
  "Get Evaluation Result",
100
  variant="primary",
101
+ interactive=False,
102
  size="lg",
103
  )
104
 
 
130
  inputs=[run_local],
131
  outputs=[inference_token, run_inference],
132
  )
133
+
134
  gr.on(
135
  triggers=[label.change for label in column_mappings],
136
  fn=write_column_mapping_to_config,
 
155
  model_id_input.change,
156
  dataset_id_input.change,
157
  dataset_config_input.change,
158
+ dataset_split_input.change],
159
+ fn=precheck_model_ds_enable_example_btn,
160
+ inputs=[
161
+ model_id_input,
162
+ dataset_id_input,
163
+ dataset_config_input,
164
+ dataset_split_input,
165
+ ],
166
+ outputs=[example_btn])
167
+
168
+ gr.on(
169
+ triggers=[
170
+ example_btn.click,
171
  ],
172
+ fn=align_columns_and_show_prediction,
173
  inputs=[
174
  model_id_input,
175
  dataset_id_input,
 
181
  example_input,
182
  example_prediction,
183
  column_mapping_accordion,
184
+ run_btn,
185
  *column_mappings,
186
  ],
187
  )
 
209
 
210
  gr.on(
211
  triggers=[
212
+ run_inference.input,
213
+ run_local.input,
214
+ inference_token.input,
215
+ scanners.input,
 
 
 
216
  ],
217
  fn=enable_run_btn,
218
  inputs=None,
 
220
  )
221
 
222
  gr.on(
223
+ triggers=[label.input for label in column_mappings],
224
  fn=enable_run_btn,
225
+ inputs=column_mappings,
226
  outputs=[run_btn],
227
  )
text_classification_ui_helpers.py CHANGED
@@ -141,8 +141,21 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label
141
 
142
  return lables + features
143
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- def check_model_and_show_prediction(
 
 
146
  model_id, dataset_id, dataset_config, dataset_split, uid
147
  ):
148
  ppl = check_model(model_id)
@@ -151,6 +164,8 @@ def check_model_and_show_prediction(
151
  return (
152
  gr.update(visible=False),
153
  gr.update(visible=False),
 
 
154
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
155
  )
156
 
@@ -164,6 +179,7 @@ def check_model_and_show_prediction(
164
  gr.update(visible=False),
165
  gr.update(visible=False),
166
  gr.update(visible=False, open=False),
 
167
  *dropdown_placement,
168
  )
169
  model_id2label = ppl.model.config.id2label
@@ -178,6 +194,7 @@ def check_model_and_show_prediction(
178
  gr.update(visible=False),
179
  gr.update(visible=False),
180
  gr.update(visible=False, open=False),
 
181
  *dropdown_placement,
182
  )
183
 
@@ -198,6 +215,7 @@ def check_model_and_show_prediction(
198
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
199
  gr.update(visible=False),
200
  gr.update(visible=True, open=True),
 
201
  *column_mappings,
202
  )
203
 
@@ -208,6 +226,7 @@ def check_model_and_show_prediction(
208
  gr.update(value=get_styled_input(prediction_input), visible=True),
209
  gr.update(value=prediction_output, visible=True),
210
  gr.update(visible=True, open=False),
 
211
  *column_mappings,
212
  )
213
 
 
141
 
142
  return lables + features
143
 
144
+ def precheck_model_ds_enable_example_btn(model_id, dataset_id, dataset_config, dataset_split):
145
+ ppl = check_model(model_id)
146
+ if ppl is None or not isinstance(ppl, TextClassificationPipeline):
147
+ gr.Warning("Please check your model.")
148
+ return gr.update(interactive=False)
149
+ ds_labels, ds_features = get_labels_and_features_from_dataset(
150
+ dataset_id, dataset_config, dataset_split
151
+ )
152
+ if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
153
+ gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
154
+ return gr.update(interactive=False)
155
 
156
+ return gr.update(interactive=True)
157
+
158
+ def align_columns_and_show_prediction(
159
  model_id, dataset_id, dataset_config, dataset_split, uid
160
  ):
161
  ppl = check_model(model_id)
 
164
  return (
165
  gr.update(visible=False),
166
  gr.update(visible=False),
167
+ gr.update(visible=False, open=False),
168
+ gr.update(interactive=False),
169
  *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
170
  )
171
 
 
179
  gr.update(visible=False),
180
  gr.update(visible=False),
181
  gr.update(visible=False, open=False),
182
+ gr.update(interactive=False),
183
  *dropdown_placement,
184
  )
185
  model_id2label = ppl.model.config.id2label
 
194
  gr.update(visible=False),
195
  gr.update(visible=False),
196
  gr.update(visible=False, open=False),
197
+ gr.update(interactive=False),
198
  *dropdown_placement,
199
  )
200
 
 
215
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
216
  gr.update(visible=False),
217
  gr.update(visible=True, open=True),
218
+ gr.update(interactive=True),
219
  *column_mappings,
220
  )
221
 
 
226
  gr.update(value=get_styled_input(prediction_input), visible=True),
227
  gr.update(value=prediction_output, visible=True),
228
  gr.update(visible=True, open=False),
229
+ gr.update(interactive=True),
230
  *column_mappings,
231
  )
232