davidberenstein1957 HF staff committed on
Commit
5d3be21
·
1 Parent(s): 105084b

update textcat prompt based on multi_label

Browse files
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -104,8 +104,11 @@ def generate_dataset(
104
  temperature=temperature,
105
  is_sample=is_sample,
106
  )
 
 
 
107
  labeller_generator = get_labeller_generator(
108
- system_prompt=f"{system_prompt}. Optional labels: {', '.join(labels)}. Only apply relevant labels. Applying less labels is better than applying too many labels.",
109
  labels=labels,
110
  multi_label=multi_label,
111
  )
@@ -181,16 +184,20 @@ def generate_dataset(
181
  [
182
  label.lower().strip()
183
  for label in x
184
- if label.lower().strip() in labels
185
  ]
186
  )
187
  )
188
  )
 
189
  else:
190
  dataframe = dataframe.rename(columns={"labels": "label"})
191
  dataframe["label"] = dataframe["label"].apply(
192
- lambda x: x.lower().strip() if x and x.lower().strip() in labels else None
 
 
193
  )
 
194
 
195
  progress(1.0, desc="Dataset created")
196
  return dataframe
 
104
  temperature=temperature,
105
  is_sample=is_sample,
106
  )
107
+ updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
108
+ if multi_label:
109
+ updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
110
  labeller_generator = get_labeller_generator(
111
+ system_prompt=updated_system_prompt,
112
  labels=labels,
113
  multi_label=multi_label,
114
  )
 
184
  [
185
  label.lower().strip()
186
  for label in x
187
+ if label is not None and label.lower().strip() in labels
188
  ]
189
  )
190
  )
191
  )
192
+ dataframe = dataframe[dataframe["labels"].notna()]
193
  else:
194
  dataframe = dataframe.rename(columns={"labels": "label"})
195
  dataframe["label"] = dataframe["label"].apply(
196
+ lambda x: x.lower().strip()
197
+ if x and x.lower().strip() in labels
198
+ else random.choice(labels)
199
  )
200
+ dataframe = dataframe[dataframe["text"].notna()]
201
 
202
  progress(1.0, desc="Dataset created")
203
  return dataframe