davidberenstein1957 HF staff committed on
Commit
6f3d06e
·
1 Parent(s): 70abf20

fix returning duplicate labels

Browse files
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -186,13 +186,15 @@ def generate_dataset(
186
  if isinstance(x, str): # single label
187
  return [x.lower().strip()]
188
  elif isinstance(x, list): # multiple labels
189
- return [
190
- label.lower().strip()
191
- for label in x
192
- if label.lower().strip() in labels
193
- ]
 
 
194
  else:
195
- return [random.choice(labels)]
196
 
197
  dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
198
  dataframe = dataframe[dataframe["labels"].notna()]
 
186
  if isinstance(x, str): # single label
187
  return [x.lower().strip()]
188
  elif isinstance(x, list): # multiple labels
189
+ return list(
190
+ set(
191
+ label.lower().strip()
192
+ for label in x
193
+ if label.lower().strip() in labels
194
+ )
195
+ )
196
  else:
197
+ return list(set([random.choice(labels)]))
198
 
199
  dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
200
  dataframe = dataframe[dataframe["labels"].notna()]