# SnapText / app.py
# hruday96's picture
# Update app.py
# a0d0642 verified
import streamlit as st # Don't forget to include `streamlit` in your `requirements.txt` file to ensure the app runs properly on Hugging Face Spaces.
from transformers import AutoProcessor, AutoModelForImageTextToText # Updated imports to reflect changes
from PIL import Image # Ensure the `pillow` library is included in your `requirements.txt`.
import torch # Since PyTorch is required for this app, specify the appropriate version of `torch` in `requirements.txt` based on compatibility with the model.
import os
def load_model():
    """Load the PaliGemma2 processor and model from the Hugging Face Hub.

    The access token is read from the ``HUGGINGFACEHUB_API_TOKEN``
    environment variable and forwarded to both ``from_pretrained`` calls.

    Returns:
        A ``(processor, model)`` tuple for ``google/paligemma2-3b-pt-224``.

    Raises:
        ValueError: If the environment variable is unset or empty.
    """
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise ValueError("Hugging Face API token not found. Please set it in the environment variables.")

    model_id = "google/paligemma2-3b-pt-224"
    # `use_auth_token` is the deprecated spelling in transformers; `token`
    # is the supported keyword for authenticating against the Hub.
    processor = AutoProcessor.from_pretrained(model_id, token=token)
    model = AutoModelForImageTextToText.from_pretrained(model_id, token=token)
    return processor, model
def process_image(image, processor, model, prompt="<image>ocr", max_new_tokens=128):
    """Extract text from an image using PaliGemma2.

    Args:
        image: Image to read; converted to RGB when its ``mode`` attribute
            says it is something else (e.g. RGBA or palette PNG uploads).
        processor: Processor returned by ``load_model``.
        model: Model returned by ``load_model``.
        prompt: Text passed to the processor alongside the image. The
            default asks for the OCR task.
            NOTE(review): the ``-pt-`` checkpoint is a pretrained base
            model, so output quality without fine-tuning may be limited.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the newly generated tokens for the first
        (only) item in the batch.
    """
    # Uploaded PNGs are often RGBA or palette-based; normalise to RGB first.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Without a text prompt the processor has no task prefix to condition on,
    # so the OCR instruction is passed explicitly.
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Set the generation budget explicitly instead of relying on the
        # library's short default length.
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # The generated sequence starts with the prompt tokens; keep only the
    # continuation so the prompt is not echoed back to the user.
    new_token_ids = generated_ids[:, prompt_length:]
    return processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]
def main():
    """Render the Streamlit page: load the model, accept an image, show its text.

    Flow: configure the page, load the processor and model (stopping the
    script with an error message if that fails), let the user upload a
    PNG/JPEG, and run ``process_image`` when "Extract Text" is pressed.
    """
    st.set_page_config(page_title="Text Reading with PaliGemma2", layout="centered")
    st.title("Text Reading from Images using PaliGemma2")

    # Streamlit re-executes this script on every widget interaction; caching
    # the loader keeps the model in memory instead of reloading it each time.
    # show_spinner=False because the spinner below already covers the wait.
    cached_load_model = st.cache_resource(show_spinner=False)(load_model)

    with st.spinner("Loading PaliGemma2 model... This may take a few moments."):
        try:
            processor, model = cached_load_model()
        except (ValueError, OSError) as e:
            # ValueError: missing token. OSError: Hub download or access
            # failures reported by from_pretrained.
            st.error(str(e))
            st.stop()
        else:
            st.success("Model loaded successfully!")

    # User input: upload image
    uploaded_image = st.file_uploader("Upload an image containing text", type=["png", "jpg", "jpeg"])
    if uploaded_image is not None:
        # Display uploaded image
        image = Image.open(uploaded_image)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Run extraction only on explicit request; inference is slow.
        if st.button("Extract Text"):
            with st.spinner("Processing image..."):
                extracted_text = process_image(image, processor, model)
            st.success("Text extraction complete!")
            st.subheader("Extracted Text")
            st.write(extracted_text)

    # Footer
    st.markdown("---")
    # Link fixed: the host was written as "huggingface.co." with a stray dot.
    st.markdown("**Built with [PaliGemma2](https://huggingface.co/google/paligemma2-3b-pt-224) and Streamlit**")
# Entry point: build the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()