Spaces:
Running
Running
import html
import os
import re
import tempfile

import PyPDF2
import pytesseract
import streamlit as st
from pdf2image import convert_from_path
from PIL import Image
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers import pipeline
# ---------- page header ----------
st.set_page_config(page_title="Automated Question Answering System")  # set page title
# heading
st.markdown(
    "<h2 style='text-align: center;'>Question Answering on Academic Essays</h2>",
    unsafe_allow_html=True,
)
# project description (the bold tag is now properly closed with </b>)
st.markdown(
    "<h3 style='text-align: left; color:#F63366; font-size:18px;'>"
    "<b>What is this project about?</b></h3>",
    unsafe_allow_html=True,
)
st.write(
    "This project is to develop a web-based automated question-and-answer system "
    "for academic essays using natural language processing (NLP). Users can enter "
    "the essay and ask questions about it, and the system will automatically create answers."
)
# The two input tabs are rendered directly below this line.
st.write("👇 Click 'Input Text' or 'Upload File' to start experiencing the system.")
# Streamlit re-executes this whole script on every widget interaction, so the
# pipeline is wrapped in st.cache_resource: it is built once and the same object
# is reused on later reruns instead of reloading the model each time
# (ref: https://docs.streamlit.io/library/advanced-features/caching).
# The cache's own spinner replaces the explicit st.spinner and shows only on a miss.
@st.cache_resource(show_spinner="Loading question model...")
def question_model():
    """Build the extractive question-answering pipeline.

    Loads the tokenizer and model for the checkpoint ``kxx-kkk/FYP_qa_final``
    with ``from_pretrained`` and wraps them in a transformers
    ``question-answering`` pipeline. ``handle_impossible_answer=True`` lets the
    pipeline return an empty answer when it judges the question unanswerable.

    Returns:
        The transformers question-answering pipeline.
    """
    model_name = "kxx-kkk/FYP_qa_final"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    question_answerer = pipeline(
        "question-answering",
        model=model,
        tokenizer=tokenizer,
        handle_impossible_answer=True,
    )
    print("QA model is downloaded and ready to use")
    return question_answerer


qamodel = question_model()
def extract_text(file_path):
    """Extract the text of a PDF by running OCR on every page.

    Each page is rendered to an image with ``pdf2image.convert_from_path`` and
    read with Tesseract (``pytesseract.image_to_string``). Only the OCR output
    is returned; the PDF's embedded text layer is not read.

    Single line breaks (a newline with no newline directly before or after it)
    are replaced by a space, which joins lines wrapped by the page layout.
    Runs of two or more newlines, such as blank-line paragraph breaks, are kept.

    Args:
        file_path: path to a PDF file on disk.

    Returns:
        The OCR text of all pages concatenated in page order.
    """
    with st.spinner(text="Extracting text from file..."):
        images = convert_from_path(file_path)  # one image per PDF page
        text = "".join(pytesseract.image_to_string(image) for image in images)
    # join soft-wrapped lines; leave multi-newline paragraph breaks intact
    return re.sub(r"(?<!\n)\n(?!\n)", " ", text)
# get the answer by passing the context & question to the model
def question_answering(context, question):
    """Answer *question* from *context* and render the result on the page.

    Runs the shared ``qamodel`` pipeline and writes the answer and the model's
    score into a bordered container. The pipeline is built with
    ``handle_impossible_answer=True``, so an empty answer means "no answer";
    that case is shown as "CANNOT ANSWER".

    Args:
        context: the essay text the answer is extracted from.
        question: the question asked about *context*.
    """
    with st.spinner(text="Getting answer..."):
        result = qamodel(context=context, question=question)
    print(result)
    answer = result["answer"] or "CANNOT ANSWER"
    # The pipeline's "score" is the model's probability for the returned span,
    # not an F1 score, so it is labelled as a confidence.
    score = result["score"]
    # The answer is a substring of user-supplied text and is rendered with
    # unsafe_allow_html=True, so escape it to stop it from injecting markup.
    container = st.container(border=True)
    container.write(
        "<h5><b>Answer:</b></h5>"
        + html.escape(answer)
        + f"<p><small>(Confidence score: {score:.4f})</small></p><br>",
        unsafe_allow_html=True,
    )
# def question_answering(context, question): | |
# with st.spinner(text="Loading question model..."): | |
# question_answerer = qamodel | |
# print("loading QA model...") | |
# with st.spinner(text="Getting answer..."): | |
# segment_size = 512 | |
# overlap_size = 32 | |
# text_length = len(context) | |
# segments = [] | |
# # Split context into segments | |
# for i in range(0, text_length, segment_size - overlap_size): | |
# segment_start = i | |
# segment_end = i + segment_size | |
# segment = context[segment_start:segment_end] | |
# segments.append(segment) | |
# answers = {} # Dictionary to store answers for each segment | |
# # Get answers for each segment | |
# for i, segment in enumerate(segments): | |
# answer = question_answerer(context=segment, question=question) | |
# answers[i] = answer | |
# # Find the answer with the highest score | |
# highest_score = -1 | |
# highest_answer = None | |
# for segment_index, answer in answers.items(): | |
# print(answer) | |
# score = answer["score"] | |
# if score > highest_score: | |
# highest_score = score | |
# highest_answer = answer | |
# if highest_answer is not None: | |
# answer = highest_answer["answer"] | |
# if answer == "": | |
# answer = "CANNOT ANSWER" | |
# answer_score = str(highest_answer["score"]) | |
# # Display the result in container | |
# container = st.container(border=True) | |
# container.write("<h5><b>Answer:</b></h5>" + answer + "<p><small>(F1 score: " + answer_score + ")</small></p><br>", | |
# unsafe_allow_html=True) | |
#-------------------- Main Webpage --------------------
# choose the source with different tabs
tab1, tab2 = st.tabs(["Input Text", "Upload File"])

#---------- input text ----------
# The user types an essay and a question, or loads the bundled example.
with tab1:
    sample_question = "What is NLP?"
    # Start from whatever the widgets held on the previous rerun.
    context = st.session_state.get("contextInput", "")
    question = st.session_state.get("questionInput", "")
    # The example file is read only when it is requested, not on every rerun.
    if st.button("Try with example"):
        with open("sample.txt", "r", encoding="utf-8") as text_file:
            context = text_file.read()
        question = sample_question
    # Show the widgets with the example values or the previous values.
    context = st.text_area("Enter the essay below:", value=context, key="contextInput", height=330)
    question = st.text_input(label="Enter the question: ", value=question, key="questionInput")
    # perform question answering when the "Get answer" button is clicked
    if st.button("Get answer", key="textInput", type="primary"):
        # whitespace-only input counts as empty
        if not context.strip() or not question.strip():
            st.error("Please enter BOTH the context and the question", icon="🚨")
        else:
            question_answering(context, question)
# ---------- upload file ----------
# The user uploads a PDF. Its OCR text pre-fills an editable context box, and
# the user can then ask a question about it.
with tab2:
    uploaded_file = st.file_uploader("Upload essay in PDF format:", type=["pdf"])
    # Track the last processed upload and whether its text was already
    # extracted, so OCR runs once per upload rather than on every rerun.
    if "file" not in st.session_state:
        st.session_state.file = None
    if "text_extracted" not in st.session_state:
        st.session_state.text_extracted = False
    # Start from whatever the widgets held on the previous rerun.
    context2 = st.session_state.get("contextInput2", "")
    question2 = st.session_state.get("questionInput2", "")
    if uploaded_file is not None:
        if st.session_state.file != uploaded_file:
            # a new upload: remember it and force a fresh extraction
            st.session_state.file = uploaded_file
            st.session_state.text_extracted = False
        if not st.session_state.text_extracted:
            # pdf2image needs a real path. Leaving the `with` closes (and so
            # flushes) the file before it is read back; otherwise a small PDF
            # could still be sitting in the write buffer. The file is removed
            # afterwards so uploads do not pile up in the temp directory.
            with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
                temp_file.write(uploaded_file.getvalue())
            try:
                context2 = extract_text(temp_file.name)
            finally:
                os.remove(temp_file.name)
            st.session_state.text_extracted = True
        question2 = st.text_input(label="Enter your question", value=question2, key="questionInput2")
        context2 = st.text_area("Your essay context: ", value=context2, height=330, key="contextInput2")
        # perform question answering when the "Get answer" button is clicked
        if st.button("Get answer", key="fileInput", type="primary"):
            # whitespace-only input counts as empty
            if not context2.strip() or not question2.strip():
                st.error("Please enter BOTH the context and the question", icon="🚨")
            else:
                question_answering(context2, question2)
# ---------- footer ----------
st.markdown(
    "<p style='text-align:center;'>© 20069913D HUI Man Ki - Final Year Project</p>",
    unsafe_allow_html=True,
)
# five line breaks of blank space below the footer text
st.markdown("<br><br><br><br><br>", unsafe_allow_html=True)