add "see more" button

app.py CHANGED

@@ -1,6 +1,7 @@
 import time
+from itertools import count, islice
 from functools import partial
-from typing import Iterator
+from typing import Iterable, Iterator, TypeVar
 
 import gradio as gr
 import requests.exceptions
@@ -10,9 +11,12 @@ from huggingface_hub import InferenceClient
 model_id = "microsoft/Phi-3-mini-4k-instruct"
 client = InferenceClient(model_id)
 
+MAX_TOTAL_NB_ITEMS = 100  # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
+MAX_NB_ITEMS_PER_GENERATION_CALL = 10
+
 GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
     "A Machine Learning Practioner is looking for a dataset that matches '{search_query}'. "
-    "Generate a list of 10 names of quality dataset that don't exist but sound plausible and would "
+    f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality dataset that don't exist but sound plausible and would "
     "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
     "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated to the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n1. DatasetName2 (tag1, tag2, tag3)"
 )
@@ -25,52 +29,9 @@ GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
     "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
 )
 
-default_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"
-
-
-def stream_reponse(msg: str, max_tokens=500) -> Iterator[str]:
-    for _ in range(3):
-        try:
-            for message in client.chat_completion(
-                messages=[{"role": "user", "content": msg}],
-                max_tokens=max_tokens,
-                stream=True,
-            ):
-                yield message.choices[0].delta.content
-        except requests.exceptions.ConnectionError as e:
-            print(e + "\n\nRetrying in 1sec")
-            time.sleep(1)
-            continue
-        break
-
-
-def gen_datasets(search_query: str) -> Iterator[str]:
-    search_query = search_query[:1000] if search_query.strip() else default_query
-    generated_text = ""
-    for token in stream_reponse(GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query)):
-        generated_text += token
-        if generated_text.endswith("\n"):
-            yield generated_text.strip()
-    yield generated_text.strip()
-    print("-----\n\n" + generated_text)
-
-
-def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
-    search_query = search_query[:1000] if search_query.strip() else default_query
-    generated_text = ""
-    for token in stream_reponse(GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
-        search_query=search_query,
-        dataset_name=dataset_name,
-        tags=tags,
-    ), max_tokens=1500):
-        generated_text += token
-        yield generated_text
-    print("-----\n\n" + generated_text)
-
-
-NB_ITEMS_PER_PAGE = 10
-
-default_output = """
+landing_page_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"
+
+landing_page_datasets_generated_text = """
 1. NewsEventsPredict (classification, media, trend)
 2. FinancialForecast (economy, stocks, regression)
 3. HealthMonitor (science, real-time, anomaly detection)
@@ -81,10 +42,15 @@ default_output = """
 8. NewsEventTracker (classification, public awareness, topical clustering)
 9. HealthVitalSigns (anomaly detection, biometrics, prediction)
 10. GameStockPredict (classification, finance, sports contingency)
-"""
-
+"""
+default_output = landing_page_datasets_generated_text.strip().split("\n")
+assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL
 
 css = """
+a {
+    color: var(--body-text-color);
+}
+
 .datasetButton {
     justify-content: start;
     justify-content: left;
@@ -93,9 +59,6 @@ css = """
     font-size: var(--button-small-text-size);
    color: var(--body-text-color-subdued);
 }
-a {
-    color: var(--body-text-color);
-}
 .topButton {
     justify-content: start;
     justify-content: left;
@@ -134,6 +97,10 @@ a {
 .buttonsGroup div {
     background: transparent;
 }
+.insivibleButtonGroup {
+    display: none;
+}
+
 @keyframes placeHolderShimmer{
     0%{
         background-position: -468px 0
@@ -155,39 +122,9 @@ a {
 }
 """
 
-def search_datasets(search_query):
-    output_values = [
-        gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
-        gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
-    ] * NB_ITEMS_PER_PAGE
-    for generated_text in gen_datasets(search_query):
-        if "I'm sorry" in generated_text:
-            raise gr.Error("Error: inappropriate content")
-        lines = [line for line in generated_text.split("\n") if line and line.split(".", 1)[0].isnumeric()][:NB_ITEMS_PER_PAGE]
-        for i, line in enumerate(lines):
-            dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
-            output_values[2 * i] = gr.Button(dataset_name, elem_classes="topButton")
-            output_values[2 * i + 1] = gr.Button(tags, elem_classes="bottomButton")
-        yield output_values
-
-
-def show_dataset(search_query, *buttons_values, i):
-    dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
-    dataset_title = f"# {dataset_name}\n\n tags: {tags}\n\n _Note: This is an AI-generated dataset so its content may be inaccurate or false_"
-    yield gr.Column(visible=False), gr.Column(visible=True), dataset_title, ""
-    for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
-        yield gr.Column(), gr.Column(), dataset_title, generated_text
-
-
-def show_search_page():
-    return gr.Column(visible=True), gr.Column(visible=False)
-
-
-def generate_full_dataset():
-    raise gr.Error("Not implemented yet sorry ! Give me some feedbacks in the Community tab in the meantime ;)")
-
 
 with gr.Blocks(css=css) as demo:
+    generated_texts_state = gr.State((landing_page_datasets_generated_text,))
     with gr.Row():
         with gr.Column(scale=4, min_width=0):
             pass
@@ -208,28 +145,34 @@ with gr.Blocks(css=css) as demo:
                 search_button = gr.Button("🔍", variant="primary")
             with gr.Column(scale=4, min_width=0):
                 pass
-    inputs = [search_bar]
-    show_dataset_outputs = [search_page]
     with gr.Row():
         with gr.Column(scale=4, min_width=0):
             pass
         with gr.Column(scale=10):
-
-
-
-
-
-
-
-
-
-
-
+            button_groups: list[gr.Group] = []
+            buttons: list[gr.Button] = []
+            for i in range(MAX_TOTAL_NB_ITEMS):
+                if i < len(default_output):
+                    line = default_output[i]
+                    dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
+                    group_classes = "buttonsGroup"
+                    dataset_name_classes = "topButton"
+                    tags_classes = "bottomButton"
+                else:
+                    dataset_name, tags = "⬜⬜⬜⬜⬜⬜", "░░░░, ░░░░, ░░░░"
+                    group_classes = "buttonsGroup insivibleButtonGroup"
+                    dataset_name_classes = "topButton linear-background"
+                    tags_classes = "bottomButton linear-background"
+                with gr.Group(elem_classes=group_classes) as button_group:
+                    button_groups.append(button_group)
+                    buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
+                    buttons.append(gr.Button(tags, elem_classes=tags_classes))
+
+            see_more = gr.Button("See more")  # TODO: disable when reaching end of page
             gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
         with gr.Column(scale=4, min_width=0):
             pass
-
-    search_button.click(search_datasets, inputs=search_bar, outputs=buttons)
+    # more.click(search_more_datasets, inputs=[generated_texts, search_bar], outputs=[generated_texts] + buttons)
     with gr.Column(visible=False) as dataset_page:
         with gr.Row():
             with gr.Column(scale=4, min_width=0):
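
Note: the new grid is allocated once at build time (MAX_TOTAL_NB_ITEMS groups, everything past the first page hidden via the insivibleButtonGroup class) because a Gradio component graph is fixed when the Blocks is created: event handlers can update or unhide existing components but cannot append new ones. A minimal, self-contained sketch of this fixed-pool pagination pattern (toy names, not code from this Space):

    import gradio as gr

    MAX_ITEMS = 30  # fixed pool, created up front
    PAGE_SIZE = 10  # how many slots each "See more" click reveals

    with gr.Blocks() as demo:
        # Create every slot now; hide everything beyond the first page.
        slots = [gr.Textbox(f"item {i}", visible=i < PAGE_SIZE) for i in range(MAX_ITEMS)]
        nb_shown = gr.State(PAGE_SIZE)
        see_more = gr.Button("See more")

        def show_more(n):
            n = min(n + PAGE_SIZE, MAX_ITEMS)
            # Returning component instances updates the existing slots in place.
            return [gr.Textbox(visible=i < n) for i in range(MAX_ITEMS)] + [n]

        see_more.click(show_more, inputs=nb_shown, outputs=slots + [nb_shown])

    demo.launch()
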
@@ -241,13 +184,176 @@ with gr.Blocks(css=css) as demo:
         with gr.Column(scale=4, min_width=0):
             pass
         with gr.Column():
-            generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")
-            generate_full_dataset_button.click(generate_full_dataset)
+            generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")  # TODO: implement
             back_button = gr.Button("< Back", size="sm")
-            back_button.click(show_search_page, inputs=[], outputs=[search_page, dataset_page])
         with gr.Column(scale=4, min_width=0):
             pass
     with gr.Column(scale=4, min_width=0):
         pass
-
+
+
+    T = TypeVar("T")
+
+    def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
+        it = iter(it)
+        while batch := list(islice(it, n)):
+            yield batch
+
+
+    def stream_reponse(msg: str, generated_texts: tuple[str] = (), max_tokens=500) -> Iterator[str]:
+        messages = [
+            {"role": "user", "content": msg}
+        ] + [
+            item
+            for generated_text in generated_texts
+            for item in [
+                {"role": "assistant", "content": generated_text},
+                {"role": "user", "content": "Can you generate more ?"},
+            ]
+        ]
+        for _ in range(3):
+            try:
+                for message in client.chat_completion(
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    stream=True,
+                    top_p=0.8,
+                ):
+                    yield message.choices[0].delta.content
+            except requests.exceptions.ConnectionError as e:
+                print(str(e) + "\n\nRetrying in 1sec")
+                time.sleep(1)
+                continue
+            break
+
+
+    def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str] = ()) -> Iterator[str]:
+        search_query = search_query[:1000] if search_query.strip() else landing_page_query
+        generated_text = ""
+        current_line = ""
+        for token in stream_reponse(
+            GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query),
+            generated_texts=generated_texts,
+        ):
+            current_line += token
+            if current_line.endswith("\n"):
+                yield current_line
+                generated_text += current_line
+                current_line = ""
+        yield current_line
+        generated_text += current_line
+        print("-----\n\n" + generated_text)
+
+
+    def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
+        search_query = search_query[:1000] if search_query.strip() else landing_page_query
+        generated_text = ""
+        for token in stream_reponse(GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
+            search_query=search_query,
+            dataset_name=dataset_name,
+            tags=tags,
+        ), max_tokens=1500):
+            generated_text += token
+            yield generated_text
+        print("-----\n\n" + generated_text)
+
+    search_datasets_inputs = search_bar
+    search_datasets_outputs = button_groups + buttons + [generated_texts_state]
+
+    def search_datasets(search_query):
+        yield {generated_texts_state: []}
+        yield {
+            button_group: gr.Group(elem_classes="buttonsGroup insivibleButtonGroup")
+            for button_group in button_groups[MAX_NB_ITEMS_PER_GENERATION_CALL:]
+        }
+        yield {
+            k: v
+            for dataset_name_button, tags_button in batched(buttons, 2)
+            for k, v in {
+                dataset_name_button: gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
+                tags_button: gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
+            }.items()
+        }
+        current_item_idx = 0
+        generated_text = ""
+        for line in gen_datasets_line_by_line(search_query):
+            if "I'm sorry" in line:
+                raise gr.Error("Error: inappropriate content")
+            if current_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
+                return
+            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
+                try:
+                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
+                except ValueError:
+                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
+                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
+                generated_text += line
+                yield {
+                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
+                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
+                    generated_texts_state: (generated_text,),
+                }
+                current_item_idx += 1
+
+    search_more_datasets_inputs = [search_bar, generated_texts_state]
+    search_more_datasets_outputs = button_groups + buttons + [generated_texts_state]
+
+    def search_more_datasets(search_query, generated_texts):
+        current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
+        yield {
+            button_group: gr.Group(elem_classes="buttonsGroup")
+            for button_group in button_groups[len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL:(len(generated_texts) + 1) * MAX_NB_ITEMS_PER_GENERATION_CALL]
+        }
+        generated_text = ""
+        for line in gen_datasets_line_by_line(search_query, generated_texts=generated_texts):
+            if "I'm sorry" in line:
+                raise gr.Error("Error: inappropriate content")
+            if current_item_idx - initial_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
+                return
+            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
+                try:
+                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
+                except ValueError:
+                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
+                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
+                generated_text += line
+                yield {
+                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
+                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
+                    generated_texts_state: (*generated_texts, generated_text),
+                }
+                current_item_idx += 1
+
+    show_dataset_inputs = [search_bar, *buttons]
+    show_dataset_outputs = [search_page, dataset_page, dataset_title, dataset_content]
+
+    def show_dataset(search_query, *buttons_values, i):
+        dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
+        yield {
+            search_page: gr.Column(visible=False),
+            dataset_page: gr.Column(visible=True),
+            dataset_title: f"# {dataset_name}\n\n tags: {tags}\n\n _Note: This is an AI-generated dataset so its content may be inaccurate or false_"
+        }
+        for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
+            yield {dataset_content: generated_text}
+
+
+    def show_search_page():
+        return gr.Column(visible=True), gr.Column(visible=False)
+
+
+    def generate_full_dataset():
+        raise gr.Error("Not implemented yet sorry ! Give me some feedbacks in the Community tab in the meantime ;)")
+
+
+    search_bar.submit(search_datasets, inputs=search_datasets_inputs, outputs=search_datasets_outputs)
+    search_button.click(search_datasets, inputs=search_datasets_inputs, outputs=search_datasets_outputs)
+    for i, (dataset_name_button, tags_button) in enumerate(batched(buttons, 2)):
+        dataset_name_button.click(partial(show_dataset, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs)
+        tags_button.click(partial(show_dataset, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs)
+    see_more.click(search_more_datasets, inputs=search_more_datasets_inputs, outputs=search_more_datasets_outputs)
+
+    generate_full_dataset_button.click(generate_full_dataset)
+    back_button.click(show_search_page, inputs=[], outputs=[search_page, dataset_page])
+
 demo.launch()
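
The batched helper introduced above is the usual stand-in for itertools.batched, which only ships with Python 3.12+; here it chunks the flat buttons list into (dataset_name_button, tags_button) pairs. A quick usage sketch:

    from itertools import islice
    from typing import Iterable, Iterator, TypeVar

    T = TypeVar("T")

    def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
        # Yield successive lists of up to n items until the iterator is exhausted.
        it = iter(it)
        while batch := list(islice(it, n)):
            yield batch

    print(list(batched(range(5), 2)))  # [[0, 1], [2, 3], [4]]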
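
The other pattern this commit leans on is dict-style updates: instead of yielding a positional tuple covering every declared output, a handler can yield a dict keyed by component, and Gradio only touches the components named in the dict. That is what lets search_datasets and search_more_datasets reveal one (name, tags) button pair at a time while leaving the other slots alone. A self-contained sketch of the pattern (hypothetical demo, not from this Space):

    import gradio as gr

    with gr.Blocks() as demo:
        first = gr.Textbox(label="first")
        second = gr.Textbox(label="second")
        run = gr.Button("Run")

        def fill():
            # Each yield updates only the components named in the dict;
            # the other declared outputs keep their current value.
            yield {first: "updated first"}
            yield {second: "updated second"}

        run.click(fill, inputs=[], outputs=[first, second])

    demo.launch()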