davidberenstein1957 HF staff commited on
Commit
93f233e
·
1 Parent(s): b8a81f2

update beta distribution

Browse files
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -126,7 +126,11 @@ def generate_dataset(
126
  inputs = []
127
  for _ in range(batch_size):
128
  if multi_label:
129
- k = int(random.betavariate(alpha=2, beta=3) * len(labels))
 
 
 
 
130
  else:
131
  k = 1
132
 
 
126
  inputs = []
127
  for _ in range(batch_size):
128
  if multi_label:
129
+ num_labels = len(labels)
130
+ k = int(
131
+ random.betavariate(alpha=(num_labels - 1), beta=num_labels)
132
+ * num_labels
133
+ )
134
  else:
135
  k = 1
136