andreybavt ZeroCommand commited on
Commit
5f9a95f
·
verified ·
1 Parent(s): 2070be3

GSK-2737-GSK-2735-support-model-url-remove-misleading-info (#90)

Browse files

- add support to input model url & remove suggestions when no related ds (77da4cdd32b75cee859556a213421840cedc380e)


Co-authored-by: zcy <[email protected]>

text_classification.py CHANGED
@@ -376,4 +376,9 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
376
  prediction_result,
377
  id2label_df,
378
  feature_map_df,
379
- )
 
 
 
 
 
 
376
  prediction_result,
377
  id2label_df,
378
  feature_map_df,
379
+ )
380
+
381
+ def strip_model_id_from_url(model_id):
382
+ if model_id.startswith("https://huggingface.co/"):
383
+ return "/".join(model_id.split("/")[-2])
384
+ return model_id
text_classification_ui_helpers.py CHANGED
@@ -11,6 +11,7 @@ import leaderboard
11
  from io_utils import read_column_mapping, write_column_mapping
12
  from run_jobs import save_job_to_pipe
13
  from text_classification import (
 
14
  check_model_task,
15
  preload_hf_inference_api,
16
  get_example_prediction,
@@ -21,6 +22,7 @@ from wordings import (
21
  CHECK_CONFIG_OR_SPLIT_RAW,
22
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
23
  MAPPING_STYLED_ERROR_WARNING,
 
24
  get_styled_input,
25
  )
26
 
@@ -32,12 +34,12 @@ ds_config = None
32
 
33
  def get_related_datasets_from_leaderboard(model_id):
34
  records = leaderboard.records
 
35
  model_records = records[records["model_id"] == model_id]
36
  datasets_unique = list(model_records["dataset_id"].unique())
37
 
38
  if len(datasets_unique) == 0:
39
- all_unique_datasets = list(records["dataset_id"].unique())
40
- return gr.update(choices=all_unique_datasets, value="")
41
 
42
  return gr.update(choices=datasets_unique, value=datasets_unique[0])
43
 
@@ -161,10 +163,11 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
161
  def precheck_model_ds_enable_example_btn(
162
  model_id, dataset_id, dataset_config, dataset_split
163
  ):
 
164
  model_task = check_model_task(model_id)
165
  preload_hf_inference_api(model_id)
166
  if model_task is None or model_task != "text-classification":
167
- gr.Warning("Please check your model.")
168
  return (gr.update(), gr.update(),"")
169
 
170
  if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
@@ -195,9 +198,10 @@ def align_columns_and_show_prediction(
195
  run_inference,
196
  inference_token,
197
  ):
 
198
  model_task = check_model_task(model_id)
199
  if model_task is None or model_task != "text-classification":
200
- gr.Warning("Please check your model.")
201
  return (
202
  gr.update(visible=False),
203
  gr.update(visible=False),
@@ -338,7 +342,7 @@ def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
338
  eval_str,
339
  threading.Lock(),
340
  )
341
- gr.Info("Your evaluation is submitted")
342
 
343
  return (
344
  gr.update(interactive=False), # Submit button
 
11
  from io_utils import read_column_mapping, write_column_mapping
12
  from run_jobs import save_job_to_pipe
13
  from text_classification import (
14
+ strip_model_id_from_url,
15
  check_model_task,
16
  preload_hf_inference_api,
17
  get_example_prediction,
 
22
  CHECK_CONFIG_OR_SPLIT_RAW,
23
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
24
  MAPPING_STYLED_ERROR_WARNING,
25
+ NOT_TEXT_CLASSIFICATION_MODEL_RAW,
26
  get_styled_input,
27
  )
28
 
 
34
 
35
  def get_related_datasets_from_leaderboard(model_id):
36
  records = leaderboard.records
37
+ model_id = strip_model_id_from_url(model_id)
38
  model_records = records[records["model_id"] == model_id]
39
  datasets_unique = list(model_records["dataset_id"].unique())
40
 
41
  if len(datasets_unique) == 0:
42
+ return gr.update(choices=[], value="")
 
43
 
44
  return gr.update(choices=datasets_unique, value=datasets_unique[0])
45
 
 
163
  def precheck_model_ds_enable_example_btn(
164
  model_id, dataset_id, dataset_config, dataset_split
165
  ):
166
+ model_id = strip_model_id_from_url(model_id)
167
  model_task = check_model_task(model_id)
168
  preload_hf_inference_api(model_id)
169
  if model_task is None or model_task != "text-classification":
170
+ gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
171
  return (gr.update(), gr.update(),"")
172
 
173
  if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
 
198
  run_inference,
199
  inference_token,
200
  ):
201
+ model_id = strip_model_id_from_url(model_id)
202
  model_task = check_model_task(model_id)
203
  if model_task is None or model_task != "text-classification":
204
+ gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
205
  return (
206
  gr.update(visible=False),
207
  gr.update(visible=False),
 
342
  eval_str,
343
  threading.Lock(),
344
  )
345
+ gr.Info("Your evaluation has been submitted")
346
 
347
  return (
348
  gr.update(interactive=False), # Submit button
wordings.py CHANGED
@@ -38,6 +38,10 @@ MAPPING_STYLED_ERROR_WARNING = """
38
  </h3>
39
  """
40
 
 
 
 
 
41
  USE_INFERENCE_API_TIP = """
42
  We recommend to use
43
  <a href="https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task">
 
38
  </h3>
39
  """
40
 
41
+ NOT_TEXT_CLASSIFICATION_MODEL_RAW = """
42
+ Your model does not fall under the category of text classification. This page is specifically designated for the evaluation of text classification models.
43
+ """
44
+
45
  USE_INFERENCE_API_TIP = """
46
  We recommend to use
47
  <a href="https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task">