# kxx-kkk's picture
# Update app.py
# 4ef7de6 verified
import html
import os
import re
import tempfile

import PyPDF2
import pytesseract
import streamlit as st
from PIL import Image
from pdf2image import convert_from_path
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers import pipeline
st.set_page_config(page_title="Automated Question Answering System")  # set page title
# heading
st.markdown("<h2 style='text-align: center;'>Question Answering on Academic Essays</h2>", unsafe_allow_html=True)
# project description (the bold tag is properly closed so the markup is valid)
st.markdown("<h3 style='text-align: left; color:#F63366; font-size:18px;'><b>What is this project about?</b></h3>", unsafe_allow_html=True)
st.write("This project is to develop a web-based automated question-and-answer system for academic essays using natural language processing (NLP). Users can enter the essay and ask questions about it, and the system will automatically create answers.")
# usage hint pointing the user at the two tabs rendered below
st.write("👍 Click 'Input Text' or 'Upload File' to start experiencing the system.")
# Store the model in Streamlit's resource cache so it is loaded once and reused
# across reruns instead of being rebuilt every time
# (ref: https://docs.streamlit.io/library/advanced-features/caching)
@st.cache_resource(show_spinner=False)
def question_model():
    """Load the fine-tuned QA checkpoint and wrap it in a transformers pipeline.

    Returns:
        A ``question-answering`` pipeline built from the
        ``kxx-kkk/FYP_qa_final`` model and tokenizer. It is created with
        ``handle_impossible_answer=True`` so it may return an empty answer;
        callers treat an empty answer as "cannot answer".
    """
    with st.spinner(text="Loading question model..."):
        model_name = "kxx-kkk/FYP_qa_final"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        question_answerer = pipeline(
            "question-answering",
            model=model,
            tokenizer=tokenizer,
            handle_impossible_answer=True,
        )
    print("QA model is downloaded and ready to use")
    return question_answerer


# Loaded at import time so every question reuses the same pipeline object
qamodel = question_model()
@st.cache_data(show_spinner=False)
def extract_text(file_path):
    """OCR every page of a PDF and return the text as a single string.

    Each page is rasterised with ``pdf2image.convert_from_path`` and read with
    Tesseract. The PDF's embedded text layer is not used; only the OCR output
    is returned.

    Line handling: a newline that has no other newline on either side (a hard
    wrap inside a paragraph) becomes a space, while runs of two or more
    newlines (paragraph breaks) are left intact.

    Args:
        file_path: Path to a PDF file on disk.

    Returns:
        The OCR text of all pages, concatenated in page order.

    NOTE: results are cached by ``file_path`` only, so the same path must not
    be reused for different file contents.
    """
    with st.spinner(text="Extracting text from file..."):
        page_images = convert_from_path(file_path)  # one image per PDF page
        text = "".join(pytesseract.image_to_string(image) for image in page_images)
    # Join hard-wrapped lines; keep blank-line paragraph breaks.
    return re.sub(r"(?<!\n)\n(?!\n)", " ", text)
# get the answer by passing the context & question to the model
def _format_answer_html(answer, score):
    """Return the HTML snippet displayed in the answer container.

    Args:
        answer: Answer text returned by the pipeline. An empty or
            whitespace-only answer is displayed as "CANNOT ANSWER".
        score: The pipeline's score for that answer (a model confidence,
            not an F1 score).

    The answer is a span copied out of user-supplied text and the snippet is
    rendered with ``unsafe_allow_html=True``, so the answer is HTML-escaped
    here to stop essay content from injecting markup into the page.
    """
    if not answer.strip():
        answer = "CANNOT ANSWER"
    return (
        "<h5><b>Answer:</b></h5>"
        + html.escape(answer)
        + f"<p><small>(Confidence score: {score:.4f})</small></p><br>"
    )


def question_answering(context, question):
    """Answer *question* from *context* with the cached QA pipeline and show it.

    Args:
        context: Essay text the answer is extracted from.
        question: The user's question.

    Side effects: prints the raw pipeline result to stdout and writes a
    bordered container with the answer and its score to the page.
    """
    # qamodel is already loaded at module level, so only inference gets a spinner
    with st.spinner(text="Getting answer..."):
        result = qamodel(context=context, question=question)
    print(result)
    container = st.container(border=True)
    container.write(_format_answer_html(result["answer"], result["score"]), unsafe_allow_html=True)
# def question_answering(context, question):
# with st.spinner(text="Loading question model..."):
# question_answerer = qamodel
# print("loading QA model...")
# with st.spinner(text="Getting answer..."):
# segment_size = 512
# overlap_size = 32
# text_length = len(context)
# segments = []
# # Split context into segments
# for i in range(0, text_length, segment_size - overlap_size):
# segment_start = i
# segment_end = i + segment_size
# segment = context[segment_start:segment_end]
# segments.append(segment)
# answers = {} # Dictionary to store answers for each segment
# # Get answers for each segment
# for i, segment in enumerate(segments):
# answer = question_answerer(context=segment, question=question)
# answers[i] = answer
# # Find the answer with the highest score
# highest_score = -1
# highest_answer = None
# for segment_index, answer in answers.items():
# print(answer)
# score = answer["score"]
# if score > highest_score:
# highest_score = score
# highest_answer = answer
# if highest_answer is not None:
# answer = highest_answer["answer"]
# if answer == "":
# answer = "CANNOT ANSWER"
# answer_score = str(highest_answer["score"])
# # Display the result in container
# container = st.container(border=True)
# container.write("<h5><b>Answer:</b></h5>" + answer + "<p><small>(F1 score: " + answer_score + ")</small></p><br>",
# unsafe_allow_html=True)
#-------------------- Main Webpage --------------------
# choose the source with different tabs
tab1, tab2 = st.tabs(["Input Text", "Upload File"])

#---------- input text ----------
# the user types or pastes the essay directly
with tab1:
    # question used by the "Try with example" button
    sample_question = "What is NLP?"

    def _load_example():
        """Fill the essay and question widgets with the bundled example.

        Registered as an ``on_click`` callback: Streamlit runs it before the
        script reruns, so the keyed widgets below already see these
        session-state values when they are drawn.
        """
        with open("sample.txt", "r", encoding="utf-8") as text_file:
            st.session_state["contextInput"] = text_file.read()
        st.session_state["questionInput"] = sample_question

    st.button("Try with example", on_click=_load_example)
    # Both widgets are keyed, so their values live in session state; no
    # `value=` is passed so the session-state value is the single source.
    context = st.text_area("Enter the essay below:", key="contextInput", height=330)
    question = st.text_input(label="Enter the question: ", key="questionInput")
    # perform question answering when the "Get answer" button is clicked
    if st.button("Get answer", key="textInput", type="primary"):
        if not context.strip() or not question.strip():
            st.error("Please enter BOTH the context and the question", icon="🚨")
        else:
            question_answering(context, question)
# ---------- upload file ----------
# the user uploads a PDF whose text is extracted by OCR
with tab2:

    def _extract_uploaded_pdf(uploaded):
        """Save *uploaded* to a temporary .pdf file, OCR it, then delete the file.

        The temporary file is closed (which flushes Python's write buffer)
        before ``extract_text`` opens it by path; reading it while it is still
        open and unflushed can see an empty or truncated file.
        """
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
            temp_file.write(uploaded.getvalue())
        try:
            return extract_text(temp_file.name)
        finally:
            os.remove(temp_file.name)  # don't leave uploaded essays on disk

    # provide upload place
    uploaded_file = st.file_uploader("Upload essay in PDF format:", type=["pdf"])
    # Session-level record of the last upload that was processed
    if "file" not in st.session_state:
        st.session_state.file = None
    # Session-level flag: has text been extracted from that upload yet?
    if "text_extracted" not in st.session_state:
        st.session_state.text_extracted = False
    # transfer file to context and allow asking questions, then perform question answering
    if uploaded_file is not None:
        if st.session_state.file != uploaded_file:
            # A different upload: remember it and force a fresh extraction
            st.session_state.file = uploaded_file
            st.session_state.text_extracted = False
        if not st.session_state.text_extracted:
            # Set before the text area below is created in this run, so the
            # widget displays the newly extracted text.
            st.session_state["contextInput2"] = _extract_uploaded_pdf(uploaded_file)
            st.session_state.text_extracted = True
        question2 = st.text_input(label="Enter your question", key="questionInput2")
        context2 = st.text_area("Your essay context: ", height=330, key="contextInput2")
        # perform question answering when the "Get answer" button is clicked
        if st.button("Get answer", key="fileInput", type="primary"):
            if not context2.strip() or not question2.strip():
                st.error("Please enter BOTH the context and the question", icon="🚨")
            else:
                question_answering(context2, question2)
# footer with copyright notice
st.markdown("<p style='text-align:center;'>© 20069913D HUI Man Ki - Final Year Project</p>", unsafe_allow_html=True)
# extra vertical space below the footer
st.markdown("<br><br><br><br><br>", unsafe_allow_html=True)