davidberenstein1957's picture
update textcat prompt based on multi_label
5d3be21
raw
history blame
20.8 kB
import json
import random
import uuid
from typing import List, Union
import argilla as rg
import gradio as gr
import pandas as pd
from datasets import ClassLabel, Dataset, Features, Sequence, Value
from distilabel.distiset import Distiset
from huggingface_hub import HfApi
from src.synthetic_dataset_generator.apps.base import (
combine_datasets,
hide_success_message,
push_pipeline_code_to_hub,
show_success_message,
test_max_num_rows,
validate_argilla_user_workspace_dataset,
validate_push_to_hub,
)
from src.synthetic_dataset_generator.pipelines.embeddings import (
get_embeddings,
get_sentence_embedding_dimensions,
)
from src.synthetic_dataset_generator.pipelines.textcat import (
DEFAULT_DATASET_DESCRIPTIONS,
generate_pipeline_code,
get_labeller_generator,
get_prompt_generator,
get_textcat_generator,
)
from src.synthetic_dataset_generator.utils import (
get_argilla_client,
get_org_dropdown,
get_preprocess_labels,
swap_visibility,
)
from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
def _get_dataframe():
    """Build an empty, read-only preview table with "labels" and "text" columns."""
    column_headers = ["labels", "text"]
    return gr.Dataframe(headers=column_headers, interactive=False, wrap=True)
def generate_system_prompt(dataset_description, progress=gr.Progress()):
    """Ask the prompt generator for a classification task prompt and labels.

    Args:
        dataset_description: Free-text description of the desired dataset.
        progress: Gradio progress tracker.

    Returns:
        A ``(system_prompt, labels)`` tuple read from the ``classification_task``
        and ``labels`` keys of the JSON returned by the generator.

    Raises:
        gr.Error: if the generator output is missing, is not valid JSON, or
            lacks the expected keys.
    """
    progress(0.0, desc="Starting")
    progress(0.3, desc="Initializing")
    generate_description = get_prompt_generator()
    progress(0.7, desc="Generating")
    result = next(
        generate_description.process(
            [
                {
                    "instruction": dataset_description,
                }
            ]
        )
    )[0]["generation"]
    try:
        # The generation is free-form LLM output: it may be None, invalid JSON,
        # a non-object JSON value, or miss keys. Surface that to the UI instead
        # of an opaque traceback.
        data = json.loads(result)
        system_prompt = data["classification_task"]
        labels = data["labels"]
    except (TypeError, ValueError, KeyError) as err:
        raise gr.Error(
            "Could not parse the generated system prompt. Please try again."
        ) from err
    # Only report completion once the output has actually been parsed.
    progress(1.0, desc="Prompt generated")
    return system_prompt, labels
def generate_sample_dataset(
    system_prompt, difficulty, clarity, labels, multi_label, progress=gr.Progress()
):
    """Produce a 10-row preview dataset using the current UI settings.

    Thin wrapper around ``generate_dataset`` with ``num_rows=10`` and
    ``is_sample=True``; returns whatever dataframe that call returns.
    """
    preview_options = dict(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        labels=labels,
        multi_label=multi_label,
        num_rows=10,
        progress=progress,
        is_sample=True,
    )
    return generate_dataset(**preview_options)
def _sample_num_labels(num_labels: int, multi_label: bool) -> int:
    """Return how many candidate labels to mention in one generation prompt."""
    if not multi_label:
        return 1
    # Draw a fraction from Beta(n - 1, n) and scale it by n, so a prompt
    # usually mentions only part of the label set (the result may be 0).
    # random.betavariate requires alpha > 0 and beta > 0; clamp both so a
    # one- or zero-label list does not raise ValueError. For n >= 2 this is
    # exactly Beta(n - 1, n).
    alpha = max(num_labels - 1, 1)
    beta = max(num_labels, 1)
    return int(random.betavariate(alpha=alpha, beta=beta) * num_labels)


def _normalise_multi_labels(value, labels: List[str]) -> List[str]:
    """Lower-case/strip a raw multi-label prediction and keep known labels only.

    A bare string is treated as one label (not iterated character by
    character), non-string entries are ignored, and duplicates are removed
    while keeping first-seen order.
    """
    if isinstance(value, str):
        value = [value]
    elif not isinstance(value, (list, tuple, set)):
        return []
    cleaned = (label.lower().strip() for label in value if isinstance(label, str))
    return list(dict.fromkeys(label for label in cleaned if label in labels))


def _normalise_single_label(value, labels: List[str]) -> str:
    """Lower-case/strip a raw single-label prediction.

    Falls back to a random known label when the prediction is missing,
    not a string, or not one of ``labels``.
    """
    if isinstance(value, str):
        cleaned = value.lower().strip()
        if cleaned in labels:
            return cleaned
    return random.choice(labels)


def generate_dataset(
    system_prompt: str,
    difficulty: str,
    clarity: str,
    labels: List[str] = None,
    multi_label: bool = False,
    num_rows: int = 10,
    temperature: float = 0.9,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate and label a synthetic text-classification dataset.

    Runs two phases: texts are generated in batches from prompts that mention
    randomly sampled labels (progress 0.0 -> 0.5), then every generated text is
    labelled by the labeller generator (progress 0.5 -> 1.0).

    Args:
        system_prompt: Description of the classification task.
        difficulty: Comprehension level passed to the text generator.
        clarity: Label clarity passed to the text generator.
        labels: Candidate labels; normalised via ``get_preprocess_labels``.
        multi_label: Whether each text may receive several labels.
        num_rows: Requested number of rows (passed through
            ``test_max_num_rows`` first).
        temperature: Sampling temperature for the text generator.
        is_sample: Forwarded to ``get_textcat_generator`` for preview runs.
        progress: Gradio progress tracker.

    Returns:
        A dataframe with ``text`` and either ``labels`` (list of known labels,
        multi-label) or ``label`` (a single known label) columns; rows whose
        text is missing are dropped, as are multi-label rows with no labels
        value at all.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="(1/2) Generating dataset")
    labels = get_preprocess_labels(labels)
    textcat_generator = get_textcat_generator(
        difficulty=difficulty,
        clarity=clarity,
        temperature=temperature,
        is_sample=is_sample,
    )
    updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
    if multi_label:
        # The prompt above already ends with ".", so append with a plain space
        # (the previous f"{...}. Only ..." produced a doubled period).
        updated_system_prompt = (
            f"{updated_system_prompt} Only apply relevant labels. "
            "Applying less labels is better than applying too many labels."
        )
    labeller_generator = get_labeller_generator(
        system_prompt=updated_system_prompt,
        labels=labels,
        multi_label=multi_label,
    )
    total_steps: int = num_rows * 2
    batch_size = DEFAULT_BATCH_SIZE

    # Phase 1: create text classification data (progress 0.0 -> 0.5).
    n_processed = 0
    textcat_results = []
    while n_processed < num_rows:
        progress(
            0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Generating dataset",
        )
        # Use a per-iteration size so a short final batch does not shrink the
        # batch size used by the labelling phase below.
        current_batch_size = min(batch_size, num_rows - n_processed)
        inputs = []
        for _ in range(current_batch_size):
            k = _sample_num_labels(len(labels), multi_label)
            sampled_labels = random.sample(labels, min(k, len(labels)))
            random.shuffle(sampled_labels)
            inputs.append(
                {
                    "task": f"{system_prompt}. The text represents the following categories: {', '.join(sampled_labels)}"
                }
            )
        batch = list(textcat_generator.process(inputs=inputs))
        textcat_results.extend(batch[0])
        n_processed += current_batch_size
    for result in textcat_results:
        result["text"] = result["input_text"]

    # Phase 2: label text classification data (progress 0.5 -> 1.0).
    progress(0.5, desc="(2/2) Labeling dataset")
    labeller_results = []
    # Iterate over what was actually generated so an unexpected result count
    # never feeds an empty batch to the labeller.
    n_texts = len(textcat_results)
    for start in range(0, n_texts, batch_size):
        progress(
            0.5 + 0.5 * start / n_texts,
            total=total_steps,
            desc="(2/2) Labeling dataset",
        )
        batch = textcat_results[start : start + batch_size]
        labels_batch = list(labeller_generator.process(inputs=batch))
        labeller_results.extend(labels_batch[0])
    progress(
        1,
        total=total_steps,
        desc="(2/2) Creating dataset",
    )

    # Create the final dataset. Passing ``columns`` guarantees both columns
    # exist (filled with NaN) even if no result carried that key.
    dataframe = pd.DataFrame(
        [
            {key: result[key] for key in ("labels", "text") if key in result}
            for result in labeller_results
        ],
        columns=["labels", "text"],
    )
    if multi_label:
        # Drop rows without a labels value *before* normalising them.
        dataframe = dataframe[dataframe["labels"].notna()].copy()
        dataframe["labels"] = dataframe["labels"].apply(
            lambda value: _normalise_multi_labels(value, labels)
        )
    else:
        dataframe = dataframe.rename(columns={"labels": "label"})
        dataframe["label"] = dataframe["label"].apply(
            lambda value: _normalise_single_label(value, labels)
        )
    dataframe = dataframe[dataframe["text"].notna()]
    progress(1.0, desc="Dataset created")
    return dataframe
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    multi_label: bool = False,
    labels: List[str] = None,
    oauth_token: Union[gr.OAuthToken, None] = None,
    private: bool = False,
    pipeline_code: str = "",
    progress=gr.Progress(),
):
    """Push a generated text-classification dataframe to the Hugging Face Hub.

    The label column is typed as ``ClassLabel`` over the preprocessed labels:
    a ``Sequence`` of them named ``labels`` for multi-label data, otherwise a
    single ``label`` column. The dataset is passed through ``combine_datasets``
    (presumably merging with any existing Hub data under ``repo_id`` — verify)
    before being pushed, and the pipeline code is pushed alongside it.

    Note: ``oauth_token.token`` is read directly, so a ``None`` token raises
    ``AttributeError``.
    """
    progress(0.0, desc="Validating")
    repo_id = validate_push_to_hub(org_name, repo_name)

    progress(0.3, desc="Preprocessing")
    class_names = get_preprocess_labels(labels)

    progress(0.7, desc="Creating dataset")
    label_column, label_feature = (
        ("labels", Sequence(feature=ClassLabel(names=class_names)))
        if multi_label
        else ("label", ClassLabel(names=class_names))
    )
    features = Features({"text": Value("string"), label_column: label_feature})
    new_rows = Dataset.from_pandas(dataframe, features=features)
    distiset = Distiset({"default": combine_datasets(repo_id, new_rows)})

    progress(0.9, desc="Pushing dataset")
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
    progress(1.0, desc="Dataset pushed")
def _build_argilla_settings(labels: List[str], multi_label: bool) -> "rg.Settings":
    """Return Argilla settings: one text field, one label question, length metadata and embeddings."""
    if multi_label:
        question = rg.MultiLabelQuestion(
            name="labels",
            title="Labels",
            # Previously "The labels of the conversation" (copy/paste from a
            # chat app); this dataset holds plain texts.
            description="The labels of the text",
            labels=labels,
        )
    else:
        question = rg.LabelQuestion(
            name="label",
            title="Label",
            description="The label of the text",
            labels=labels,
        )
    return rg.Settings(
        fields=[
            rg.TextField(
                name="text",
                description="The text classification data",
                title="Text",
            ),
        ],
        questions=[question],
        metadata=[
            rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
        ],
        vectors=[
            rg.VectorField(
                name="text_embeddings",
                dimensions=get_sentence_embedding_dimensions(),
            )
        ],
        guidelines="Please review the text and provide or correct the label where needed.",
    )


def _suggestion_value(sample, labels: List[str], multi_label: bool):
    """Return the label suggestion for ``sample``, or None when it should be skipped."""
    if multi_label:
        value = sample["labels"]
        # Skip empty label lists: an empty suggestion carries no information.
        # NOTE(review): Argilla may also reject empty multi-label values — confirm.
        if value and all(label in labels for label in value):
            return value
        return None
    value = sample["label"]
    return value if value in labels else None


def _build_argilla_records(hf_dataset, labels: List[str], multi_label: bool) -> list:
    """Convert dataset rows into Argilla records, attaching valid label suggestions."""
    question_name = "labels" if multi_label else "label"
    records = []
    for sample in hf_dataset:
        value = _suggestion_value(sample, labels, multi_label)
        suggestions = (
            []
            if value is None
            else [rg.Suggestion(question_name=question_name, value=value)]
        )
        records.append(
            rg.Record(
                fields={"text": sample["text"]},
                metadata={"text_length": sample["text_length"]},
                vectors={"text_embeddings": sample["text_embeddings"]},
                suggestions=suggestions,
            )
        )
    return records


def push_dataset(
    org_name: str,
    repo_name: str,
    system_prompt: str,
    difficulty: str,
    clarity: str,
    multi_label: int = 1,
    num_rows: int = 10,
    labels: List[str] = None,
    private: bool = False,
    temperature: float = 0.8,
    pipeline_code: str = "",
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Generate a full dataset, push it to the Hub, then mirror it into Argilla.

    Rows with empty or missing text are sent to Argilla only after the Hub
    push (the Hub receives the unfiltered dataframe). The Argilla dataset is
    created in the caller's HF-user workspace if it does not exist yet.

    Returns:
        ``""`` — also when no Argilla client is configured. The value is
        written into the success-message component by the UI.

    Raises:
        gr.Error: wrapping any exception raised during the Argilla step.
    """
    dataframe = generate_dataset(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        multi_label=multi_label,
        labels=labels,
        num_rows=num_rows,
        temperature=temperature,
    )
    push_dataset_to_hub(
        dataframe,
        org_name,
        repo_name,
        multi_label,
        labels,
        oauth_token,
        private,
        pipeline_code,
    )
    # .copy() so the column assignments below act on an independent frame
    # rather than a filtered view (avoids SettingWithCopyWarning).
    dataframe = dataframe[
        (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
    ].copy()
    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""
        labels = get_preprocess_labels(labels)
        settings = _build_argilla_settings(labels, multi_label)
        dataframe["text_length"] = dataframe["text"].apply(len)
        dataframe["text_embeddings"] = get_embeddings(dataframe["text"].to_list())
        progress(0.5, desc="Creating dataset")
        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()
        progress(0.7, desc="Pushing dataset")
        hf_dataset = Dataset.from_pandas(dataframe)
        records = _build_argilla_records(hf_dataset, labels, multi_label)
        rg_dataset.records.log(records=records)
        progress(1.0, desc="Dataset pushed")
    except Exception as e:
        # Top-level UI boundary: report to the user, keep the original cause.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""
def validate_input_labels(labels):
    """Return ``labels`` unchanged if at least two were selected.

    Raises:
        gr.Error: when ``labels`` is empty/None or holds fewer than two items.
    """
    selected = len(labels) if labels else 0
    if selected >= 2:
        return labels
    raise gr.Error(
        f"Please select at least 2 labels to classify your text. You selected {selected}."
    )
def show_pipeline_code_visibility():
    """Return a Gradio update that reveals the pipeline-code accordion."""
    revealed = gr.Accordion(visible=True)
    return {pipeline_code_ui: revealed}
def hide_pipeline_code_visibility():
    """Return a Gradio update that hides the pipeline-code accordion."""
    hidden = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden}
######################
# Gradio UI
######################


def _generate_pipeline_code_from_ui(
    system_prompt, difficulty, clarity, labels, multi_label, num_rows, temperature
):
    """Render the distilabel pipeline code from the current UI values.

    ``generate_pipeline_code`` takes ``num_labels``, not a multi-label flag.
    This adapter derives it from the checkbox (all selected labels when
    multi-label, otherwise 1). Both the initial render and the post-push
    refresh use it, so the boolean is never passed in the ``num_labels`` slot.
    """
    return generate_pipeline_code(
        system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        labels=labels,
        num_labels=len(labels or []) if multi_label else 1,
        num_rows=num_rows,
        temperature=temperature,
    )


with gr.Blocks() as app:
    with gr.Column() as main_ui:
        gr.Markdown("## 1. Describe the dataset you want")
        with gr.Row():
            with gr.Column(scale=2):
                dataset_description = gr.Textbox(
                    label="Dataset description",
                    placeholder="Give a precise description of your desired dataset.",
                )
                with gr.Row():
                    clear_btn_part = gr.Button(
                        "Clear",
                        variant="secondary",
                    )
                    load_btn = gr.Button(
                        "Create",
                        variant="primary",
                    )
            with gr.Column(scale=3):
                examples = gr.Examples(
                    examples=DEFAULT_DATASET_DESCRIPTIONS,
                    inputs=[dataset_description],
                    cache_examples=False,
                    label="Examples",
                )

        gr.HTML("<hr>")
        gr.Markdown("## 2. Configure your dataset")
        with gr.Row(equal_height=True):
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    system_prompt = gr.Textbox(
                        label="System prompt",
                        placeholder="You are a helpful assistant.",
                        visible=True,
                    )
                    labels = gr.Dropdown(
                        choices=[],
                        allow_custom_value=True,
                        interactive=True,
                        label="Labels",
                        multiselect=True,
                        info="Add the labels to classify the text.",
                    )
                    multi_label = gr.Checkbox(
                        label="Multi-label",
                        value=False,
                        interactive=True,
                        info="If checked, the text will be classified into multiple labels.",
                    )
                    clarity = gr.Dropdown(
                        choices=[
                            ("Clear", "clear"),
                            (
                                "Understandable",
                                "understandable with some effort",
                            ),
                            ("Ambiguous", "ambiguous"),
                            ("Mixed", "mixed"),
                        ],
                        value="mixed",
                        label="Clarity",
                        info="Set how easily the correct label or labels can be identified.",
                        interactive=True,
                    )
                    difficulty = gr.Dropdown(
                        choices=[
                            ("High School", "high school"),
                            ("College", "college"),
                            ("PhD", "PhD"),
                            ("Mixed", "mixed"),
                        ],
                        value="high school",
                        label="Difficulty",
                        info="Select the comprehension level for the text. Ensure it matches the task context.",
                        interactive=True,
                    )
                    with gr.Row():
                        clear_btn_full = gr.Button("Clear", variant="secondary")
                        btn_apply_to_sample_dataset = gr.Button(
                            "Save", variant="primary"
                        )
                with gr.Column(scale=3):
                    dataframe = _get_dataframe()

        gr.HTML("<hr>")
        gr.Markdown("## 3. Generate your dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                org_name = get_org_dropdown()
                repo_name = gr.Textbox(
                    label="Repo name",
                    placeholder="dataset_name",
                    value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                    interactive=True,
                )
                num_rows = gr.Number(
                    label="Number of rows",
                    value=10,
                    interactive=True,
                    scale=1,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1,
                    value=0.8,
                    step=0.1,
                    interactive=True,
                )
                private = gr.Checkbox(
                    label="Private dataset",
                    value=False,
                    interactive=True,
                    scale=1,
                )
                btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
            with gr.Column(scale=3):
                success_message = gr.Markdown(
                    visible=True,
                    min_height=100,  # don't remove this otherwise progress is not visible
                )
                with gr.Accordion(
                    "Customize your pipeline with distilabel",
                    open=False,
                    visible=False,
                ) as pipeline_code_ui:
                    code = _generate_pipeline_code_from_ui(
                        system_prompt.value,
                        difficulty.value,
                        clarity.value,
                        labels.value,
                        multi_label.value,
                        num_rows.value,
                        temperature.value,
                    )
                    pipeline_code = gr.Code(
                        value=code,
                        language="python",
                        label="Distilabel Pipeline Code",
                    )

    load_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description],
        outputs=[system_prompt, labels],
        show_progress=True,
    ).then(
        fn=generate_sample_dataset,
        inputs=[system_prompt, difficulty, clarity, labels, multi_label],
        outputs=[dataframe],
        show_progress=True,
    )

    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[system_prompt, difficulty, clarity, labels, multi_label],
        outputs=[dataframe],
        show_progress=True,
    )

    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
        show_progress=True,
    ).success(
        # `.success` (not `.then`): `.then` runs even if the Argilla
        # validation above raised, which let the push continue anyway.
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            system_prompt,
            difficulty,
            clarity,
            multi_label,
            num_rows,
            labels,
            private,
            temperature,
            pipeline_code,
        ],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        # Adapter converts the multi-label checkbox into `num_labels` instead
        # of passing the boolean positionally into that parameter.
        fn=_generate_pipeline_code_from_ui,
        inputs=[
            system_prompt,
            difficulty,
            clarity,
            labels,
            multi_label,
            num_rows,
            temperature,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    gr.on(
        triggers=[clear_btn_part.click, clear_btn_full.click],
        fn=lambda _: (
            "",
            "",
            [],
            _get_dataframe(),
        ),
        inputs=[dataframe],
        outputs=[dataset_description, system_prompt, labels, dataframe],
    )

    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])