Spaces:
Runtime error
Runtime error
philipp-zettl
commited on
Commit
•
1ccde3b
1
Parent(s):
ddc0abc
Update app.py
Browse files
app.py
CHANGED
@@ -158,6 +158,7 @@ def find_best_parameters(eval_data, model, tokenizer, max_length=85):
|
|
158 |
4: [2],
|
159 |
6: [2], # 6x3 == 4x2
|
160 |
8: [2], # 8x4 == 6x3 == 4x2
|
|
|
161 |
10: [2], # 10x5 == 8x4 == 6x3 == 4x2
|
162 |
}
|
163 |
|
@@ -249,7 +250,9 @@ def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_q
|
|
249 |
)
|
250 |
|
251 |
if optimize_questions:
|
252 |
-
q_params = find_best_parameters(
|
|
|
|
|
253 |
|
254 |
question = run_model(
|
255 |
inputs,
|
@@ -308,69 +311,89 @@ def create_file_download(qnas):
|
|
308 |
return 'qnas.tsv'
|
309 |
|
310 |
|
311 |
-
with gr.Blocks(
|
312 |
-
with gr.
|
313 |
-
gr.
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
)
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
|
375 |
demo.queue()
|
376 |
-
demo.launch()
|
|
|
158 |
4: [2],
|
159 |
6: [2], # 6x3 == 4x2
|
160 |
8: [2], # 8x4 == 6x3 == 4x2
|
161 |
+
9: [3],
|
162 |
10: [2], # 10x5 == 8x4 == 6x3 == 4x2
|
163 |
}
|
164 |
|
|
|
250 |
)
|
251 |
|
252 |
if optimize_questions:
|
253 |
+
q_params = find_best_parameters(
|
254 |
+
list(chain.from_iterable(question)), qg_model, tokenizer, max_length=max_length
|
255 |
+
)
|
256 |
|
257 |
question = run_model(
|
258 |
inputs,
|
|
|
311 |
return 'qnas.tsv'
|
312 |
|
313 |
|
314 |
+
with gr.Blocks() as demo:
|
315 |
+
with gr.Tab(label='Description'):
|
316 |
+
with gr.Row(equal_height=True):
|
317 |
+
gr.Markdown(
|
318 |
+
"""
|
319 |
+
# QA-Generator
|
320 |
+
A combination of fine-tuned flan-T5(-small) models chained into sequence
|
321 |
+
to generate:
|
322 |
+
|
323 |
+
a) a versatile set of questions
|
324 |
+
b) an accurate set of matching answers
|
325 |
+
|
326 |
+
according to a given piece of text content.
|
327 |
+
The idea is simple:
|
328 |
+
|
329 |
+
1. Add your content
|
330 |
+
2. Select the amount of questions you want to generate
|
331 |
+
2.2 (optional) Select the amount of answers you want to generate per goven question
|
332 |
+
3. Press generate
|
333 |
+
4. ???
|
334 |
+
5. Profit
|
335 |
+
If you're satisfied with the generated data set, you can export it as TSV
|
336 |
+
to edit or import it into your favourite tool.
|
337 |
+
""")
|
338 |
+
with gr.Row(equal_height=True):
|
339 |
+
with gr.Accordion(label='Optimization', open=False):
|
340 |
+
gr.Markdown("""
|
341 |
+
For optimization of the question generation we apply the following combined score:
|
342 |
+
|
343 |
+
$$\\text{combined} = \\text{dist1} + \\text{dist2} - \\text{fluency} + \\text{contextual} - \\text{jsd}$$
|
344 |
+
|
345 |
+
Here's a brief explanation of each component:
|
346 |
+
|
347 |
+
1. **dist1 and dist2**: These represent the diversity of the generated outputs. dist1 measures the ratio of unique unigrams to total unigrams, and dist2 measures the ratio of unique bigrams to total bigrams. <u>**Higher values indicate more diverse outputs.**</u>
|
348 |
+
|
349 |
+
2. **fluency**: This is the perplexity of the generated outputs, which measures how well the outputs match the language model's expectations. <u>**Lower values indicate better fluency.**</u>
|
350 |
+
|
351 |
+
3. **contextual**: This measures the similarity between the input and generated outputs using embedding similarity. <u>**Higher values indicate better contextual relevance.**</u>
|
352 |
+
|
353 |
+
4. **jsd**: This is the Jensen-Shannon Divergence between the n-gram distributions of the generated outputs and the reference data. <u>**Lower values indicate greater similarity between distributions.**</u>
|
354 |
+
""", latex_delimiters=[{'display': False, 'left': '$$', 'right': '$$'}])
|
355 |
+
with gr.Tab(label='QA Generator'):
|
356 |
+
with gr.Row(equal_height=True):
|
357 |
+
with gr.Group("Content"):
|
358 |
+
content = gr.Textbox(label='Content', lines=15, placeholder='Enter text here', max_lines=10_000)
|
359 |
+
with gr.Group("Settings"):
|
360 |
+
temperature_qg = gr.Slider(label='Diversity Penalty QG', value=0.2, minimum=0, maximum=1, step=0.01)
|
361 |
+
temperature_qa = gr.Slider(label='Diversity Penalty QA', value=0.5, minimum=0, maximum=1, step=0.01)
|
362 |
+
max_length = gr.Number(label='Max Length', value=85, minimum=1, step=1, maximum=512)
|
363 |
+
num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
|
364 |
+
num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
|
365 |
+
seed = gr.Number(label="seed", value=42069)
|
366 |
+
optimize_questions = gr.Checkbox(label="Optimize questions?", value=False)
|
367 |
+
|
368 |
+
with gr.Row():
|
369 |
+
gen_btn = gr.Button("Generate")
|
370 |
+
|
371 |
+
@gr.render(
|
372 |
+
inputs=[
|
373 |
+
content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
|
374 |
+
max_length, seed, optimize_questions
|
375 |
+
],
|
376 |
+
triggers=[gen_btn.click]
|
377 |
)
|
378 |
+
def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa, max_length, seed, optimize_questions):
|
379 |
+
if not content.strip():
|
380 |
+
raise gr.Error('Please enter some content to generate questions and answers.')
|
381 |
+
qnas = gen(
|
382 |
+
content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
|
383 |
+
max_length, seed, optimize_questions
|
384 |
+
)
|
385 |
+
df = gr.Dataframe(
|
386 |
+
value=[u.values() for u in qnas],
|
387 |
+
headers=['Question', 'Answer'],
|
388 |
+
col_count=2,
|
389 |
+
wrap=True
|
390 |
+
)
|
391 |
+
pd_df = pd.DataFrame([u.values() for u in qnas], columns=['Question', 'Answer'])
|
392 |
+
|
393 |
+
download = gr.DownloadButton(label='Download (without headers)', value=create_file_download(pd_df))
|
394 |
+
|
395 |
+
content.change(lambda x: x.strip(), content)
|
396 |
|
397 |
|
398 |
demo.queue()
|
399 |
+
demo.launch()
|