# NOTE: page-scrape residue ("Spaces: Sleeping Sleeping") - not part of the program.
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set page configuration
# NOTE(review): the icon literal was mis-decoded in this copy of the file (the
# garbled text is consistent with a dingbat emoji followed by a variation
# selector). The writing-hand emoji is a best guess - confirm against the
# deployed app.
st.set_page_config(page_title="Gemma Paraphraser", page_icon="✍️")
# Load model and tokenizer
@st.cache_resource
def load_model():
    """Load the paraphrasing model and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction, so
    the loaded objects are cached as a shared resource instead of being
    re-created on each rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_name = "EmTpro01/gemma-paraphraser-16bit"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    return model, tokenizer
# Paraphrase function
def paraphrase_text(text, model, tokenizer, max_new_tokens=512):
    """Generate a paraphrase of ``text`` with the given model and tokenizer.

    Args:
        text: The paragraph to paraphrase.
        model: Object exposing ``generate(**inputs, ...)``.
        tokenizer: Callable that encodes the prompt into a mapping containing
            ``input_ids`` and that exposes ``decode(...)``.
        max_new_tokens: Upper bound on the number of generated tokens. The
            prompt length does not count against it, so long inputs still
            leave room for the paraphrase.

    Returns:
        The generated paraphrase with surrounding whitespace removed.
    """
    # Prepare the prompt using Alpaca format
    system_prompt = "Below is provided a paragraph, paraphrase it"
    prompt = f"{system_prompt}\n\n### Input:\n{text}\n\n### Output:\n"

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    # Pass the full encoding so the attention mask reaches generate() along
    # with the token ids.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
    )

    # The returned sequence starts with the prompt. Decode only the tokens
    # after it rather than searching the decoded text for "### Output:",
    # which breaks when the user's text contains that marker or when the
    # marker is missing from the decoded string.
    prompt_length = len(inputs["input_ids"][0])
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# Streamlit App
def main():
    """Render the paraphraser page and handle a single script run.

    Loads the model, shows the input area, and on "Paraphrase" displays the
    generated text. Errors from loading or generation are reported in the
    page instead of being raised.
    """
    # NOTE(review): the leading glyph in the title looks like a mis-decoded
    # emoji; it is left untouched here - restore it from the original source.
    st.title("π Gemma Paraphraser")
    st.write("Paraphrase your text using the Gemma model")

    # Load model
    try:
        model, tokenizer = load_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    # Input text area
    input_text = st.text_area("Enter text to paraphrase:", height=200)

    # Paraphrase button
    if st.button("Paraphrase"):
        # Whitespace-only input carries nothing to paraphrase.
        if input_text.strip():
            with st.spinner("Generating paraphrase..."):
                try:
                    paraphrased_text = paraphrase_text(input_text, model, tokenizer)
                except Exception as e:
                    st.error(f"Error during paraphrasing: {e}")
                else:
                    # Display results
                    st.subheader("Paraphrased Text:")
                    st.write(paraphrased_text)
                    # A button nested under another button cannot work here:
                    # clicking it reruns the script with "Paraphrase" no
                    # longer pressed, so the result vanishes and nothing is
                    # copied. st.code renders the text with a built-in
                    # copy-to-clipboard control instead.
                    st.code(paraphrased_text, language=None)
        else:
            st.warning("Please enter some text to paraphrase.")

    # Additional information
    st.sidebar.info(
        "Model: EmTpro01/gemma-paraphraser-16bit\n\n"
        "Tips:\n"
        "- Enter a paragraph to paraphrase\n"
        "- Click 'Paraphrase' to generate\n"
        "- Running on CPU with 16-bit precision"
    )
# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()