import io
import logging

import numpy as np
import streamlit as st
import torch
from mtranslate import translate
from PIL import Image
from torchvision.transforms import (
    CenterCrop,
    Compose,
    ConvertImageDtype,
    Normalize,
    Resize,
    ToTensor,
)
from torchvision.transforms.functional import InterpolationMode
from transformers import MarianTokenizer

from flax_clip_vision_marian.modeling_clip_vision_marian import (
    FlaxCLIPVisionMarianForConditionalGeneration,
)


class CaptionGenerator:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.marian_model_name = "Helsinki-NLP/opus-mt-en-id"
        self.clip_marian_model_name = "flax-community/Image-captioning-Indonesia"
        self.config = None
        self.image_size = None
        self.custom_transforms = None

    def load(self):
        logging.info("Loading tokenizer...")
        self.tokenizer = MarianTokenizer.from_pretrained(self.marian_model_name)
        logging.info("Tokenizer loaded.")

        logging.info("Loading model...")
        self.model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(
            self.clip_marian_model_name
        )
        logging.info("Model loaded.")

        self.config = self.model.config
        self.image_size = self.config.clip_vision_config.image_size
        # CLIP-style preprocessing: resize the shorter edge, center-crop to a
        # square, convert to float, then normalize with the CLIP mean/std.
        self.custom_transforms = torch.nn.Sequential(
            Resize([self.image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(self.image_size),
            ConvertImageDtype(torch.float),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        )

    def process_image(self, file):
        logging.info("Loading image...")
        image_data = file.read()
        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        loader = Compose([ToTensor()])
        image = loader(input_image)
        image = self.custom_transforms(image)
        # The Flax vision encoder expects channels-last input:
        # (batch, height, width, channels).
        pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()
        logging.info("Image loaded.")
        return pixel_values

    def generate_step(self, pixel_values, max_len, num_beams):
        gen_kwargs = {"max_length": max_len, "num_beams": num_beams}
        logging.info("Generating caption...")
        output_ids = self.model.generate(pixel_values, **gen_kwargs)
        token_ids = np.array(output_ids.sequences)[0]
        # skip_special_tokens drops <pad>/</s> markers from the decoded caption.
        caption = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        logging.info("Caption generated.")
        return caption

    def get_caption(self, file, max_len, num_beams):
        pixel_values = self.process_image(file)
        caption = self.generate_step(pixel_values, max_len, num_beams)
        return caption


@st.cache(allow_output_mutation=True)
def load_caption_generator():
    generator = CaptionGenerator()
    generator.load()
    return generator


def main():
    st.set_page_config(page_title="Indonesian Image Captioning Demo", page_icon="🖼️")
    generator = load_caption_generator()

    st.title("Indonesian Image Captioning Demo")
    st.markdown(
        """Indonesian image captioning demo, built on
        [CLIP](https://huggingface.co/transformers/model_doc/clip.html) and
        [Marian](https://huggingface.co/transformers/model_doc/marian.html).
        Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
        """
    )

    st.sidebar.subheader("Configurable parameters")
    max_len = st.sidebar.number_input(
        "Maximum length",
        value=8,
        help="The maximum length of the sequence (caption) to be generated.",
    )
    num_beams = st.sidebar.number_input(
        "Number of beams",
        value=4,
        help="Number of beams for beam search. 1 means no beam search.",
    )

    input_image = st.file_uploader("Insert image")
    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            if input_image:
                caption = generator.get_caption(
                    file=input_image, max_len=max_len, num_beams=num_beams
                )
                st.subheader("Result")
                st.write(caption)
                st.text("English translation")
                st.write(translate(caption, "en", "id"))
            else:
                st.write("Please upload an image.")


if __name__ == "__main__":
    main()
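
# Usage sketch: launch the demo with the standard Streamlit CLI. The filename
# below is an assumption (this repo does not fix one); model weights are
# downloaded from the Hugging Face Hub on the first run of from_pretrained.
#
#     streamlit run app.py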