lhoestq (HF staff) committed
Commit 4482b40 · 1 Parent(s): e4a82b3

add full generation

Files changed (1):
  1. app.py +227 -28
app.py CHANGED
@@ -1,9 +1,14 @@
  import time
  from itertools import islice
  from functools import partial
- from typing import Iterable, Iterator, TypeVar

  import gradio as gr
  import requests.exceptions
  from huggingface_hub import InferenceClient
 
@@ -13,6 +18,8 @@ client = InferenceClient(model_id)
  MAX_TOTAL_NB_ITEMS = 100 # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
  MAX_NB_ITEMS_PER_GENERATION_CALL = 10
  URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"

  GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
@@ -29,6 +36,23 @@ GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
  "Focus on quality text content and use a 'label' or 'labels' column if it makes sense (invent labels, avoid reusing the keywords, be accurate while labelling texts). "
  "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
  )

  landing_page_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"
 
@@ -174,27 +198,24 @@ with gr.Blocks(css=css) as demo:
          with gr.Column(scale=4, min_width=0):
              pass
      with gr.Column(visible=False) as dataset_page:
-         with gr.Row():
-             with gr.Column(scale=4, min_width=0):
-                 pass
-             with gr.Column(scale=10):
-                 dataset_title = gr.Markdown()
-                 dataset_content = gr.Markdown()
-                 with gr.Row():
-                     with gr.Column(scale=4, min_width=0):
-                         pass
-                     with gr.Column():
-                         generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")  # TODO: implement
-                         dataset_share_button = gr.Button("Share Dataset URL")
-                         dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
-                         back_button = gr.Button("< Back", size="sm")
-                     with gr.Column(scale=4, min_width=0):
-                         pass
-             with gr.Column(scale=4, min_width=0):
-                 pass

      app_state = gr.State({})

      T = TypeVar("T")

      def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
@@ -264,6 +285,139 @@ with gr.Blocks(css=css) as demo:
          print("-----\n\n" + generated_text)


      def _search_datasets(search_query):
          yield {generated_texts_state: [], app_state: {"search_query": search_query}}
          yield {
@@ -303,11 +457,13 @@ with gr.Blocks(css=css) as demo:
      @search_button.click(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state, app_state])
      def search_dataset_from_search_button(search_query):
          yield from _search_datasets(search_query)
-
      @search_bar.submit(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state, app_state])
      def search_dataset_from_search_bar(search_query):
          yield from _search_datasets(search_query)

      @load_more_datasets.click(inputs=[search_bar, generated_texts_state], outputs=button_groups + buttons + [generated_texts_state])
      def search_more_datasets(search_query, generated_texts):
          current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
@@ -339,8 +495,11 @@ with gr.Blocks(css=css) as demo:
          yield {
              search_page: gr.Column(visible=False),
              dataset_page: gr.Column(visible=True),
-             dataset_title: f"# {dataset_name}\n\n tags: {tags}\n\n _Note: This is an AI-generated dataset so its content may be inaccurate or false_",
              dataset_share_textbox: gr.Textbox(visible=False),
              app_state: {
                  "search_query": search_query,
                  "dataset_name": dataset_name,
@@ -352,7 +511,7 @@ with gr.Blocks(css=css) as demo:
      show_dataset_inputs = [search_bar, *buttons]
-     show_dataset_outputs = [app_state, search_page, dataset_page, dataset_title, dataset_content, dataset_share_textbox]
      scroll_to_top_js = """
      function (...args) {
          console.log(args);
@@ -363,7 +522,7 @@ with gr.Blocks(css=css) as demo:
          }
          return args;
      }
-     """.replace("len(show_dataset_inputs)", str(len(show_dataset_inputs)))

      def show_dataset_from_button(search_query, *buttons_values, i):
          dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
@@ -374,18 +533,58 @@ with gr.Blocks(css=css) as demo:
          tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)


-     @back_button.click(outputs=[search_page, dataset_page])
      def show_search_page():
          return gr.Column(visible=True), gr.Column(visible=False)

-     @generate_full_dataset_button.click()
-     def generate_full_dataset():
-         raise gr.Error("Not implemented yet sorry ! Request your dataset in the Discussion tab (provide the dataset URL)")

      @dataset_share_button.click(inputs=[app_state], outputs=[dataset_share_textbox])
      def show_dataset_url(state):
          return gr.Textbox(
-             f"{URL}?q={state['search_query'].replace(' ', '+')}&dataset={state['dataset_name']}&tags={state['tags']}",
              visible=True,
          )
 
 
+ import io
+ import re
  import time
  from itertools import islice
  from functools import partial
+ from multiprocessing.pool import ThreadPool
+ from queue import Queue, Empty
+ from typing import Callable, Iterable, Iterator, Optional, TypeVar

  import gradio as gr
+ import pandas as pd
  import requests.exceptions
  from huggingface_hub import InferenceClient
 
 
  MAX_TOTAL_NB_ITEMS = 100 # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
  MAX_NB_ITEMS_PER_GENERATION_CALL = 10
+ NUM_ROWS = 100
+ NUM_VARIANTS = 10
  URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"

  GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
 
  "Focus on quality text content and use a 'label' or 'labels' column if it makes sense (invent labels, avoid reusing the keywords, be accurate while labelling texts). "
  "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
  )
+ GENERATE_MORE_ROWS = "Can you give me 10 additional samples in CSV format as well? Use the same CSV header '{csv_header}'."
+ GENERATE_VARIANTS_WITH_RARITY_AND_LABEL = "Focus on generating samples for the label '{label}' and ideally generate {rarity} samples."
+ GENERATE_VARIANTS_WITH_RARITY = "Focus on generating {rarity} samples."
+
+ RARITIES = ["pretty obvious", "common/regular", "unexpected but useful", "uncommon but still plausible", "rare/niche but still plausible"]
+ LONG_RARITIES = [
+     "obvious",
+     "expected",
+     "common",
+     "regular",
+     "unexpected but useful",
+     "original but useful",
+     "specific but not far-fetched",
+     "uncommon but still plausible",
+     "rare but still plausible",
+     "very niche but still plausible",
+ ]

  landing_page_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"
 
          with gr.Column(scale=4, min_width=0):
              pass
      with gr.Column(visible=False) as dataset_page:
+         dataset_title = gr.Markdown()
+         gr.Markdown("_Note: This is an AI-generated dataset so its content may be inaccurate or false_")
+         dataset_content = gr.Markdown()
+         generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")
+         dataset_dataframe = gr.DataFrame(visible=False, interactive=False, wrap=True)
+         save_dataset_button = gr.Button("💾 Save Dataset", variant="primary", visible=False)
+         dataset_share_button = gr.Button("Share Dataset URL")
+         dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
+         back_button = gr.Button("< Back", size="sm")

      app_state = gr.State({})

+     ###################################
+     #
+     # Utils
+     #
+     ###################################
+
      T = TypeVar("T")

      def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
 
          print("-----\n\n" + generated_text)


+     def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
+         # runs in a worker thread: drain one generator into the shared queue
+         for result in func(**kwargs):
+             queue.put(result)
+         return None
+
+
+     def iflatmap_unordered(
+         func: Callable[..., Iterable[T]],
+         *,
+         kwargs_iterable: Iterable[dict],
+     ) -> Iterable[T]:
+         queue = Queue()
+         with ThreadPool() as pool:
+             async_results = [
+                 pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable
+             ]
+             try:
+                 while True:
+                     try:
+                         yield queue.get(timeout=0.05)
+                     except Empty:
+                         if all(async_result.ready() for async_result in async_results) and queue.empty():
+                             break
+             finally:
+                 # we get the result in case there's an error to raise
+                 [async_result.get(timeout=0.05) for async_result in async_results]
+
+
+     def generate_partial_dataset(title: str, content: str, search_query: str, variant: str, csv_header: str, output: list[dict[str, str]], indices_to_generate: list[int], max_tokens=1500) -> Iterator[int]:
+         dataset_name, tags = title.strip("# ").split("\n\n tags:", 1)
+         dataset_name, tags = dataset_name.strip(), tags.strip()
+         messages = [
+             {
+                 "role": "user",
+                 "content": GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
+                     dataset_name=dataset_name,
+                     tags=tags,
+                     search_query=search_query,
+                 )
+             },
+             {"role": "assistant", "content": title + "\n\n" + content},
+             {"role": "user", "content": GENERATE_MORE_ROWS.format(csv_header=csv_header) + " " + variant},
+         ]
+         for _ in range(3):  # retry a few times on connection errors
+             generated_text = ""
+             generated_csv = ""
+             current_line = ""
+             nb_samples = 0
+             _in_csv = False
+             try:
+                 for message in client.chat_completion(
+                     messages=messages,
+                     max_tokens=max_tokens,
+                     stream=True,
+                     top_p=0.8,
+                     seed=42,
+                 ):
+                     if nb_samples >= len(indices_to_generate):
+                         break
+                     current_line += message.choices[0].delta.content
+                     generated_text += message.choices[0].delta.content
+                     if current_line.endswith("\n"):
+                         # toggle the flag every time a ``` fence line goes by
+                         _in_csv = _in_csv ^ current_line.lstrip().startswith("```")
+                         if current_line.strip() and _in_csv and not current_line.lstrip().startswith("```"):
+                             generated_csv += current_line
+                             try:
+                                 generated_df = parse_csv_df(generated_csv.strip(), csv_header=csv_header)
+                                 if len(generated_df) > nb_samples:
+                                     output[indices_to_generate[nb_samples]] = generated_df.iloc[-1].to_dict()
+                                     nb_samples += 1
+                                     yield 1
+                             except Exception:
+                                 pass
+                         current_line = ""
+             except requests.exceptions.ConnectionError as e:
+                 print(f"{e}\n\nRetrying in 1sec")
+                 time.sleep(1)
+                 continue
+             break
+         # for debugging
+         # with open(f"output{indices_to_generate[0]}.txt", "w") as f:
+         #     f.write(generated_text)
+
+
+     def generate_variants(preview_df: pd.DataFrame):
+         label_candidate_columns = [column for column in preview_df.columns if "label" in column.lower()]
+         if label_candidate_columns:
+             labels = preview_df[label_candidate_columns[0]].unique()
+             if len(labels) > 1:
+                 return [
+                     GENERATE_VARIANTS_WITH_RARITY_AND_LABEL.format(rarity=rarity, label=label)
+                     for rarity in RARITIES
+                     for label in labels
+                 ]
+         return [
+             GENERATE_VARIANTS_WITH_RARITY.format(rarity=rarity)
+             for rarity in LONG_RARITIES
+         ]
+
+
+     def parse_preview_df(content: str) -> tuple[str, pd.DataFrame]:
+         # keep only the lines inside the ``` code fence (the CSV preview)
+         _in_csv = False
+         csv = "\n".join(
+             line for line in content.split("\n") if line.strip()
+             and (_in_csv := (_in_csv ^ line.lstrip().startswith("```")))
+             and not line.lstrip().startswith("```")
+         )
+         if not csv:
+             raise gr.Error("Failed to parse CSV Preview")
+         return csv.split("\n")[0], parse_csv_df(csv)
+
+
+     def parse_csv_df(csv: str, csv_header: Optional[str] = None) -> pd.DataFrame:
+         # Fix generation mistake when providing a list that is not in quotes
+         if ",[" in csv:
+             for match in re.finditer(r'\[("[\w ]+"[, ]?)+\]', csv):
+                 span = match.group(0)
+                 csv = csv.replace(span, '"' + span.replace('"', "'") + '"')
+         # Add header if missing
+         if csv_header and csv.strip().split("\n")[0] != csv_header:
+             csv = csv_header + "\n" + csv
+         # Read CSV
+         df = pd.read_csv(io.StringIO(csv))
+         return df
+
+
+     ###################################
+     #
+     # Buttons
+     #
+     ###################################
+
+
      def _search_datasets(search_query):
          yield {generated_texts_state: [], app_state: {"search_query": search_query}}
          yield {
 
      @search_button.click(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state, app_state])
      def search_dataset_from_search_button(search_query):
          yield from _search_datasets(search_query)
+
+
      @search_bar.submit(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state, app_state])
      def search_dataset_from_search_bar(search_query):
          yield from _search_datasets(search_query)

+
      @load_more_datasets.click(inputs=[search_bar, generated_texts_state], outputs=button_groups + buttons + [generated_texts_state])
      def search_more_datasets(search_query, generated_texts):
          current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
 
          yield {
              search_page: gr.Column(visible=False),
              dataset_page: gr.Column(visible=True),
+             dataset_title: f"# {dataset_name}\n\n tags: {tags}",
              dataset_share_textbox: gr.Textbox(visible=False),
+             dataset_dataframe: gr.DataFrame(visible=False),
+             generate_full_dataset_button: gr.Button(visible=True),
+             save_dataset_button: gr.Button(visible=False),
              app_state: {
                  "search_query": search_query,
                  "dataset_name": dataset_name,
 
      show_dataset_inputs = [search_bar, *buttons]
+     show_dataset_outputs = [app_state, search_page, dataset_page, dataset_title, dataset_content, generate_full_dataset_button, dataset_dataframe, save_dataset_button, dataset_share_textbox]
      scroll_to_top_js = """
      function (...args) {
          console.log(args);
 
          }
          return args;
      }
+     """

      def show_dataset_from_button(search_query, *buttons_values, i):
          dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
 
          tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)


+     @back_button.click(outputs=[search_page, dataset_page], js=scroll_to_top_js)
      def show_search_page():
          return gr.Column(visible=True), gr.Column(visible=False)

+
+     @generate_full_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar], outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button])
+     def generate_full_dataset(title, content, search_query):
+         csv_header, preview_df = parse_preview_df(content)
+         # Remove dummy "id" columns
+         for column_name, values in preview_df.to_dict(orient="series").items():
+             try:
+                 if [int(v) for v in values] == list(range(len(preview_df))):
+                     preview_df = preview_df.drop(columns=column_name)
+                 if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
+                     preview_df = preview_df.drop(columns=column_name)
+             except Exception:
+                 pass
+         columns = list(preview_df)
+         output: list[Optional[dict]] = [None] * NUM_ROWS
+         output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
+         yield {
+             dataset_dataframe: gr.DataFrame(pd.DataFrame([{"idx": i, **x} for i, x in enumerate(output) if x]), visible=True),
+             generate_full_dataset_button: gr.Button(visible=False),
+             save_dataset_button: gr.Button(visible=True, interactive=False)
+         }
+         kwargs_iterable = [
+             {
+                 "title": title,
+                 "content": content,
+                 "search_query": search_query,
+                 "variant": variant,
+                 "csv_header": csv_header,
+                 "output": output,
+                 "indices_to_generate": list(range(len(preview_df) + i, NUM_ROWS, NUM_VARIANTS)),
+             }
+             for i, variant in enumerate(islice(generate_variants(preview_df), NUM_VARIANTS))
+         ]
+         # run all variants in parallel and refresh the table every time a row lands
+         for _ in iflatmap_unordered(generate_partial_dataset, kwargs_iterable=kwargs_iterable):
+             yield {dataset_dataframe: pd.DataFrame([{"idx": i, **{column_name: x.get(column_name) for column_name in columns}} for i, x in enumerate(output) if x])}
+         yield {save_dataset_button: gr.Button(visible=True, interactive=True)}
+         print(f"Successfully generated {title}!")
+
+
+     @save_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, dataset_dataframe])
+     def save_dataset(title, content, search_query, df):
+         raise gr.Error("Not implemented yet, sorry! Request your dataset to be saved in the Discussion tab (provide the dataset URL)")
+

      @dataset_share_button.click(inputs=[app_state], outputs=[dataset_share_textbox])
      def show_dataset_url(state):
          return gr.Textbox(
+             f"{URL}?q={state['search_query'].replace(' ', '+')}&dataset={state['dataset_name'].replace(' ', '+')}&tags={state['tags'].replace(' ', '+')}",
              visible=True,
          )
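
For context on how the new pieces fit together: `generate_full_dataset` fans one `generate_partial_dataset` call per prompt variant out to a thread pool via `iflatmap_unordered`, and re-renders the DataFrame every time any thread yields a parsed row. Below is a minimal, self-contained sketch of that streaming pattern; the two helpers mirror the ones added in this commit, while `slow_count` and its arguments are made up for illustration (a stand-in for `generate_partial_dataset`):

```python
import time
from multiprocessing.pool import ThreadPool
from queue import Empty, Queue
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
    # Runs in a worker thread: drain one generator into the shared queue.
    for result in func(**kwargs):
        queue.put(result)


def iflatmap_unordered(func: Callable[..., Iterable[T]], *, kwargs_iterable: Iterable[dict]) -> Iterator[T]:
    queue: Queue = Queue()
    with ThreadPool() as pool:
        async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable]
        try:
            while True:
                try:
                    yield queue.get(timeout=0.05)
                except Empty:
                    # stop once every worker is done and the queue is drained
                    if all(r.ready() for r in async_results) and queue.empty():
                        break
        finally:
            [r.get(timeout=0.05) for r in async_results]  # surface worker exceptions


def slow_count(name: str, n: int, delay: float) -> Iterator[str]:
    # Made-up stand-in for generate_partial_dataset: yields one item at a time.
    for i in range(n):
        time.sleep(delay)
        yield f"{name}-{i}"


if __name__ == "__main__":
    # Items from the three generators arrive interleaved, in completion order,
    # which is what lets the UI refresh after every generated sample.
    kwargs_iterable = [{"name": name, "n": 3, "delay": delay} for name, delay in [("a", 0.10), ("b", 0.07), ("c", 0.13)]]
    for item in iflatmap_unordered(slow_count, kwargs_iterable=kwargs_iterable):
        print(item)
```

On normal completion the trailing `.get()` calls return immediately and re-raise any exception that occurred inside a worker thread, so failures in one variant are not silently dropped.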