# Provenance (from the hosting page header): commit 5f89cc0 by EmreYY20,
# "add hybrid summarization"; file size 7.93 kB.
import streamlit as st
import re
import PyPDF2
import matplotlib.pyplot as plt
import io
from wordcloud import WordCloud
from PIL import Image
from rouge import Rouge
from datasets import load_dataset
from extractive_summarization import summarize_with_textrank, summarize_with_lsa
from abstractive_summarization import summarize_with_bart_cnn, summarize_with_bart_ft, summarize_with_led, summarize_with_t5
from keyword_extraction import extract_keywords
from keyphrase_extraction import extract_sentences_with_obligations
from hybrid_summarization import summarize_hybrid
#-------------------------------------------------------------------#
# Load the ToS-Summaries dataset (repo id "EE21/ToS-Summaries") once, at import time.
dataset = load_dataset("EE21/ToS-Summaries")

# Label every row of the train split by its position: "Document 0", "Document 1", ...
tos_titles = [f"Document {doc_idx}" for doc_idx in range(len(dataset["train"]))]

# Configure the page to use the wide layout.
st.set_page_config(layout="wide")
# Function to handle file upload and return its content
def load_pdf(file):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        file: Anything ``PyPDF2.PdfReader`` accepts; in this app it is the
            object returned by ``st.file_uploader``.

    Returns:
        One string with the page texts joined without a separator. A page
        whose ``extract_text()`` result is falsy (``None`` or ``""``)
        contributes an empty string.
    """
    reader = PyPDF2.PdfReader(file)
    page_texts = [page.extract_text() or "" for page in reader.pages]
    return "".join(page_texts)
# Main app
def main():
    """Render the Terms of Service summarizer page.

    Layout is three columns: the summarizer choice (left), the input sources
    plus the "Summarize" button (middle), and the stored summary with its
    word cloud and ROUGE scores (right).

    When "Summarize" is pressed the input is chosen as follows: a PDF and
    typed text supplied together are rejected with a warning; otherwise the
    PDF wins, then the typed text, then the selected dataset document.

    The summary is kept in ``st.session_state.summary`` so it survives
    Streamlit reruns. ``st.session_state.summary_reference_index`` records
    which dataset row (if any) that summary was generated from, so ROUGE is
    only ever computed against the matching reference summary.
    """
    st.title("Terms of Service Summarizer")

    # Radio label -> summarizer callable. Insertion order is the display order
    # of the radio buttons, and the same table drives the dispatch below.
    summarizers = {
        "Hybrid (RAKE + BART Fine-tuned)": summarize_hybrid,
        "Abstractive (LongT5)": summarize_with_t5,
        "Abstractive (LED)": summarize_with_led,
        "Abstractive (BART Fine-tuned)": summarize_with_bart_ft,
        "Abstractive (BART-large-CNN)": summarize_with_bart_cnn,
        "Extractive (TextRank)": summarize_with_textrank,
        "Extractive (Latent Semantic Analysis)": summarize_with_lsa,
        "Keyphrase Extraction (RAKE)": extract_sentences_with_obligations,
        "Keyword Extraction (RAKE)": extract_keywords,
    }

    # Layout: 3 columns
    choice_col, input_col, output_col = st.columns([1, 3, 2], gap="large")

    # Left column: radio buttons for the summarizer choice
    with choice_col:
        radio_options = list(summarizers)
        help_text = (
            "Abstractive: Abstractive summarization generates a summary that may contain words not present in the original text. "
            "It uses a fine-tuned model on BART-large-CNN.<br>"
            "Extractive: Extractive summarization selects and extracts sentences or phrases directly from the original text to create a summary using the TextRank algorithm.<br>"
            "Keyword Extraction: Keyword extraction identifies and extracts important keywords or terms from the text using the Rake algorithm. "
            "These keywords can be used for various purposes such as content analysis and SEO.<br>"
            "Keyphrase Extraction: Keyphrase extraction is similar to keyword extraction but focuses on identifying multi-word phrases or expressions that are significant in the text using the Rake algorithm."
        )
        radio_selection = st.radio("Choose type of summarizer:", radio_options, help=help_text)

    # Middle column: text input, file uploader, document picker and the button
    with input_col:
        user_input = st.text_area("Enter a text")
        uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
        # The options are row indices; format_func shows the matching title.
        tos_selection_index = st.selectbox(
            "Select a Terms of Service Document",
            range(len(tos_titles)),
            format_func=lambda x: tos_titles[x],
        )

        if st.button("Summarize"):
            # The selected index must not take part in this check: the first
            # document has index 0, which is falsy and would mask the conflict.
            if uploaded_file and user_input:
                st.warning("Please provide either text input or a PDF file, not both.")
                return

            # Dataset row the text came from; stays None for PDF / typed text.
            reference_index = None
            if uploaded_file:
                # Extract text from PDF
                file_content = load_pdf(uploaded_file)
                st.write("PDF uploaded successfully.")
            elif user_input:
                file_content = user_input
            elif tos_selection_index is not None:
                file_content = dataset['train'][tos_selection_index]['plain_text']
                reference_index = tos_selection_index
            else:
                st.warning("Please upload a PDF, enter some text, or select a document to summarize.")
                return

            # An unknown selection leaves any stored summary untouched.
            summarize = summarizers.get(radio_selection)
            if summarize is not None:
                st.session_state.summary = summarize(file_content)
                st.session_state.summary_reference_index = reference_index

    # Right column: the stored summary, shown on every rerun once it exists
    with output_col:
        st.write("Summary:")
        if 'summary' in st.session_state:
            summary = st.session_state.summary
            st.write(summary)

            # Generate the word cloud and hand it to Streamlit as PNG bytes.
            wordcloud = WordCloud(width=800, height=400, background_color='white', max_words=20).generate(summary)
            image = wordcloud.to_image()
            buf = io.BytesIO()
            image.save(buf, format='PNG')
            st.image(buf.getvalue(), caption='Word Cloud of Summary', use_column_width=True)

            # Score only against the reference of the document the summary was
            # generated from, not whatever the widgets show on this rerun.
            reference_index = st.session_state.get("summary_reference_index")
            if reference_index is not None:
                reference_row = dataset['train'][reference_index]
                if 'summary' in reference_row:
                    scores = Rouge().get_scores(summary, reference_row['summary'])[0]

                    # Display the ROUGE F-scores as styled text, one per column.
                    # Fresh names avoid shadowing the outer layout columns.
                    rouge_metrics = (
                        ("ROUGE-1", "rouge-1"),
                        ("ROUGE-2", "rouge-2"),
                        ("ROUGE-L", "rouge-l"),
                    )
                    for score_col, (label, metric) in zip(st.columns(3), rouge_metrics):
                        f_score = scores[metric]['f']
                        with score_col:
                            st.markdown(
                                f"<p style='text-align: center; color: black; border: 1px solid #cccccc; padding: 5px; border-radius: 4px;'>{label}: {f_score:.4f}</p>",
                                unsafe_allow_html=True,
                            )
# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()