philipp-zettl committed
Commit 83a1143
1 Parent(s): 43d48ee

Update app.py

Files changed (1)
  1. app.py +171 -147
app.py CHANGED
@@ -10,6 +10,10 @@ from collections import Counter
 from itertools import chain
 from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 import math
+import markdown
+from .text import doctree_from_url, get_selectors_for_class, split_by_heading, DocTree
+from .optimization import ngrams, count_ngrams, self_bleu, dist_n, perplexity, js_divergence
+
 
 model_name = 'philipp-zettl/t5-small-long-qa'
 qa_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
@@ -31,55 +35,6 @@ max_answers = 1
 max_elem_value = 100
 
 
-
-def ngrams(sequence, n):
-    return [tuple(sequence[i:i+n]) for i in range(len(sequence)-n+1)]
-
-def count_ngrams(sequence, max_n):
-    counts = Counter()
-    for n in range(1, max_n + 1):
-        counts.update(ngrams(sequence, n))
-    return counts
-
-def self_bleu(outputs):
-    smoothing_function = SmoothingFunction().method1
-    scores = []
-    for i in range(len(outputs)):
-        references = outputs[:i] + outputs[i+1:]
-        # Avoid calculating BLEU score for empty references
-        if references:
-            scores.append(sentence_bleu(references, outputs[i], smoothing_function=smoothing_function))
-    # If all references are empty, return a default value
-    if not scores:
-        return 0
-    return sum(scores) / len(scores)
-
-def dist_n(outputs, n):
-    all_ngrams = list(chain(*[ngrams(output, n) for output in outputs]))
-    unique_ngrams = set(all_ngrams)
-    return len(unique_ngrams) / len(all_ngrams) if all_ngrams else 0
-
-def perplexity(model, tokenizer, texts):
-    encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
-    max_length = model.config.n_positions
-    stride = 512
-    lls = []
-    for i in range(0, encodings.input_ids.size(1), stride):
-        begin_loc = max(i + stride - max_length, 0)
-        end_loc = i + stride
-        trg_len = end_loc - i
-        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
-        target_ids = input_ids.clone()
-        target_ids[:, :-trg_len] = -100
-
-        with torch.no_grad():
-            outputs = model(input_ids, labels=target_ids)
-            log_likelihood = outputs.loss * trg_len
-        lls.append(log_likelihood)
-
-    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
-    return ppl.item()
-
 def embedding_similarity(inputs, outputs):
     global embedding_model, embedding_tokenizer, device
     def embed(texts):
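Reviewer note (not part of the commit): for readers tracing the refactor into the unshown `.optimization` module, this is what the two diversity metrics removed above report on a toy batch of tokenized outputs, using the definitions as written:

```python
# Toy check of self_bleu and dist_n as defined above; token lists, as
# sentence_bleu expects.
outputs = [
    'what is a star'.split(),
    'what is the sun'.split(),
    'where is the sun located'.split(),
]
print(self_bleu(outputs))  # high when the generated questions parrot one another
print(dist_n(outputs, 2))  # unique-bigram ratio; 1.0 = maximally diverse
```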
@@ -94,16 +49,6 @@ def embedding_similarity(inputs, outputs):
     similarities = pairwise_distances(input_embeddings, output_embeddings, metric='cosine')
     return sum(similarities) / len(similarities)
 
-def js_divergence(p, q):
-    def kl_divergence(p, q):
-        return sum(p[i] * math.log(p[i] / q[i]) for i in range(len(p)) if p[i] != 0 and q[i] != 0)
-
-    p_norm = [float(i)/sum(p) for i in p]
-    q_norm = [float(i)/sum(q) for i in q]
-
-    m = [(p_norm[i] + q_norm[i]) / 2 for i in range(len(p_norm))]
-
-    return (kl_divergence(p_norm, m) + kl_divergence(q_norm, m)) / 2
 
 def evaluate_model(num_beams, num_beam_groups, model, tokenizer, eval_data, max_length=85):
     generated_outputs = []
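Reviewer note (not part of the commit): `js_divergence` consumes two aligned count vectors, so the `jsd_score` returned by `evaluate_model` presumably flattens both output sets into n-gram counts over a shared vocabulary first. That glue code now lives in `.optimization` and is not shown in this diff; the alignment step below is an assumption:

```python
# Toy reconstruction using count_ngrams and js_divergence as defined above.
generated = count_ngrams('what is a qa generator'.split(), 2)
reference = count_ngrams('what does a qa generator do'.split(), 2)

# Assumed glue: align both Counters over one shared vocabulary.
vocab = sorted(set(generated) | set(reference))
p = [generated[v] for v in vocab]  # Counter returns 0 for missing n-grams
q = [reference[v] for v in vocab]
print(js_divergence(p, q))  # 0.0 = identical distributions, log(2) = disjoint
```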
@@ -150,6 +95,7 @@ def evaluate_model(num_beams, num_beam_groups, model, tokenizer, eval_data, max_
         "jsd_score": jsd_score
     }
 
+
 def find_best_parameters(eval_data, model, tokenizer, max_length=85):
 
     # Parameter ranges
@@ -184,8 +130,6 @@ def find_best_parameters(eval_data, model, tokenizer, max_length=85):
     return best_params
 
 
-
-
 def run_model(inputs, tokenizer, model, num_beams=2, num_beam_groups=2, temperature=0.5, num_return_sequences=1, max_length=85, seed=42069):
     all_outputs = []
    torch.manual_seed(seed)
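Reviewer note (not part of the commit): this signature is what the UI sliders introduced later in the diff map onto (`temperature` feeds `diversity_penalty`, as the next hunk shows). A direct call might look like the sketch below; `qg_model` is a hypothetical question-generation counterpart to the `qa_model` loaded at the top, and raw-string inputs are an assumption, since the tokenization step falls outside the shown hunks:

```python
# Hypothetical direct invocation of run_model; qg_model and the raw-string
# input format are assumptions, not part of this diff.
questions = run_model(
    ['context: The sun is the star at the center of the Solar System.'],
    tokenizer,                # assumed: the tokenizer paired with the model
    qg_model,                 # hypothetical question-generation model
    temperature=0.2,          # feeds diversity_penalty ("Diversity Penalty QG" slider)
    num_return_sequences=3,   # "Number Questions" in the UI
    max_length=85,
    seed=42069,
)
```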
@@ -198,29 +142,13 @@ def run_model(inputs, tokenizer, model, num_beams=2, num_beam_groups=2, temperat
         sample_output = model.generate(
             input_ids[:1],
             max_length=max_length,
-            #temperature=temperature,
-            #do_sample=True,
             num_return_sequences=num_return_sequences,
             low_memory=True,
-            #top_p=temperature,
-            #num_beams=max(2, num_return_sequences),
             use_cache=True,
-            # Contrastive search
-            #penalty_alpha=0.6,
-            #top_k=4,
-            # Multi-nomial sampling
-            #do_sample=True,
-            #num_beams=1,
-            # Beam search
-            #num_beams=5,
-            # Beam search multinomial sampling
-            #num_beams=5,
-            #do_sample=True,
             # Diverse Beam search decoding
             num_beams=max(2, num_return_sequences),
             num_beam_groups=max(2, num_return_sequences),
             diversity_penalty=temperature,
-            #do_sample=True,
 
         )
         for i, sample_output in enumerate(sample_output):
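Reviewer note (not part of the commit): after pruning the commented-out alternatives, the call above settles on plain diverse beam search. In isolation the decoding setup looks like this; `google/flan-t5-small` is a stand-in public checkpoint, not the app's `philipp-zettl/t5-small-long-qa`:

```python
# Minimal diverse-beam-search sketch with the Hugging Face generate() API.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small')

input_ids = tokenizer('context: The sun is the star at the center of the Solar System.',
                      return_tensors='pt').input_ids
outputs = model.generate(
    input_ids,
    max_length=85,
    num_return_sequences=3,
    num_beams=6,             # must be divisible by num_beam_groups
    num_beam_groups=3,       # groups penalize each other's token choices
    diversity_penalty=0.5,   # what the UI's "Diversity Penalty" sliders feed
)
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```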
@@ -311,38 +239,26 @@ def create_file_download(qnas):
     return 'qnas.tsv'
 
 
-with gr.Blocks() as demo:
-    with gr.Tab(label='Description'):
-        with gr.Row(equal_height=True):
-            with gr.Column():
-                gr.Markdown(
-                    """
-                    # QA-Generator
-                    A combination of fine-tuned flan-T5(-small) models chained into sequence
-                    to generate:
-
-                    a) a versatile set of questions
-                    b) an accurate set of matching answers
-
-                    according to a given piece of text content.""")
-            with gr.Column():
-                gr.Markdown(
-                    """
-                    The idea is simple:
-
-                    1. Add your content
-                    2. Select the amount of questions you want to generate
-                    3. (optional) Select the amount of answers you want to generate per given question
-                    4. Press generate
-                    5. ???
-                    6. Profit
-                    """)
-        with gr.Row(equal_height=True):
-            gr.Markdown("""
-            If you're satisfied with the generated data set, you can export it as TSV
-            to edit or import it into your favourite tool.
-            """)
-        with gr.Row(equal_height=True):
+def main():
+    with gr.Tab(label='QA Generator'):
+        with gr.Tab(label='Explanation'):
+            gr.Markdown(
+                '''
+                # QA Generator
+                This tab allows you to generate questions and answers from a given piece of text content.
+
+                ## How to use
+                1. Enter the text content you want to generate questions and answers from.
+                2. Adjust the diversity penalty for question generation and answer generation.
+                3. Set the maximum length of the generated questions and answers.
+                4. Choose the number of questions and answers you want to generate.
+                5. Click on the "Generate" button.
+
+                The next section will give you insights into the generated questions and answers.
+
+                If you're satisfied with the generated questions and answers, you can download them as a TSV file.
+                '''
+            )
         with gr.Accordion(label='Optimization', open=False):
             gr.Markdown("""
             For optimization of the question generation we apply the following combined score:
@@ -359,48 +275,156 @@ with gr.Blocks() as demo:
 
             4. **jsd**: This is the Jensen-Shannon Divergence between the n-gram distributions of the generated outputs and the reference data. <u>**Lower values indicate greater similarity between distributions.**</u>
             """, latex_delimiters=[{'display': False, 'left': '$$', 'right': '$$'}])
-    with gr.Tab(label='QA Generator'):
-        with gr.Row(equal_height=True):
-            with gr.Group("Content"):
-                content = gr.Textbox(label='Content', lines=15, placeholder='Enter text here', max_lines=10_000)
-            with gr.Group("Settings"):
-                temperature_qg = gr.Slider(label='Diversity Penalty QG', value=0.2, minimum=0, maximum=1, step=0.01)
-                temperature_qa = gr.Slider(label='Diversity Penalty QA', value=0.5, minimum=0, maximum=1, step=0.01)
-                max_length = gr.Number(label='Max Length', value=85, minimum=1, step=1, maximum=512)
-                num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
-                num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
-                seed = gr.Number(label="seed", value=42069)
-                optimize_questions = gr.Checkbox(label="Optimize questions?", value=False)
-
-        with gr.Row():
-            gen_btn = gr.Button("Generate")
-
-        @gr.render(
-            inputs=[
-                content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
-                max_length, seed, optimize_questions
-            ],
-            triggers=[gen_btn.click]
-        )
-        def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa, max_length, seed, optimize_questions):
-            if not content.strip():
-                raise gr.Error('Please enter some content to generate questions and answers.')
-            qnas = gen(
-                content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
-                max_length, seed, optimize_questions
+        with gr.Tab(label='Generate QA'):
+            with gr.Row(equal_height=True):
+                with gr.Group("Content"):
+                    content = gr.Textbox(label='Content', lines=15, placeholder='Enter text here', max_lines=10_000)
+                with gr.Group("Settings"):
+                    temperature_qg = gr.Slider(label='Diversity Penalty QG', value=0.2, minimum=0, maximum=1, step=0.01)
+                    temperature_qa = gr.Slider(label='Diversity Penalty QA', value=0.5, minimum=0, maximum=1, step=0.01)
+                    max_length = gr.Number(label='Max Length', value=85, minimum=1, step=1, maximum=512)
+                    num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
+                    num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
+                    seed = gr.Number(label="seed", value=42069)
+                    optimize_questions = gr.Checkbox(label="Optimize questions?", value=False)
+
+            with gr.Row():
+                gen_btn = gr.Button("Generate")
+
+            @gr.render(
+                inputs=[
+                    content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
+                    max_length, seed, optimize_questions
+                ],
+                triggers=[gen_btn.click]
             )
-            df = gr.Dataframe(
-                value=[u.values() for u in qnas],
-                headers=['Question', 'Answer'],
-                col_count=2,
-                wrap=True
+            def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa, max_length, seed, optimize_questions):
+                if not content.strip():
+                    raise gr.Error('Please enter some content to generate questions and answers.')
+                qnas = gen(
+                    content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
+                    max_length, seed, optimize_questions
+                )
+                df = gr.Dataframe(
+                    value=[u.values() for u in qnas],
+                    headers=['Question', 'Answer'],
+                    col_count=2,
+                    wrap=True
+                )
+                pd_df = pd.DataFrame([u.values() for u in qnas], columns=['Question', 'Answer'])
+
+                download = gr.DownloadButton(label='Download (without headers)', value=create_file_download(pd_df))
+
+            content.change(lambda x: x.strip(), content)
+
+
+def new_main():
+    with gr.Tab('Content extraction from URL'):
+        with gr.Tab(label='Explanation'):
+            gr.Markdown(
+                '''
+                # Content extraction from URL
+                This tab allows you to extract content from a URL and chunk it into sections.
+
+                ## How to use
+                1. Enter the URL of the webpage you want to extract content from.
+                2. Select the element class and class name of the HTML element you want to extract content from.
+                3. Click on the "Extract content" button.
+
+                The next section will give you insights into the extracted content.
+
+                This gives you the chance to review the extracted content, as well as to manipulate it further.
+
+                Once you have extracted the content, you can choose the depth level at which to chunk it into sections.
+                1. Enter the depth level you want to chunk the content into. **Note: <u>This is based on the HTML structure of the webpage; we utilize heading tags for this purpose.</u>**
+                2. Click on the "Chunk content" button.
+                '''
+            )
+        with gr.Tab(label='Extract content'):
+            url = gr.Textbox(label='URL', placeholder='Enter URL here', lines=1, max_lines=1)
+            elem_class = gr.Dropdown(label='CSS element class', choices=['div', 'p', 'span'], value='div')
+            class_name = gr.Dropdown(label='CSS class name', choices=[], allow_custom_value=True)
+
+            extract_btn = gr.Button('Extract content')
+
+            with gr.Group():
+                content_state = gr.State(None)
+                final_content = gr.Textbox(value='', show_copy_button=True, label='Final content', interactive=True)
+                with gr.Accordion('Reveal original input', open=False):
+                    og_content = gr.Textbox(value='', label='OG HTML content')
+
+            with gr.Group(visible=False) as step_2_group:
+                depth_level = gr.Number(label='Depth level', value=1, minimum=0, step=1, maximum=6)
+                continue_btn = gr.Button('Chunk content')
+
+            def render_results(url, elem_class_, class_name_):
+                if not url.strip():
+                    raise gr.Error('Please enter a URL to extract content.')
+                content = doctree_from_url(url, elem_class_, class_name_)
+                return [
+                    content,
+                    content.content,
+                    content.as_markdown(content.merge_sections(content.get_sections(0))),
+                    gr.Group(visible=True)
+                ]
+
+            def get_class_options(url, elem_class):
+                if not url.strip():
+                    raise gr.Error('Please enter a URL to extract content.')
+
+                return gr.Dropdown(label='CSS class name', choices=list(set(get_selectors_for_class(url, elem_class))))
+
+            def update_content_state_on_final_change(final_content):
+                html_content = markdown.markdown(final_content)
+                return DocTree(split_by_heading(html_content, 1))
+
+            @gr.render(inputs=[content_state, depth_level], triggers=[continue_btn.click])
+            def select_content(content, depth_level):
+                if not content:
+                    raise gr.Error('Please extract content first.')
+
+                sections = content.get_sections_by_depth(depth_level)
+                print(f'Found {len(sections)} sections')
+                ds = []
+                for idx, section in enumerate(sections):
+                    ds.append([idx, content.as_markdown(content.merge_sections(section))])
+                gr.Dataframe(value=ds, headers=['Section #', 'Content'], interactive=True, wrap=True)
+
+            url.change(
+                get_class_options,
+                inputs=[url, elem_class],
+                outputs=[class_name]
             )
-            pd_df = pd.DataFrame([u.values() for u in qnas], columns=['Question', 'Answer'])
 
-            download = gr.DownloadButton(label='Download (without headers)', value=create_file_download(pd_df))
+            extract_btn.click(
+                render_results,
+                inputs=[
+                    url, elem_class, class_name,
+                ],
+                outputs=[
+                    content_state, og_content, final_content, step_2_group
+                ]
+            )
+            final_content.change(update_content_state_on_final_change, inputs=[final_content], outputs=[content_state])
 
-        content.change(lambda x: x.strip(), content)
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        '''
+        # QA-Generator
+        A tool to build FAQs or QnAs from a given piece of text content.
+
+        ## How to use
+        We provide two major functionalities:
+        1. **Content extraction from URL**: Extract content from a URL and chunk it into sections.
+        2. **QA Generator**: Generate questions and answers from a given piece of text content.
+
+        Select the tab you want to use and follow the instructions.
+        '''
+    )
+    new_main()
+    main()
 
 
 demo.queue()
-demo.launch()
+demo.launch()
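Reviewer note (not part of the commit): the `.text` module imported in the first hunk ships none of its code in this diff. For orientation, a minimal `split_by_heading` compatible with how `update_content_state_on_final_change` calls it could look as follows; the real implementation, and the structure `DocTree` actually expects, are unknown:

```python
# Assumed sketch of .text.split_by_heading; not the repository's code.
import re

def split_by_heading(html: str, level: int) -> list[str]:
    # Split HTML into chunks, each beginning at an <h{level}> tag.
    parts = re.split(rf'(?=<h{level}[ >])', html)
    return [p for p in parts if p.strip()]

print(split_by_heading('<h1>A</h1><p>x</p><h1>B</h1><p>y</p>', 1))
# ['<h1>A</h1><p>x</p>', '<h1>B</h1><p>y</p>']
```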