# Spaces: Sleeping (page-header residue from the Hugging Face Spaces listing; not part of the program)
import tempfile
from pathlib import Path

import gradio as gr
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Load the Mistral model from Hugging Face.
# Mistral-7B-Instruct is a decoder-only (causal) language model, so it must be
# loaded through AutoModelForCausalLM. AutoModelForSeq2SeqLM only maps
# encoder-decoder architectures and raises ValueError for a Mistral config,
# which would stop the app before the UI is ever built.
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Define the text splitter: chunks of at most 1000 characters, with a
# 200-character overlap so text cut at a chunk boundary appears in both chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
# Define the summarization function
def _build_model_inputs(text, n_words):
    """Tokenize one chunk, adding a summarization instruction for decoder-only models."""
    if model.config.is_encoder_decoder:
        # Encoder-decoder summarizers consume the raw text as encoder input.
        return tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

    # A decoder-only model just continues its input, so it has to be asked
    # explicitly for a summary. No truncation here: cutting the prompt on the
    # right would drop the end-of-instruction marker, and the splitter already
    # caps each chunk at 1000 characters.
    instruction = f"Summarize the following text in at most {n_words} words:\n\n{text}"
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": instruction}],
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template is expected to carry its own special tokens;
        # adding them again would duplicate BOS.
        return tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    return tokenizer(instruction + "\n\nSummary:", return_tensors="pt")


def summarize(file, n_words):
    """Summarize an uploaded UTF-8 text file chunk by chunk.

    Args:
        file: The value delivered by the ``gr.File`` input. Depending on the
            Gradio version this is an object exposing the path as ``.name`` or
            a plain path string; ``None`` when nothing was uploaded.
        n_words: Requested summary length from the slider. It is used both in
            the instruction text and as the per-chunk generation budget (in
            tokens, which only approximates words).

    Returns:
        The per-chunk summaries joined with single spaces, or an empty string
        when no file was supplied or the file contains no text.
    """
    if file is None:
        return ""

    # Read the content of the uploaded file
    file_path = getattr(file, "name", file)
    with open(file_path, "r", encoding="utf-8") as f:
        file_content = f.read()
    if not file_content.strip():
        return ""

    # The slider may deliver a float; generate() needs an integer budget.
    max_new_tokens = int(n_words)
    decoder_only = not model.config.is_encoder_decoder

    # Summarize each chunk and concatenate the results
    summaries = []
    for chunk in text_splitter.create_documents([file_content]):
        # LangChain documents expose their text as ``page_content``.
        inputs = _build_model_inputs(chunk.page_content, max_new_tokens)
        output_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # max_new_tokens (not max_length) so the prompt does not eat the budget.
            max_new_tokens=max_new_tokens,
            num_beams=4,
            early_stopping=True,
        )[0]
        if decoder_only:
            # Decoder-only output starts with the prompt; keep only the continuation.
            output_ids = output_ids[inputs["input_ids"].shape[1]:]
        summaries.append(tokenizer.decode(output_ids, skip_special_tokens=True).strip())
    return " ".join(summaries)
# Define the download summary function
def download_summary(output_text):
    """Write the summary to a ``summary.txt`` file and return its path.

    Each call writes into a freshly created temporary directory. A single
    fixed path would be shared by every user of the app, so two people
    downloading at the same time could receive each other's summary.

    Args:
        output_text: The summary text to save.

    Returns:
        A ``Path`` to the written ``summary.txt``, or ``None`` when
        ``output_text`` is empty or ``None``.
    """
    if not output_text:
        return None

    # NOTE: the temporary directory is not removed here; it is left to the
    # operating system's temp cleanup.
    target_dir = Path(tempfile.mkdtemp(prefix="summary_"))
    file_path = target_dir / "summary.txt"
    file_path.write_text(output_text, encoding="utf-8")
    return file_path
def create_download_file(summary_text):
    """Save the summary via ``download_summary`` and return the path as a string.

    Returns ``None`` when ``download_summary`` produced no file.
    """
    saved_path = download_summary(summary_text)
    if not saved_path:
        return None
    return str(saved_path)
# Create the Gradio interface
# NOTE(review): the source had lost its indentation; the Row/Column nesting
# below is the most plausible reading of it - confirm against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("## Document Summarizer")

    # Inputs in the first column, generated summary in the second.
    with gr.Row():
        with gr.Column():
            n_words = gr.Slider(minimum=50, maximum=500, step=50, label="Number of words")
            file = gr.File(label="Submit a file")
        with gr.Column():
            output_text = gr.Textbox(label="Summary will be printed here", lines=20)

    # Run the summarizer on the uploaded file and show the result in the textbox.
    submit_button = gr.Button("Summarize")
    submit_button.click(fn=summarize, inputs=[file, n_words], outputs=output_text)

    # Offer the current textbox content as a downloadable file.
    download_button = gr.Button("Download Summary")
    download_target = gr.File()
    download_button.click(
        fn=create_download_file,
        inputs=[output_text],
        outputs=download_target,
    )

# Run the Gradio app
demo.launch(share=True)