import torch
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from PIL import Image
import streamlit as st


# Load the model and its components once and cache them across Streamlit reruns.
# st.cache_resource belongs on this loader (a heavyweight, unhashable resource),
# not on the caption function, where it would try to hash the PIL image argument.
@st.cache_resource
def load_model():
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    # Load the saved (e.g. fine-tuned) state dictionary on top of the base weights
    model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
    model.eval()
    feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return model, feature_extractor, tokenizer


# Generate a caption for a PIL image using beam search
def generate_caption(image):
    model, feature_extractor, tokenizer = load_model()
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():  # inference only, no gradients needed
        output_ids = model.generate(pixel_values, max_length=100, num_beams=5, early_stopping=True)
    caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return caption


# Streamlit app
def main():
    st.title("Image Captioning")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Convert to RGB so RGBA or grayscale uploads don't break the feature extractor
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)
        caption = generate_caption(image)
        st.write(f"Caption: {caption}")


if __name__ == "__main__":
    main()
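
# Usage sketch (assumes this script is saved as app.py and that "model.pth"
# holds a state dict for the same ViT-GPT2 architecture; both names are
# illustrative, not fixed by the code above):
#
#   pip install streamlit torch transformers pillow
#   streamlit run app.py
#
# Streamlit reruns the whole script on every interaction, which is why the
# model loading is wrapped in a cached function rather than run at module level.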