sdiazlor HF staff commited on
Commit
cef916e
ยท
1 Parent(s): e8f4283

fix cast index and no-labels errors

Browse files
.gitignore CHANGED
@@ -129,6 +129,7 @@ venv/
129
  ENV/
130
  env.bak/
131
  venv.bak/
 
132
 
133
  # Spyder project settings
134
  .spyderproject
 
129
  ENV/
130
  env.bak/
131
  venv.bak/
132
+ .python-version
133
 
134
  # Spyder project settings
135
  .spyderproject
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -64,7 +64,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
64
  progress(1.0, desc="Prompt generated")
65
  data = json.loads(result)
66
  system_prompt = data["classification_task"]
67
- labels = data["labels"]
68
  return system_prompt, labels
69
 
70
 
@@ -177,14 +177,20 @@ def generate_dataset(
177
  distiset_results.append(record)
178
 
179
  dataframe = pd.DataFrame(distiset_results)
 
 
 
 
 
 
 
180
  if multi_label:
181
  dataframe["labels"] = dataframe["labels"].apply(
182
  lambda x: list(
183
  set(
184
  [
185
- label.lower().strip()
186
  for label in x
187
- if label is not None and label.lower().strip() in labels
188
  ]
189
  )
190
  )
@@ -214,6 +220,7 @@ def push_dataset_to_hub(
214
  pipeline_code: str = "",
215
  progress=gr.Progress(),
216
  ):
 
217
  progress(0.0, desc="Validating")
218
  repo_id = validate_push_to_hub(org_name, repo_name)
219
  progress(0.3, desc="Preprocessing")
@@ -230,7 +237,10 @@ def push_dataset_to_hub(
230
  features = Features(
231
  {"text": Value("string"), "label": ClassLabel(names=labels)}
232
  )
233
- dataset = Dataset.from_pandas(dataframe, features=features)
 
 
 
234
  dataset = combine_datasets(repo_id, dataset)
235
  distiset = Distiset({"default": dataset})
236
  progress(0.9, desc="Pushing dataset")
@@ -269,6 +279,7 @@ def push_dataset(
269
  num_rows=num_rows,
270
  temperature=temperature,
271
  )
 
272
  push_dataset_to_hub(
273
  dataframe,
274
  org_name,
@@ -365,7 +376,7 @@ def push_dataset(
365
  and all(label in labels for label in sample["labels"])
366
  )
367
  )
368
- else []
369
  ),
370
  )
371
  for sample in hf_dataset
 
64
  progress(1.0, desc="Prompt generated")
65
  data = json.loads(result)
66
  system_prompt = data["classification_task"]
67
+ labels = get_preprocess_labels(data["labels"])
68
  return system_prompt, labels
69
 
70
 
 
177
  distiset_results.append(record)
178
 
179
  dataframe = pd.DataFrame(distiset_results)
180
+ if (
181
+ not labels
182
+ or len(set(label.lower().strip() for label in labels if label.strip())) < 2
183
+ ):
184
+ raise gr.Error(
185
+ "Please provide at least 2 unique, non-empty labels to classify your text."
186
+ )
187
  if multi_label:
188
  dataframe["labels"] = dataframe["labels"].apply(
189
  lambda x: list(
190
  set(
191
  [
192
+ label.lower().strip() if (label is not None and label.lower().strip() in labels) else random.choice(labels)
193
  for label in x
 
194
  ]
195
  )
196
  )
 
220
  pipeline_code: str = "",
221
  progress=gr.Progress(),
222
  ):
223
+ gr.Info(message=f"Dataframe columns in push dataset to hub: {dataframe.columns}", duration=20)
224
  progress(0.0, desc="Validating")
225
  repo_id = validate_push_to_hub(org_name, repo_name)
226
  progress(0.3, desc="Preprocessing")
 
237
  features = Features(
238
  {"text": Value("string"), "label": ClassLabel(names=labels)}
239
  )
240
+ dataset = Dataset.from_pandas(
241
+ dataframe.reset_index(drop=True),
242
+ features=features,
243
+ )
244
  dataset = combine_datasets(repo_id, dataset)
245
  distiset = Distiset({"default": dataset})
246
  progress(0.9, desc="Pushing dataset")
 
279
  num_rows=num_rows,
280
  temperature=temperature,
281
  )
282
+ gr.Info(message=f"Dataframe columns: {dataframe.columns}", duration=20)
283
  push_dataset_to_hub(
284
  dataframe,
285
  org_name,
 
376
  and all(label in labels for label in sample["labels"])
377
  )
378
  )
379
+ else None
380
  ),
381
  )
382
  for sample in hf_dataset