polinaeterna HF staff commited on
Commit
d806dcd
Β·
1 Parent(s): e5960a0

add progress bar

Browse files
Files changed (1) hide show
  1. app.py +52 -25
app.py CHANGED
@@ -2,11 +2,13 @@ import gradio as gr
2
  import polars as pl
3
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
4
  import torch
5
- import spaces
 
6
  from torch import nn
7
  from transformers import AutoModel, AutoTokenizer, AutoConfig
8
  from huggingface_hub import PyTorchModelHubMixin
9
  import pandas as pd
 
10
 
11
 
12
  class QualityModel(nn.Module, PyTorchModelHubMixin):
@@ -31,7 +33,7 @@ model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(dev
31
  model.eval()
32
 
33
 
34
- @spaces.GPU
35
  def predict(texts: list[str]):
36
  inputs = tokenizer(
37
  texts, return_tensors="pt", padding="longest", truncation=True
@@ -44,12 +46,23 @@ def predict(texts: list[str]):
44
  return predicted_domains
45
 
46
 
 
 
 
 
 
47
  def plot_and_df(texts, preds):
48
  texts_df = pd.DataFrame({"quality": preds, "text": texts})
49
- counts = pd.DataFrame({"quality": preds}).value_counts().to_frame()
50
- counts.reset_index(inplace=True)
 
 
 
 
 
 
51
  return (
52
- gr.BarPlot(counts, x="quality", y="count"),
53
  texts_df[texts_df["quality"] == "Low"][:20],
54
  texts_df[texts_df["quality"] == "Medium"][:20],
55
  texts_df[texts_df["quality"] == "High"][:20],
@@ -62,42 +75,56 @@ def run_quality_check(dataset, column, batch_size, num_examples):
62
  texts = data[column].to_list()
63
  # batch_size = 100
64
  predictions, texts_processed = [], []
65
- for i in range(0, min(len(texts), num_examples), batch_size):
 
66
  batch_texts = texts[i:i+batch_size]
67
  batch_predictions = predict(batch_texts)
68
  predictions.extend(batch_predictions)
69
  texts_processed.extend(batch_texts)
70
- yield plot_and_df(texts_processed, predictions)
71
-
72
 
73
  with gr.Blocks() as demo:
74
- gr.Markdown("# πŸ’« Dataset Quality Checker πŸ’«")
 
 
 
 
 
75
  dataset_name = HuggingfaceHubSearch(
76
  label="Hub Dataset ID",
77
  placeholder="Search for dataset id on Huggingface",
78
  search_type="dataset",
79
- value="fka/awesome-chatgpt-prompts",
80
  )
81
  # config_name = "default" # TODO: user input
82
- @gr.render(inputs=dataset_name)
83
- def embed(name):
84
- html_code = f"""
85
- <iframe
86
- src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
87
- frameborder="0"
88
- width="100%"
89
- height="700px"
90
- ></iframe>
91
- """
92
- return gr.HTML(value=html_code)
 
 
93
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
94
- batch_size = gr.Number(100, label="Batch size")
95
- num_examples = gr.Number(1000, label="Num examples to check")
96
  gr_check_btn = gr.Button("Check Dataset")
 
97
  plot = gr.BarPlot()
98
 
99
  with gr.Accordion("Explore some individual examples for each class", open=False):
100
- df_low, df_medium, df_high = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()
101
- gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size, num_examples], outputs=[plot, df_low, df_medium, df_high])
 
 
 
 
 
102
 
103
  demo.launch()
 
2
  import polars as pl
3
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
4
  import torch
5
+ from holoviews.ipython.widgets import progress
6
+ # import spaces
7
  from torch import nn
8
  from transformers import AutoModel, AutoTokenizer, AutoConfig
9
  from huggingface_hub import PyTorchModelHubMixin
10
  import pandas as pd
11
+ from collections import Counter
12
 
13
 
14
  class QualityModel(nn.Module, PyTorchModelHubMixin):
 
33
  model.eval()
34
 
35
 
36
+ # @spaces.GPU
37
  def predict(texts: list[str]):
38
  inputs = tokenizer(
39
  texts, return_tensors="pt", padding="longest", truncation=True
 
46
  return predicted_domains
47
 
48
 
49
+ # def progress():
50
+ # title = f"Scan finished" if num_rows == next_row_idx else "Scan in progress..."
51
+
52
+
53
+
54
  def plot_and_df(texts, preds):
55
  texts_df = pd.DataFrame({"quality": preds, "text": texts})
56
+ counts = Counter(preds)
57
+ counts_df = pd.DataFrame(
58
+ {
59
+ "quality": ["Low", "Medium", "High"],
60
+ "count": [counts.get("Low", 0), counts.get("Medium", 0), counts.get("High", 0)]
61
+ }
62
+ )
63
+ # counts.reset_index(inplace=True)
64
  return (
65
+ gr.BarPlot(counts_df, x="quality", y="count"),
66
  texts_df[texts_df["quality"] == "Low"][:20],
67
  texts_df[texts_df["quality"] == "Medium"][:20],
68
  texts_df[texts_df["quality"] == "High"][:20],
 
75
  texts = data[column].to_list()
76
  # batch_size = 100
77
  predictions, texts_processed = [], []
78
+ num_examples = min(len(texts), num_examples)
79
+ for i in range(0, num_examples, batch_size):
80
  batch_texts = texts[i:i+batch_size]
81
  batch_predictions = predict(batch_texts)
82
  predictions.extend(batch_predictions)
83
  texts_processed.extend(batch_texts)
84
+ yield {"scan in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions)
85
+ yield {"finished": 1.}, *plot_and_df(texts_processed, predictions)
86
 
87
  with gr.Blocks() as demo:
88
+ gr.Markdown(
89
+ """
90
+ # πŸ’« Dataset Quality Checker πŸ’«
91
+ Use [nvidia/quality-classifier-deberta](https://huggingface.co/nvidia/quality-classifier-deberta) on any text dataset on the Hub.
92
+ """
93
+ )
94
  dataset_name = HuggingfaceHubSearch(
95
  label="Hub Dataset ID",
96
  placeholder="Search for dataset id on Huggingface",
97
  search_type="dataset",
98
+ # value="fka/awesome-chatgpt-prompts",
99
  )
100
  # config_name = "default" # TODO: user input
101
+ with gr.Accordion("Dataset preview", open=False):
102
+ @gr.render(inputs=dataset_name)
103
+ def embed(name):
104
+ html_code = f"""
105
+ <iframe
106
+ src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
107
+ frameborder="0"
108
+ width="100%"
109
+ height="700px"
110
+ ></iframe>
111
+ """
112
+ return gr.HTML(value=html_code)
113
+
114
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
115
+ batch_size = gr.Slider(0, 128, 64, step=8, label="Batch size (set this to smaller value if this space crashes.)")
116
+ num_examples = gr.Number(1000, label="Number of first examples to check")
117
  gr_check_btn = gr.Button("Check Dataset")
118
+ progress_bar = gr.Label(show_label=False)
119
  plot = gr.BarPlot()
120
 
121
  with gr.Accordion("Explore some individual examples for each class", open=False):
122
+ gr.Markdown("### Low")
123
+ df_low = gr.DataFrame()
124
+ gr.Markdown("### Medium")
125
+ df_medium = gr.DataFrame()
126
+ gr.Markdown("### High")
127
+ df_high = gr.DataFrame()
128
+ gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size, num_examples], outputs=[progress_bar, plot, df_low, df_medium, df_high])
129
 
130
  demo.launch()