legacy107 committed
Commit 94a7057 · 1 Parent(s): 73800d5

Create app.py

Files changed (1): app.py (+106 -0)
app.py ADDED
import gradio as gr
from gradio.components import Textbox
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import torch
import datasets

# Load the base model and tokenizer, then attach the IA3 adapter
# fine-tuned for question answering on cpgQA
model_name = "google/flan-t5-large"
peft_name = "legacy107/flan-t5-large-ia3-cpgQA"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = PeftModel.from_pretrained(model, peft_name)

# Load a second copy of the base model with the paraphrasing adapter
peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
paraphrase_model = PeftModel.from_pretrained(paraphrase_model, peft_name)

max_length = 512         # maximum input length, in tokens
max_target_length = 200  # maximum generated length, in tokens

# Load the dataset and pick five random examples for the examples pane
dataset = datasets.load_dataset(
    "minh21/cpgQA-v1.0-unique-context-test-10-percent-validation-10-percent",
    split="test",
)
dataset = dataset.shuffle()
dataset = dataset.select(range(5))


def paraphrase_answer(question, answer):
    # Combine the question and the raw answer into a paraphrasing prompt
    input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"

    # Tokenize the input text
    input_ids = tokenizer(
        input_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).input_ids

    # Generate the paraphrased answer
    with torch.no_grad():
        generated_ids = paraphrase_model.generate(
            input_ids=input_ids, max_new_tokens=max_target_length
        )

    # Decode and return the paraphrased answer
    paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return paraphrased_answer


# Generate an answer for a question/context pair, then paraphrase it.
# ground_truth is supplied by the examples pane but is not used for generation.
def generate_answer(question, context, ground_truth):
    # Combine question and context
    input_text = f"question: {question} context: {context}"

    # Tokenize the input text
    input_ids = tokenizer(
        input_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).input_ids

    # Generate the answer
    with torch.no_grad():
        generated_ids = model.generate(input_ids, max_new_tokens=max_target_length)

    # Decode the generated answer
    generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Paraphrase the answer into a more natural phrasing
    paraphrased_answer = paraphrase_answer(question, generated_answer)

    return generated_answer, paraphrased_answer


# Define a function to list examples from the dataset
def list_examples():
    examples = []
    for example in dataset:
        context = example["context"]
        question = example["question"]
        answer = example["answer_text"]
        examples.append([question, context, answer])
    return examples


# Create a Gradio interface
iface = gr.Interface(
    fn=generate_answer,
    inputs=[
        Textbox(label="Question"),
        Textbox(label="Context"),
        Textbox(label="Ground truth"),
    ],
    outputs=[
        Textbox(label="Generated Answer"),
        Textbox(label="Natural Answer"),
    ],
    examples=list_examples(),
)

# Launch the Gradio interface
iface.launch()
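
For a quick sanity check outside the UI, here is a minimal sketch (assumed usage, not part of the commit) that runs one of the five sampled dataset rows through the two-stage pipeline. It assumes the definitions in app.py have already been executed, e.g. in a notebook with the final iface.launch() call skipped; the field names match those read in list_examples():

# Assumed usage sketch, not part of the committed app.py
example = dataset[0]  # one of the five sampled rows
generated, natural = generate_answer(
    example["question"], example["context"], example["answer_text"]
)
print("Generated answer:", generated)
print("Natural answer:", natural)
print("Ground truth:", example["answer_text"])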