Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,147 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
pipeline_tag: image-text-to-text
|
6 |
+
inference: false
|
7 |
+
tags:
|
8 |
+
- vision
|
9 |
+
- image-text-to-text
|
10 |
+
arxiv: 2409.18869
|
11 |
+
---
|
12 |
+
|
13 |
+
|
14 |
+
<div align='center'>
|
15 |
+
<h1>Emu3: Next-Token Prediction is All You Need</h1h1>
|
16 |
+
<h3></h3>
|
17 |
+
|
18 |
+
[Emu3 Team, BAAI](https://www.baai.ac.cn/english.html)
|
19 |
+
|
20 |
+
</div>
|
21 |
+
|
22 |
+
|
23 |
+
<div align='left'>
|
24 |
+
<img src="https://github.com/baaivision/Emu3/blob/main/assets/arch.png?raw=True" class="interpolation-image" alt="arch." height="80%" width="70%" />
|
25 |
+
</div>
|
26 |
+
|
27 |
+
Below is the model card of Emu3-Chat model, which is adapted from the original Emu3 model card that you can find [here](https://huggingface.co/BAAI/Emu3-Gen).
|
28 |
+
|
29 |
+
|
30 |
+
## Model details
|
31 |
+
|
32 |
+
**Model type:**
|
33 |
+
Emu3 is an open-source multimodal models trained with next-token prediction task. By tokenizing images and text into a discrete space, Emu3 is trained as a single transformer from scratch on a mixture of multimodal sequences.
|
34 |
+
It is an auto-regressive language model, based on the transformer architecture.
|
35 |
+
|
36 |
+
**Paper or resources for more information:**
|
37 |
+
https://github.com/baaivision/Emu3
|
38 |
+
|
39 |
+
|
40 |
+
## Highlights
|
41 |
+
|
42 |
+
- **Emu3** is capable of generating high-quality images following the text input, by simply predicting the next vision token. The model naturally supports flexible resolutions and styles.
|
43 |
+
- **Emu3** shows strong vision-language understanding capabilities to see the physical world and provides coherent text responses. Notably, this capability is achieved without depending on a CLIP and a pretrained LLM.
|
44 |
+
- **Emu3** simply generates a video causally by predicting the next token in a video sequence, unlike the video diffusion model as in Sora. With a video in context, Emu3 can also naturally extend the video and predict what will happen next.
|
45 |
+
- **Emu3** outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship open models such as SDXL, LLaVA-1.6 and OpenSora-1.2, while eliminating the need for diffusion or compositional architectures.
|
46 |
+
|
47 |
+
|
48 |
+
## How to use the model
|
49 |
+
|
50 |
+
First, make sure to have `transformers >= 4.48.0`.
|
51 |
+
Below is an example script to run generation in `float16` precision on a GPU device:
|
52 |
+
|
53 |
+
```python
|
54 |
+
import requests
|
55 |
+
from PIL import Image
|
56 |
+
|
57 |
+
import torch
|
58 |
+
from transformers import AutoProcessor, Emu3ForConditionalGeneration
|
59 |
+
|
60 |
+
model_id = "BAAI/Emu3-Gen-hf"
|
61 |
+
model = Emu3ForConditionalGeneration.from_pretrained(
|
62 |
+
model_id,
|
63 |
+
torch_dtype=torch.float16,
|
64 |
+
low_cpu_mem_usage=True,
|
65 |
+
device_map="cuda:0",
|
66 |
+
)
|
67 |
+
|
68 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
69 |
+
inputs = processor(
|
70 |
+
text=["a portrait of young girl. masterpiece, film grained, best quality."],
|
71 |
+
padding=True,
|
72 |
+
return_tensors="pt",
|
73 |
+
return_for_image_generation=True,
|
74 |
+
).to(model.device)
|
75 |
+
|
76 |
+
image_sizes = inputs.pop("image_sizes")
|
77 |
+
HEIGHT, WIDTH = image_sizes[0]
|
78 |
+
VISUAL_TOKENS = model.vocabulary_mapping.image_tokens
|
79 |
+
|
80 |
+
def prefix_allowed_tokens_fn(batch_id, input_ids):
|
81 |
+
height, width = HEIGHT, WIDTH
|
82 |
+
visual_tokens = VISUAL_TOKENS
|
83 |
+
image_wrapper_token_id = torch.tensor([processor.tokenizer.image_wrapper_token_id], device=model.device)
|
84 |
+
eoi_token_id = torch.tensor([processor.tokenizer.eoi_token_id], device=model.device)
|
85 |
+
eos_token_id = torch.tensor([processor.tokenizer.eos_token_id], device=model.device)
|
86 |
+
pad_token_id = torch.tensor([processor.tokenizer.pad_token_id], device=model.device)
|
87 |
+
eof_token_id = torch.tensor([processor.tokenizer.eof_token_id], device=model.device)
|
88 |
+
eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]
|
89 |
+
|
90 |
+
position = torch.nonzero(input_ids == image_wrapper_token_id, as_tuple=True)[0][0]
|
91 |
+
offset = input_ids.shape[0] - position
|
92 |
+
if offset % (width + 1) == 0:
|
93 |
+
return (eol_token_id,)
|
94 |
+
elif offset == (width + 1) * height + 1:
|
95 |
+
return (eof_token_id,)
|
96 |
+
elif offset == (width + 1) * height + 2:
|
97 |
+
return (eoi_token_id,)
|
98 |
+
elif offset == (width + 1) * height + 3:
|
99 |
+
return (eos_token_id,)
|
100 |
+
elif offset > (width + 1) * height + 3:
|
101 |
+
return (pad_token_id,)
|
102 |
+
else:
|
103 |
+
return visual_tokens
|
104 |
+
|
105 |
+
out = model.generate(
|
106 |
+
**inputs,
|
107 |
+
max_new_tokens=9_000,
|
108 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
109 |
+
do_sample=True,
|
110 |
+
)
|
111 |
+
|
112 |
+
image = model.decode_image_tokens(out.sequences[:, inputs.input_ids.shape[1]: ], height=HEIGHT, width=WIDTH)
|
113 |
+
images = processor.postprocess(list(image.float()), return_tensors="PIL.Image.Image")
|
114 |
+
for i, image in enumerate(images['pixel_values']):
|
115 |
+
image.save(f"result{i}.png")
|
116 |
+
|
117 |
+
```
|
118 |
+
|
119 |
+
### Model optimization
|
120 |
+
|
121 |
+
|
122 |
+
#### Use Flash-Attention 2 to further speed-up generation
|
123 |
+
|
124 |
+
First make sure to install `flash-attn`. Refer to the [original repository of Flash Attention](https://github.com/Dao-AILab/flash-attention) regarding that package installation. Simply change the snippet above with:
|
125 |
+
|
126 |
+
```diff
|
127 |
+
model = Emu3ForConditionalGeneration.from_pretrained(
|
128 |
+
model_id,
|
129 |
+
torch_dtype=torch.float16,
|
130 |
+
low_cpu_mem_usage=True,
|
131 |
+
+ attn_implementation="flash_attention_2",
|
132 |
+
device_map="cuda:0",
|
133 |
+
)
|
134 |
+
```
|
135 |
+
|
136 |
+
# Citation
|
137 |
+
```
|
138 |
+
@misc{wang2024emu3nexttokenpredictionneed,
|
139 |
+
title={Emu3: Next-Token Prediction is All You Need},
|
140 |
+
author={Xinlong Wang and Xiaosong Zhang and Zhengxiong Luo and Quan Sun and Yufeng Cui and Jinsheng Wang and Fan Zhang and Yueze Wang and Zhen Li and Qiying Yu and Yingli Zhao and Yulong Ao and Xuebin Min and Tao Li and Boya Wu and Bo Zhao and Bowen Zhang and Liangdong Wang and Guang Liu and Zheqi He and Xi Yang and Jingjing Liu and Yonghua Lin and Tiejun Huang and Zhongyuan Wang},
|
141 |
+
year={2024},
|
142 |
+
eprint={2409.18869},
|
143 |
+
archivePrefix={arXiv},
|
144 |
+
primaryClass={cs.CV},
|
145 |
+
url={https://arxiv.org/abs/2409.18869},
|
146 |
+
}
|
147 |
+
```
|