lhoestq HF staff committed on
Commit
e2928bf
·
1 Parent(s): 8fe2070

add "see more" button

Browse files
Files changed (1) hide show
  1. app.py +209 -103
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import time
 
2
  from functools import partial
3
- from typing import Iterator
4
 
5
  import gradio as gr
6
  import requests.exceptions
@@ -10,9 +11,12 @@ from huggingface_hub import InferenceClient
10
  model_id = "microsoft/Phi-3-mini-4k-instruct"
11
  client = InferenceClient(model_id)
12
 
 
 
 
13
  GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
14
  "A Machine Learning Practioner is looking for a dataset that matches '{search_query}'. "
15
- "Generate a list of 10 names of quality dataset that don't exist but sound plausible and would "
16
  "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
17
  "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated to the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n1. DatasetName2 (tag1, tag2, tag3)"
18
  )
@@ -25,52 +29,9 @@ GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
25
  "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
26
  )
27
 
28
- default_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"
29
-
30
-
31
- def stream_reponse(msg: str, max_tokens=500) -> Iterator[str]:
32
- for _ in range(3):
33
- try:
34
- for message in client.chat_completion(
35
- messages=[{"role": "user", "content": msg}],
36
- max_tokens=max_tokens,
37
- stream=True,
38
- ):
39
- yield message.choices[0].delta.content
40
- except requests.exceptions.ConnectionError as e:
41
- print(e + "\n\nRetrying in 1sec")
42
- time.sleep(1)
43
- continue
44
- break
45
-
46
-
47
- def gen_datasets(search_query: str) -> Iterator[str]:
48
- search_query = search_query[:1000] if search_query.strip() else default_query
49
- generated_text = ""
50
- for token in stream_reponse(GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query)):
51
- generated_text += token
52
- if generated_text.endswith("\n"):
53
- yield generated_text.strip()
54
- yield generated_text.strip()
55
- print("-----\n\n" + generated_text)
56
-
57
-
58
- def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
59
- search_query = search_query[:1000] if search_query.strip() else default_query
60
- generated_text = ""
61
- for token in stream_reponse(GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
62
- search_query=search_query,
63
- dataset_name=dataset_name,
64
- tags=tags,
65
- ), max_tokens=1500):
66
- generated_text += token
67
- yield generated_text
68
- print("-----\n\n" + generated_text)
69
-
70
-
71
- NB_ITEMS_PER_PAGE = 10
72
-
73
- default_output = """
74
  1. NewsEventsPredict (classification, media, trend)
75
  2. FinancialForecast (economy, stocks, regression)
76
  3. HealthMonitor (science, real-time, anomaly detection)
@@ -81,10 +42,15 @@ default_output = """
81
  8. NewsEventTracker (classification, public awareness, topical clustering)
82
  9. HealthVitalSigns (anomaly detection, biometrics, prediction)
83
  10. GameStockPredict (classification, finance, sports contingency)
84
- """.strip().split("\n")
85
- assert len(default_output) == NB_ITEMS_PER_PAGE
 
86
 
87
  css = """
 
 
 
 
88
  .datasetButton {
89
  justify-content: start;
90
  justify-content: left;
@@ -93,9 +59,6 @@ css = """
93
  font-size: var(--button-small-text-size);
94
  color: var(--body-text-color-subdued);
95
  }
96
- a {
97
- color: var(--body-text-color);
98
- }
99
  .topButton {
100
  justify-content: start;
101
  justify-content: left;
@@ -134,6 +97,10 @@ a {
134
  .buttonsGroup div {
135
  background: transparent;
136
  }
 
 
 
 
137
  @keyframes placeHolderShimmer{
138
  0%{
139
  background-position: -468px 0
@@ -155,39 +122,9 @@ a {
155
  }
156
  """
157
 
158
- def search_datasets(search_query):
159
- output_values = [
160
- gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
161
- gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
162
- ] * NB_ITEMS_PER_PAGE
163
- for generated_text in gen_datasets(search_query):
164
- if "I'm sorry" in generated_text:
165
- raise gr.Error("Error: inappropriate content")
166
- lines = [line for line in generated_text.split("\n") if line and line.split(".", 1)[0].isnumeric()][:NB_ITEMS_PER_PAGE]
167
- for i, line in enumerate(lines):
168
- dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
169
- output_values[2 * i] = gr.Button(dataset_name, elem_classes="topButton")
170
- output_values[2 * i + 1] = gr.Button(tags, elem_classes="bottomButton")
171
- yield output_values
172
-
173
-
174
- def show_dataset(search_query, *buttons_values, i):
175
- dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
176
- dataset_title = f"# {dataset_name}\n\n tags: {tags}\n\n _Note: This is an AI-generated dataset so its content may be inaccurate or false_"
177
- yield gr.Column(visible=False), gr.Column(visible=True), dataset_title, ""
178
- for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
179
- yield gr.Column(), gr.Column(), dataset_title, generated_text
180
-
181
-
182
- def show_search_page():
183
- return gr.Column(visible=True), gr.Column(visible=False)
184
-
185
-
186
- def generate_full_dataset():
187
- raise gr.Error("Not implemented yet sorry ! Give me some feedbacks in the Community tab in the meantime ;)")
188
-
189
 
190
  with gr.Blocks(css=css) as demo:
 
191
  with gr.Row():
192
  with gr.Column(scale=4, min_width=0):
193
  pass
@@ -208,28 +145,34 @@ with gr.Blocks(css=css) as demo:
208
  search_button = gr.Button("🔍", variant="primary")
209
  with gr.Column(scale=4, min_width=0):
210
  pass
211
- inputs = [search_bar]
212
- show_dataset_outputs = [search_page]
213
  with gr.Row():
214
  with gr.Column(scale=4, min_width=0):
215
  pass
216
  with gr.Column(scale=10):
217
- buttons = []
218
- for i in range(10):
219
- line = default_output[i]
220
- dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
221
- with gr.Group(elem_classes="buttonsGroup"):
222
- top = gr.Button(dataset_name, elem_classes="topButton")
223
- bottom = gr.Button(tags, elem_classes="bottomButton")
224
- buttons += [top, bottom]
225
- top.click(partial(show_dataset, i=i), inputs=inputs, outputs=show_dataset_outputs)
226
- bottom.click(partial(show_dataset, i=i), inputs=inputs, outputs=show_dataset_outputs)
227
- inputs += buttons
 
 
 
 
 
 
 
 
 
228
  gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
229
  with gr.Column(scale=4, min_width=0):
230
  pass
231
- search_bar.submit(search_datasets, inputs=search_bar, outputs=buttons)
232
- search_button.click(search_datasets, inputs=search_bar, outputs=buttons)
233
  with gr.Column(visible=False) as dataset_page:
234
  with gr.Row():
235
  with gr.Column(scale=4, min_width=0):
@@ -241,13 +184,176 @@ with gr.Blocks(css=css) as demo:
241
  with gr.Column(scale=4, min_width=0):
242
  pass
243
  with gr.Column():
244
- generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")
245
- generate_full_dataset_button.click(generate_full_dataset)
246
  back_button = gr.Button("< Back", size="sm")
247
- back_button.click(show_search_page, inputs=[], outputs=[search_page, dataset_page])
248
  with gr.Column(scale=4, min_width=0):
249
  pass
250
  with gr.Column(scale=4, min_width=0):
251
  pass
252
- show_dataset_outputs += [dataset_page, dataset_title, dataset_content]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  demo.launch()
 
1
  import time
2
+ from itertools import count, islice
3
  from functools import partial
4
+ from typing import Iterable, Iterator, TypeVar
5
 
6
  import gradio as gr
7
  import requests.exceptions
 
11
  model_id = "microsoft/Phi-3-mini-4k-instruct"
12
  client = InferenceClient(model_id)
13
 
14
+ MAX_TOTAL_NB_ITEMS = 100 # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
15
+ MAX_NB_ITEMS_PER_GENERATION_CALL = 10
16
+
17
  GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
18
  "A Machine Learning Practioner is looking for a dataset that matches '{search_query}'. "
19
+ f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality dataset that don't exist but sound plausible and would "
20
  "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
21
  "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated to the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n1. DatasetName2 (tag1, tag2, tag3)"
22
  )
 
29
  "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
30
  )
31
 
32
+ landing_page_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"
33
+
34
+ landing_page_datasets_generated_text = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  1. NewsEventsPredict (classification, media, trend)
36
  2. FinancialForecast (economy, stocks, regression)
37
  3. HealthMonitor (science, real-time, anomaly detection)
 
42
  8. NewsEventTracker (classification, public awareness, topical clustering)
43
  9. HealthVitalSigns (anomaly detection, biometrics, prediction)
44
  10. GameStockPredict (classification, finance, sports contingency)
45
+ """
46
+ default_output = landing_page_datasets_generated_text.strip().split("\n")
47
+ assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL
48
 
49
  css = """
50
+ a {
51
+ color: var(--body-text-color);
52
+ }
53
+
54
  .datasetButton {
55
  justify-content: start;
56
  justify-content: left;
 
59
  font-size: var(--button-small-text-size);
60
  color: var(--body-text-color-subdued);
61
  }
 
 
 
62
  .topButton {
63
  justify-content: start;
64
  justify-content: left;
 
97
  .buttonsGroup div {
98
  background: transparent;
99
  }
100
+ .insivibleButtonGroup {
101
+ display: none;
102
+ }
103
+
104
  @keyframes placeHolderShimmer{
105
  0%{
106
  background-position: -468px 0
 
122
  }
123
  """
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  with gr.Blocks(css=css) as demo:
127
+ generated_texts_state = gr.State((landing_page_datasets_generated_text,))
128
  with gr.Row():
129
  with gr.Column(scale=4, min_width=0):
130
  pass
 
145
  search_button = gr.Button("🔍", variant="primary")
146
  with gr.Column(scale=4, min_width=0):
147
  pass
 
 
148
  with gr.Row():
149
  with gr.Column(scale=4, min_width=0):
150
  pass
151
  with gr.Column(scale=10):
152
+ button_groups: list[gr.Group] = []
153
+ buttons: list[gr.Button] = []
154
+ for i in range(MAX_TOTAL_NB_ITEMS):
155
+ if i < len(default_output):
156
+ line = default_output[i]
157
+ dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
158
+ group_classes = "buttonsGroup"
159
+ dataset_name_classes = "topButton"
160
+ tags_classes = "bottomButton"
161
+ else:
162
+ dataset_name, tags = "⬜⬜⬜⬜⬜⬜", "░░░░, ░░░░, ░░░░"
163
+ group_classes = "buttonsGroup insivibleButtonGroup"
164
+ dataset_name_classes = "topButton linear-background"
165
+ tags_classes = "bottomButton linear-background"
166
+ with gr.Group(elem_classes=group_classes) as button_group:
167
+ button_groups.append(button_group)
168
+ buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
169
+ buttons.append(gr.Button(tags, elem_classes=tags_classes))
170
+
171
+ see_more = gr.Button("See more") # TODO: disable when reaching end of page
172
  gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
173
  with gr.Column(scale=4, min_width=0):
174
  pass
175
+ # more.click(search_more_datasets, inputs=[generated_texts, search_bar], outputs=[generated_texts] + buttons)
 
176
  with gr.Column(visible=False) as dataset_page:
177
  with gr.Row():
178
  with gr.Column(scale=4, min_width=0):
 
184
  with gr.Column(scale=4, min_width=0):
185
  pass
186
  with gr.Column():
187
+ generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary") # TODO: implement
 
188
  back_button = gr.Button("< Back", size="sm")
 
189
  with gr.Column(scale=4, min_width=0):
190
  pass
191
  with gr.Column(scale=4, min_width=0):
192
  pass
193
+
194
+
195
T = TypeVar("T")


def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
    """Yield successive lists of up to ``n`` items taken from ``it``.

    Every list holds ``n`` items except possibly the last one, which holds
    whatever is left over. Nothing is yielded for an empty iterable.
    """
    source = iter(it)
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            # The underlying iterator is exhausted.
            return
        yield chunk
201
+
202
+
203
def stream_reponse(msg: str, generated_texts: tuple[str, ...] = (), max_tokens=500) -> Iterator[str]:
    """Stream the model's reply to ``msg`` one text fragment at a time.

    Args:
        msg: The user prompt that opens the conversation.
        generated_texts: Earlier replies to the same prompt. Each one is
            replayed as an assistant turn followed by a
            "Can you generate more ?" user turn.
        max_tokens: Forwarded to ``client.chat_completion``.

    Yields:
        The text fragments of the streamed completion, in order.

    The request is attempted up to 3 times on ``ConnectionError``, sleeping
    one second between attempts. If every attempt fails the generator ends
    without raising.
    """
    messages = [{"role": "user", "content": msg}]
    for generated_text in generated_texts:
        messages.append({"role": "assistant", "content": generated_text})
        messages.append({"role": "user", "content": "Can you generate more ?"})
    for _ in range(3):
        try:
            for message in client.chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                stream=True,
                top_p=0.8,
            ):
                content = message.choices[0].delta.content
                # A chunk's delta may carry no text (content is optional in the
                # streaming schema); callers concatenate fragments with `+=`,
                # so never hand them None.
                if content is not None:
                    yield content
        except requests.exceptions.ConnectionError as e:
            # Format the exception: `e + "..."` raises TypeError and would
            # abort the retry loop instead of retrying.
            # NOTE(review): a retry restarts the completion from scratch, so
            # fragments already yielded before the failure are not rolled back.
            print(f"{e}\n\nRetrying in 1sec")
            time.sleep(1)
            continue
        break
228
+
229
+
230
def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str] = ()) -> Iterator[str]:
    """Stream the generated list of dataset names, flushed at newlines.

    The query is truncated to 1000 characters; a blank query falls back to
    ``landing_page_query``. Tokens coming from ``stream_reponse`` are buffered
    and the buffer is yielded each time it ends with a newline. Whatever
    remains in the buffer when the stream ends is yielded last (this may be an
    empty string). The full text is printed once the stream is consumed.
    """
    if search_query.strip():
        search_query = search_query[:1000]
    else:
        search_query = landing_page_query
    prompt = GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query)
    emitted = []
    pending = ""
    for token in stream_reponse(prompt, generated_texts=generated_texts):
        pending += token
        if not pending.endswith("\n"):
            continue
        emitted.append(pending)
        yield pending
        pending = ""
    # Flush the trailing text that was not terminated by a newline.
    emitted.append(pending)
    yield pending
    print("-----\n\n" + "".join(emitted))
246
+
247
+
248
def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
    """Stream the generated description and CSV preview of one dataset.

    The query is truncated to 1000 characters; a blank query falls back to
    ``landing_page_query``. After every token received from ``stream_reponse``
    the whole text accumulated so far is yielded (not just the new token).
    The final text is printed once the stream is consumed.
    """
    query = search_query[:1000] if search_query.strip() else landing_page_query
    prompt = GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
        search_query=query,
        dataset_name=dataset_name,
        tags=tags,
    )
    text_so_far = ""
    for token in stream_reponse(prompt, max_tokens=1500):
        text_so_far = text_so_far + token
        yield text_so_far
    print("-----\n\n" + text_so_far)
259
+
260
search_datasets_inputs = search_bar
search_datasets_outputs = button_groups + buttons + [generated_texts_state]


def search_datasets(search_query):
    """Run a new search and stream the results into the first page of buttons.

    Yields dicts mapping components to their updates:

    1. the generation history state is reset,
    2. every button group after the first ``MAX_NB_ITEMS_PER_GENERATION_CALL``
       gets the ``insivibleButtonGroup`` class,
    3. every button is turned into a loading placeholder,
    4. then, for each numbered line produced by the model, the matching
       name/tags buttons are filled in and the state is updated with the text
       generated so far.

    Raises:
        gr.Error: if a generated line contains "I'm sorry".
    """
    # Reset the history with a tuple, like every other value stored in this state.
    yield {generated_texts_state: ()}
    yield {
        button_group: gr.Group(elem_classes="buttonsGroup insivibleButtonGroup")
        for button_group in button_groups[MAX_NB_ITEMS_PER_GENERATION_CALL:]
    }
    placeholders = {}
    for dataset_name_button, tags_button in batched(buttons, 2):
        placeholders[dataset_name_button] = gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background")
        placeholders[tags_button] = gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
    yield placeholders
    current_item_idx = 0
    generated_text = ""
    for line in gen_datasets_line_by_line(search_query):
        if "I'm sorry" in line:
            raise gr.Error("Error: inappropriate content")
        if current_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
            return
        # Only lines shaped like "<number>.<rest>" describe a dataset. Requiring
        # the dot avoids an IndexError on a bare number (e.g. a truncated line).
        number, dot, rest = line.strip().partition(".")
        if not (dot and number.isnumeric()):
            continue
        rest = rest.strip(" )")
        if " (" in rest:
            dataset_name, tags = rest.split(" (", 1)
        else:
            # No "(tags)" part: treat the first word as the name and the rest
            # as tags. partition() never fails, even when there is no space.
            dataset_name, _, tags = rest.partition(" ")
        dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
        if not dataset_name:
            # Nothing usable after the number; do not show an empty button.
            continue
        generated_text += line
        yield {
            buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
            buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
            generated_texts_state: (generated_text,),
        }
        current_item_idx += 1
297
+
298
search_more_datasets_inputs = [search_bar, generated_texts_state]
search_more_datasets_outputs = button_groups + buttons + [generated_texts_state]


def search_more_datasets(search_query, generated_texts):
    """Generate one more page of results below the ones already displayed.

    ``generated_texts`` holds the text of each previous generation; its length
    decides which slice of ``MAX_NB_ITEMS_PER_GENERATION_CALL`` button groups
    is used for the new page. Those groups get the plain ``buttonsGroup``
    class, then each numbered line produced by the model fills the next
    name/tags buttons and the state is updated with the new text appended.

    Does nothing once all pre-built button groups are in use.

    Raises:
        gr.Error: if a generated line contains "I'm sorry".
    """
    current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
    if initial_item_idx >= len(button_groups):
        # Every pre-built button is already used; indexing further would
        # raise IndexError.
        return
    yield {
        button_group: gr.Group(elem_classes="buttonsGroup")
        for button_group in button_groups[initial_item_idx:initial_item_idx + MAX_NB_ITEMS_PER_GENERATION_CALL]
    }
    generated_text = ""
    for line in gen_datasets_line_by_line(search_query, generated_texts=generated_texts):
        if "I'm sorry" in line:
            raise gr.Error("Error: inappropriate content")
        if current_item_idx - initial_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
            return
        if current_item_idx >= len(button_groups):
            # Ran out of pre-built buttons in the middle of a page.
            return
        # Only lines shaped like "<number>.<rest>" describe a dataset. Requiring
        # the dot avoids an IndexError on a bare number (e.g. a truncated line).
        number, dot, rest = line.strip().partition(".")
        if not (dot and number.isnumeric()):
            continue
        rest = rest.strip(" )")
        if " (" in rest:
            dataset_name, tags = rest.split(" (", 1)
        else:
            # No "(tags)" part: keep the first word as the name, without tags.
            dataset_name, tags = rest.split(" ", 1)[0], ""
        dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
        if not dataset_name:
            # Nothing usable after the number; do not show an empty button.
            continue
        generated_text += line
        yield {
            buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
            buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
            generated_texts_state: (*generated_texts, generated_text),
        }
        current_item_idx += 1
326
+
327
show_dataset_inputs = [search_bar] + buttons
show_dataset_outputs = [search_page, dataset_page, dataset_title, dataset_content]


def show_dataset(search_query, *buttons_values, i):
    """Switch to the dataset page for item ``i`` and stream its content.

    ``buttons_values`` are the current values of all buttons, laid out as
    consecutive (dataset name, tags) pairs; pair ``i`` is the selected one.
    The first update hides the search page, shows the dataset page and sets
    the title; each following update carries the content generated so far.
    """
    start = 2 * i
    dataset_name, tags = buttons_values[start:start + 2]
    title_markdown = f"# {dataset_name}\n\n tags: {tags}\n\n _Note: This is an AI-generated dataset so its content may be inaccurate or false_"
    page_switch = {
        search_page: gr.Column(visible=False),
        dataset_page: gr.Column(visible=True),
        dataset_title: title_markdown,
    }
    yield page_switch
    content_stream = gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags)
    for generated_text in content_stream:
        yield {dataset_content: generated_text}
339
+
340
+
341
def show_search_page():
    """Return two column updates: the first visible, the second hidden.

    Wired to the back button with ``[search_page, dataset_page]`` as outputs,
    so it brings the search page back and hides the dataset page.
    """
    shown = gr.Column(visible=True)
    hidden = gr.Column(visible=False)
    return shown, hidden
343
+
344
+
345
def generate_full_dataset():
    """Placeholder handler for the "Generate Full Dataset" button.

    Raises:
        gr.Error: always, since the feature is not implemented yet.
    """
    message = "Not implemented yet sorry ! Give me some feedbacks in the Community tab in the meantime ;)"
    raise gr.Error(message)
347
+
348
+
349
# A search starts either by submitting the search bar or clicking the search button.
for register_search in (search_bar.submit, search_button.click):
    register_search(search_datasets, inputs=search_datasets_inputs, outputs=search_datasets_outputs)

# `buttons` holds consecutive (dataset name, tags) pairs; clicking either
# button of pair `i` opens the dataset page for item `i`.
for i, button_pair in enumerate(batched(buttons, 2)):
    for button in button_pair:
        button.click(partial(show_dataset, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs)

see_more.click(search_more_datasets, inputs=search_more_datasets_inputs, outputs=search_more_datasets_outputs)

generate_full_dataset_button.click(generate_full_dataset)
back_button.click(show_search_page, inputs=[], outputs=[search_page, dataset_page])
358
+
359
  demo.launch()