"""Gradio demo comparing a PEFT-fine-tuned FLAN-T5 QA model against its base model.

Three instances of ``google/flan-t5-large`` are loaded:

* ``model`` -- base model wrapped with the cpgQA IA3 adapter; answers a question
  given a context passage.
* ``paraphrase_model`` -- base model wrapped with the BioASQ paraphrase IA3
  adapter; rewrites an answer to read more naturally.
* ``pretrained_model`` -- the untouched base model, used as a baseline for both
  tasks.

NOTE(review): each model gets its own base-model instance. PEFT injects adapter
layers into the module it wraps, so sharing one instance would presumably also
alter the baseline -- confirm before consolidating to save memory.
"""
import gradio as gr
from gradio.components import Textbox, Checkbox
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from peft import PeftModel, PeftConfig
import torch
import datasets

# Load the base model name and the two fine-tuned IA3 adapters.
model_name = "google/flan-t5-large"
qa_peft_name = "legacy107/flan-t5-large-ia3-cpgQA"
paraphrase_peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Untouched base model, used as the comparison baseline.
pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)

# QA model: base model + cpgQA adapter.
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = PeftModel.from_pretrained(model, qa_peft_name)

# Paraphrase model: base model + BioASQ paraphrase adapter.
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
paraphrase_model = PeftModel.from_pretrained(paraphrase_model, paraphrase_peft_name)

max_length = 512  # maximum number of prompt tokens; longer prompts are truncated
max_target_length = 200  # maximum number of newly generated tokens

# Load the evaluation dataset and keep a fixed subset to show as UI examples.
dataset = datasets.load_dataset("minh21/cpgQA-v1.0-unique-context-test-10-percent-validation-10-percent", split="test")
dataset = dataset.select([32, 7, 92, 8, 108, 51, 64, 84, 93, 94])


def _generate(generator, input_text):
    """Run ``generator`` on a single prompt and return the decoded output text.

    The prompt is truncated to ``max_length`` tokens but not padded. A single
    prompt needs no padding; padding every request to ``max_length`` only made
    the encoder process pad positions. The attention mask is passed explicitly
    instead of being left for ``generate`` to infer.
    """
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        generated_ids = generator.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=max_target_length,
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def paraphrase_answer(question, answer, use_pretrained=False):
    """Rewrite ``answer`` to ``question`` so it reads more naturally.

    Args:
        question: The question the answer responds to.
        answer: The answer text to rephrase.
        use_pretrained: If True, paraphrase with the base ``pretrained_model``
            instead of the fine-tuned ``paraphrase_model``.

    Returns:
        The decoded paraphrase with special tokens removed.
    """
    # Prompt text is kept verbatim -- presumably it matches the format the
    # paraphrase adapter was fine-tuned on.
    input_text = (
        f"question: {question}. "
        f"Paraphrase the answer to make it more natural answer: {answer}"
    )
    generator = pretrained_model if use_pretrained else paraphrase_model
    return _generate(generator, input_text)


def generate_answer(question, context, ground_truth, do_pretrained, do_natural, do_pretrained_natural):
    """Answer ``question`` from ``context``, optionally adding comparison outputs.

    Args:
        question: Question text.
        context: Passage given to the model alongside the question.
        ground_truth: Reference answer; displayed in the UI for comparison
            only and not used here.
        do_pretrained: Also answer with the base ``pretrained_model``.
        do_natural: Also paraphrase the fine-tuned answer with
            ``paraphrase_model``.
        do_pretrained_natural: Also paraphrase the fine-tuned answer with the
            base ``pretrained_model``.

    Returns:
        ``(generated_answer, paraphrased_answer, pretrained_answer,
        pretrained_paraphrased_answer)``. Each optional output is ``""`` when
        its flag is off.
    """
    input_text = f"question: {question} context: {context}"
    generated_answer = _generate(model, input_text)

    paraphrased_answer = ""
    if do_natural:
        paraphrased_answer = paraphrase_answer(question, generated_answer)

    pretrained_answer = ""
    if do_pretrained:
        pretrained_answer = _generate(pretrained_model, input_text)

    # NOTE(review): this paraphrases the *fine-tuned* model's answer with the
    # base model (not the base model's own answer), i.e. it compares the two
    # paraphrasers on the same input. Confirm that is the intended comparison.
    pretrained_paraphrased_answer = ""
    if do_pretrained_natural:
        pretrained_paraphrased_answer = paraphrase_answer(question, generated_answer, True)

    return generated_answer, paraphrased_answer, pretrained_answer, pretrained_paraphrased_answer


def list_examples():
    """Build Gradio example rows from the selected dataset subset.

    Each row follows the interface's input order: question, context,
    ground-truth answer, then all three comparison checkboxes enabled.
    """
    return [
        [example["question"], example["context"], example["answer_text"], True, True, True]
        for example in dataset
    ]


# Create a Gradio interface
iface = gr.Interface(
    fn=generate_answer,
    inputs=[
        Textbox(label="Question"),
        Textbox(label="Context"),
        Textbox(label="Ground truth"),
        Checkbox(label="Include pretrained model's answer"),
        Checkbox(label="Include natural answer"),
        Checkbox(label="Include pretrained model's natural answer"),
    ],
    outputs=[
        Textbox(label="Generated Answer"),
        Textbox(label="Natural Answer"),
        Textbox(label="Pretrained Model's Answer"),
        Textbox(label="Pretrained Model's Natural Answer"),
    ],
    examples=list_examples(),
)

# Launch the Gradio interface
iface.launch()