Image-captioning-Indonesia
This is an encoder-decoder image captioning model that uses CLIP as the visual encoder and Marian as the textual decoder, trained on datasets with Indonesian captions.
This model was trained with Flax using HuggingFace Transformers as part of the JAX/Flax Community Week organized by HuggingFace. All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.
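In a generic encoder-decoder captioner, the text decoder reads the visual encoder's output through cross-attention. The code below is a minimal single-head version in NumPy that illustrates this. The dimensions (50 image tokens of width 768, decoder width 512) are illustrative assumptions, and the weight matrices are random stand-ins. This is not this model's actual implementation or its parameters.

```python
import numpy as np

def cross_attention(decoder_states, encoder_states, w_q, w_k, w_v):
    """Single-head scaled dot-product cross-attention (illustrative sketch).

    decoder_states: (T, d_model) hidden states for T caption tokens
    encoder_states: (S, d_enc) visual features for S image tokens
    Returns the attended values (T, d_k) and the attention weights (T, S).
    """
    q = decoder_states @ w_q                        # queries come from the text side
    k = encoder_states @ w_k                        # keys come from the image side
    v = encoder_states @ w_v                        # values come from the image side
    scores = q @ k.T / np.sqrt(q.shape[-1])         # (T, S) similarity scores
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over image tokens
    return weights @ v, weights

# Hypothetical sizes and random weights, for shape-checking only
rng = np.random.default_rng(0)
T, S, d_model, d_enc, d_k = 5, 50, 512, 768, 64
decoder_states = rng.normal(size=(T, d_model))
encoder_states = rng.normal(size=(S, d_enc))
w_q = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
w_k = rng.normal(size=(d_enc, d_k)) / np.sqrt(d_enc)
w_v = rng.normal(size=(d_enc, d_k)) / np.sqrt(d_enc)

context, attn = cross_attention(decoder_states, encoder_states, w_q, w_k, w_v)
print(context.shape, attn.shape)  # (5, 64) (5, 50)
```

Each row of attn is a probability distribution over the image tokens. This lets every caption position weight different parts of the image. This card does not describe how the real model bridges the width gap between CLIP and Marian. Here, the key and value matrices simply project from the encoder width.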
How to use
At the time of writing, you will need to install HuggingFace Transformers from its latest master branch in order to load FlaxMarian.
You will also need to have the flax_clip_vision_marian folder in your project directory to load the model using the FlaxCLIPVisionMarianForConditionalGeneration class.
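Before running the example, you can quickly confirm that the required packages and the local flax_clip_vision_marian folder are importable. The helper below is a hypothetical convenience built on Python's standard importlib.util.find_spec. The module list is an assumption based on the imports the usage example uses, plus jax and flax for the Flax model. It does not check the Transformers version or whether FlaxMarian is available.

```python
import importlib.util

# Assumed requirements: the example's imports, plus jax/flax for the Flax model.
# "flax_clip_vision_marian" is the local folder described above.
REQUIRED = ("numpy", "torch", "torchvision", "transformers", "jax", "flax", "flax_clip_vision_marian")

def missing_modules(names=REQUIRED):
    """Return the top-level module names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

Run it from your project directory. A missing flax_clip_vision_marian entry usually means the folder is not in the current working directory.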
```python
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
import torch
import numpy as np
from transformers import MarianTokenizer
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

# Load the captioning model and the Marian tokenizer used by its decoder
clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)
marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

# Image transformation: resize, center-crop, scale to [0, 1], normalize with CLIP statistics
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

# Hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def generate_step(pixel_values):
    # Beam-search decode captions for a batch of preprocessed images
    output_ids = model.generate(pixel_values, **gen_kwargs)
    # Take the generated token ids for the first image in the batch
    token_ids = np.array(output_ids.sequences)[0]
    # Drop special tokens such as <pad> and </s> from the decoded text
    caption = tokenizer.decode(token_ids, skip_special_tokens=True)
    return caption

image_file_path = "path/to/your/image.jpg"  # replace with the image you want to caption
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)
# Add a batch dimension and reorder (N, C, H, W) -> (N, H, W, C) before converting to NumPy
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()
caption = generate_step(pixel_values)
print(caption)
```
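To see the numbers the model receives, consider the last three preprocessing steps. ConvertImageDtype, Normalize, and the stack/permute produce the following for an image that is already resized and center-cropped. The sketch reproduces those steps in plain NumPy, assuming a hypothetical 224x224 all-white input. The real size comes from config.clip_vision_config.image_size. The mean and standard deviation are the CLIP values from the Normalize call in the example.

```python
import numpy as np

# CLIP normalization constants, copied from the Normalize(...) call in the example
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def to_pixel_values(image_chw_uint8):
    """Turn an already resized and cropped uint8 image (C, H, W) into a
    normalized float32 batch of one image in (N, H, W, C) layout."""
    x = image_chw_uint8.astype(np.float32) / 255.0                 # like ConvertImageDtype: [0, 255] -> [0, 1]
    x = (x - CLIP_MEAN[:, None, None]) / CLIP_STD[:, None, None]   # like Normalize: per-channel standardization
    return x[None].transpose(0, 2, 3, 1)                           # like stack + permute(0, 2, 3, 1)

image = np.full((3, 224, 224), 255, dtype=np.uint8)  # hypothetical all-white image
pixel_values = to_pixel_values(image)
print(pixel_values.shape)  # (1, 224, 224, 3)
```

Every pixel of a white image maps to (1 - mean) / std per channel. The resize and crop steps are omitted here because they only change the spatial size.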
Training data
The model was trained on Indonesian translations of the COCO, Flickr, and VizWiz datasets, produced with Google Translate and Marian MT. For each dataset, we kept only 2 randomly chosen captions per image.
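The card does not say how the two captions per image were chosen. The following is one plausible way to do it, shown only as an illustration. It assumes a hypothetical layout where each dataset is a dict mapping an image id to its list of translated captions, and it uses a fixed seed so the subset is reproducible.

```python
import random

def sample_captions(captions_by_image, k=2, seed=0):
    """Keep at most k randomly chosen captions per image.

    captions_by_image: hypothetical dict of image id -> list of captions.
    Images with fewer than k captions keep all of them.
    """
    rng = random.Random(seed)
    return {
        image_id: rng.sample(captions, min(k, len(captions)))
        for image_id, captions in captions_by_image.items()
    }

# Hypothetical toy data with placeholder captions
toy_dataset = {
    "img_001": ["caption A", "caption B", "caption C", "caption D", "caption E"],
    "img_002": ["caption F"],
}
subset = sample_captions(toy_dataset, k=2)
print(subset)
```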
Training procedure
The model was trained on a TPUv3-8 VM provided by the Google Cloud team.
Team members
- Cahya Wirawan (@cahya)
- Galuh Sahid (@Galuh)
- Muhammad Agung Hambali (@AyameRushia)
- Samsul Rahmadani (@munggok)