legacy107 committed
Commit 742563f
1 Parent(s): 0b844d0

Update app.py

Files changed (1): app.py +16 -6
app.py CHANGED
@@ -27,7 +27,7 @@ dataset = dataset.shuffle()
 dataset = dataset.select(range(10))
 
 
-def paraphrase_answer(question, answer):
+def paraphrase_answer(question, answer, use_pretrained=False):
     # Combine question and context
     input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
 
@@ -42,7 +42,10 @@ def paraphrase_answer(question, answer):
 
     # Generate the answer
     with torch.no_grad():
-        generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+        if use_pretrained:
+            generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+        else:
+            generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
 
     # Decode and return the generated answer
     paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
@@ -51,7 +54,7 @@ def paraphrase_answer(question, answer):
 
 
 # Define your function to generate answers
-def generate_answer(question, context, ground_truth, do_pretrained, do_natural):
+def generate_answer(question, context, ground_truth, do_pretrained, do_natural, do_pretrained_natural):
     # Combine question and context
     input_text = f"question: {question} context: {context}"
 
@@ -83,7 +86,12 @@ def generate_answer(question, context, ground_truth, do_pretrained, do_natural):
         pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
         pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
 
-    return generated_answer, paraphrased_answer, pretrained_answer
+    # Get pretrained model's natural answer
+    pretrained_paraphrased_answer = ""
+    if do_pretrained_natural:
+        pretrained_paraphrased_answer = paraphrase_answer(question, generated_answer, True)
+
+    return generated_answer, paraphrased_answer, pretrained_answer, pretrained_paraphrased_answer
 
 
 # Define a function to list examples from the dataset
@@ -104,13 +112,15 @@ iface = gr.Interface(
         Textbox(label="Question"),
         Textbox(label="Context"),
         Textbox(label="Ground truth"),
-        Checkbox(label="Include pretrained model's result"),
-        Checkbox(label="Include natural answer")
+        Checkbox(label="Include pretrained model's answer"),
+        Checkbox(label="Include natural answer"),
+        Checkbox(label="Include pretrained model's natural answer")
     ],
     outputs=[
         Textbox(label="Generated Answer"),
         Textbox(label="Natural Answer"),
         Textbox(label="Pretrained Model's Answer"),
+        Textbox(label="Pretrained Model's Natural Answer")
     ],
     examples=list_examples()
)
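
For context, a minimal standalone sketch of how the updated paraphrase_answer could be exercised. The prompt format, the use_pretrained branch, and the generate/decode calls mirror the diff above; the tokenizer/model loading and the checkpoint name are assumptions for illustration, not what app.py actually loads.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoint; app.py loads its own fine-tuned and pretrained models.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
pretrained_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

max_source_length = 512   # assumed limits, not taken from the diff
max_target_length = 64

def paraphrase_answer(question, answer, use_pretrained=False):
    # Same paraphrasing prompt as in app.py
    input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
    input_ids = tokenizer(input_text, max_length=max_source_length,
                          truncation=True, return_tensors="pt").input_ids

    # The use_pretrained flag added in this commit selects which model decodes
    with torch.no_grad():
        if use_pretrained:
            generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
        else:
            generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print(paraphrase_answer("Who wrote Hamlet?", "Shakespeare", use_pretrained=True))

With use_pretrained=True the helper decodes with the pretrained checkpoint, which is how generate_answer now fills the new "Pretrained Model's Natural Answer" textbox when the matching checkbox is ticked.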