metadata
language: en
datasets:
  - laion2b

OpenFlamingo-3B (CLIP ViT-L/14, MPT-1B)

Blog post | Code | Demo

OpenFlamingo is an open source implementation of DeepMind's Flamingo models. This 3B-parameter model uses a CLIP ViT-L/14 vision encoder and MPT-1B language model.

Model Details

We follow the Flamingo modeling paradigm, augmenting the layers of a pretrained, frozen language model so that they cross-attend to visual features when decoding. As in Flamingo, we freeze the vision encoder and language model and train only the connecting modules on web-scraped image-text sequences. Specifically, this model was trained on a mixture of LAION-2B and Multimodal C4.

This model has cross-attention modules inserted in every decoder block. It was trained using DistributedDataParallel across 64 A100 80GB GPUs at FP32 precision.
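
The "every decoder block" placement corresponds to `cross_attn_every_n_layers=1` in the initialization snippet in this card. As a rough illustration of what that setting means, the sketch below lists which decoder blocks would receive a cross-attention module for a given interval. The helper name, the 24-layer count, and the `(i + 1) % every_n` indexing rule are assumptions made for illustration and are not taken from the OpenFlamingo code.

```python
def cross_attention_layer_indices(num_decoder_layers: int, every_n: int) -> list[int]:
    """Return 0-based indices of decoder blocks that get a cross-attention module.

    Assumed rule (hypothetical): block i is selected when (i + 1) is a multiple of every_n.
    """
    if every_n < 1:
        raise ValueError("every_n must be a positive integer")
    return [i for i in range(num_decoder_layers) if (i + 1) % every_n == 0]


# Hypothetical 24-block decoder, used only to make the example concrete.
every_block = cross_attention_layer_indices(24, every_n=1)
every_fourth = cross_attention_layer_indices(24, every_n=4)

print(len(every_block))   # number of blocks selected with every_n=1
print(every_fourth)       # blocks selected with every_n=4
```

With an interval of 1, every block is selected, which is the configuration this card describes. Larger intervals would add fewer trainable modules.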

The MPT-1B modeling code does not accept the labels kwarg or compute cross-entropy loss within forward(). To train with the OpenFlamingo codebase, we suggest a version that accepts the labels kwarg, available here.
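
For readers unfamiliar with what a `labels` path usually computes, here is a minimal sketch of next-token cross-entropy in PyTorch. It follows the common causal language model convention (shift by one position, ignore label `-100`) and is an illustrative assumption, not the code of the suggested MPT-1B version. The function name `causal_lm_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Next-token cross-entropy: the logits at position t are scored against the token at t + 1.

    logits: (batch, seq_len, vocab_size); labels: (batch, seq_len).
    """
    # Drop the last position's logits (no next token) and the first label (nothing predicts it).
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,  # positions labelled -100 do not contribute to the loss
    )


# Toy check: uniform logits over a 5-token vocabulary.
toy_logits = torch.zeros(1, 4, 5)
toy_labels = torch.tensor([[1, 2, 3, 4]])
loss = causal_lm_loss(toy_logits, toy_labels)
print(loss.item())
```

A model whose forward() accepts labels would typically return this value alongside the logits, so the training loop can call backward() on it directly.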

Uses

OpenFlamingo models process arbitrarily interleaved sequences of images and text to output text. This allows the models to accept in-context examples and undertake tasks like captioning, visual question answering, and image classification.
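
In-context examples are passed as a single interleaved string. The sketch below assembles such a prompt from a list of example texts, using the `<image>` and `<|endofchunk|>` special tokens that appear in the generation example in this card. The helper name `build_few_shot_prompt` is hypothetical and not part of the `open_flamingo` package.

```python
IMAGE_TOKEN = "<image>"
END_OF_CHUNK = "<|endofchunk|>"


def build_few_shot_prompt(example_texts, query_prefix=""):
    """Interleave one <image> placeholder with each text chunk.

    example_texts: completed texts for the in-context images, in order.
    query_prefix: the unfinished text that follows the final (query) image.
    """
    shots = "".join(f"{IMAGE_TOKEN}{text}{END_OF_CHUNK}" for text in example_texts)
    # The query chunk is left open so the model can complete it.
    return f"{shots}{IMAGE_TOKEN}{query_prefix}"


prompt = build_few_shot_prompt(
    ["An image of two cats.", "An image of a bathroom sink."],
    query_prefix="An image of",
)
print(prompt)
```

The number of `<image>` tokens in the prompt should match the number of images supplied to the model, here two examples plus one query.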

Initialization

from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1
)

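# Download the model checkpoint from the Hugging Face Hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
# map_location="cpu" lets the checkpoint load on machines without a GPU;
# strict=False tolerates model keys that are absent from the checkpoint
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)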

Generation example

Below is an example of generating text conditioned on interleaved images/text. In particular, let's try few-shot image captioning.

from PIL import Image
import requests

"""
Step 1: Load images
"""
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)

demo_image_two = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
        stream=True
    ).raw
)

query_image = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 
        stream=True
    ).raw
)


"""
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape 
 batch_size x num_media x num_frames x channels x height x width. 
 In this case batch_size = 1, num_media = 3, num_frames = 1,
 channels = 3, height = 224, width = 224.
"""
vision_x = [
    image_processor(demo_image_one).unsqueeze(0),
    image_processor(demo_image_two).unsqueeze(0),
    image_processor(query_image).unsqueeze(0),
]
vision_x = torch.cat(vision_x, dim=0)  # (num_media, channels, height, width)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)  # add the num_frames and batch_size dimensions

"""
Step 3: Preprocessing text
Details: In the text we expect an <image> special token to indicate where an image is.
 We also expect an <|endofchunk|> special token to indicate the end of the text 
 portion associated with an image.
"""
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
lang_x = tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
    return_tensors="pt",
)


"""
Step 4: Generate text
"""
generated_text = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)

print("Generated text: ", tokenizer.decode(generated_text[0]))
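
The decoded string printed in Step 4 may still contain the prompt and special tokens. A minimal post-processing sketch, assuming the decoded text keeps the `<image>` and `<|endofchunk|>` markers, is shown below. The helper name `extract_completion` is hypothetical, and the sample string is made up for illustration rather than real model output.

```python
IMAGE_TOKEN = "<image>"
END_OF_CHUNK = "<|endofchunk|>"


def extract_completion(decoded: str, num_images: int) -> str:
    """Return the text that follows the last (query) image in a decoded sequence."""
    chunks = decoded.split(IMAGE_TOKEN)
    if len(chunks) <= num_images:
        raise ValueError("decoded text contains fewer <image> tokens than expected")
    # chunks[0] is whatever precedes the first image (for example, padding);
    # chunks[k] is the text that follows the k-th image.
    query_chunk = chunks[num_images]
    # Cut at the first end-of-chunk marker, if the model emitted one.
    return query_chunk.split(END_OF_CHUNK, 1)[0].strip()


# Made-up decoded string for illustration, not real model output.
decoded = (
    "<image>An image of two cats.<|endofchunk|>"
    "<image>An image of a bathroom sink.<|endofchunk|>"
    "<image>An image of a buffet table.<|endofchunk|>"
)
caption = extract_completion(decoded, num_images=3)
print(caption)
```

Indexing by the number of prompt images, rather than taking the text after the last `<image>`, keeps the result stable if the model goes on to generate further chunks.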

Bias, Risks, and Limitations

OpenFlamingo models inherit the risks of their parent models, especially the language model. As an open-source research effort, we highly value open, accessible, reproducible multimodal model research; however, it is crucial to be aware that these models are trained on web data, have not been finetuned for safety, and thus may produce unintended, inappropriate, unreliable, and/or inaccurate outputs. Please use caution before deploying OpenFlamingo models in real applications. We also hope that OpenFlamingo enables further safety and reliability research to address these issues.

In an effort to mitigate current potential biases and harms, we have deployed a text content filter on model outputs in the OpenFlamingo demo. We continue to red-team the model to understand and improve its safety.

Evaluation

| Benchmark | 0-shot | 4-shot | 8-shot | 16-shot | 32-shot |
|---|---|---|---|---|---|
| COCO (CIDEr) | 74.9 (0.2) | 77.3 (0.3) | 85.9 (0.6) | 89.8 (0.2) | 93.0 (0.6) |
| Flickr-30K (CIDEr) | 52.3 (1.0) | 57.2 (0.4) | 58.6 (1.1) | 59.2 (0.5) | 61.1 (1.3) |
| VQAv2 (Accuracy) | 44.6 (0.7) | 45.9 (0.7) | 45.8 (0.5) | 45.5 (0.2) | 45.8 (0.4) |
| OK-VQA (Accuracy) | 26.8 (0.3) | 27.6 (0.2) | 27.7 (0.1) | 28.4 (0.1) | 29.3 (0.2) |
| TextVQA (Accuracy) | 22.8 (0.2) | 25.8 (0.2) | 24.7 (0.1) | 25.2 (0.2) | 26.3 (0.2) |
| Vizwiz (Accuracy) | 18.3 (0.6) | 23.3 (1.1) | 31.8 (0.7) | 38.4 (1.1) | 42.1 (0.6) |
| Hateful Memes (ROC AUC) | 51.4 (3.3) | 51.4 (0.6) | 52.1 (0.7) | 51.6 (1.1) | 51.6 (1.6) |
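
To make the effect of in-context examples easier to read off, the short script below computes the change from 0-shot to 32-shot for each benchmark. It uses only the main scores from the table above and leaves out the parenthesized values. It is a convenience sketch, not part of the evaluation code.

```python
SHOTS = (0, 4, 8, 16, 32)

# Main scores copied from the evaluation table, one entry per shot count in SHOTS.
scores = {
    "COCO (CIDEr)": (74.9, 77.3, 85.9, 89.8, 93.0),
    "Flickr-30K (CIDEr)": (52.3, 57.2, 58.6, 59.2, 61.1),
    "VQAv2 (Accuracy)": (44.6, 45.9, 45.8, 45.5, 45.8),
    "OK-VQA (Accuracy)": (26.8, 27.6, 27.7, 28.4, 29.3),
    "TextVQA (Accuracy)": (22.8, 25.8, 24.7, 25.2, 26.3),
    "Vizwiz (Accuracy)": (18.3, 23.3, 31.8, 38.4, 42.1),
    "Hateful Memes (ROC AUC)": (51.4, 51.4, 52.1, 51.6, 51.6),
}

# Difference between the 32-shot and 0-shot score, rounded to one decimal place.
gains = {name: round(vals[-1] - vals[0], 1) for name, vals in scores.items()}

for name, gain in sorted(gains.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {gain:+.1f}")
```

On these numbers, the captioning benchmarks and Vizwiz gain the most from additional shots, while VQAv2 and Hateful Memes stay nearly flat.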