Commit
·
136bd13
1
Parent(s):
d2df8be
update progress messages, remove table-view styling, and rework the textcat configuration layout
Browse files
src/synthetic_dataset_generator/app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from synthetic_dataset_generator._tabbedinterface import TabbedInterface
|
|
|
2 |
# from synthetic_dataset_generator.apps.eval import app as eval_app
|
3 |
from synthetic_dataset_generator.apps.readme import app as readme_app
|
4 |
from synthetic_dataset_generator.apps.sft import app as sft_app
|
@@ -15,9 +16,6 @@ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-prima
|
|
15 |
#system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
|
16 |
.container {padding-inline: 0 !important}
|
17 |
#sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
|
18 |
-
.table-view .table-wrap {
|
19 |
-
max-height: 450px;
|
20 |
-
}
|
21 |
"""
|
22 |
|
23 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
|
|
1 |
from synthetic_dataset_generator._tabbedinterface import TabbedInterface
|
2 |
+
|
3 |
# from synthetic_dataset_generator.apps.eval import app as eval_app
|
4 |
from synthetic_dataset_generator.apps.readme import app as readme_app
|
5 |
from synthetic_dataset_generator.apps.sft import app as sft_app
|
|
|
16 |
#system_prompt_examples { color: var(--body-text-color) !important; background-color: var(--block-background-fill) !important;}
|
17 |
.container {padding-inline: 0 !important}
|
18 |
#sign_in_button { flex-grow: 0; width: auto !important; display: flex; align-items: center; justify-content: center; margin: 0 auto; }
|
|
|
|
|
|
|
19 |
"""
|
20 |
|
21 |
image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
|
src/synthetic_dataset_generator/apps/eval.py
CHANGED
@@ -750,7 +750,6 @@ with gr.Blocks() as app:
|
|
750 |
headers=["prompt", "completion", "evaluation"],
|
751 |
wrap=True,
|
752 |
interactive=False,
|
753 |
-
elem_classes="table-view",
|
754 |
)
|
755 |
|
756 |
gr.HTML(value="<hr>")
|
|
|
750 |
headers=["prompt", "completion", "evaluation"],
|
751 |
wrap=True,
|
752 |
interactive=False,
|
|
|
753 |
)
|
754 |
|
755 |
gr.HTML(value="<hr>")
|
src/synthetic_dataset_generator/apps/sft.py
CHANGED
@@ -55,10 +55,10 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
55 |
|
56 |
|
57 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
58 |
-
progress(0.0, desc="
|
59 |
-
progress(0.3, desc="Initializing
|
60 |
generate_description = get_prompt_generator()
|
61 |
-
progress(0.7, desc="Generating
|
62 |
result = next(
|
63 |
generate_description.process(
|
64 |
[
|
@@ -68,7 +68,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
68 |
]
|
69 |
)
|
70 |
)[0]["generation"]
|
71 |
-
progress(1.0, desc="
|
72 |
return result
|
73 |
|
74 |
|
@@ -88,7 +88,6 @@ def _get_dataframe():
|
|
88 |
headers=["prompt", "completion"],
|
89 |
wrap=True,
|
90 |
interactive=False,
|
91 |
-
elem_classes="table-view",
|
92 |
)
|
93 |
|
94 |
|
|
|
55 |
|
56 |
|
57 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
58 |
+
progress(0.0, desc="Starting")
|
59 |
+
progress(0.3, desc="Initializing")
|
60 |
generate_description = get_prompt_generator()
|
61 |
+
progress(0.7, desc="Generating")
|
62 |
result = next(
|
63 |
generate_description.process(
|
64 |
[
|
|
|
68 |
]
|
69 |
)
|
70 |
)[0]["generation"]
|
71 |
+
progress(1.0, desc="Prompt generated")
|
72 |
return result
|
73 |
|
74 |
|
|
|
88 |
headers=["prompt", "completion"],
|
89 |
wrap=True,
|
90 |
interactive=False,
|
|
|
91 |
)
|
92 |
|
93 |
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -42,15 +42,14 @@ def _get_dataframe():
|
|
42 |
headers=["labels", "text"],
|
43 |
wrap=True,
|
44 |
interactive=False,
|
45 |
-
elem_classes="table-view",
|
46 |
)
|
47 |
|
48 |
|
49 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
50 |
-
progress(0.0, desc="
|
51 |
-
progress(0.3, desc="Initializing
|
52 |
generate_description = get_prompt_generator()
|
53 |
-
progress(0.7, desc="Generating
|
54 |
result = next(
|
55 |
generate_description.process(
|
56 |
[
|
@@ -60,7 +59,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
60 |
]
|
61 |
)
|
62 |
)[0]["generation"]
|
63 |
-
progress(1.0, desc="
|
64 |
data = json.loads(result)
|
65 |
system_prompt = data["classification_task"]
|
66 |
labels = data["labels"]
|
@@ -94,7 +93,7 @@ def generate_dataset(
|
|
94 |
is_sample: bool = False,
|
95 |
progress=gr.Progress(),
|
96 |
) -> pd.DataFrame:
|
97 |
-
progress(0.0, desc="(1/2) Generating
|
98 |
labels = get_preprocess_labels(labels)
|
99 |
textcat_generator = get_textcat_generator(
|
100 |
difficulty=difficulty,
|
@@ -117,7 +116,7 @@ def generate_dataset(
|
|
117 |
progress(
|
118 |
2 * 0.5 * n_processed / num_rows,
|
119 |
total=total_steps,
|
120 |
-
desc="(1/2) Generating
|
121 |
)
|
122 |
remaining_rows = num_rows - n_processed
|
123 |
batch_size = min(batch_size, remaining_rows)
|
@@ -139,14 +138,14 @@ def generate_dataset(
|
|
139 |
result["text"] = result["input_text"]
|
140 |
|
141 |
# label text classification data
|
142 |
-
progress(2 * 0.5, desc="(
|
143 |
n_processed = 0
|
144 |
labeller_results = []
|
145 |
while n_processed < num_rows:
|
146 |
progress(
|
147 |
0.5 + 0.5 * n_processed / num_rows,
|
148 |
total=total_steps,
|
149 |
-
desc="(
|
150 |
)
|
151 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
152 |
labels_batch = list(labeller_generator.process(inputs=batch))
|
@@ -182,7 +181,7 @@ def generate_dataset(
|
|
182 |
)
|
183 |
)
|
184 |
)
|
185 |
-
progress(1.0, desc="Dataset
|
186 |
return dataframe
|
187 |
|
188 |
|
@@ -316,7 +315,7 @@ def push_dataset(
|
|
316 |
client=client,
|
317 |
)
|
318 |
rg_dataset = rg_dataset.create()
|
319 |
-
progress(0.7, desc="Pushing dataset
|
320 |
hf_dataset = Dataset.from_pandas(dataframe)
|
321 |
records = [
|
322 |
rg.Record(
|
@@ -347,7 +346,7 @@ def push_dataset(
|
|
347 |
for sample in hf_dataset
|
348 |
]
|
349 |
rg_dataset.records.log(records=records)
|
350 |
-
progress(1.0, desc="Dataset pushed
|
351 |
except Exception as e:
|
352 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
353 |
return ""
|
@@ -406,61 +405,64 @@ with gr.Blocks() as app:
|
|
406 |
|
407 |
gr.HTML("<hr>")
|
408 |
gr.Markdown("## 2. Configure your dataset")
|
409 |
-
with gr.Row(equal_height=
|
410 |
-
with gr.
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
|
|
|
|
|
|
464 |
|
465 |
gr.HTML("<hr>")
|
466 |
gr.Markdown("## 3. Generate your dataset")
|
|
|
42 |
headers=["labels", "text"],
|
43 |
wrap=True,
|
44 |
interactive=False,
|
|
|
45 |
)
|
46 |
|
47 |
|
48 |
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
49 |
+
progress(0.0, desc="Starting")
|
50 |
+
progress(0.3, desc="Initializing")
|
51 |
generate_description = get_prompt_generator()
|
52 |
+
progress(0.7, desc="Generating")
|
53 |
result = next(
|
54 |
generate_description.process(
|
55 |
[
|
|
|
59 |
]
|
60 |
)
|
61 |
)[0]["generation"]
|
62 |
+
progress(1.0, desc="Prompt generated")
|
63 |
data = json.loads(result)
|
64 |
system_prompt = data["classification_task"]
|
65 |
labels = data["labels"]
|
|
|
93 |
is_sample: bool = False,
|
94 |
progress=gr.Progress(),
|
95 |
) -> pd.DataFrame:
|
96 |
+
progress(0.0, desc="(1/2) Generating dataset")
|
97 |
labels = get_preprocess_labels(labels)
|
98 |
textcat_generator = get_textcat_generator(
|
99 |
difficulty=difficulty,
|
|
|
116 |
progress(
|
117 |
2 * 0.5 * n_processed / num_rows,
|
118 |
total=total_steps,
|
119 |
+
desc="(1/2) Generating dataset",
|
120 |
)
|
121 |
remaining_rows = num_rows - n_processed
|
122 |
batch_size = min(batch_size, remaining_rows)
|
|
|
138 |
result["text"] = result["input_text"]
|
139 |
|
140 |
# label text classification data
|
141 |
+
progress(2 * 0.5, desc="(2/2) Labeling dataset")
|
142 |
n_processed = 0
|
143 |
labeller_results = []
|
144 |
while n_processed < num_rows:
|
145 |
progress(
|
146 |
0.5 + 0.5 * n_processed / num_rows,
|
147 |
total=total_steps,
|
148 |
+
desc="(2/2) Labeling dataset",
|
149 |
)
|
150 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
151 |
labels_batch = list(labeller_generator.process(inputs=batch))
|
|
|
181 |
)
|
182 |
)
|
183 |
)
|
184 |
+
progress(1.0, desc="Dataset created")
|
185 |
return dataframe
|
186 |
|
187 |
|
|
|
315 |
client=client,
|
316 |
)
|
317 |
rg_dataset = rg_dataset.create()
|
318 |
+
progress(0.7, desc="Pushing dataset")
|
319 |
hf_dataset = Dataset.from_pandas(dataframe)
|
320 |
records = [
|
321 |
rg.Record(
|
|
|
346 |
for sample in hf_dataset
|
347 |
]
|
348 |
rg_dataset.records.log(records=records)
|
349 |
+
progress(1.0, desc="Dataset pushed")
|
350 |
except Exception as e:
|
351 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
352 |
return ""
|
|
|
405 |
|
406 |
gr.HTML("<hr>")
|
407 |
gr.Markdown("## 2. Configure your dataset")
|
408 |
+
with gr.Row(equal_height=True):
|
409 |
+
with gr.Row(equal_height=False):
|
410 |
+
with gr.Column(scale=2):
|
411 |
+
system_prompt = gr.Textbox(
|
412 |
+
label="System prompt",
|
413 |
+
placeholder="You are a helpful assistant.",
|
414 |
+
visible=True,
|
415 |
+
)
|
416 |
+
labels = gr.Dropdown(
|
417 |
+
choices=[],
|
418 |
+
allow_custom_value=True,
|
419 |
+
interactive=True,
|
420 |
+
label="Labels",
|
421 |
+
multiselect=True,
|
422 |
+
info="Add the labels to classify the text.",
|
423 |
+
)
|
424 |
+
num_labels = gr.Number(
|
425 |
+
label="Number of labels per text",
|
426 |
+
value=1,
|
427 |
+
minimum=1,
|
428 |
+
maximum=10,
|
429 |
+
info="Select 1 for single-label and >1 for multi-label.",
|
430 |
+
interactive=True,
|
431 |
+
)
|
432 |
+
clarity = gr.Dropdown(
|
433 |
+
choices=[
|
434 |
+
("Clear", "clear"),
|
435 |
+
(
|
436 |
+
"Understandable",
|
437 |
+
"understandable with some effort",
|
438 |
+
),
|
439 |
+
("Ambiguous", "ambiguous"),
|
440 |
+
("Mixed", "mixed"),
|
441 |
+
],
|
442 |
+
value="understandable with some effort",
|
443 |
+
label="Clarity",
|
444 |
+
info="Set how easily the correct label or labels can be identified.",
|
445 |
+
interactive=True,
|
446 |
+
)
|
447 |
+
difficulty = gr.Dropdown(
|
448 |
+
choices=[
|
449 |
+
("High School", "high school"),
|
450 |
+
("College", "college"),
|
451 |
+
("PhD", "PhD"),
|
452 |
+
("Mixed", "mixed"),
|
453 |
+
],
|
454 |
+
value="high school",
|
455 |
+
label="Difficulty",
|
456 |
+
info="Select the comprehension level for the text. Ensure it matches the task context.",
|
457 |
+
interactive=True,
|
458 |
+
)
|
459 |
+
with gr.Row():
|
460 |
+
clear_btn_full = gr.Button("Clear", variant="secondary")
|
461 |
+
btn_apply_to_sample_dataset = gr.Button(
|
462 |
+
"Save", variant="primary"
|
463 |
+
)
|
464 |
+
with gr.Column(scale=3):
|
465 |
+
dataframe = _get_dataframe()
|
466 |
|
467 |
gr.HTML("<hr>")
|
468 |
gr.Markdown("## 3. Generate your dataset")
|