Spaces:
Running
on
Zero
Running
on
Zero
Commit
Β·
d806dcd
1
Parent(s):
e5960a0
add progress bar
Browse files
app.py
CHANGED
@@ -2,11 +2,13 @@ import gradio as gr
|
|
2 |
import polars as pl
|
3 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
4 |
import torch
|
5 |
-
import
|
|
|
6 |
from torch import nn
|
7 |
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
8 |
from huggingface_hub import PyTorchModelHubMixin
|
9 |
import pandas as pd
|
|
|
10 |
|
11 |
|
12 |
class QualityModel(nn.Module, PyTorchModelHubMixin):
|
@@ -31,7 +33,7 @@ model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(dev
|
|
31 |
model.eval()
|
32 |
|
33 |
|
34 |
-
@spaces.GPU
|
35 |
def predict(texts: list[str]):
|
36 |
inputs = tokenizer(
|
37 |
texts, return_tensors="pt", padding="longest", truncation=True
|
@@ -44,12 +46,23 @@ def predict(texts: list[str]):
|
|
44 |
return predicted_domains
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
47 |
def plot_and_df(texts, preds):
|
48 |
texts_df = pd.DataFrame({"quality": preds, "text": texts})
|
49 |
-
counts =
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
return (
|
52 |
-
gr.BarPlot(
|
53 |
texts_df[texts_df["quality"] == "Low"][:20],
|
54 |
texts_df[texts_df["quality"] == "Medium"][:20],
|
55 |
texts_df[texts_df["quality"] == "High"][:20],
|
@@ -62,42 +75,56 @@ def run_quality_check(dataset, column, batch_size, num_examples):
|
|
62 |
texts = data[column].to_list()
|
63 |
# batch_size = 100
|
64 |
predictions, texts_processed = [], []
|
65 |
-
|
|
|
66 |
batch_texts = texts[i:i+batch_size]
|
67 |
batch_predictions = predict(batch_texts)
|
68 |
predictions.extend(batch_predictions)
|
69 |
texts_processed.extend(batch_texts)
|
70 |
-
yield plot_and_df(texts_processed, predictions)
|
71 |
-
|
72 |
|
73 |
with gr.Blocks() as demo:
|
74 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
75 |
dataset_name = HuggingfaceHubSearch(
|
76 |
label="Hub Dataset ID",
|
77 |
placeholder="Search for dataset id on Huggingface",
|
78 |
search_type="dataset",
|
79 |
-
value="fka/awesome-chatgpt-prompts",
|
80 |
)
|
81 |
# config_name = "default" # TODO: user input
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
93 |
text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
|
94 |
-
batch_size = gr.
|
95 |
-
num_examples = gr.Number(1000, label="
|
96 |
gr_check_btn = gr.Button("Check Dataset")
|
|
|
97 |
plot = gr.BarPlot()
|
98 |
|
99 |
with gr.Accordion("Explore some individual examples for each class", open=False):
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
demo.launch()
|
|
|
2 |
import polars as pl
|
3 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
4 |
import torch
|
5 |
+
from holoviews.ipython.widgets import progress
|
6 |
+
# import spaces
|
7 |
from torch import nn
|
8 |
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
9 |
from huggingface_hub import PyTorchModelHubMixin
|
10 |
import pandas as pd
|
11 |
+
from collections import Counter
|
12 |
|
13 |
|
14 |
class QualityModel(nn.Module, PyTorchModelHubMixin):
|
|
|
33 |
model.eval()
|
34 |
|
35 |
|
36 |
+
# @spaces.GPU
|
37 |
def predict(texts: list[str]):
|
38 |
inputs = tokenizer(
|
39 |
texts, return_tensors="pt", padding="longest", truncation=True
|
|
|
46 |
return predicted_domains
|
47 |
|
48 |
|
49 |
+
# def progress():
|
50 |
+
# title = f"Scan finished" if num_rows == next_row_idx else "Scan in progress..."
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
def plot_and_df(texts, preds):
|
55 |
texts_df = pd.DataFrame({"quality": preds, "text": texts})
|
56 |
+
counts = Counter(preds)
|
57 |
+
counts_df = pd.DataFrame(
|
58 |
+
{
|
59 |
+
"quality": ["Low", "Medium", "High"],
|
60 |
+
"count": [counts.get("Low", 0), counts.get("Medium", 0), counts.get("High", 0)]
|
61 |
+
}
|
62 |
+
)
|
63 |
+
# counts.reset_index(inplace=True)
|
64 |
return (
|
65 |
+
gr.BarPlot(counts_df, x="quality", y="count"),
|
66 |
texts_df[texts_df["quality"] == "Low"][:20],
|
67 |
texts_df[texts_df["quality"] == "Medium"][:20],
|
68 |
texts_df[texts_df["quality"] == "High"][:20],
|
|
|
75 |
texts = data[column].to_list()
|
76 |
# batch_size = 100
|
77 |
predictions, texts_processed = [], []
|
78 |
+
num_examples = min(len(texts), num_examples)
|
79 |
+
for i in range(0, num_examples, batch_size):
|
80 |
batch_texts = texts[i:i+batch_size]
|
81 |
batch_predictions = predict(batch_texts)
|
82 |
predictions.extend(batch_predictions)
|
83 |
texts_processed.extend(batch_texts)
|
84 |
+
yield {"scan in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions)
|
85 |
+
yield {"finished": 1.}, *plot_and_df(texts_processed, predictions)
|
86 |
|
87 |
with gr.Blocks() as demo:
|
88 |
+
gr.Markdown(
|
89 |
+
"""
|
90 |
+
# π« Dataset Quality Checker π«
|
91 |
+
Use [nvidia/quality-classifier-deberta](https://huggingface.co/nvidia/quality-classifier-deberta) on any text dataset on the Hub.
|
92 |
+
"""
|
93 |
+
)
|
94 |
dataset_name = HuggingfaceHubSearch(
|
95 |
label="Hub Dataset ID",
|
96 |
placeholder="Search for dataset id on Huggingface",
|
97 |
search_type="dataset",
|
98 |
+
# value="fka/awesome-chatgpt-prompts",
|
99 |
)
|
100 |
# config_name = "default" # TODO: user input
|
101 |
+
with gr.Accordion("Dataset preview", open=False):
|
102 |
+
@gr.render(inputs=dataset_name)
|
103 |
+
def embed(name):
|
104 |
+
html_code = f"""
|
105 |
+
<iframe
|
106 |
+
src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
|
107 |
+
frameborder="0"
|
108 |
+
width="100%"
|
109 |
+
height="700px"
|
110 |
+
></iframe>
|
111 |
+
"""
|
112 |
+
return gr.HTML(value=html_code)
|
113 |
+
|
114 |
text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
|
115 |
+
batch_size = gr.Slider(0, 128, 64, step=8, label="Batch size (set this to smaller value if this space crashes.)")
|
116 |
+
num_examples = gr.Number(1000, label="Number of first examples to check")
|
117 |
gr_check_btn = gr.Button("Check Dataset")
|
118 |
+
progress_bar = gr.Label(show_label=False)
|
119 |
plot = gr.BarPlot()
|
120 |
|
121 |
with gr.Accordion("Explore some individual examples for each class", open=False):
|
122 |
+
gr.Markdown("### Low")
|
123 |
+
df_low = gr.DataFrame()
|
124 |
+
gr.Markdown("### Medium")
|
125 |
+
df_medium = gr.DataFrame()
|
126 |
+
gr.Markdown("### High")
|
127 |
+
df_high = gr.DataFrame()
|
128 |
+
gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size, num_examples], outputs=[progress_bar, plot, df_low, df_medium, df_high])
|
129 |
|
130 |
demo.launch()
|