import logging
import time
from pathlib import Path

import gradio as gr
import nltk
from cleantext import clean

from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
from utils import load_example_filenames, truncate_word_count

_here = Path(__file__).parent

nltk.download("stopwords")  # TODO=find where this requirement originates from

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 768,
):
    """
    proc_submission - a helper function for the gradio module to process submissions

    Args:
        input_text (str): the input text to summarize
        model_size (str): the model variant to use; "base" selects the base model,
            anything else selects the large model
        num_beams (int): the number of beams to use
        token_batch_length (int): the length of the token batches to use
        length_penalty (float): the length penalty to use
        repetition_penalty (float): the repetition penalty to use
        no_repeat_ngram_size (int): the no repeat ngram size to use
        max_input_length (int, optional): the maximum input length (in whitespace-
            separated words) for the large model. Defaults to 768. The base model
            always uses 2048.

    Returns:
        tuple of (str, str, str): runtime / warning info in HTML format, the summary
            text (one line per section), and the per-section summary scores
    """
    if not input_text or not input_text.strip():
        # nothing to summarize - don't run the model on empty input
        return "<p>WARNING: no input text was provided.</p>", "", ""

    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": int(num_beams),
        "min_length": 4,
        # each batch's summary is capped at a quarter of the token batch length
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    st = time.perf_counter()

    clean_text = clean(input_text, lower=False)
    # the base model is allowed a longer input than the large model
    max_input_length = 2048 if model_size == "base" else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)

    if processed["was_truncated"]:
        tr_in = processed["truncated_text"]
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
    else:
        # use the cleaned text so both branches summarize the same normalized input
        tr_in = clean_text
        msg = None

    # model / tokenizer globals are loaded in the __main__ section below
    if model_size == "base":
        _model, _tokenizer = model_sm, tokenizer_sm
    else:
        _model, _tokenizer = model, tokenizer

    _summaries = summarize_via_tokenbatches(
        tr_in,
        _model,
        _tokenizer,
        batch_length=int(token_batch_length),
        **settings,
    )
    sum_text = [f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)]
    sum_scores = [
        f" - Section {i}: {round(s['summary_score'], 4)}"
        for i, s in enumerate(_summaries)
    ]

    sum_text_out = "\n".join(sum_text)
    scores_out = "\n".join(sum_scores)

    rt = round((time.perf_counter() - st) / 60, 2)
    logging.info(f"Runtime: {rt} minutes")

    html = f"<p>Runtime: {rt} minutes on CPU</p>"
    if msg is not None:
        html += f"<h2>WARNING:</h2><hr><b>{msg}</b><br><br>"

    return html, sum_text_out, scores_out


def load_single_example_text(
    example_path: str,
):
    """
    load_single_example_text - a helper function for the gradio module to load an example

    Args:
        example_path (str): the example name selected in the dropdown; used as a key
            into the module-level ``name_to_path`` mapping

    Returns:
        str, the cleaned text of the example, or an error message if no valid
            example was selected
    """
    if not example_path or example_path not in name_to_path:
        return "Error: Please select an example from the dropdown first."

    full_ex_path = Path(name_to_path[example_path])
    with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f:
        raw_text = f.read()
    return clean(raw_text, lower=False)


def load_uploaded_file(file_obj):
    """
    load_uploaded_file - process an uploaded file

    Args:
        file_obj: Gradio file object exposing a ``.name`` path, possibly wrapped in
            a list; may be None if nothing was uploaded

    Returns:
        str, the cleaned file contents, or an error message if it could not be read
    """
    # gradio may hand over the file object wrapped in a list
    if isinstance(file_obj, list):
        file_obj = file_obj[0] if file_obj else None
    if file_obj is None:
        return "Error: No file was uploaded. Please upload a text file first."

    file_path = Path(file_obj.name)
    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            raw_text = f.read()
        return clean(raw_text, lower=False)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the callback
        logging.error(f"Trying to load file with path {file_path}, error: {e}")
        return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8."


if __name__ == "__main__":

    model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
    model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")

    name_to_path = load_example_filenames(_here / "examples")
    logging.info(f"Loaded {len(name_to_path)} examples")
    demo = gr.Blocks()

    with demo:

        gr.Markdown("# Long-Form Summarization: LED & BookSum")
        gr.Markdown(
            "A simple demo using a fine-tuned LED model to summarize long-form text. See [model card](https://huggingface.co./pszemraj/led-large-book-summary) for a notebook with GPU inference (much faster) on Colab."
        )
        with gr.Column():

            gr.Markdown("## Load Inputs & Select Parameters")
            gr.Markdown(
                "Enter text below in the text area. The text will be summarized [using the selected parameters](https://huggingface.co./blog/how-to-generate). Optionally load an example below or upload a file."
            )
            with gr.Row():
                model_size = gr.Radio(
                    choices=["base", "large"], label="Model Variant", value="large"
                )
                num_beams = gr.Radio(
                    choices=[2, 3, 4],
                    label="Beam Search: # of Beams",
                    value=2,
                )
            gr.Markdown(
                "_The base model is less performant than the large model, but is faster and will accept up to 2048 words per input (Large model accepts up to 768)._"
            )
            with gr.Row():
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="length penalty",
                    value=0.7,
                    step=0.05,
                )
                token_batch_length = gr.Radio(
                    choices=[512, 768, 1024],
                    label="token batch length",
                    value=512,
                )

            with gr.Row():
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    label="repetition penalty",
                    value=3.5,
                    step=0.1,
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[2, 3, 4],
                    label="no repeat ngram size",
                    value=3,
                )
            with gr.Row():
                example_name = gr.Dropdown(
                    list(name_to_path.keys()),
                    label="Choose an Example",
                )
                load_examples_button = gr.Button(
                    "Load Example",
                )
            input_text = gr.Textbox(
                lines=6,
                label="Input Text (for summarization)",
                placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
            )
            gr.Markdown("Upload your own file:")
            with gr.Row():
                uploaded_file = gr.File(
                    label="Upload a text file",
                    file_count="single",
                    type="file",
                )
                load_file_button = gr.Button("Load Uploaded File")

        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## Generate Summary")
            gr.Markdown(
                "Summary generation should take approximately 1-2 minutes for most settings."
            )
            summarize_button = gr.Button(
                "Summarize!",
                variant="primary",
            )

            output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
            gr.Markdown("### Summary Output")
            summary_text = gr.Textbox(
                label="Summary", placeholder="The generated summary will appear here"
            )
            gr.Markdown(
                "The summary scores can be thought of as representing the quality of the summary. less-negative numbers (closer to 0) are better:"
            )
            summary_scores = gr.Textbox(
                label="Summary Scores", placeholder="Summary scores will appear here"
            )

        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## About the Model")
            gr.Markdown(
                "- [This model](https://huggingface.co./pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co./allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
            )
            gr.Markdown(
                "- The two most important parameters-empirically-are the `num_beams` and `token_batch_length`. However, increasing these will also increase the amount of time it takes to generate a summary. The `length_penalty` and `repetition_penalty` parameters are also important for the model to generate good summaries."
            )
            gr.Markdown(
                "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co./pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial."
            )
            gr.Markdown("---")

        load_examples_button.click(
            fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
        )

        load_file_button.click(
            fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
        )

        summarize_button.click(
            fn=proc_submission,
            inputs=[
                input_text,
                model_size,
                num_beams,
                token_batch_length,
                length_penalty,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            outputs=[output_text, summary_text, summary_scores],
        )

    demo.launch(enable_queue=True, share=True)