# artigen / app.py — by CallmeKaito
# Last commit: e67b575 (verified), "Update app.py"
# File size: 1.39 kB
import torch
from transformers import AutoProcessor, AutoModel, VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from PIL import Image
import streamlit as st
# Load the captioning model and its pre/post-processing components.
# Streamlit re-executes this whole script on every user interaction, so the
# slow loading is wrapped in a function cached with st.cache_resource: later
# reruns reuse the same objects instead of reloading them from disk/network.
@st.cache_resource
def _load_captioning_components():
    """Load the ViT-GPT2 captioning model, feature extractor and tokenizer.

    The architecture and the pre/post-processors come from the
    ``nlpconnect/vit-gpt2-image-captioning`` checkpoint; the model weights
    are then replaced by the local ``model.pth`` state dict, mapped to CPU.

    Returns:
        A ``(model, feature_extractor, tokenizer)`` tuple.
    """
    checkpoint = "nlpconnect/vit-gpt2-image-captioning"
    captioning_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    # NOTE(review): torch.load unpickles the file, so model.pth must be a
    # trusted artifact shipped with the app. If it only holds tensors,
    # consider passing weights_only=True.
    state_dict = torch.load("model.pth", map_location=torch.device('cpu'))
    captioning_model.load_state_dict(state_dict)
    extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
    caption_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return captioning_model, extractor, caption_tokenizer


# Module-level names kept so the rest of the script can use them directly.
model, feature_extractor, tokenizer = _load_captioning_components()
# Function to generate a caption for an image
@st.cache_data
def generate_caption(image, *, max_length=100, num_beams=5):
    """Generate a text caption for ``image`` using beam search.

    Args:
        image: A PIL image. Images in any mode other than RGB (for example
            RGBA or palette PNGs, or grayscale JPEGs) are converted to RGB
            first, because the ViT feature extractor expects three-channel
            RGB input.
        max_length: Maximum length, in tokens, of the generated sequence.
        num_beams: Number of beams passed to ``model.generate``.

    Returns:
        The first decoded caption, with special tokens removed.
    """
    # st.cache_data rather than st.cache_resource: the result is a plain
    # string. Streamlit keys the cache on the hashed arguments, so reruns
    # with the same upload and settings skip the expensive beam search.
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    output_ids = model.generate(
        inputs.pixel_values,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
# Streamlit app
def main():
    """Render the image-captioning page.

    Shows a file uploader for JPEG/PNG images. Once a file is uploaded, it is
    displayed and a generated caption is written below it. Files that Pillow
    cannot identify as an image produce an error message instead of an
    unhandled exception.
    """
    from PIL import UnidentifiedImageError

    st.title("Image Captioning")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return  # nothing uploaded yet: only the title and uploader are shown
    try:
        image = Image.open(uploaded_file)
    except UnidentifiedImageError:
        # The uploader filters by file extension only, so the contents may
        # still not be a decodable image.
        st.error("The uploaded file could not be read as an image.")
        return
    st.image(image, caption="Uploaded Image", use_column_width=True)
    caption = generate_caption(image)
    st.write(f"Caption: {caption}")


if __name__ == "__main__":
    main()