---
language: en
datasets:
- laion2b
---

# OpenFlamingo-3B (CLIP ViT-L/14, MPT-1B)

[Blog post]() | [Code](https://github.com/mlfoundations/open_flamingo) | [Demo]()

OpenFlamingo is an open-source implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) models.
This 3B-parameter model uses a [CLIP ViT-L/14](https://huggingface.co/openai/clip-vit-large-patch14) vision encoder and an [MPT-1B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) language model.

## Model Details
We follow the Flamingo modeling paradigm, outfitting the layers of a pretrained, frozen language model so that they cross-attend to visual features when decoding. Following Flamingo, we freeze the vision encoder and language model but train the connecting modules on web-scraped image-text sequences. Specifically, we trained this model on a mixture of [LAION-2B](https://arxiv.org/abs/2210.08402) and [Multimodal C4](https://arxiv.org/abs/2304.06939).
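During training, only the connecting modules receive gradient updates. A minimal sketch of that freezing scheme is below; the parameter-name filters (`perceiver`, `gated_cross_attn`) are illustrative assumptions, not the exact module names in the OpenFlamingo codebase.

``` python
# Sketch only: freeze the vision encoder and language model, keep the
# connecting modules trainable. The substring filters are assumed names.
import torch.nn as nn

def freeze_all_but_connecting_modules(model: nn.Module) -> int:
    """Returns the number of parameters left trainable."""
    trainable = 0
    for name, param in model.named_parameters():
        param.requires_grad = any(
            key in name for key in ("perceiver", "gated_cross_attn")
        )
        if param.requires_grad:
            trainable += param.numel()
    return trainable
```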

This model has cross-attention modules inserted in *every* decoder block. It was trained using DistributedDataParallel across 64 A100 80GB GPUs at FP32 precision.

The [MPT-1B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) modeling code does not accept the `labels` kwarg or compute the cross-entropy loss within `forward()`. To train with the OpenFlamingo codebase, we suggest using a version that accepts the `labels` kwarg, available [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b).
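
For reference, "accepting the `labels` kwarg" means the language model computes the standard shifted causal-LM cross-entropy loss inside `forward()`, roughly as in the sketch below (a standalone illustration, not code from either repository).

``` python
# Illustrative only: the shifted cross-entropy a `labels`-aware forward() computes.
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """logits: (batch, seq_len, vocab_size); labels: (batch, seq_len),
    with -100 marking positions excluded from the loss."""
    shift_logits = logits[:, :-1, :].contiguous()  # tokens < n predict token n
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```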

## Uses
OpenFlamingo models process arbitrarily interleaved sequences of images and text to output text. This allows the models to accept in-context examples and undertake tasks like captioning, visual question answering, and image classification.

### Generation example
Below is an example of generating text conditioned on interleaved images/text. In particular, let's try few-shot image captioning.
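
The example assumes that `model`, `image_processor`, and `tokenizer` have already been created, e.g. with `create_model_and_transforms` from the [open_flamingo](https://github.com/mlfoundations/open_flamingo) package as sketched below; the Hugging Face Hub repository id and checkpoint filename used here are assumptions, so adjust them for this release.

``` python
# Minimal initialization sketch; the Hub repo id and checkpoint filename are assumed.
import torch
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,  # cross-attention in every decoder block
)

checkpoint_path = hf_hub_download(
    "openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt"  # assumed repo id/filename
)
model.load_state_dict(torch.load(checkpoint_path), strict=False)
```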

``` python
from PIL import Image
import requests
import torch

"""
Step 1: Load images
"""
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)

demo_image_two = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
        stream=True
    ).raw
)

query_image = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
        stream=True
    ).raw
)


"""
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
batch_size x num_media x num_frames x channels x height x width.
In this case batch_size = 1, num_media = 3, num_frames = 1,
channels = 3, height = 224, width = 224.
"""
vision_x = [
    image_processor(demo_image_one).unsqueeze(0),
    image_processor(demo_image_two).unsqueeze(0),
    image_processor(query_image).unsqueeze(0),
]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)

"""
Step 3: Preprocessing text
Details: In the text we expect an <image> special token to indicate where an image is.
We also expect an <|endofchunk|> special token to indicate the end of the text
portion associated with an image.
"""
tokenizer.padding_side = "left"  # For generation, padding tokens should be on the left
lang_x = tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
    return_tensors="pt",
)


"""
Step 4: Generate text
"""
generated_text = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)

print("Generated text: ", tokenizer.decode(generated_text[0]))
```

### Bias, Risks, and Limitations
OpenFlamingo models inherit the risks of their parent models, especially the language model. As an open-source research effort, we highly value open, accessible, reproducible multimodal model research; however, it is crucial to be aware that these models are trained on web data, have not been finetuned for safety, and thus may produce unintended, inappropriate, unreliable, and/or inaccurate outputs. Please use caution before deploying OpenFlamingo models in real applications. We also hope that OpenFlamingo enables further safety and reliability research to address these issues.

In an effort to mitigate current potential biases and harms, we have deployed a text content filter on model outputs in the OpenFlamingo demo. We continue to red-team the model to understand and improve its safety.

## Evaluation

<table>
  <tr>
    <th></th>
    <th>0-shot</th>
    <th>4-shot</th>
    <th>8-shot</th>
    <th>16-shot</th>
    <th>32-shot</th>
  </tr>
  <tr>
    <th>COCO (CIDEr)</th>
    <td>0</td>
    <td>4</td>
    <td>8</td>
    <td>16</td>
    <td>32</td>
  </tr>
  <tr>
    <th>Flickr-30K (CIDEr)</th>
    <td>0</td>
    <td>4</td>
    <td>8</td>
    <td>16</td>
    <td>32</td>
  </tr>
  <tr>
    <th>VQAv2 (Accuracy)</th>
    <td>0</td>
    <td>4</td>
    <td>8</td>
    <td>16</td>
    <td>32</td>
  </tr>
  <tr>
    <th>OK-VQA (Accuracy)</th>
    <td>0</td>
    <td>4</td>
    <td>8</td>
    <td>16</td>
    <td>32</td>
  </tr>
  <tr>
    <th>TextVQA (Accuracy)</th>
    <td>0</td>
    <td>4</td>
    <td>8</td>
    <td>16</td>
    <td>32</td>
  </tr>
  <tr>
    <th>VizWiz (Accuracy)</th>
    <td>0</td>
    <td>4</td>
    <td>8</td>
    <td>16</td>
    <td>32</td>
  </tr>
  <tr>
    <th>Hateful Memes (ROC AUC)</th>
    <td>0</td>
    <td>4</td>
    <td>8</td>
    <td>16</td>
    <td>32</td>
  </tr>
</table>