Spaces:
Sleeping
Sleeping
File size: 2,919 Bytes
8b467d6 8d8d3b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Set page configuration: browser-tab title and favicon for the app.
# NOTE(review): Streamlit expects set_page_config to be the first st.* call
# in the script — keep it above any other Streamlit statement.
st.set_page_config(page_title="Gemma Paraphraser", page_icon="✍️")
# Model/tokenizer loading (cached once per server process).
@st.cache_resource
def load_model():
    """Fetch and cache the paraphrasing model and its tokenizer.

    Decorated with ``st.cache_resource`` so the weights are loaded a single
    time per Streamlit server process instead of on every script rerun.

    Returns:
        tuple: ``(model, tokenizer)`` ready for generation on CPU.
    """
    checkpoint = "EmTpro01/gemma-paraphraser-16bit"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    # NOTE(review): float16 on CPU is unusual — some CPU kernels lack fp16
    # support; confirm this checkpoint actually runs in half precision here.
    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    return lm, tok
# Paraphrase function
def paraphrase_text(text, model, tokenizer):
    """Paraphrase ``text`` with the given causal LM.

    Builds an Alpaca-style prompt, samples a completion, and returns only
    the portion after the ``### Output:`` marker.

    Args:
        text: The paragraph to paraphrase.
        model: A loaded ``AutoModelForCausalLM``.
        tokenizer: The matching ``AutoTokenizer``.

    Returns:
        str: The paraphrased text (stripped of surrounding whitespace).
    """
    # Prepare the prompt using Alpaca format.
    system_prompt = "Below is provided a paragraph, paraphrase it"
    prompt = f"{system_prompt}\n\n### Input:\n{text}\n\n### Output:\n"

    # Tokenize input.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    # Inference only: no_grad avoids building an autograd graph.
    # FIX: pass attention_mask explicitly (transformers otherwise warns and
    # may attend to padding), and budget generation with max_new_tokens —
    # the original max_length=512 counted the prompt too, so a long input
    # could leave no room for any generated output.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
        )

    # Decode and clean the output.
    paraphrased = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the part after "### Output:". FIX: if the marker is missing,
    # str.find returns -1 and the original "-1 + len(marker)" arithmetic
    # silently sliced from a wrong offset; fall back to the full decode.
    marker = "### Output:"
    marker_pos = paraphrased.find(marker)
    if marker_pos == -1:
        return paraphrased.strip()
    return paraphrased[marker_pos + len(marker):].strip()
# Streamlit App
def main():
    """Render the UI: title, input area, paraphrase button, and results."""
    st.title("📝 Gemma Paraphraser")
    st.write("Paraphrase your text using the Gemma model")

    # Load the (cached) model; abort rendering the rest of the page on failure.
    try:
        model, tokenizer = load_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    # Input text area.
    input_text = st.text_area("Enter text to paraphrase:", height=200)

    # Paraphrase button.
    if st.button("Paraphrase"):
        if input_text:
            with st.spinner("Generating paraphrase..."):
                try:
                    paraphrased_text = paraphrase_text(input_text, model, tokenizer)
                    # Display results.
                    st.subheader("Paraphrased Text:")
                    st.write(paraphrased_text)
                    # FIX: the original nested st.button("Copy to Clipboard")
                    # could never work — clicking it reruns the script with the
                    # outer "Paraphrase" button evaluating False, so the branch
                    # (and the on_click lambda's output) vanish. st.code renders
                    # the text with a built-in copy-to-clipboard control.
                    st.code(paraphrased_text, language=None)
                except Exception as e:
                    st.error(f"Error during paraphrasing: {e}")
        else:
            st.warning("Please enter some text to paraphrase.")

    # Additional information in the sidebar.
    st.sidebar.info(
        "Model: EmTpro01/gemma-paraphraser-16bit\n\n"
        "Tips:\n"
        "- Enter a paragraph to paraphrase\n"
        "- Click 'Paraphrase' to generate\n"
        "- Running on CPU with 16-bit precision"
    )


if __name__ == "__main__":
    main()