|
import torch |
|
import gradio as gr |
|
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig |
|
|
|
|
|
# Hugging Face Hub checkpoints offered in the "Choose a model" dropdown.
# NOTE(review): "google/t5-v1_1-large" looks like a pretrained-only checkpoint
# (likely poor summaries without fine-tuning), and
# "togethercomputer/LLaMA-2-7B-32K" looks like a decoder-only model that a
# "summarization" pipeline probably cannot load -- confirm both entries.
model_names = [
    "google/bigbird-pegasus-large-arxiv",
    "facebook/bart-large-cnn",
    "google/t5-v1_1-large",
    "sshleifer/distilbart-cnn-12-6",
    "allenai/led-base-16384",
    "google/pegasus-xsum",
    "togethercomputer/LLaMA-2-7B-32K",
]

# State of the currently loaded model. Everything starts as None; load_model()
# assigns these and summarize_text() reads them.
summarizer = tokenizer = max_tokens = None
|
|
|
|
|
|
|
def _resolve_max_tokens(config, tokenizer, default=1024):
    """Return the longest input, in tokens, the loaded model should accept.

    Candidate limits are gathered from the model config and the tokenizer:

    * the first positive integer among ``config.max_position_embeddings``
      and ``config.max_encoder_position_embeddings`` (some encoder-decoder
      configs, e.g. LED, only define the latter);
    * ``tokenizer.model_max_length``, which is the length the pipeline's
      ``truncation=True`` cuts inputs to.

    The smallest candidate wins, because anything longer would be silently
    truncated. ``default`` is returned when no usable candidate exists.
    """
    candidates = []
    for attr in ("max_position_embeddings", "max_encoder_position_embeddings"):
        value = getattr(config, attr, None)
        if isinstance(value, int) and value > 0:
            candidates.append(value)
            break

    # Tokenizers without a configured limit report a huge sentinel
    # (on the order of 1e30); treat that as "no limit known".
    model_max = getattr(tokenizer, "model_max_length", None)
    if isinstance(model_max, int) and 0 < model_max < 1_000_000:
        candidates.append(model_max)

    return min(candidates) if candidates else default


def load_model(model_name):
    """Load ``model_name`` into a summarization pipeline and make it current.

    On success, the module-level ``summarizer``, ``tokenizer`` and
    ``max_tokens`` are replaced together. On failure they are left untouched,
    so a previously loaded model remains usable.

    Args:
        model_name: Hugging Face Hub model id (e.g. an entry of ``model_names``).

    Returns:
        A status message for the UI. Errors are reported in the message
        rather than raised.
    """
    global summarizer, tokenizer, max_tokens

    try:
        # Build everything in locals first, so a failure part-way through
        # cannot leave the globals describing two different models.
        new_summarizer = pipeline("summarization", model=model_name, torch_dtype=torch.float32)
        # The pipeline has already loaded this model's tokenizer and config;
        # reuse them instead of fetching and parsing them a second time.
        new_tokenizer = new_summarizer.tokenizer
        new_max_tokens = _resolve_max_tokens(new_summarizer.model.config, new_tokenizer)
    except Exception as e:
        return f"Failed to load model {model_name}. Error: {str(e)}"

    summarizer, tokenizer, max_tokens = new_summarizer, new_tokenizer, new_max_tokens
    return f"Model {model_name} loaded successfully! Max tokens: {max_tokens}"
|
|
|
|
|
|
|
def summarize_text(input, min_length, max_length):
    """Summarize ``input`` with the currently loaded model.

    Args:
        input: Text to summarize. The name shadows the builtin but is kept
            so the function's signature stays unchanged.
        min_length: Minimum summary length, as a percentage (0-100) of the
            input's token count. A floor of 10 tokens is always applied.
        max_length: Maximum summary length, as a percentage (0-100) of the
            input's token count.

    Returns:
        The generated summary, or a human-readable status/error message.
        Errors are returned as strings rather than raised, so the UI can
        show them directly in the output textbox.
    """
    if summarizer is None or tokenizer is None or max_tokens is None:
        return "No model loaded!"

    if not input or not input.strip():
        return "Error: Please enter some text to summarize."

    try:
        # Only the token count is needed, so skip building a tensor.
        num_tokens = len(tokenizer.encode(input))
        if num_tokens > max_tokens:
            return f"Error: Input exceeds the max token limit of {max_tokens}."

        # Accept the two sliders in either order.
        low_pct, high_pct = sorted((min_length, max_length))

        # Some encoder-decoder configs (e.g. LED) give the decoder a separate,
        # smaller position limit than the encoder; never request more output
        # positions than the decoder has.
        model_config = getattr(getattr(summarizer, "model", None), "config", None)
        decoder_limit = getattr(model_config, "max_decoder_position_embeddings", None) or max_tokens
        output_limit = min(max_tokens, decoder_limit)

        # Multiply before dividing: int(100 * (29 / 100)) is 28 because of
        # float rounding, whereas int(100 * 29 / 100) is 29.
        min_summary_length = min(output_limit, max(10, int(num_tokens * low_pct / 100)))
        max_summary_length = min(output_limit, int(num_tokens * high_pct / 100))
        # For short inputs the 10-token floor can exceed the percentage-based
        # maximum (e.g. 20% of 30 tokens is 6). Generation needs
        # max_length >= min_length, so lift the maximum to the minimum.
        max_summary_length = max(max_summary_length, min_summary_length)

        output = summarizer(
            input,
            min_length=min_summary_length,
            max_length=max_summary_length,
            truncation=True,
        )
        return output[0]['summary_text']
    except Exception as e:
        return f"Summarization failed: {str(e)}"
|
|
|
|
|
|
|
# Assemble the interface. Components render in creation order, top to bottom.
with gr.Blocks() as demo:
    # Model picker and its load trigger share one row.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Choose a model",
            choices=model_names,
            value="sshleifer/distilbart-cnn-12-6",
        )
        load_button = gr.Button("Load Model")

    load_message = gr.Textbox(interactive=False, label="Load Status")

    # Summary-length bounds, as percentages of the input's token count.
    min_length_slider = gr.Slider(
        label="Minimum Summary Length (%)",
        minimum=0,
        maximum=100,
        step=1,
        value=10,
    )
    max_length_slider = gr.Slider(
        label="Maximum Summary Length (%)",
        minimum=0,
        maximum=100,
        step=1,
        value=20,
    )

    input_text = gr.Textbox(lines=6, label="Input text to summarize")
    summarize_button = gr.Button("Summarize Text")
    output_text = gr.Textbox(lines=4, label="Summarized text")

    # Event wiring: each button calls its handler and writes to one textbox.
    load_button.click(
        fn=load_model,
        inputs=[model_dropdown],
        outputs=[load_message],
    )
    summarize_button.click(
        fn=summarize_text,
        inputs=[input_text, min_length_slider, max_length_slider],
        outputs=[output_text],
    )

demo.launch()
|
|