# awacke1's picture
# Create new file
# cce4162
import logging
import time
from pathlib import Path
import gradio as gr
import nltk
from cleantext import clean
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
from utils import load_example_filenames, truncate_word_count
# Directory containing this script; the bundled "examples" folder is resolved
# relative to it when the app starts.
_here = Path(__file__).parent

# TODO: find where the "stopwords" requirement originates from.
# Runs at import time, before the models are loaded.
nltk.download("stopwords")

# Root logger configuration shared by every helper in this module.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 768,
):
    """
    proc_submission - a helper function for the gradio module to process submissions

    Args:
        input_text (str): the input text to summarize
        model_size (str): the size of the model to use; "base" selects the
            module-level ``model_sm``/``tokenizer_sm`` pair, any other value
            selects ``model``/``tokenizer``
        num_beams (int): the number of beams to use
        token_batch_length (int): the length of the token batches to use; the
            generated summary per batch is capped at a quarter of this
        length_penalty (float): the length penalty to use
        repetition_penalty (float): the repetition penalty to use
        no_repeat_ngram_size (int): the no repeat ngram size to use
        max_input_length (int, optional): the maximum input length, in
            whitespace-separated words. Ignored for the "base" model, which
            always uses 2048. Defaults to 768.

    Returns:
        tuple[str, str, str]: an HTML status snippet (runtime plus any
        truncation warning), the per-section summary text, and the
        per-section summary scores.
    """
    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": int(num_beams),
        "min_length": 4,
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    st = time.perf_counter()
    use_base = model_size == "base"

    clean_text = clean(input_text, lower=False)
    # the base model gets a larger word budget than the large model
    max_input_length = 2048 if use_base else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)
    if processed["was_truncated"]:
        tr_in = processed["truncated_text"]
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
    else:
        # summarize the cleaned text so both branches feed the model the same
        # normalized input
        tr_in = clean_text
        msg = None

    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if use_base else model,
        tokenizer_sm if use_base else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )

    sum_text_out = "\n".join(
        f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)
    )
    scores_out = "\n".join(
        f" - Section {i}: {round(s['summary_score'], 4)}"
        for i, s in enumerate(_summaries)
    )

    rt = round((time.perf_counter() - st) / 60, 2)
    logging.info("Runtime: %s minutes", rt)

    html = f"<p>Runtime: {rt} minutes on CPU</p>"
    if msg is not None:
        html += f"<h2>WARNING:</h2><hr><b>{msg}</b><br><br>"
    return html, sum_text_out, scores_out
def load_single_example_text(
    example_path: str,
):
    """
    load_single_example_text - a helper function for the gradio module to load an example

    Args:
        example_path (str): the example name chosen in the dropdown; used as a
            key into the module-level ``name_to_path`` mapping

    Returns:
        str: the cleaned text of the selected example file, or an error
        message if no example was selected
    """
    if not example_path:
        # the dropdown is created without a default value, so the button can
        # be clicked before anything is chosen; avoid a KeyError in that case
        return "Error: No example selected. Choose an example from the dropdown first."
    # name_to_path is only read here, so no `global` statement is required
    full_ex_path = Path(name_to_path[example_path])
    raw_text = full_ex_path.read_text(encoding="utf-8", errors="ignore")
    return clean(raw_text, lower=False)
def load_uploaded_file(file_obj):
    """
    load_uploaded_file - process an uploaded file

    Args:
        file_obj: the object passed by the Gradio file component, possibly
            wrapped in a list; its ``.name`` attribute is used as the path.
            ``None`` or an empty list means nothing was uploaded.

    Returns:
        str: the cleaned file contents, or an error message (prefixed with
        "Error:") if no file was provided or it could not be read
    """
    # the component may hand over a single object or a list of them
    if isinstance(file_obj, list):
        file_obj = file_obj[0] if file_obj else None
    if file_obj is None:
        return "Error: No file uploaded. Upload a text file first."

    file_path = Path(file_obj.name)
    try:
        # errors="ignore" drops undecodable bytes instead of raising
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            raw_text = f.read()
        return clean(raw_text, lower=False)
    except Exception as e:
        # best-effort: surface a readable message in the UI instead of crashing
        logging.error("Trying to load file with path %s, error: %s", file_path, e)
        return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8."
if __name__ == "__main__":
    # Load both checkpoints once at startup. proc_submission and
    # load_single_example_text read model/tokenizer/model_sm/tokenizer_sm and
    # name_to_path as module-level globals, so they must stay at module scope.
    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")
    name_to_path = load_example_filenames(_here / "examples")
    logging.info(f"Loaded {len(name_to_path)} examples")
    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Long-Form Summarization: LED & BookSum")
        gr.Markdown(
            "A simple demo using a fine-tuned LED model to summarize long-form text. See [model card](https://huggingface.co/pszemraj/led-large-book-summary) for a notebook with GPU inference (much faster) on Colab."
        )
        # --- inputs and generation parameters ---
        with gr.Column():
            gr.Markdown("## Load Inputs & Select Parameters")
            gr.Markdown(
                "Enter text below in the text area. The text will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). Optionally load an example below or upload a file."
            )
            with gr.Row():
                model_size = gr.Radio(
                    choices=["base", "large"], label="Model Variant", value="large"
                )
                num_beams = gr.Radio(
                    choices=[2, 3, 4],
                    label="Beam Search: # of Beams",
                    value=2,
                )
            gr.Markdown(
                "_The base model is less performant than the large model, but is faster and will accept up to 2048 words per input (Large model accepts up to 768)._"
            )
            with gr.Row():
                # gr.Slider(value=...) replaces the deprecated
                # gr.inputs.Slider(default=...), matching the other components
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="length penalty",
                    value=0.7,
                    step=0.05,
                )
                token_batch_length = gr.Radio(
                    choices=[512, 768, 1024],
                    label="token batch length",
                    value=512,
                )
            with gr.Row():
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    label="repetition penalty",
                    value=3.5,
                    step=0.1,
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[2, 3, 4],
                    label="no repeat ngram size",
                    value=3,
                )
            with gr.Row():
                example_name = gr.Dropdown(
                    list(name_to_path.keys()),
                    label="Choose an Example",
                )
                load_examples_button = gr.Button(
                    "Load Example",
                )
            input_text = gr.Textbox(
                lines=6,
                label="Input Text (for summarization)",
                placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
            )
            gr.Markdown("Upload your own file:")
            with gr.Row():
                # NOTE(review): load_uploaded_file reads the path from the
                # uploaded object's `.name`; this relies on type="file" — revisit
                # that handler if switching to type="filepath" on newer Gradio.
                uploaded_file = gr.File(
                    label="Upload a text file",
                    file_count="single",
                    type="file",
                )
                load_file_button = gr.Button("Load Uploaded File")
        gr.Markdown("---")
        # --- summary outputs ---
        with gr.Column():
            gr.Markdown("## Generate Summary")
            gr.Markdown(
                "Summary generation should take approximately 1-2 minutes for most settings."
            )
            summarize_button = gr.Button(
                "Summarize!",
                variant="primary",
            )
            output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
            gr.Markdown("### Summary Output")
            summary_text = gr.Textbox(
                label="Summary", placeholder="The generated summary will appear here"
            )
            gr.Markdown(
                "The summary scores can be thought of as representing the quality of the summary. less-negative numbers (closer to 0) are better:"
            )
            summary_scores = gr.Textbox(
                label="Summary Scores", placeholder="Summary scores will appear here"
            )
        gr.Markdown("---")
        # --- static model information ---
        with gr.Column():
            gr.Markdown("## About the Model")
            gr.Markdown(
                "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
            )
            gr.Markdown(
                "- The two most important parameters-empirically-are the `num_beams` and `token_batch_length`. However, increasing these will also increase the amount of time it takes to generate a summary. The `length_penalty` and `repetition_penalty` parameters are also important for the model to generate good summaries."
            )
            gr.Markdown(
                "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial."
            )
        gr.Markdown("---")
        # --- event wiring ---
        load_examples_button.click(
            fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
        )
        load_file_button.click(
            fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
        )
        # input order must match proc_submission's positional parameters
        summarize_button.click(
            fn=proc_submission,
            inputs=[
                input_text,
                model_size,
                num_beams,
                token_batch_length,
                length_penalty,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            outputs=[output_text, summary_text, summary_scores],
        )
    # NOTE(review): `enable_queue` is deprecated in later Gradio 3.x releases in
    # favour of `demo.queue()` — confirm against the pinned Gradio version.
    demo.launch(enable_queue=True, share=True)