polinaeterna HF staff commited on
Commit
0a44dc6
Β·
1 Parent(s): 46c2a69
Files changed (1) hide show
  1. app.py +49 -31
app.py CHANGED
@@ -23,21 +23,6 @@ retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
23
  session.mount('http://', HTTPAdapter(max_retries=retries))
24
 
25
 
26
- def proportion_non_ascii(s):
27
- """
28
- Compute the proportion of non-ASCII characters in a string.
29
-
30
- Parameters:
31
- s (str): The input string.
32
-
33
- Returns:
34
- float: The proportion of non-ASCII characters in the string.
35
- """
36
- non_ascii_count = sum(1 for c in s if ord(c) > 127)
37
- total_chars = len(s)
38
- return non_ascii_count / total_chars if total_chars > 0 else 0.0
39
-
40
-
41
  class QualityModel(nn.Module, PyTorchModelHubMixin):
42
  def __init__(self, config):
43
  super(QualityModel, self).__init__()
@@ -95,7 +80,7 @@ def plot_and_df(texts, preds):
95
  def run_quality_check(dataset, column, batch_size, num_examples):
96
  info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
97
  if "error" in info_resp:
98
- yield "❌ " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
99
  return
100
  config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
101
  split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
@@ -106,10 +91,10 @@ def run_quality_check(dataset, column, batch_size, num_examples):
106
  try:
107
  data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
108
  except Exception as error:
109
- yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
110
  return
111
  texts = data[column].to_list()
112
- texts_sample = data.sample(100, shuffle=True, seed=16).to_pandas()
113
  # batch_size = 100
114
  predictions, texts_processed = [], []
115
  num_examples = min(len(texts), num_examples)
@@ -118,18 +103,18 @@ def run_quality_check(dataset, column, batch_size, num_examples):
118
  batch_predictions = predict(batch_texts)
119
  predictions.extend(batch_predictions)
120
  texts_processed.extend(batch_texts)
121
- yield {"check in progress...": min(i+batch_size, num_examples) / num_examples}, *plot_and_df(texts_processed, predictions), plt.Figure(), pd.DataFrame()
122
 
123
- with multiprocessing.Pool(processes=8) as pool:
124
- props = pool.map(proportion_non_ascii, texts)
 
 
 
 
 
 
125
 
126
- # non_ascii_df = pd.DataFrame.from_dict({"prop_non_ascii": props, "text": texts})
127
- plt.hist(props, bins=20, range=(0., 1.))
128
- plt.title('Histogram of proportion of non-ASCII characters')
129
- plt.xlabel('Proportion of non-ASCII characters')
130
- plt.ylabel('Number of texts')
131
-
132
- yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), plt.gcf(), texts_sample
133
 
134
 
135
  PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
@@ -199,12 +184,41 @@ def call_perspective_api(texts_df, column_name):#, s):
199
  return req_att_scores
200
  if i % 10 == 0:
201
  plot_toxicity(req_att_scores)
202
- yield {"toxicity check in progress...": i / n_samples}, plt.gcf(), pd.DataFrame()
203
 
204
  plot_toxicity(req_att_scores)
205
  yield {"toxicity check finished.": 1.}, plt.gcf(), pd.DataFrame.from_dict({column_name: texts, **req_att_scores})
206
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  with gr.Blocks() as demo:
209
  gr.Markdown(
210
  """
@@ -248,14 +262,18 @@ with gr.Blocks() as demo:
248
  gr.Markdown("### High")
249
  df_high = gr.DataFrame()
250
 
251
- non_ascii_hist = gr.Plot()
252
  texts_sample_df = gr.DataFrame(visible=False)
253
  gr_check_btn.click(
254
  run_quality_check,
255
  inputs=[dataset_name, text_column, batch_size, num_examples],
256
- outputs=[progress_bar, plot, df_low, df_medium, df_high, non_ascii_hist, texts_sample_df]
257
  )
258
 
 
 
 
 
 
259
  gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
260
  toxicity_progress_bar = gr.Label(show_label=False)
261
  toxicity_hist = gr.Plot()
 
23
  session.mount('http://', HTTPAdapter(max_retries=retries))
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class QualityModel(nn.Module, PyTorchModelHubMixin):
27
  def __init__(self, config):
28
  super(QualityModel, self).__init__()
 
80
  def run_quality_check(dataset, column, batch_size, num_examples):
81
  info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
82
  if "error" in info_resp:
83
+ yield "❌ " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(),
84
  return
85
  config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
86
  split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
 
91
  try:
92
  data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
93
  except Exception as error:
94
+ yield f"❌ {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(),
95
  return
96
  texts = data[column].to_list()
97
+ # texts_sample = data.sample(100, shuffle=True, seed=16).to_pandas()
98
  # batch_size = 100
99
  predictions, texts_processed = [], []
100
  num_examples = min(len(texts), num_examples)
 
103
  batch_predictions = predict(batch_texts)
104
  predictions.extend(batch_predictions)
105
  texts_processed.extend(batch_texts)
106
+ yield {"check in progress...": min(i+batch_size, num_examples) / num_examples}, *plot_and_df(texts_processed, predictions), pd.DataFrame()
107
 
108
+ # with multiprocessing.Pool(processes=8) as pool:
109
+ # props = pool.map(proportion_non_ascii, texts)
110
+ #
111
+ # # non_ascii_df = pd.DataFrame.from_dict({"prop_non_ascii": props, "text": texts})
112
+ # plt.hist(props, bins=20, range=(0., 1.))
113
+ # plt.title('Histogram of proportion of non-ASCII characters')
114
+ # plt.xlabel('Proportion of non-ASCII characters')
115
+ # plt.ylabel('Number of texts')
116
 
117
+ yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), data
 
 
 
 
 
 
118
 
119
 
120
  PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
 
184
  return req_att_scores
185
  if i % 10 == 0:
186
  plot_toxicity(req_att_scores)
187
+ yield {"toxicity check in progress...": i / n_samples}, plt.gcf(), pd.DataFrame.from_dict({column_name: texts[:i], **req_att_scores})
188
 
189
  plot_toxicity(req_att_scores)
190
  yield {"toxicity check finished.": 1.}, plt.gcf(), pd.DataFrame.from_dict({column_name: texts, **req_att_scores})
191
 
192
 
193
+ def proportion_non_ascii(s):
194
+ """
195
+ Compute the proportion of non-ASCII characters in a string.
196
+
197
+ Parameters:
198
+ s (str): The input string.
199
+
200
+ Returns:
201
+ float: The proportion of non-ASCII characters in the string.
202
+ """
203
+ non_ascii_count = sum(1 for c in s if ord(c) > 127)
204
+ total_chars = len(s)
205
+ return non_ascii_count / total_chars if total_chars > 0 else 0.0
206
+
207
+
208
+ def non_ascii_check(texts_df, column_name):
209
+ texts = texts_df[column_name].to_list()
210
+ with multiprocessing.Pool(processes=8) as pool:
211
+ props = pool.map(proportion_non_ascii, texts)
212
+
213
+ # non_ascii_df = pd.DataFrame.from_dict({"prop_non_ascii": props, "text": texts})
214
+ plt.hist(props, bins=20, range=(0., 1.))
215
+ plt.title('Histogram of proportion of non-ASCII characters')
216
+ plt.xlabel('Proportion of non-ASCII characters')
217
+ plt.ylabel('Number of texts')
218
+
219
+ return plt.gcf()
220
+
221
+
222
  with gr.Blocks() as demo:
223
  gr.Markdown(
224
  """
 
262
  gr.Markdown("### High")
263
  df_high = gr.DataFrame()
264
 
 
265
  texts_sample_df = gr.DataFrame(visible=False)
266
  gr_check_btn.click(
267
  run_quality_check,
268
  inputs=[dataset_name, text_column, batch_size, num_examples],
269
+ outputs=[progress_bar, plot, df_low, df_medium, df_high, texts_sample_df]
270
  )
271
 
272
+ gr_ascii_btn = gr.Button("Non ascii chars.")
273
+ non_ascii_hist = gr.Plot()
274
+
275
+ gr_ascii_btn.click(non_ascii_check, inputs=[texts_sample_df, text_column], outputs=[non_ascii_hist])
276
+
277
  gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
278
  toxicity_progress_bar = gr.Label(show_label=False)
279
  toxicity_hist = gr.Plot()