Chan-Y's picture
Update app.py
162dd8b verified
raw
history blame
No virus
2.61 kB
from pathlib import Path

import gradio as gr
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Load the Mistral model from Hugging Face.
# Mistral is a decoder-only (causal) language model, so it has to be loaded
# through AutoModelForCausalLM; AutoModelForSeq2SeqLM does not accept its
# configuration and fails at start-up.
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Define the text splitter used to cut long documents into overlapping chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)


# Define the summarization function
def summarize(file, n_words):
    """Summarize an uploaded text file chunk by chunk.

    The file is read as UTF-8, split with ``text_splitter``, and every chunk
    is summarized independently by the instruct model. The per-chunk
    summaries are joined with single spaces.

    Args:
        file: The value supplied by the ``gr.File`` input. Either a path
            string or an object exposing the path as ``.name`` is accepted
            (which one Gradio passes depends on its version/configuration).
            ``None`` means nothing was uploaded.
        n_words: Requested summary length per chunk, in words. It is stated
            in the prompt and also bounds the number of generated tokens.

    Returns:
        The concatenated chunk summaries, or an empty string when no file
        was provided.
    """
    if file is None:
        # Nothing uploaded yet: return an empty summary instead of crashing.
        return ""

    # A plain ``str`` has no ``.name`` attribute, so it falls through as the path.
    file_path = getattr(file, "name", file)
    with open(file_path, "r", encoding="utf-8") as f:
        file_content = f.read()

    # The slider may deliver a float; generation limits must be integers.
    word_limit = int(n_words)
    # Headroom: a word usually spans more than one token.
    max_new_tokens = 2 * word_limit

    # Split the content into chunks
    chunks = text_splitter.create_documents([file_content])

    # Summarize each chunk and concatenate the results
    summaries = []
    for chunk in chunks:
        # LangChain documents keep their text in ``page_content``.
        messages = [
            {
                "role": "user",
                "content": (
                    f"Summarize the following text in at most {word_limit} words:"
                    f"\n\n{chunk.page_content}"
                ),
            }
        ]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The chat template already emits the special tokens, so do not add
        # them again. No truncation: the splitter already bounds chunk size,
        # and truncating would cut off the closing instruction marker.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        inputs = inputs.to(model.device)
        prompt_length = inputs["input_ids"].shape[1]

        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=4,
            early_stopping=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # A causal model echoes the prompt; keep only the newly generated part.
        generated_ids = output_ids[0][prompt_length:]
        summary = tokenizer.decode(generated_ids, skip_special_tokens=True)
        summaries.append(summary.strip())

    return " ".join(summaries)
# Persist the summary text so it can be offered as a download
def download_summary(output_text):
    """Write ``output_text`` to ``summary.txt`` and return its path.

    Args:
        output_text: The summary text to save.

    Returns:
        A ``Path`` pointing at ``summary.txt`` in the current working
        directory, or ``None`` when ``output_text`` is empty or ``None``.
    """
    # Guard clause: nothing to save, nothing to download.
    if not output_text:
        return None

    target = Path('summary.txt')
    target.write_text(output_text, encoding='utf-8')
    return target
def create_download_file(summary_text):
    """Save the summary and return the file location as a string.

    Delegates the writing to ``download_summary`` and converts the returned
    path to ``str``. Returns ``None`` when ``download_summary`` produced no
    file.
    """
    saved_path = download_summary(summary_text)
    if not saved_path:
        return None
    return str(saved_path)
# Build the Gradio UI: inputs on the left, summary text on the right,
# followed by the action buttons.
with gr.Blocks() as demo:
    gr.Markdown("## Document Summarizer")

    with gr.Row():
        with gr.Column():
            # Summary length control and the document to summarize.
            n_words = gr.Slider(
                minimum=50,
                maximum=500,
                step=50,
                label="Number of words",
            )
            file = gr.File(label="Submit a file")
        with gr.Column():
            output_text = gr.Textbox(
                label="Summary will be printed here",
                lines=20,
            )

    # Run the summarizer on the uploaded file and show the result.
    submit_button = gr.Button("Summarize")
    submit_button.click(fn=summarize, inputs=[file, n_words], outputs=output_text)

    # Turn the displayed summary into a downloadable file.
    download_button = gr.Button("Download Summary")
    summary_file = gr.File()
    download_button.click(
        fn=create_download_file,
        inputs=[output_text],
        outputs=summary_file,
    )

# Start the app with a public share link
demo.launch(share=True)