# paraphrasing/app.py - Streamlit front end for EmTpro01/gemma-paraphraser-16bit.
# Last commit: 8d8d3b6 "Update app.py" by EmTpro01 (verified), 2.92 kB.
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Browser-tab title and icon for the app. Issued before any other st.* call,
# which older Streamlit releases require of set_page_config.
st.set_page_config(
    page_icon="✍️",
    page_title="Gemma Paraphraser",
)
@st.cache_resource
def load_model(model_name="EmTpro01/gemma-paraphraser-16bit"):
    """Load the paraphrasing model and its tokenizer onto the CPU.

    ``st.cache_resource`` keeps the returned objects alive across Streamlit
    reruns and sessions, so the weights are loaded once per ``model_name``
    rather than on every interaction.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.
            Defaults to the fine-tuned Gemma paraphraser this app uses.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        # bfloat16 instead of float16. Both are 16-bit, but float16 tops out
        # around 65504, and Gemma activations are known to overflow it, which
        # yields inf/NaN logits. bfloat16 keeps float32's exponent range, is
        # the dtype recommended for Gemma, and is supported on CPU.
        torch_dtype=torch.bfloat16,
    )
    return model, tokenizer
def paraphrase_text(text, model, tokenizer, max_new_tokens=512, temperature=0.7):
    """Generate a paraphrase of *text* with a causal language model.

    *text* is wrapped in the app's instruction prompt. The model samples a
    continuation, and only the tokens it generated (not the echoed prompt)
    are decoded and returned.

    Args:
        text: Paragraph to paraphrase.
        model: Causal LM exposing ``generate`` (e.g. from ``load_model``).
        tokenizer: Tokenizer matching *model*.
        max_new_tokens: Maximum number of tokens to generate. This limit does
            not depend on the prompt length, so long inputs still get output.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        The generated paraphrase with surrounding whitespace stripped.
    """
    # Prompt template the fine-tuned model expects (kept byte-identical).
    system_prompt = "Below is provided a paragraph, paraphrase it"
    prompt = f"{system_prompt}\n\n### Input:\n{text}\n\n### Output:\n"

    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[-1]

    outputs = model.generate(
        input_ids=input_ids,
        # Passing the mask explicitly avoids Transformers' "attention mask
        # not set" warning and any pad/eos ambiguity.
        attention_mask=inputs.get("attention_mask"),
        # max_new_tokens, not max_length: max_length counts prompt tokens too,
        # so a long paragraph left little or no room for the paraphrase.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        temperature=temperature,
        do_sample=True,
    )

    # Decode only the continuation. Searching the full decoded text for
    # "### Output:" breaks when the user's text contains that marker, and
    # returns a garbage slice when the marker is missing (find() == -1).
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def main():
    """Render the paraphraser UI.

    Loads the cached model, shows a text area and a "Paraphrase" button, and
    displays the generated paraphrase with a copyable plain-text view. Model
    load or generation failures are reported in the page with ``st.error``
    rather than crashing the app.
    """
    # Title literal repaired: it had been stored as mojibake ("πŸ“") of 📝.
    st.title("📝 Gemma Paraphraser")
    st.write("Paraphrase your text using the Gemma model")

    try:
        model, tokenizer = load_model()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading model: {e}")
        return

    input_text = st.text_area("Enter text to paraphrase:", height=200)

    if st.button("Paraphrase"):
        # strip() so whitespace-only input gets the warning, not a model call.
        if input_text.strip():
            with st.spinner("Generating paraphrase..."):
                try:
                    paraphrased_text = paraphrase_text(input_text, model, tokenizer)
                except Exception as e:  # UI boundary for generation errors
                    st.error(f"Error during paraphrasing: {e}")
                else:
                    st.subheader("Paraphrased Text:")
                    st.write(paraphrased_text)
                    # st.code renders a built-in copy-to-clipboard control.
                    # The previous nested st.button only st.write-ed the text,
                    # and its rerun hid the result because the outer button
                    # is False on that rerun.
                    with st.expander("Copy to Clipboard"):
                        st.code(paraphrased_text, language=None)
        else:
            st.warning("Please enter some text to paraphrase.")

    st.sidebar.info(
        "Model: EmTpro01/gemma-paraphraser-16bit\n\n"
        "Tips:\n"
        "- Enter a paragraph to paraphrase\n"
        "- Click 'Paraphrase' to generate\n"
        "- Running on CPU with 16-bit precision"
    )
# Entry point. `streamlit run` executes this file as __main__, so the UI is
# rebuilt on every Streamlit rerun. A plain import does not render anything.
if __name__ == "__main__":
    main()