---
license: apache-2.0
---

OFA-base-caption

This is the base version of the OFA model fine-tuned for image captioning. OFA is a unified multimodal pretrained model that unifies modalities (i.e., cross-modality, vision, language) and tasks (e.g., image generation, visual grounding, image captioning, image classification, text generation, etc.) in a simple sequence-to-sequence learning framework.

The directory includes 4 files, namely config.json, which contains the model configuration, vocab.json and merges.txt for our OFA tokenizer, and lastly pytorch_model.bin, which contains the model weights. There is no need to worry about a mismatch between fairseq and transformers, as we have already addressed this issue.
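
If you want to inspect these files locally, they can be fetched with huggingface_hub. This is a minimal sketch, assuming huggingface_hub is installed (it is not required by the example below) and using the repository id from the example further down.

from huggingface_hub import snapshot_download
import os

# download the repository snapshot and list its files:
# config.json, vocab.json, merges.txt and pytorch_model.bin (plus repo metadata)
local_dir = snapshot_download("OFA-sys/OFA-base-caption")
print(sorted(os.listdir(local_dir)))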

To use it with transformers, please refer to https://github.com/OFA-Sys/OFA/tree/feature/add_transformers. Install the transformers fork and download the model as shown below.

git clone --single-branch --branch feature/add_transformers https://github.com/OFA-Sys/OFA.git
pip install OFA/transformers/
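
You can verify that the fork was installed correctly with a quick import check (a minimal sketch; the OFA classes are only available in this branch of transformers):

# if this import succeeds, the branch with OFA support is installed
from transformers import OFATokenizer, OFAModel
print(OFATokenizer.__name__, OFAModel.__name__)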

Afterwards, prepare an image for the example below. Also make sure that pillow and torchvision are installed in your environment.
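
If you do not have a test image at hand, a placeholder can be generated with pillow (a minimal sketch; a solid-color image only exercises the pipeline and will not produce a meaningful caption):

from PIL import Image

# a plain 256x256 RGB image, just to check that the pipeline runs end to end
Image.new("RGB", (256, 256), color=(128, 128, 128)).save("test.jpg")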

import re
import time
from PIL import Image
from torchvision import transforms
from transformers import OFATokenizer, OFAModel

model_name = "OFA-sys/OFA-base-caption"

mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
resolution = 256

patch_resize_transform = transforms.Compose([
        lambda image: image.convert("RGB"),
        transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
        transforms.ToTensor(), 
        transforms.Normalize(mean=mean, std=std)
    ])

start = time.time()
tokenizer = OFATokenizer.from_pretrained(model_name)
model = OFAModel.from_pretrained(model_name, use_cache=False)
elapsed = time.time() - start
print(f"Loaded in {elapsed} secs")


def caption_image(txt, img):
    # tokenize the prompt and turn the image into a single-item batch of patches
    inputs = tokenizer([txt], return_tensors="pt").input_ids
    patch_img = patch_resize_transform(img).unsqueeze(0)

    # beam-search decoding; no_repeat_ngram_size prevents repeated trigrams
    gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)
    results = tokenizer.batch_decode(gen, skip_special_tokens=True)

    # strip whitespace and punctuation from the decoded caption
    result = results[0].strip()
    result = re.sub(r'[^\w\s]', '', result)

    return result


if __name__ == "__main__":
    txt = "What does the image describe?"
    img = Image.open('/path/to/input/image.jpg')
    caption = caption_image(txt, img)
    print(caption)
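
For faster inference you can move the model and inputs to a GPU. The variant below is a minimal sketch of caption_image with device placement added, assuming a CUDA-capable device is available; caption_image_on_device is a hypothetical helper, not part of the original example.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def caption_image_on_device(txt, img):
    # same steps as caption_image above, with all tensors placed on the selected device
    inputs = tokenizer([txt], return_tensors="pt").input_ids.to(device)
    patch_img = patch_resize_transform(img).unsqueeze(0).to(device)
    gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)
    return tokenizer.batch_decode(gen, skip_special_tokens=True)[0].strip()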