JosephPai committed
Commit • 8741abe
1 Parent(s): 36c1777
init
- LICENSE +21 -0
- README.md +2 -2
- app.py +526 -0
- configs/showo_demo.yaml +49 -0
- configs/showo_demo_w_clip_vit.yaml +49 -0
- gradio/app.py +488 -0
- gradio/app_gradio.py +281 -0
- gradio/app_w_clip.py +559 -0
- gradio/share_btn.py +113 -0
- inference_mmu.py +174 -0
- inference_t2i.py +331 -0
- inpainting_validation/.DS_Store +0 -0
- inpainting_validation/alpine_lake.jpg +0 -0
- inpainting_validation/bedroom.jpg +0 -0
- inpainting_validation/bedroom_mask.webp +0 -0
- inpainting_validation/bench.jpg +0 -0
- inpainting_validation/bench_mask.webp +0 -0
- inpainting_validation/bus.jpg +0 -0
- inpainting_validation/bus_mask.webp +0 -0
- inpainting_validation/lake_mountain.jpg +0 -0
- inpainting_validation/maya.png +0 -0
- inpainting_validation/river.png +0 -0
- inpainting_validation/train.jpg +0 -0
- inpainting_validation/train_mask.webp +0 -0
- inpainting_validation/truebsee.jpg +0 -0
- inpainting_validation/truebsee_mask.webp +0 -0
- inpainting_validation/wukong1.jpg +0 -0
- inpainting_validation/wukong2.jpg +0 -0
- mmu_validation/sofa_under_water.jpg +0 -0
- models/__init__.py +4 -0
- models/clip_encoder.py +140 -0
- models/common_modules.py +407 -0
- models/logging.py +338 -0
- models/lr_schedulers.py +292 -0
- models/misc.py +53 -0
- models/modeling_magvitv2.py +440 -0
- models/modeling_showo.py +206 -0
- models/modeling_utils.py +1207 -0
- models/phi.py +1489 -0
- models/sampling.py +118 -0
- models/training_utils.py +455 -0
- prompting_utils.py +528 -0
- requirements.txt +228 -0
- training/__init__.py +1 -0
- training/conversation.py +432 -0
- training/utils.py +185 -0
- training_utils.py +185 -0
- validation_prompts/showoprompts.txt +24 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Jinheng Xie

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:df8da49db1cd3db14e34d015e398bd3d5ab51c3988fd98976405701ce1838ef5
+size 224
app.py
ADDED
@@ -0,0 +1,526 @@
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from PIL import Image
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu
from training_utils import image_transform
from models import Showo, MAGVITv2, get_mask_chedule

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = OmegaConf.load("configs/showo_demo.yaml")
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")

uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
                                   special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>",
                                                   "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                   ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)

vq_model = MAGVITv2()
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
vq_model.requires_grad_(False)
vq_model.eval()

model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
model.eval()
mask_token_id = model.config.mask_token_id


@spaces.GPU
def text_to_image_generation(input_text, guidance_scale=1.75, generation_timesteps=18):
    prompts = [input_text]
    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
                              dtype=torch.long, device=device) * mask_token_id

    input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')

    if config.training.guidance_scale > 0:
        uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
        attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
    else:
        attention_mask = create_attention_mask_predict_next(input_ids,
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
        uncond_input_ids = None

    if config.get("mask_schedule", None) is not None:
        schedule = config.mask_schedule.schedule
        args = config.mask_schedule.get("params", {})
        mask_schedule = get_mask_chedule(schedule, **args)
    else:
        mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

    with torch.no_grad():
        gen_token_ids = model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            seq_len=config.model.showo.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
        )

    gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
    images = vq_model.decode_code(gen_token_ids)

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    return images[0]


@spaces.GPU
def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidance_scale=1.75, generation_timesteps=16):
    prompt = [input_text]

    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
    inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)

    inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)

    inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
    inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
    inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)

    inpainting_mask[inpainting_mask < 0.5] = 0
    inpainting_mask[inpainting_mask >= 0.5] = 1

    inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
    inpainting_mask = inpainting_mask.to(torch.bool)

    inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
    inpainting_image_tokens[inpainting_mask] = mask_token_id

    input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')

    if config.training.guidance_scale > 0:
        uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
        attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
    else:
        attention_mask = create_attention_mask_predict_next(input_ids,
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
        uncond_input_ids = None

    if config.get("mask_schedule", None) is not None:
        schedule = config.mask_schedule.schedule
        args = config.mask_schedule.get("params", {})
        mask_schedule = get_mask_chedule(schedule, **args)
    else:
        mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

    with torch.no_grad():
        gen_token_ids = model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            seq_len=config.model.showo.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
        )

    gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
    images = vq_model.decode_code(gen_token_ids)

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    return images[0]


@spaces.GPU
def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidance_scale=1.75, generation_timesteps=16):
    config.offset = 0
    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    extra_direction = ['right'] * int(right_ext) + ['left'] * int(left_ext)
    prompt = [input_text] * len(extra_direction)
    W = config.dataset.params.resolution // 16
    for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
        prt = [prt] * config.training.batch_size
        if id == 0:
            # extrapolation_image = Image.open(config.image_path).convert("RGB")
            extrapolation_image = input_img
            extrapolation_image = image_transform(extrapolation_image,
                                                  resolution=config.dataset.params.resolution).to(device)

            B, _, _ = extrapolation_image.shape
            extrapolation_image = extrapolation_image.unsqueeze(0)
            extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
            extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
                                                                            config.dataset.params.resolution // 16,
                                                                            config.dataset.params.resolution // 16)
            extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
        else:

            extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)

        image_left_part = extrapolation_image_tokens[:, :, :-(W // 2 - config.offset)] - len(
            uni_prompting.text_tokenizer)
        image_right_part = extrapolation_image_tokens[:, :, W // 2 - config.offset:] - len(uni_prompting.text_tokenizer)
        image_up_part = extrapolation_image_tokens[:, :-(W // 2 - config.offset), :] - len(uni_prompting.text_tokenizer)
        image_down_part = extrapolation_image_tokens[:, W // 2 - config.offset:, :] - len(uni_prompting.text_tokenizer)

        if direction in ['left', 'right']:
            extrapolation_mask = torch.zeros((config.training.batch_size,
                                              config.dataset.params.resolution // 16,
                                              config.dataset.params.resolution // 16 // 2 + config.offset),
                                             dtype=torch.int64, device=device) + mask_token_id
        else:
            extrapolation_mask = torch.zeros((config.training.batch_size,
                                              config.dataset.params.resolution // 16 // 2 + config.offset,
                                              config.dataset.params.resolution // 16),
                                             dtype=torch.int64, device=device) + mask_token_id

        if direction == 'left':
            extrapolation_image_tokens = torch.cat(
                [extrapolation_mask, extrapolation_image_tokens[:, :, :W // 2 - config.offset]], dim=-1)
        elif direction == 'right':
            extrapolation_image_tokens = torch.cat(
                [extrapolation_image_tokens[:, :, -(W // 2 - config.offset):], extrapolation_mask], dim=-1)
        elif direction == 'up':
            extrapolation_image_tokens = torch.cat(
                [extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
        else:
            extrapolation_image_tokens = torch.cat(
                [extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)

        extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)

        input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')

        if config.training.guidance_scale > 0:
            uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
            attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                                pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                                soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                                eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                                rm_pad_in_image=True)
        else:
            attention_mask = create_attention_mask_predict_next(input_ids,
                                                                pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                                soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                                eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                                rm_pad_in_image=True)
            uncond_input_ids = None

        if config.get("mask_schedule", None) is not None:
            schedule = config.mask_schedule.schedule
            args = config.mask_schedule.get("params", {})
            mask_schedule = get_mask_chedule(schedule, **args)
        else:
            mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

        with torch.no_grad():
            gen_token_ids = model.t2i_generate(
                input_ids=input_ids,
                uncond_input_ids=uncond_input_ids,
                attention_mask=attention_mask,
                guidance_scale=config.training.guidance_scale,
                temperature=config.training.get("generation_temperature", 1.0),
                timesteps=config.training.generation_timesteps,
                noise_schedule=mask_schedule,
                noise_type=config.training.get("noise_type", "mask"),
                seq_len=config.model.showo.num_vq_tokens,
                uni_prompting=uni_prompting,
                config=config,
            )

        gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
        gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
                                              config.dataset.params.resolution // 16,
                                              config.dataset.params.resolution // 16)
        if direction == 'left':
            gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
        elif direction == 'right':
            gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
        elif direction == 'up':
            gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
        else:
            gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-2)

    _, h, w = gen_token_ids.shape
    gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
    images = vq_model.decode_code(gen_token_ids, shape=(h, w))

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    return images[0]


@spaces.GPU
def multimodal_understanding(input_img, input_text, chat_history):
    top_k = 1  # retain only the top_k most likely tokens, clamp others to have 0 probability

    image_ori = input_img
    image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
    image = image.unsqueeze(0)
    image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)

    question = input_text
    input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[
        'input_ids']
    input_ids = torch.tensor(input_ids).to(device)

    input_ids = torch.cat([
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device),
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
        image_tokens,
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device),
        input_ids
    ], dim=1).long()

    attention_mask = create_attention_mask_for_mmu(input_ids.to(device),
                                                   eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))

    cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask,
                                        max_new_tokens=100, top_k=top_k,
                                        eot_token=uni_prompting.sptids_dict['<|eot|>'])

    cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]

    output_text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)

    output_text = output_text[0].strip()

    chat_history.append((input_text, output_text))

    return "", chat_history


with gr.Blocks() as demo:
    gr.HTML("""
            <h1 class="display-2 fw-bold title">
              <a style="color: #70a8dc;">S</a><a style="color: #6fb051;">h</a><a style="color: #e06766;">o</a><a style="color: #f7b26b;">w</a>-o
            </h1>
            <p>This is the official Gradio demo for the Show-o model, a unified model that can do multimodal understanding and generation.</p>

            <strong>Paper:</strong> <a href="https://arxiv.org/abs/2408.12528" target="_blank">Show-o: One Single Transformer To Unify Multimodal Understanding and Generation </a>
            <br/>
            <strong>Project Website:</strong> <a href="https://showlab.github.io/Show-o/" target="_blank">Show-o Website</a>
            <br/>
            <strong>Code and Models:</strong> <a href="https://github.com/showlab/Show-o" target="_blank">GitHub</a>
            <br/>
            <br/>
            """)

    with gr.Row():
        with gr.Column():
            text_prompt_t2i = gr.Textbox(
                label="Text prompt",
                lines=2,
                placeholder="Input the text prompt here for image generation."
            )
            guidance_scale_t2i = gr.Slider(
                label="guidance scale",
                minimum=0,
                maximum=5,
                step=0.05,
                value=1.75
            )
            generation_timesteps_t2i = gr.Slider(
                label="timesteps",
                minimum=1,
                maximum=30,
                step=1,
                value=18
            )
        generated_img_t2i = gr.Image(
            label="Output image"
        )
    examples_t2i = gr.Examples(
        label="Text to image generation examples",
        examples=[
            "A dynamic scene of a rally car race.",
            "Paper artwork, layered paper, colorful Chinese dragon surrounded by clouds.",
            "Pixel art character riding a dragon through the clouds.",
        ],
        inputs=text_prompt_t2i,
    )
    submit_btn_t2i = gr.Button("Generate: Text-to-image")
    submit_btn_t2i.click(text_to_image_generation,
                         [text_prompt_t2i, guidance_scale_t2i, generation_timesteps_t2i],
                         [generated_img_t2i])

    with gr.Row():
        inpainting_input_img = gr.Image(
            label="Input image",
            type="pil",
        )
        inpainting_input_mask = gr.Image(
            label="Inpainting mask",
            image_mode="L",
            type="pil",
        )

        with gr.Column():
            text_prompt_inpainting = gr.Textbox(
                label="Text prompt",
                lines=2,
                placeholder="Input the text prompt here for image inpainting."
            )
            guidance_scale_inpainting = gr.Slider(
                label="guidance scale",
                minimum=0,
                maximum=5,
                step=0.05,
                value=1.75
            )
            generation_timesteps_inpainting = gr.Slider(
                label="timesteps",
                minimum=1,
                maximum=30,
                step=1,
                value=16
            )
        generated_img_inpainting = gr.Image(
            label="Output image"
        )
    examples_inpainting = gr.Examples(
        label="Text-guided inpainting examples",
        examples=[
            [
                "a blue sports car with sleek curves and tinted windows, parked on a bustling city street.",
                Image.open("./inpainting_validation/bus.jpg").convert("RGB"),
                Image.open("./inpainting_validation/bus_mask.webp").convert("L"),
            ],
            [
                "a clear, shallow river with some vibrant flowers in it.",
                Image.open("./inpainting_validation/train.jpg").convert("RGB"),
                Image.open("./inpainting_validation/train_mask.webp").convert("L"),
            ],
        ],
        inputs=[text_prompt_inpainting, inpainting_input_img, inpainting_input_mask],
    )
    submit_btn_inpainting = gr.Button("Generate: Text-guided Inpainting")
    submit_btn_inpainting.click(text_guided_inpainting,
                                [text_prompt_inpainting, inpainting_input_img, inpainting_input_mask,
                                 guidance_scale_inpainting, generation_timesteps_inpainting],
                                [generated_img_inpainting])

    with gr.Row():
        extra_input_img = gr.Image(
            label="Input image",
            type="pil",
            image_mode="RGB",
        )

        with gr.Column():
            text_prompt_extrapolation = gr.Textbox(
                label="Text prompt",
                lines=1,
                placeholder="Input the text prompt here for image extrapolation."
            )
            guidance_scale_extrapolation = gr.Slider(
                label="guidance scale",
                minimum=0,
                maximum=5,
                step=0.05,
                value=1.75
            )
            generation_timesteps_extrapolation = gr.Slider(
                label="timesteps",
                minimum=1,
                maximum=30,
                step=1,
                value=16
            )
            left_extrapolation = gr.Slider(
                label="left extrapolation",
                minimum=0,
                maximum=5,
                step=1,
                value=1
            )
            right_extrapolation = gr.Slider(
                label="right extrapolation",
                minimum=0,
                maximum=5,
                step=1,
                value=1
            )
        generated_img_extrapolation = gr.Image(
            label="Output image"
        )
    examples_extra = gr.Examples(
        label="Text-guided extrapolation examples",
        examples=[
            [
                Image.open("./inpainting_validation/wukong2.jpg").convert("RGB"),
                "the continuous mountain ranges and jungles, with meandering rivers occasionally appearing.",
                2,
                2,
            ],
            [
                Image.open("./inpainting_validation/alpine_lake.jpg").convert("RGB"),
                "a serene natural landscape featuring a clear, blue lake surrounded by lush green trees.",
                2,
                2,
            ],
        ],
        inputs=[extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation],
    )
    submit_btn_inpainting = gr.Button("Generate: Text-guided Extrapolation")
    submit_btn_inpainting.click(text_guided_extrapolation,
                                [extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation,
                                 guidance_scale_extrapolation, generation_timesteps_extrapolation],
                                [generated_img_extrapolation])
    with gr.Row():
        with gr.Row():
            chat_input_img = gr.Image(
                label="Input image",
                type="pil",
                image_mode="RGB",
            )
        with gr.Column():
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Press Enter to send a message for chat")
            clear = gr.ClearButton([msg, chatbot])
            msg.submit(multimodal_understanding, [chat_input_img, msg, chatbot], [msg, chatbot])

demo.launch()
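
The callbacks above are plain functions, so they can also be driven without the Gradio UI. A minimal sketch (not part of the commit), assuming the module-level setup in app.py has already run on a CUDA machine and the `gr.Blocks` section is skipped; the output file name is only an illustration:

# Hypothetical driver for the text_to_image_generation callback defined above.
# Assumes config, uni_prompting, vq_model, model and mask_token_id are loaded as in app.py.
from PIL import Image

frame = text_to_image_generation(
    "A dynamic scene of a rally car race.",  # one of the demo's example prompts
    guidance_scale=1.75,                     # same default as the demo slider
    generation_timesteps=18,                 # same default as the demo slider
)                                            # returns an HxWx3 uint8 numpy array
Image.fromarray(frame).save("rally_car.png")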
configs/showo_demo.yaml
ADDED
@@ -0,0 +1,49 @@
wandb:
    entity: null
    # run_id: askkz9i2
    resume: 'auto'

experiment:
    project: "demo"
    name: "show-o-demo"
    output_dir: "show-o-demo"

model:
    vq_model:
        type: "magvitv2"
        vq_model_name: "showlab/magvitv2"

    showo:
        pretrained_model_path: "showlab/show-o"
        w_clip_vit: False
        vocab_size: 58498
        llm_vocab_size: 50295
        llm_model_path: 'microsoft/phi-1_5'
        codebook_size: 8192
        num_vq_tokens: 256

    gradient_checkpointing: True
    enable_xformers_memory_efficient_attention: True


dataset:
    gen_type: "t2i"
    und_type: "large_cap"
    params:
        batch_size: ${training.batch_size}
        shuffle_buffer_size: 1000
        num_workers: 32
        resolution: 256
        pin_memory: True
        persistent_workers: True

    preprocessing:
        max_seq_length: 128
        resolution: 256
        center_crop: False
        random_flip: False

training:
    gradient_accumulation_steps: 1
    cond_dropout_prob: 0.1
    batch_size: 20
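
Note that `dataset.params.batch_size` is an OmegaConf interpolation of `training.batch_size`, which is why the demo callbacks can drop the effective batch size to 1 per request by assigning `config.training.batch_size = 1`. A small sketch of that behavior:

# Minimal sketch of how the ${training.batch_size} interpolation resolves.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/showo_demo.yaml")
print(cfg.dataset.params.batch_size)   # 20, taken from training.batch_size

cfg.training.batch_size = 1            # what each demo callback does per request
print(cfg.dataset.params.batch_size)   # 1, the interpolation follows the override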
configs/showo_demo_w_clip_vit.yaml
ADDED
@@ -0,0 +1,49 @@
wandb:
    entity: null
    # run_id: askkz9i2
    resume: 'auto'

experiment:
    project: "demo"
    name: "show-o-demo"
    output_dir: "show-o-demo"

model:
    vq_model:
        type: "magvitv2"
        vq_model_name: "showlab/magvitv2"

    showo:
        pretrained_model_path: "showlab/show-o-w-clip-vit"
        w_clip_vit: True
        vocab_size: 58498
        llm_vocab_size: 50295
        llm_model_path: 'microsoft/phi-1_5'
        codebook_size: 8192
        num_vq_tokens: 256

    gradient_checkpointing: True
    enable_xformers_memory_efficient_attention: True


dataset:
    gen_type: "t2i"
    und_type: "large_cap"
    params:
        batch_size: ${training.batch_size}
        shuffle_buffer_size: 1000
        num_workers: 32
        resolution: 256
        pin_memory: True
        persistent_workers: True

    preprocessing:
        max_seq_length: 128
        resolution: 256
        center_crop: False
        random_flip: False

training:
    gradient_accumulation_steps: 1
    cond_dropout_prob: 0.1
    batch_size: 20
gradio/app.py
ADDED
@@ -0,0 +1,488 @@
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import numpy as np
import gradio as gr
import torch
from PIL import Image
from omegaconf import OmegaConf
from transformers import AutoTokenizer
import torch.nn.functional as F
from transformers import CLIPImageProcessor

import sys
sys.path.insert(0, ".")
from training import conversation as conversation_lib
from prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu
from training_utils import image_transform
from models import Showo, MAGVITv2, get_mask_chedule, CLIPVisionTower

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]
SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
                "The assistant gives helpful, detailed, and polite answers to the user's questions."
SYSTEM_PROMPT_LEN = 28


config = OmegaConf.load("configs/showo_demo.yaml")
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")

uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
                                   special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>",
                                                   "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                   ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)

vq_model = MAGVITv2()
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
vq_model.requires_grad_(False)
vq_model.eval()

model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
model.eval()
mask_token_id = model.config.mask_token_id


def text_to_image_generation(input_text, guidance_scale, generation_timesteps):
    prompts = [input_text]
    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
                              dtype=torch.long, device=device) * mask_token_id

    input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')

    if config.training.guidance_scale > 0:
        uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
        attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
    else:
        attention_mask = create_attention_mask_predict_next(input_ids,
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
        uncond_input_ids = None

    if config.get("mask_schedule", None) is not None:
        schedule = config.mask_schedule.schedule
        args = config.mask_schedule.get("params", {})
        mask_schedule = get_mask_chedule(schedule, **args)
    else:
        mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

    with torch.no_grad():
        gen_token_ids = model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            seq_len=config.model.showo.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
        )

    gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
    images = vq_model.decode_code(gen_token_ids)

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    return images[0]


def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidance_scale, generation_timesteps):
    prompt = [input_text]

    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
    inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)

    inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)

    inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
    inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
    inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)

    inpainting_mask[inpainting_mask < 0.5] = 0
    inpainting_mask[inpainting_mask >= 0.5] = 1

    inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
    inpainting_mask = inpainting_mask.to(torch.bool)

    inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
    inpainting_image_tokens[inpainting_mask] = mask_token_id

    input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')

    if config.training.guidance_scale > 0:
        uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
        attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
    else:
        attention_mask = create_attention_mask_predict_next(input_ids,
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
        uncond_input_ids = None

    if config.get("mask_schedule", None) is not None:
        schedule = config.mask_schedule.schedule
        args = config.mask_schedule.get("params", {})
        mask_schedule = get_mask_chedule(schedule, **args)
    else:
        mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

    with torch.no_grad():
        gen_token_ids = model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            seq_len=config.model.showo.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
        )

    gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
    images = vq_model.decode_code(gen_token_ids)

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    return images[0]


def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidance_scale, generation_timesteps):
    config.offset = 0
    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    extra_direction = ['right'] * int(right_ext) + ['left'] * int(left_ext)
    prompt = [input_text] * len(extra_direction)
    W = config.dataset.params.resolution // 16
    for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
        prt = [prt] * config.training.batch_size
        if id == 0:
            # extrapolation_image = Image.open(config.image_path).convert("RGB")
            extrapolation_image = input_img
            extrapolation_image = image_transform(extrapolation_image,
                                                  resolution=config.dataset.params.resolution).to(device)

            B, _, _ = extrapolation_image.shape
            extrapolation_image = extrapolation_image.unsqueeze(0)
            extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
            extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
                                                                            config.dataset.params.resolution // 16,
                                                                            config.dataset.params.resolution // 16)
            extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
        else:

            extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)

        image_left_part = extrapolation_image_tokens[:, :, :-(W // 2 - config.offset)] - len(
            uni_prompting.text_tokenizer)
        image_right_part = extrapolation_image_tokens[:, :, W // 2 - config.offset:] - len(uni_prompting.text_tokenizer)
        image_up_part = extrapolation_image_tokens[:, :-(W // 2 - config.offset), :] - len(uni_prompting.text_tokenizer)
        image_down_part = extrapolation_image_tokens[:, W // 2 - config.offset:, :] - len(uni_prompting.text_tokenizer)

        if direction in ['left', 'right']:
            extrapolation_mask = torch.zeros((config.training.batch_size,
                                              config.dataset.params.resolution // 16,
                                              config.dataset.params.resolution // 16 // 2 + config.offset),
                                             dtype=torch.int64, device=device) + mask_token_id
        else:
            extrapolation_mask = torch.zeros((config.training.batch_size,
                                              config.dataset.params.resolution // 16 // 2 + config.offset,
                                              config.dataset.params.resolution // 16),
                                             dtype=torch.int64, device=device) + mask_token_id

        if direction == 'left':
            extrapolation_image_tokens = torch.cat(
                [extrapolation_mask, extrapolation_image_tokens[:, :, :W // 2 - config.offset]], dim=-1)
        elif direction == 'right':
            extrapolation_image_tokens = torch.cat(
                [extrapolation_image_tokens[:, :, -(W // 2 - config.offset):], extrapolation_mask], dim=-1)
        elif direction == 'up':
            extrapolation_image_tokens = torch.cat(
                [extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
        else:
            extrapolation_image_tokens = torch.cat(
                [extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)

        extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)

        input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')

        if config.training.guidance_scale > 0:
            uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
            attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                                pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                                soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                                eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                                rm_pad_in_image=True)
        else:
            attention_mask = create_attention_mask_predict_next(input_ids,
                                                                pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                                soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                                eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                                rm_pad_in_image=True)
            uncond_input_ids = None

        if config.get("mask_schedule", None) is not None:
            schedule = config.mask_schedule.schedule
            args = config.mask_schedule.get("params", {})
            mask_schedule = get_mask_chedule(schedule, **args)
        else:
            mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

        with torch.no_grad():
            gen_token_ids = model.t2i_generate(
                input_ids=input_ids,
                uncond_input_ids=uncond_input_ids,
                attention_mask=attention_mask,
                guidance_scale=config.training.guidance_scale,
                temperature=config.training.get("generation_temperature", 1.0),
                timesteps=config.training.generation_timesteps,
                noise_schedule=mask_schedule,
                noise_type=config.training.get("noise_type", "mask"),
                seq_len=config.model.showo.num_vq_tokens,
                uni_prompting=uni_prompting,
                config=config,
            )

        gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
        gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
                                              config.dataset.params.resolution // 16,
                                              config.dataset.params.resolution // 16)
        if direction == 'left':
            gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
        elif direction == 'right':
            gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
        elif direction == 'up':
            gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
        else:
            gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-2)

    _, h, w = gen_token_ids.shape
    gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
    images = vq_model.decode_code(gen_token_ids, shape=(h, w))

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    return images[0]


def multimodal_understanding(input_img, input_text, chat_history):
    top_k = 1  # retain only the top_k most likely tokens, clamp others to have 0 probability

    image_ori = input_img
    image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
    image = image.unsqueeze(0)
    image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)

    question = input_text
    input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[
        'input_ids']
    input_ids = torch.tensor(input_ids).to(device)

    input_ids = torch.cat([
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device),
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
        image_tokens,
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
        (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device),
        input_ids
    ], dim=1).long()

    attention_mask = create_attention_mask_for_mmu(input_ids.to(device),
                                                   eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))

    cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask,
                                        max_new_tokens=100, top_k=top_k,
                                        eot_token=uni_prompting.sptids_dict['<|eot|>'])

    cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]

    output_text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)

    output_text = output_text[0].strip()

    chat_history.append((input_text, output_text))

    return "", chat_history


with gr.Blocks() as demo:
    gr.HTML("""
            <h1 class="display-2 fw-bold title">
              <a style="color: #70a8dc;">S</a><a style="color: #6fb051;">h</a><a style="color: #e06766;">o</a><a style="color: #f7b26b;">w</a>-o
            </h1>
            <p>This is the official Gradio demo for the Show-o model, a unified model that can do multimodal understanding and generation.</p>

            <strong>Paper:</strong> <a href="https://arxiv.org/abs/2408.12528" target="_blank">Show-o: One Single Transformer To Unify Multimodal Understanding and Generation </a>
            <br/>
            <strong>Project Website:</strong> <a href="https://showlab.github.io/Show-o/" target="_blank">Show-o Website</a>
            <br/>
            <strong>Code and Models:</strong> <a href="https://github.com/showlab/Show-o" target="_blank">GitHub</a>
            <br/>
            <br/>
            """)

    with gr.Row():
        with gr.Column():
            text_prompt_t2i = gr.Textbox(
                label="Text prompt",
                lines=2,
                placeholder="Input the text prompt here for image generation."
            )
            guidance_scale_t2i = gr.Slider(
                label="guidance scale",
                minimum=0,
                maximum=5,
                step=0.05,
                value=1.75
            )
            generation_timesteps_t2i = gr.Slider(
                label="timesteps",
                minimum=1,
                maximum=30,
                step=1,
                value=18
            )
        generated_img_t2i = gr.Image(
            label="Output image"
        )
    submit_btn_t2i = gr.Button("Generate: Text-to-image")
    submit_btn_t2i.click(text_to_image_generation,
                         [text_prompt_t2i, guidance_scale_t2i, generation_timesteps_t2i],
                         [generated_img_t2i])

    with gr.Row():
        inpainting_input_img = gr.Image(
            label="Input image",
            type="pil",
        )
        inpainting_input_mask = gr.Image(
            label="Inpainting mask",
            image_mode="L",
            type="pil",
        )

        with gr.Column():
            text_prompt_inpainting = gr.Textbox(
                label="Text prompt",
                lines=2,
                placeholder="Input the text prompt here for image inpainting."
            )
            guidance_scale_inpainting = gr.Slider(
                label="guidance scale",
                minimum=0,
                maximum=5,
                step=0.05,
                value=1.75
            )
            generation_timesteps_inpainting = gr.Slider(
                label="timesteps",
                minimum=1,
                maximum=30,
                step=1,
                value=16
            )
        generated_img_inpainting = gr.Image(
            label="Output image"
        )
    submit_btn_inpainting = gr.Button("Generate: Text-guided Inpainting")
    submit_btn_inpainting.click(text_guided_inpainting,
                                [text_prompt_inpainting, inpainting_input_img, inpainting_input_mask,
                                 guidance_scale_inpainting, generation_timesteps_inpainting],
                                [generated_img_inpainting])

    with gr.Row():
        extra_input_img = gr.Image(
            label="Input image",
            type="pil",
            image_mode="RGB",
        )

        with gr.Column():
            text_prompt_extrapolation = gr.Textbox(
                label="Text prompt",
                lines=1,
                placeholder="Input the text prompt here for image extrapolation."
            )
            guidance_scale_extrapolation = gr.Slider(
                label="guidance scale",
                minimum=0,
                maximum=5,
                step=0.05,
                value=1.75
            )
            generation_timesteps_extrapolation = gr.Slider(
                label="timesteps",
                minimum=1,
                maximum=30,
                step=1,
                value=16
            )
            left_extrapolation = gr.Slider(
                label="left extrapolation",
                minimum=0,
                maximum=5,
                step=1,
                value=1
            )
            right_extrapolation = gr.Slider(
                label="right extrapolation",
                minimum=0,
                maximum=5,
                step=1,
                value=1
            )
        generated_img_extrapolation = gr.Image(
            label="Output image"
        )
    submit_btn_inpainting = gr.Button("Generate: Text-guided Extrapolation")
    submit_btn_inpainting.click(text_guided_extrapolation,
                                [extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation,
                                 guidance_scale_extrapolation, generation_timesteps_extrapolation],
                                [generated_img_extrapolation])

    with gr.Row():
        with gr.Row():
            chat_input_img = gr.Image(
                label="Input image",
                type="pil",
                image_mode="RGB",
            )
        with gr.Column():
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Press Enter to send a message for chat")
            clear = gr.ClearButton([msg, chatbot])
            msg.submit(multimodal_understanding, [chat_input_img, msg, chatbot], [msg, chatbot])

demo.launch()
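
The chat callback follows the usual Gradio chatbot contract: it takes the image, the message, and the running history, appends a `(question, answer)` pair, and returns an empty string (to clear the textbox) plus the updated history. A minimal sketch (not part of the commit), assuming the model-loading preamble above has executed; the question string is only an illustration:

# Hypothetical direct call into the multimodal_understanding handler defined above.
from PIL import Image

history = []
_, history = multimodal_understanding(
    Image.open("mmu_validation/sofa_under_water.jpg").convert("RGB"),  # image shipped in this commit
    "Please describe this image in detail.",                           # illustrative question
    history,
)
print(history[-1][1])  # the decoded answer appended as (question, answer)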
gradio/app_gradio.py
ADDED
@@ -0,0 +1,281 @@
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import tempfile
from share_btn import share_js, save_js
import gradio as gr
import numpy as np  # needed by text_to_image_generation below
import wandb  # needed by text_to_image_generation below
from PIL import Image
import torch
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from models import Showo, MAGVITv2, get_mask_chedule
from prompting_utils import UniversalPrompting, create_attention_mask_predict_next


# Prepare model
config = OmegaConf.load("configs/showo_demo.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")

uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
                                   special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                   ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)

vq_model = MAGVITv2(config.model.vq_model.type)
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
vq_model.requires_grad_(False)
vq_model.eval()

model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
model.eval()

mask_token_id = model.config.mask_token_id


css = """
#chatbot { min-height: 300px; }
#save-btn {
    background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#save-btn:hover {
    background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#share-btn {
    background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#share-btn:hover {
    background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#gallery { z-index: 999999; }
#gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;}
#gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
@media (hover: none) {
    #gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; 0;}
}
.html2canvas-container { width: 3000px !important; height: 3000px !important; }
"""


def upload_image(state, image_input):
    conversation = state[0]
    chat_history = state[1]
    input_image = Image.open(image_input.name).resize(
        (224, 224)).convert('RGB')
    input_image.save(image_input.name)  # Overwrite with smaller image.
    conversation += [(f'<img src="./file={image_input.name}" style="display: inline-block;">', "")]
    return [conversation, chat_history + [input_image, ""]], conversation


def reset():
    return [[], []], []


def reset_last(state):
    conversation = state[0][:-1]
    chat_history = state[1][:-2]
    return [conversation, chat_history], conversation


def save_image_to_local(image: Image.Image):
    filename = next(tempfile._get_candidate_names()) + '.png'
    image.save(filename)
    return filename


def text_to_image_generation(input_text, state, guidance_scale, generation_timesteps):
    prompts = [input_text]
    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
                              dtype=torch.long, device=device) * mask_token_id

    input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')

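    # Classifier-free guidance: when guidance_scale > 0, the conditional and unconditional
    # (empty-prompt) sequences are batched together so t2i_generate can contrast them.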
    if config.training.guidance_scale > 0:
        uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
        attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
    else:
        attention_mask = create_attention_mask_predict_next(input_ids,
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
        uncond_input_ids = None

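    # The mask schedule controls how many image tokens remain masked at each of the
    # generation timesteps (cosine schedule by default).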
    if config.get("mask_schedule", None) is not None:
        schedule = config.mask_schedule.schedule
        args = config.mask_schedule.get("params", {})
        mask_schedule = get_mask_chedule(schedule, **args)
    else:
        mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

    with torch.no_grad():
        gen_token_ids = model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            seq_len=config.model.showo.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
        )

    gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
    images = vq_model.decode_code(gen_token_ids)

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    pil_images = [Image.fromarray(image) for image in images]

    wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)]
    wandb.log({"generated_images": wandb_images}, step=step)


def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
    g_cuda = torch.Generator(device='cuda').manual_seed(1337)

    # Ignore empty inputs.
    if len(input_text) == 0:
        return state, state[0], gr.update(visible=True)

    input_prompt = 'Q: ' + input_text + '\nA:'
    conversation = state[0]
    chat_history = state[1]
    print('Generating for', chat_history, flush=True)

    # If an image was uploaded, prepend it to the model.
    model_inputs = chat_history
    model_inputs.append(input_prompt)
    # Remove empty text.
    model_inputs = [s for s in model_inputs if s != '']

    top_p = 1.0
    if temperature != 0.0:
        top_p = 0.95

    print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
    model_outputs = model.generate_for_images_and_texts(model_inputs,
                                                        num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
                                                        temperature=temperature, max_num_rets=1,
                                                        num_inference_steps=50, generator=g_cuda)
    print('model_outputs', model_outputs, ret_scale_factor, flush=True)

    response = ''
    text_outputs = []
    for output_i, p in enumerate(model_outputs):
        if type(p) == str:
            if output_i > 0:
                response += '<br/>'
            # Remove the image tokens for output.
            text_outputs.append(p.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', ''))
            response += p
            if len(model_outputs) > 1:
                response += '<br/>'
        elif type(p) == dict:
            # Decide whether to generate or retrieve.
            if p['decision'] is not None and p['decision'][0] == 'gen':
                image = p['gen'][0][0]  # .resize((224, 224))
                filename = save_image_to_local(image)
                response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Generated)</p>'
            else:
                image = p['ret'][0][0]  # .resize((224, 224))
                filename = save_image_to_local(image)
                response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Retrieved)</p>'

    chat_history = model_inputs + \
        [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
    # Remove [RET] from outputs.
    conversation.append((input_text, response.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', '')))

    # Set input image to None.
    print('state', state, flush=True)
    print('updated state', [conversation, chat_history], flush=True)
    return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True)


with gr.Blocks(css=css) as demo:
    gr.HTML("""
<h1>🐟 GILL</h1>
<p>This is the official Gradio demo for the GILL model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.</p>

<strong>Paper:</strong> <a href="https://arxiv.org/abs/2305.17216" target="_blank">Generating Images with Multimodal Language Models</a>
<br/>
<strong>Project Website:</strong> <a href="https://jykoh.com/gill" target="_blank">GILL Website</a>
<br/>
<strong>Code and Models:</strong> <a href="https://github.com/kohjingyu/gill" target="_blank">GitHub</a>
<br/>
<br/>

<strong>Tips:</strong>
<ul>
<li>Start by inputting either image or text prompts (or both) and chat with GILL to get image-and-text replies.</li>
<li>Tweak the level of sensitivity to images and text using the parameters on the right.</li>
<li>Check out cool conversations in the examples or community tab for inspiration and share your own!</li>
<li>If the model outputs a blank image, it is because Stable Diffusion's safety filter detected inappropriate content. Please try again with a different prompt.</li>
<li>Outputs may differ slightly from the paper due to slight implementation differences. For reproducing paper results, please use our <a href="https://github.com/kohjingyu/gill" target="_blank">official code</a>.</li>
<li>For faster inference without waiting in queue, you may duplicate the space and use your own GPU: <a href="https://huggingface.co/spaces/jykoh/gill?duplicate=true"><img style="display: inline-block; margin-top: 0em; margin-bottom: 0em" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></li>
</ul>
    """)

    gr_state = gr.State([[], []])  # conversation, chat_history

    with gr.Row():
        with gr.Column(scale=0.7, min_width=500):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chatbot", label="🐟 GILL Chatbot")
            with gr.Row():
                image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])

                text_input = gr.Textbox(label="Message", placeholder="Type a message")

                with gr.Column():
                    submit_btn = gr.Button("Submit", interactive=True, variant="primary")
                    clear_last_btn = gr.Button("Undo")
                    clear_btn = gr.Button("Reset All")
            with gr.Row(visible=False) as save_group:
                save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn")

            with gr.Row(visible=False) as share_group:
                share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")

        with gr.Column(scale=0.3, min_width=400):
            ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.3, step=0.1, interactive=True,
                                         label="Frequency multiplier for returning images (higher means more frequent)")
            gr_max_len = gr.Slider(minimum=1, maximum=64, value=32,
                                   step=1, interactive=True, label="Max # of words")
            gr_temperature = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)")

            gallery = gr.Gallery(
                value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery",
            ).style(grid=[2], height="auto")

    text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
                                            gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
    text_input.submit(lambda: "", None, text_input)  # Reset chatbox.

    submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
                                           gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
    submit_btn.click(lambda: "", None, text_input)  # Reset chatbox.

    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])
    share_button.click(None, [], [], _js=share_js)
    save_button.click(None, [], [], _js=save_js)


demo.queue(concurrency_count=1, api_open=False, max_size=16)
demo.launch(debug=True, server_name="0.0.0.0")
gradio/app_w_clip.py
ADDED
@@ -0,0 +1,559 @@
1 |
+
import os
|
2 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
|
3 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
4 |
+
import numpy as np
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from transformers import CLIPImageProcessor
|
12 |
+
|
13 |
+
import sys
|
14 |
+
sys.path.insert(0, ".")
|
15 |
+
from training import conversation as conversation_lib
|
16 |
+
from prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu_vit
|
17 |
+
from training_utils import image_transform
|
18 |
+
from models import Showo, MAGVITv2, get_mask_chedule, CLIPVisionTower
|
19 |
+
|
20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]
|
22 |
+
SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
|
23 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
24 |
+
SYSTEM_PROMPT_LEN = 28
|
25 |
+
|
26 |
+
|
27 |
+
def load_discrete_checkpoint():
|
28 |
+
config = OmegaConf.load("configs/showo_demo.yaml")
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
|
30 |
+
|
31 |
+
uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
|
32 |
+
special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>",
|
33 |
+
"<|t2v|>", "<|v2v|>", "<|lvg|>"),
|
34 |
+
ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
|
35 |
+
|
36 |
+
vq_model = MAGVITv2()
|
37 |
+
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
|
38 |
+
vq_model.requires_grad_(False)
|
39 |
+
vq_model.eval()
|
40 |
+
|
41 |
+
model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
|
42 |
+
model.eval()
|
43 |
+
mask_token_id = model.config.mask_token_id
|
44 |
+
|
45 |
+
return config, uni_prompting, tokenizer, vq_model, model, mask_token_id
|
46 |
+
|
47 |
+
|
48 |
+
config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen, mask_token_id = load_discrete_checkpoint()
|
49 |
+
|
50 |
+
|
51 |
+
def load_continuous_checkpoint():
|
52 |
+
config = OmegaConf.load("configs/showo_demo_w_clip_vit.yaml")
|
53 |
+
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
|
55 |
+
|
56 |
+
uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
|
57 |
+
special_tokens=(
|
58 |
+
"<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>",
|
59 |
+
"<|v2v|>", "<|lvg|>"),
|
60 |
+
ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
|
61 |
+
|
62 |
+
vision_tower_name = "openai/clip-vit-large-patch14-336"
|
63 |
+
vision_tower = CLIPVisionTower(vision_tower_name).to(device)
|
64 |
+
clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
|
65 |
+
|
66 |
+
model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
|
67 |
+
model.eval()
|
68 |
+
|
69 |
+
return config, uni_prompting, tokenizer, model, vision_tower, clip_image_processor
|
70 |
+
|
71 |
+
|
72 |
+
config_mmu = uni_prompting_mmu = tokenizer_mmu = model_mmu = vision_tower = clip_image_processor = None
|
73 |
+
|
74 |
+
|
75 |
+
def text_to_image_generation(input_text, guidance_scale, generation_timesteps):
|
76 |
+
config, uni_prompting, tokenizer, vq_model, model = config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen
|
77 |
+
|
78 |
+
prompts = [input_text]
|
79 |
+
config.training.batch_size = config.batch_size = 1
|
80 |
+
config.training.guidance_scale = config.guidance_scale = guidance_scale
|
81 |
+
config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
|
82 |
+
|
83 |
+
image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
|
84 |
+
dtype=torch.long, device=device) * mask_token_id
|
85 |
+
|
86 |
+
input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
|
87 |
+
|
88 |
+
if config.training.guidance_scale > 0:
|
89 |
+
uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
|
90 |
+
attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
|
91 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
92 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
93 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
94 |
+
rm_pad_in_image=True)
|
95 |
+
else:
|
96 |
+
attention_mask = create_attention_mask_predict_next(input_ids,
|
97 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
98 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
99 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
100 |
+
rm_pad_in_image=True)
|
101 |
+
uncond_input_ids = None
|
102 |
+
|
103 |
+
if config.get("mask_schedule", None) is not None:
|
104 |
+
schedule = config.mask_schedule.schedule
|
105 |
+
args = config.mask_schedule.get("params", {})
|
106 |
+
mask_schedule = get_mask_chedule(schedule, **args)
|
107 |
+
else:
|
108 |
+
mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
|
109 |
+
|
110 |
+
with torch.no_grad():
|
111 |
+
gen_token_ids = model.t2i_generate(
|
112 |
+
input_ids=input_ids,
|
113 |
+
uncond_input_ids=uncond_input_ids,
|
114 |
+
attention_mask=attention_mask,
|
115 |
+
guidance_scale=config.training.guidance_scale,
|
116 |
+
temperature=config.training.get("generation_temperature", 1.0),
|
117 |
+
timesteps=config.training.generation_timesteps,
|
118 |
+
noise_schedule=mask_schedule,
|
119 |
+
noise_type=config.training.get("noise_type", "mask"),
|
120 |
+
seq_len=config.model.showo.num_vq_tokens,
|
121 |
+
uni_prompting=uni_prompting,
|
122 |
+
config=config,
|
123 |
+
)
|
124 |
+
|
125 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
126 |
+
images = vq_model.decode_code(gen_token_ids)
|
127 |
+
|
128 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
129 |
+
images *= 255.0
|
130 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
131 |
+
|
132 |
+
return images[0]
|
133 |
+
|
134 |
+
|
135 |
+
def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidance_scale, generation_timesteps):
|
136 |
+
config, uni_prompting, tokenizer, vq_model, model = config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen
|
137 |
+
|
138 |
+
prompt = [input_text]
|
139 |
+
|
140 |
+
config.training.batch_size = config.batch_size = 1
|
141 |
+
config.training.guidance_scale = config.guidance_scale = guidance_scale
|
142 |
+
config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
|
143 |
+
|
144 |
+
inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
|
145 |
+
inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)
|
146 |
+
|
147 |
+
inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)
|
148 |
+
|
149 |
+
inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
|
150 |
+
inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
|
151 |
+
inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)
|
152 |
+
|
153 |
+
inpainting_mask[inpainting_mask < 0.5] = 0
|
154 |
+
inpainting_mask[inpainting_mask >= 0.5] = 1
|
155 |
+
|
156 |
+
inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
|
157 |
+
inpainting_mask = inpainting_mask.to(torch.bool)
|
158 |
+
|
159 |
+
inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
|
160 |
+
inpainting_image_tokens[inpainting_mask] = mask_token_id
|
161 |
+
|
162 |
+
input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')
|
163 |
+
|
164 |
+
if config.training.guidance_scale > 0:
|
165 |
+
uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
|
166 |
+
attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
|
167 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
168 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
169 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
170 |
+
rm_pad_in_image=True)
|
171 |
+
else:
|
172 |
+
attention_mask = create_attention_mask_predict_next(input_ids,
|
173 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
174 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
175 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
176 |
+
rm_pad_in_image=True)
|
177 |
+
uncond_input_ids = None
|
178 |
+
|
179 |
+
if config.get("mask_schedule", None) is not None:
|
180 |
+
schedule = config.mask_schedule.schedule
|
181 |
+
args = config.mask_schedule.get("params", {})
|
182 |
+
mask_schedule = get_mask_chedule(schedule, **args)
|
183 |
+
else:
|
184 |
+
mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
|
185 |
+
|
186 |
+
with torch.no_grad():
|
187 |
+
gen_token_ids = model.t2i_generate(
|
188 |
+
input_ids=input_ids,
|
189 |
+
uncond_input_ids=uncond_input_ids,
|
190 |
+
attention_mask=attention_mask,
|
191 |
+
guidance_scale=config.training.guidance_scale,
|
192 |
+
temperature=config.training.get("generation_temperature", 1.0),
|
193 |
+
timesteps=config.training.generation_timesteps,
|
194 |
+
noise_schedule=mask_schedule,
|
195 |
+
noise_type=config.training.get("noise_type", "mask"),
|
196 |
+
seq_len=config.model.showo.num_vq_tokens,
|
197 |
+
uni_prompting=uni_prompting,
|
198 |
+
config=config,
|
199 |
+
)
|
200 |
+
|
201 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
202 |
+
images = vq_model.decode_code(gen_token_ids)
|
203 |
+
|
204 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
205 |
+
images *= 255.0
|
206 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
207 |
+
|
208 |
+
return images[0]
|
209 |
+
|
210 |
+
|
211 |
+
def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidance_scale, generation_timesteps):
|
212 |
+
config, uni_prompting, tokenizer, vq_model, model = config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen
|
213 |
+
|
214 |
+
config.offset = 0
|
215 |
+
config.training.batch_size = config.batch_size = 1
|
216 |
+
config.training.guidance_scale = config.guidance_scale = guidance_scale
|
217 |
+
config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
|
218 |
+
|
219 |
+
extra_direction = ['right'] * int(right_ext) + ['left'] * int(left_ext)
|
220 |
+
prompt = [input_text] * len(extra_direction)
|
221 |
+
W = config.dataset.params.resolution // 16
|
222 |
+
for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
|
223 |
+
prt = [prt] * config.training.batch_size
|
224 |
+
if id == 0:
|
225 |
+
# extrapolation_image = Image.open(config.image_path).convert("RGB")
|
226 |
+
extrapolation_image = input_img
|
227 |
+
extrapolation_image = image_transform(extrapolation_image,
|
228 |
+
resolution=config.dataset.params.resolution).to(device)
|
229 |
+
|
230 |
+
B, _, _ = extrapolation_image.shape
|
231 |
+
extrapolation_image = extrapolation_image.unsqueeze(0)
|
232 |
+
extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
|
233 |
+
extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
|
234 |
+
config.dataset.params.resolution // 16,
|
235 |
+
config.dataset.params.resolution // 16)
|
236 |
+
extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
|
237 |
+
else:
|
238 |
+
|
239 |
+
extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)
|
240 |
+
|
241 |
+
image_left_part = extrapolation_image_tokens[:, :, :-(W // 2 - config.offset)] - len(
|
242 |
+
uni_prompting.text_tokenizer)
|
243 |
+
image_right_part = extrapolation_image_tokens[:, :, W // 2 - config.offset:] - len(uni_prompting.text_tokenizer)
|
244 |
+
image_up_part = extrapolation_image_tokens[:, :-(W // 2 - config.offset), :] - len(uni_prompting.text_tokenizer)
|
245 |
+
image_down_part = extrapolation_image_tokens[:, W // 2 - config.offset:, :] - len(uni_prompting.text_tokenizer)
|
246 |
+
|
247 |
+
if direction in ['left', 'right']:
|
248 |
+
extrapolation_mask = torch.zeros((config.training.batch_size,
|
249 |
+
config.dataset.params.resolution // 16,
|
250 |
+
config.dataset.params.resolution // 16 // 2 + config.offset),
|
251 |
+
dtype=torch.int64, device=device) + mask_token_id
|
252 |
+
else:
|
253 |
+
extrapolation_mask = torch.zeros((config.training.batch_size,
|
254 |
+
config.dataset.params.resolution // 16 // 2 + config.offset,
|
255 |
+
config.dataset.params.resolution // 16),
|
256 |
+
dtype=torch.int64, device=device) + mask_token_id
|
257 |
+
|
258 |
+
if direction == 'left':
|
259 |
+
extrapolation_image_tokens = torch.cat(
|
260 |
+
[extrapolation_mask, extrapolation_image_tokens[:, :, :W // 2 - config.offset]], dim=-1)
|
261 |
+
elif direction == 'right':
|
262 |
+
extrapolation_image_tokens = torch.cat(
|
263 |
+
[extrapolation_image_tokens[:, :, -(W // 2 - config.offset):], extrapolation_mask], dim=-1)
|
264 |
+
elif direction == 'up':
|
265 |
+
extrapolation_image_tokens = torch.cat(
|
266 |
+
[extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
|
267 |
+
else:
|
268 |
+
extrapolation_image_tokens = torch.cat(
|
269 |
+
[extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)
|
270 |
+
|
271 |
+
extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)
|
272 |
+
|
273 |
+
input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')
|
274 |
+
|
275 |
+
if config.training.guidance_scale > 0:
|
276 |
+
uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
|
277 |
+
attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
|
278 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
279 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
280 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
281 |
+
rm_pad_in_image=True)
|
282 |
+
else:
|
283 |
+
attention_mask = create_attention_mask_predict_next(input_ids,
|
284 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
285 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
286 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
287 |
+
rm_pad_in_image=True)
|
288 |
+
uncond_input_ids = None
|
289 |
+
|
290 |
+
if config.get("mask_schedule", None) is not None:
|
291 |
+
schedule = config.mask_schedule.schedule
|
292 |
+
args = config.mask_schedule.get("params", {})
|
293 |
+
mask_schedule = get_mask_chedule(schedule, **args)
|
294 |
+
else:
|
295 |
+
mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
|
296 |
+
|
297 |
+
with torch.no_grad():
|
298 |
+
gen_token_ids = model.t2i_generate(
|
299 |
+
input_ids=input_ids,
|
300 |
+
uncond_input_ids=uncond_input_ids,
|
301 |
+
attention_mask=attention_mask,
|
302 |
+
guidance_scale=config.training.guidance_scale,
|
303 |
+
temperature=config.training.get("generation_temperature", 1.0),
|
304 |
+
timesteps=config.training.generation_timesteps,
|
305 |
+
noise_schedule=mask_schedule,
|
306 |
+
noise_type=config.training.get("noise_type", "mask"),
|
307 |
+
seq_len=config.model.showo.num_vq_tokens,
|
308 |
+
uni_prompting=uni_prompting,
|
309 |
+
config=config,
|
310 |
+
)
|
311 |
+
|
312 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
313 |
+
gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
|
314 |
+
config.dataset.params.resolution // 16,
|
315 |
+
config.dataset.params.resolution // 16)
|
316 |
+
if direction == 'left':
|
317 |
+
gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
|
318 |
+
elif direction == 'right':
|
319 |
+
gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
|
320 |
+
elif direction == 'up':
|
321 |
+
gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
|
322 |
+
else:
|
323 |
+
gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-2)
|
324 |
+
|
325 |
+
_, h, w = gen_token_ids.shape
|
326 |
+
gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
|
327 |
+
images = vq_model.decode_code(gen_token_ids, shape=(h, w))
|
328 |
+
|
329 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
330 |
+
images *= 255.0
|
331 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
332 |
+
|
333 |
+
return images[0]
|
334 |
+
|
335 |
+
|
336 |
+
def multimodal_understanding(input_img, input_text, chat_history):
|
337 |
+
global config_mmu, uni_prompting_mmu, tokenizer_mmu, model_mmu, vision_tower, clip_image_processor
|
338 |
+
if model_mmu is None:
|
339 |
+
config_mmu, uni_prompting_mmu, tokenizer_mmu, model_mmu, vision_tower, clip_image_processor = load_continuous_checkpoint()
|
340 |
+
config, uni_prompting, tokenizer, model = config_mmu, uni_prompting_mmu, tokenizer_mmu, model_mmu
|
341 |
+
|
342 |
+
image_ori = input_img
|
343 |
+
pixel_values = clip_image_processor.preprocess(image_ori, return_tensors="pt")["pixel_values"][0]
|
344 |
+
batch_size = 1
|
345 |
+
question = input_text
|
346 |
+
top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability
|
347 |
+
|
348 |
+
conv = conversation_lib.default_conversation.copy()
|
349 |
+
conv.append_message(conv.roles[0], question)
|
350 |
+
conv.append_message(conv.roles[1], None)
|
351 |
+
prompt_question = conv.get_prompt()
|
352 |
+
question_input = []
|
353 |
+
question_input.append(prompt_question.strip())
|
354 |
+
|
355 |
+
input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors="pt", padding="longest").input_ids
|
356 |
+
for _ in range(batch_size)]
|
357 |
+
input_ids_system = torch.stack(input_ids_system, dim=0)
|
358 |
+
assert input_ids_system.shape[-1] == 28
|
359 |
+
input_ids_system = input_ids_system.to(device)
|
360 |
+
input_ids_system = input_ids_system[0]
|
361 |
+
|
362 |
+
input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors="pt", padding="longest").input_ids
|
363 |
+
for prompt in question_input]
|
364 |
+
|
365 |
+
input_ids = torch.stack(input_ids)
|
366 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
367 |
+
input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id
|
368 |
+
)
|
369 |
+
input_ids = torch.tensor(input_ids).to(device).squeeze(0)
|
370 |
+
input_ids_llava = torch.cat([
|
371 |
+
(torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device),
|
372 |
+
input_ids_system,
|
373 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
|
374 |
+
# place your img embedding here
|
375 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
|
376 |
+
input_ids,
|
377 |
+
], dim=1).long()
|
378 |
+
|
379 |
+
images_embeddings = vision_tower(pixel_values[None])
|
380 |
+
images_embeddings = model.mm_projector(images_embeddings)
|
381 |
+
|
382 |
+
text_embeddings = model.showo.model.embed_tokens(input_ids_llava)
|
383 |
+
|
384 |
+
# Full input seq
|
385 |
+
part1 = text_embeddings[:, :2 + SYSTEM_PROMPT_LEN, :]
|
386 |
+
part2 = text_embeddings[:, 2 + SYSTEM_PROMPT_LEN:, :]
|
387 |
+
input_embeddings = torch.cat((part1, images_embeddings, part2), dim=1)
|
388 |
+
|
389 |
+
attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings,
|
390 |
+
system_prompt_len=SYSTEM_PROMPT_LEN)
|
391 |
+
|
392 |
+
cont_toks_list = model.mmu_generate(input_embeddings=input_embeddings,
|
393 |
+
attention_mask=attention_mask_llava[0].unsqueeze(0),
|
394 |
+
max_new_tokens=100,
|
395 |
+
top_k=top_k,
|
396 |
+
# eot_token=uni_prompting.sptids_dict['<|eot|>']
|
397 |
+
eot_token=tokenizer.eos_token_id
|
398 |
+
)
|
399 |
+
|
400 |
+
cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
|
401 |
+
|
402 |
+
output_text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
|
403 |
+
|
404 |
+
output_text = output_text[0].strip()
|
405 |
+
|
406 |
+
chat_history.append((input_text, output_text))
|
407 |
+
|
408 |
+
return "", chat_history
|
409 |
+
|
410 |
+
|
411 |
+
with gr.Blocks() as demo:
|
412 |
+
gr.HTML("""
|
413 |
+
<h1 class="display-2 fw-bold title">
|
414 |
+
<a style="color: #70a8dc;">S</a><a style="color: #6fb051;">h</a><a style="color: #e06766;">o</a><a style="color: #f7b26b;">w</a>-o
|
415 |
+
</h1>
|
416 |
+
<p>This is the official Gradio demo for the Show-o model, a unified model that can do multimodal understanding and generation.</p>
|
417 |
+
|
418 |
+
<strong>Paper:</strong> <a href="https://arxiv.org/abs/2408.12528" target="_blank">Show-o: One Single Transformer To Unify Multimodal Understanding and Generation </a>
|
419 |
+
<br/>
|
420 |
+
<strong>Project Website:</strong> <a href="https://showlab.github.io/Show-o/" target="_blank">Show-o Website</a>
|
421 |
+
<br/>
|
422 |
+
<strong>Code and Models:</strong> <a href="https://github.com/showlab/Show-o" target="_blank">GitHub</a>
|
423 |
+
<br/>
|
424 |
+
<br/>
|
425 |
+
""")
|
426 |
+
|
427 |
+
with gr.Row():
|
428 |
+
with gr.Column():
|
429 |
+
text_prompt_t2i = gr.Textbox(
|
430 |
+
label="Text prompt",
|
431 |
+
lines=2,
|
432 |
+
placeholder="Input the text prompt here for image generation."
|
433 |
+
)
|
434 |
+
guidance_scale_t2i = gr.Slider(
|
435 |
+
label="guidance scale",
|
436 |
+
minimum=0,
|
437 |
+
maximum=5,
|
438 |
+
step=0.05,
|
439 |
+
value=1.75
|
440 |
+
)
|
441 |
+
generation_timesteps_t2i = gr.Slider(
|
442 |
+
label="timesteps",
|
443 |
+
minimum=1,
|
444 |
+
maximum=30,
|
445 |
+
step=1,
|
446 |
+
value=18
|
447 |
+
)
|
448 |
+
generated_img_t2i = gr.Image(
|
449 |
+
label="Output image"
|
450 |
+
)
|
451 |
+
submit_btn_t2i = gr.Button("Generate: Text-to-image")
|
452 |
+
submit_btn_t2i.click(text_to_image_generation,
|
453 |
+
[text_prompt_t2i, guidance_scale_t2i, generation_timesteps_t2i],
|
454 |
+
[generated_img_t2i])
|
455 |
+
|
456 |
+
with gr.Row():
|
457 |
+
inpainting_input_img = gr.Image(
|
458 |
+
label="Input image",
|
459 |
+
type="pil",
|
460 |
+
)
|
461 |
+
inpainting_input_mask = gr.Image(
|
462 |
+
label="Inpainting mask",
|
463 |
+
image_mode="L",
|
464 |
+
type="pil",
|
465 |
+
)
|
466 |
+
|
467 |
+
with gr.Column():
|
468 |
+
text_prompt_inpainting = gr.Textbox(
|
469 |
+
label="Text prompt",
|
470 |
+
lines=2,
|
471 |
+
placeholder="Input the text prompt here for image inpainting."
|
472 |
+
)
|
473 |
+
guidance_scale_inpainting = gr.Slider(
|
474 |
+
label="guidance scale",
|
475 |
+
minimum=0,
|
476 |
+
maximum=5,
|
477 |
+
step=0.05,
|
478 |
+
value=1.75
|
479 |
+
)
|
480 |
+
generation_timesteps_inpainting = gr.Slider(
|
481 |
+
label="timesteps",
|
482 |
+
minimum=1,
|
483 |
+
maximum=30,
|
484 |
+
step=1,
|
485 |
+
value=16
|
486 |
+
)
|
487 |
+
generated_img_inpainting = gr.Image(
|
488 |
+
label="Output image"
|
489 |
+
)
|
490 |
+
submit_btn_inpainting = gr.Button("Generate: Text-guided Inpainting")
|
491 |
+
submit_btn_inpainting.click(text_guided_inpainting,
|
492 |
+
[text_prompt_inpainting, inpainting_input_img, inpainting_input_mask,
|
493 |
+
guidance_scale_inpainting, generation_timesteps_inpainting],
|
494 |
+
[generated_img_inpainting])
|
495 |
+
|
496 |
+
with gr.Row():
|
497 |
+
extra_input_img = gr.Image(
|
498 |
+
label="Input image",
|
499 |
+
type="pil",
|
500 |
+
image_mode="RGB",
|
501 |
+
)
|
502 |
+
|
503 |
+
with gr.Column():
|
504 |
+
text_prompt_extrapolation = gr.Textbox(
|
505 |
+
label="Text prompt",
|
506 |
+
lines=1,
|
507 |
+
placeholder="Input the text prompt here for image extrapolation."
|
508 |
+
)
|
509 |
+
guidance_scale_extrapolation = gr.Slider(
|
510 |
+
label="guidance scale",
|
511 |
+
minimum=0,
|
512 |
+
maximum=5,
|
513 |
+
step=0.05,
|
514 |
+
value=1.75
|
515 |
+
)
|
516 |
+
generation_timesteps_extrapolation = gr.Slider(
|
517 |
+
label="timesteps",
|
518 |
+
minimum=1,
|
519 |
+
maximum=30,
|
520 |
+
step=1,
|
521 |
+
value=16
|
522 |
+
)
|
523 |
+
left_extrapolation = gr.Slider(
|
524 |
+
label="left extrapolation",
|
525 |
+
minimum=0,
|
526 |
+
maximum=5,
|
527 |
+
step=1,
|
528 |
+
value=1
|
529 |
+
)
|
530 |
+
right_extrapolation = gr.Slider(
|
531 |
+
label="right extrapolation",
|
532 |
+
minimum=0,
|
533 |
+
maximum=5,
|
534 |
+
step=1,
|
535 |
+
value=1
|
536 |
+
)
|
537 |
+
generated_img_extrapolation = gr.Image(
|
538 |
+
label="Output image"
|
539 |
+
)
|
540 |
+
submit_btn_inpainting = gr.Button("Generate: Text-guided Extrapolation")
|
541 |
+
submit_btn_inpainting.click(text_guided_extrapolation,
|
542 |
+
[extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation,
|
543 |
+
guidance_scale_extrapolation, generation_timesteps_extrapolation],
|
544 |
+
[generated_img_extrapolation])
|
545 |
+
|
546 |
+
with gr.Row():
|
547 |
+
with gr.Row():
|
548 |
+
chat_input_img = gr.Image(
|
549 |
+
label="Input image",
|
550 |
+
type="pil",
|
551 |
+
image_mode="RGB",
|
552 |
+
)
|
553 |
+
with gr.Column():
|
554 |
+
chatbot = gr.Chatbot()
|
555 |
+
msg = gr.Textbox(label="Press Enter to send a message for chat")
|
556 |
+
clear = gr.ClearButton([msg, chatbot])
|
557 |
+
msg.submit(multimodal_understanding, [chat_input_img, msg, chatbot], [msg, chatbot])
|
558 |
+
|
559 |
+
demo.launch()
|
gradio/share_btn.py
ADDED
@@ -0,0 +1,113 @@
1 |
+
# Modified from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/79681cd8cb235160a27cdd100673346eb1784e53/share_btn.py
|
2 |
+
|
3 |
+
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
|
4 |
+
<path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
|
5 |
+
<path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
|
6 |
+
</svg>"""
|
7 |
+
|
8 |
+
loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
|
9 |
+
style="color: #ffffff;
|
10 |
+
"
|
11 |
+
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
|
12 |
+
|
13 |
+
share_js = """
|
14 |
+
async () => {
|
15 |
+
const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
|
16 |
+
async function uploadFile(file) {
|
17 |
+
console.log(file.type)
|
18 |
+
const UPLOAD_URL = 'https://huggingface.co/uploads';
|
19 |
+
const response = await fetch(UPLOAD_URL, {
|
20 |
+
method: 'POST',
|
21 |
+
headers: {
|
22 |
+
'Content-Type': file.type,
|
23 |
+
'X-Requested-With': 'XMLHttpRequest',
|
24 |
+
},
|
25 |
+
body: file, /// <- File inherits from Blob
|
26 |
+
});
|
27 |
+
const url = await response.text();
|
28 |
+
return url;
|
29 |
+
}
|
30 |
+
async function getImageFile(div) {
|
31 |
+
let chatbot = document.getElementById("chatbot");
|
32 |
+
chatbot.style.height = "";
|
33 |
+
return new Promise((resolve, reject) =>
|
34 |
+
html2canvas(div)
|
35 |
+
.then((canvas) => {
|
36 |
+
chatbot.style.height = "400px";
|
37 |
+
const imageBlob = canvas.toBlob((blob) => {
|
38 |
+
const imageId = Date.now();
|
39 |
+
const fileName = "GILL-" + imageId + ".jpg";
|
40 |
+
resolve(new File([blob], fileName, { type: 'image/jpeg' }));
|
41 |
+
}, 'image/jpeg', 0.95);
|
42 |
+
})
|
43 |
+
|
44 |
+
)
|
45 |
+
}
|
46 |
+
|
47 |
+
const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
|
48 |
+
const chatbotEl = gradioEl.querySelector('#chatbot')
|
49 |
+
const imageFile = await getImageFile(chatbotEl);
|
50 |
+
console.log(imageFile);
|
51 |
+
const urlChatbotImage = await uploadFile(imageFile);
|
52 |
+
console.log(urlChatbotImage);
|
53 |
+
let titleTxt = `GILL Example`;
|
54 |
+
|
55 |
+
//const shareBtnEl = gradioEl.querySelector('#share-btn');
|
56 |
+
//shareBtnEl.style.pointerEvents = 'none';
|
57 |
+
const descriptionMd = `
|
58 |
+
|
59 |
+
<img src='${urlChatbotImage}'>
|
60 |
+
`;
|
61 |
+
const params = new URLSearchParams({
|
62 |
+
title: titleTxt,
|
63 |
+
description: descriptionMd,
|
64 |
+
});
|
65 |
+
const paramsStr = params.toString();
|
66 |
+
window.open(`https://huggingface.co/spaces/jykoh/gill/discussions/new?${paramsStr}`, '_blank');
|
67 |
+
//shareBtnEl.style.removeProperty('pointer-events');
|
68 |
+
}
|
69 |
+
"""
|
70 |
+
|
71 |
+
save_js = """
|
72 |
+
async () => {
|
73 |
+
const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
|
74 |
+
|
75 |
+
function saveAs(uri, filename) {
|
76 |
+
var link = document.createElement('a');
|
77 |
+
if (typeof link.download === 'string') {
|
78 |
+
link.href = uri;
|
79 |
+
link.download = filename;
|
80 |
+
|
81 |
+
//Firefox requires the link to be in the body
|
82 |
+
document.body.appendChild(link);
|
83 |
+
|
84 |
+
//simulate click
|
85 |
+
link.click();
|
86 |
+
|
87 |
+
//remove the link when done
|
88 |
+
document.body.removeChild(link);
|
89 |
+
} else {
|
90 |
+
window.open(uri);
|
91 |
+
}
|
92 |
+
}
|
93 |
+
|
94 |
+
async function getImageFile(div) {
|
95 |
+
let chatbot = document.getElementById("chatbot");
|
96 |
+
chatbot.style.height = "";
|
97 |
+
return new Promise((resolve, reject) =>
|
98 |
+
html2canvas(div)
|
99 |
+
.then((canvas) => {
|
100 |
+
chatbot.style.height = "400px";
|
101 |
+
const imageId = Date.now();
|
102 |
+
const fileName = "GILL-" + imageId + ".png";
|
103 |
+
saveAs(canvas.toDataURL(), fileName);
|
104 |
+
})
|
105 |
+
|
106 |
+
)
|
107 |
+
}
|
108 |
+
const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
|
109 |
+
const chatbotEl = gradioEl.querySelector('#chatbot')
|
110 |
+
const imageFile = await getImageFile(chatbotEl);
|
111 |
+
console.log(imageFile);
|
112 |
+
}
|
113 |
+
"""
|
inference_mmu.py
ADDED
@@ -0,0 +1,174 @@
1 |
+
import os
|
2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
3 |
+
from PIL import Image
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import wandb
|
8 |
+
from models import Showo, MAGVITv2
|
9 |
+
from prompting_utils import UniversalPrompting, create_attention_mask_for_mmu, create_attention_mask_for_mmu_vit
|
10 |
+
from training.utils import get_config, flatten_omega_conf, image_transform
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
from models.clip_encoder import CLIPVisionTower
|
13 |
+
from transformers import CLIPImageProcessor
|
14 |
+
|
15 |
+
# import.training.conversation as conversation_lib
|
16 |
+
from training import conversation as conversation_lib
|
17 |
+
|
18 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]
|
19 |
+
SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
|
20 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
21 |
+
SYSTEM_PROMPT_LEN = 28
|
22 |
+
|
23 |
+
def get_vq_model_class(model_type):
|
24 |
+
if model_type == "magvitv2":
|
25 |
+
return MAGVITv2
|
26 |
+
else:
|
27 |
+
raise ValueError(f"model_type {model_type} not supported.")
|
28 |
+
|
29 |
+
if __name__ == '__main__':
|
30 |
+
|
31 |
+
config = get_config()
|
32 |
+
|
33 |
+
resume_wandb_run = config.wandb.resume
|
34 |
+
run_id = config.wandb.get("run_id", None)
|
35 |
+
if run_id is None:
|
36 |
+
resume_wandb_run = False
|
37 |
+
run_id = wandb.util.generate_id()
|
38 |
+
config.wandb.run_id = run_id
|
39 |
+
|
40 |
+
wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
|
41 |
+
|
42 |
+
wandb.init(
|
43 |
+
project="demo",
|
44 |
+
name=config.experiment.name + '_mmu',
|
45 |
+
config=wandb_config,
|
46 |
+
)
|
47 |
+
|
48 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
49 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
|
50 |
+
|
51 |
+
uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
|
52 |
+
special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
|
53 |
+
ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
|
54 |
+
|
55 |
+
vq_model = get_vq_model_class(config.model.vq_model.type)
|
56 |
+
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
|
57 |
+
vq_model.requires_grad_(False)
|
58 |
+
vq_model.eval()
|
59 |
+
|
60 |
+
vision_tower_name = "openai/clip-vit-large-patch14-336"
|
61 |
+
vision_tower = CLIPVisionTower(vision_tower_name).to(device)
|
62 |
+
clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
|
63 |
+
|
64 |
+
model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
|
65 |
+
model.eval()
|
66 |
+
|
67 |
+
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
|
68 |
+
top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability
|
69 |
+
|
70 |
+
file_list = os.listdir(config.mmu_image_root)
|
71 |
+
responses = ['' for i in range(len(file_list))]
|
72 |
+
images = []
|
73 |
+
config.question = config.question.split(' *** ')
|
74 |
+
for i, file_name in enumerate(tqdm(file_list)):
|
75 |
+
image_path = os.path.join(config.mmu_image_root, file_name)
|
76 |
+
image_ori = Image.open(image_path).convert("RGB")
|
77 |
+
image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
|
78 |
+
image = image.unsqueeze(0)
|
79 |
+
images.append(image)
|
80 |
+
|
81 |
+
pixel_values = clip_image_processor.preprocess(image_ori, return_tensors="pt")["pixel_values"][0]
|
82 |
+
|
83 |
+
image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)
|
84 |
+
batch_size = 1
|
85 |
+
|
86 |
+
for question in config.question:
|
87 |
+
if config.model.showo.w_clip_vit:
|
88 |
+
conv = conversation_lib.default_conversation.copy()
|
89 |
+
conv.append_message(conv.roles[0], question)
|
90 |
+
conv.append_message(conv.roles[1], None)
|
91 |
+
prompt_question = conv.get_prompt()
|
92 |
+
question_input = []
|
93 |
+
question_input.append(prompt_question.strip())
|
94 |
+
|
95 |
+
input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors="pt", padding="longest").input_ids
|
96 |
+
for _ in range(batch_size)]
|
97 |
+
input_ids_system = torch.stack(input_ids_system, dim=0)
|
98 |
+
assert input_ids_system.shape[-1] == 28
|
99 |
+
input_ids_system = input_ids_system.to(device)
|
100 |
+
input_ids_system = input_ids_system[0]
|
101 |
+
|
102 |
+
input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors="pt", padding="longest").input_ids
|
103 |
+
for prompt in question_input]
|
104 |
+
|
105 |
+
input_ids = torch.stack(input_ids)
|
106 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
107 |
+
input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id
|
108 |
+
)
|
109 |
+
input_ids = torch.tensor(input_ids).to(device).squeeze(0)
|
110 |
+
# import pdb; pdb.set_trace()
|
111 |
+
input_ids_llava = torch.cat([
|
112 |
+
(torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device),
|
113 |
+
input_ids_system,
|
114 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
|
115 |
+
# place your img embedding here
|
116 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
|
117 |
+
input_ids,
|
118 |
+
], dim=1).long()
|
119 |
+
|
120 |
+
images_embeddings = vision_tower(pixel_values[None])
|
121 |
+
images_embeddings = model.mm_projector(images_embeddings)
|
122 |
+
|
123 |
+
text_embeddings = model.showo.model.embed_tokens(input_ids_llava)
|
124 |
+
|
125 |
+
# Full input seq
|
126 |
+
part1 = text_embeddings[:, :2 + SYSTEM_PROMPT_LEN, :]
|
127 |
+
part2 = text_embeddings[:, 2 + SYSTEM_PROMPT_LEN:, :]
|
128 |
+
input_embeddings = torch.cat((part1, images_embeddings, part2), dim=1)
|
129 |
+
|
130 |
+
attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings,
|
131 |
+
system_prompt_len=SYSTEM_PROMPT_LEN)
|
132 |
+
|
133 |
+
cont_toks_list = model.mmu_generate(input_embeddings=input_embeddings,
|
134 |
+
attention_mask=attention_mask_llava[0].unsqueeze(0),
|
135 |
+
max_new_tokens=100,
|
136 |
+
top_k=top_k,
|
137 |
+
eot_token=uni_prompting.sptids_dict['<|eot|>']
|
138 |
+
)
|
139 |
+
else:
|
140 |
+
input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[
|
141 |
+
'input_ids']
|
142 |
+
input_ids = torch.tensor(input_ids).to(device)
|
143 |
+
|
144 |
+
input_ids = torch.cat([
|
145 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device),
|
146 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
|
147 |
+
image_tokens,
|
148 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
|
149 |
+
(torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device),
|
150 |
+
input_ids
|
151 |
+
], dim=1).long()
|
152 |
+
|
153 |
+
attention_mask = create_attention_mask_for_mmu(input_ids.to(device),
|
154 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))
|
155 |
+
|
156 |
+
cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask,
|
157 |
+
max_new_tokens=100, top_k=top_k,
|
158 |
+
eot_token=uni_prompting.sptids_dict['<|eot|>'])
|
159 |
+
|
160 |
+
cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
|
161 |
+
|
162 |
+
text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
|
163 |
+
print(text)
|
164 |
+
responses[i] += 'User: ' + question + '\nAnswer: ' + text[0] + '\n'
|
165 |
+
|
166 |
+
images = torch.cat(images, dim=0)
|
167 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
168 |
+
images *= 255.0
|
169 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
170 |
+
pil_images = [Image.fromarray(image) for image in images]
|
171 |
+
|
172 |
+
wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)]
|
173 |
+
wandb.log({"multimodal understanding": wandb_images}, step=0)
|
174 |
+
|
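A minimal, self-contained sketch of the embedding splice performed by the w_clip_vit branch above: the projected CLIP patch embeddings are inserted between <|soi|> and <|eoi|>, i.e. right after the first 2 + SYSTEM_PROMPT_LEN positions of the text embeddings. The hidden size (2048), the 576 patch tokens, and the question length are illustrative assumptions, not values read from the config.

import torch

batch, hidden = 1, 2048
system_prompt_len = 28                                   # matches SYSTEM_PROMPT_LEN above
# layout: <|mmu|> (1) + system prompt (28) + <|soi|> (1) + <|eoi|> (1) + question (12)
text_emb = torch.randn(batch, 2 + system_prompt_len + 1 + 12, hidden)
img_emb = torch.randn(batch, 576, hidden)                # stand-in for projected CLIP patch features

part1 = text_emb[:, :2 + system_prompt_len, :]           # up to and including <|soi|>
part2 = text_emb[:, 2 + system_prompt_len:, :]           # <|eoi|> and the question tokens
full_seq = torch.cat((part1, img_emb, part2), dim=1)
print(full_seq.shape)                                    # torch.Size([1, 619, 2048])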
inference_t2i.py
ADDED
@@ -0,0 +1,331 @@
1 |
+
import os
|
2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
3 |
+
from PIL import Image
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import wandb
|
8 |
+
from models import Showo, MAGVITv2, get_mask_chedule
|
9 |
+
from prompting_utils import UniversalPrompting, create_attention_mask_predict_next
|
10 |
+
from training.utils import get_config, flatten_omega_conf, image_transform
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
def get_vq_model_class(model_type):
|
15 |
+
if model_type == "magvitv2":
|
16 |
+
return MAGVITv2
|
17 |
+
else:
|
18 |
+
raise ValueError(f"model_type {model_type} not supported.")
|
19 |
+
|
20 |
+
if __name__ == '__main__':
|
21 |
+
|
22 |
+
config = get_config()
|
23 |
+
|
24 |
+
resume_wandb_run = config.wandb.resume
|
25 |
+
run_id = config.wandb.get("run_id", None)
|
26 |
+
if run_id is None:
|
27 |
+
resume_wandb_run = False
|
28 |
+
run_id = wandb.util.generate_id()
|
29 |
+
config.wandb.run_id = run_id
|
30 |
+
|
31 |
+
wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
|
32 |
+
|
33 |
+
wandb.init(
|
34 |
+
project="demo",
|
35 |
+
name=config.experiment.name + '_t2i' + f'_{config.mode}',
|
36 |
+
config=wandb_config,
|
37 |
+
)
|
38 |
+
|
39 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
|
41 |
+
|
42 |
+
uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
|
43 |
+
special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
|
44 |
+
ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
|
45 |
+
|
46 |
+
vq_model = get_vq_model_class(config.model.vq_model.type)
|
47 |
+
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
|
48 |
+
vq_model.requires_grad_(False)
|
49 |
+
vq_model.eval()
|
50 |
+
|
51 |
+
model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
|
52 |
+
model.eval()
|
53 |
+
|
54 |
+
mask_token_id = model.config.mask_token_id
|
55 |
+
|
56 |
+
# load from users passed arguments
|
57 |
+
if config.get("validation_prompts_file", None) is not None:
|
58 |
+
config.dataset.params.validation_prompts_file = config.validation_prompts_file
|
59 |
+
config.training.batch_size = config.batch_size
|
60 |
+
config.training.guidance_scale = config.guidance_scale
|
61 |
+
config.training.generation_timesteps = config.generation_timesteps
|
62 |
+
# load from users passed arguments
|
63 |
+
|
64 |
+
if config.mode == 'inpainting':
|
65 |
+
|
66 |
+
prompt = [config.prompt] * config.batch_size
|
67 |
+
inpainting_image = Image.open(config.image_path).convert("RGB")
|
68 |
+
inpainting_mask = Image.open(config.inpainting_mask_path).convert("L")
|
69 |
+
|
70 |
+
import pdb
|
71 |
+
pdb.set_trace()
|
72 |
+
|
73 |
+
inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
|
74 |
+
inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)
|
75 |
+
|
76 |
+
# record original image and inpainting mask
|
77 |
+
images = torch.clamp(
|
78 |
+
(torch.stack([inpainting_image, inpainting_mask.repeat(3, 1, 1).to(device)], dim=0) + 1.0) / 2.0,
|
79 |
+
min=0.0, max=1.0)
|
80 |
+
images *= 255.0
|
81 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
82 |
+
pil_images = [Image.fromarray(image) for image in images]
|
83 |
+
|
84 |
+
labels = ['original image', 'inpainting mask']
|
85 |
+
wandb_images = [wandb.Image(image, caption=labels[i]) for i, image in enumerate(pil_images)]
|
86 |
+
|
87 |
+
inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)
|
88 |
+
|
89 |
+
inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
|
90 |
+
inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
|
91 |
+
inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)
|
92 |
+
|
93 |
+
inpainting_mask[inpainting_mask < 0.5] = 0
|
94 |
+
inpainting_mask[inpainting_mask >= 0.5] = 1
|
95 |
+
|
96 |
+
inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
|
97 |
+
inpainting_mask = inpainting_mask.to(torch.bool)
|
98 |
+
|
99 |
+
inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
|
100 |
+
inpainting_image_tokens[inpainting_mask] = mask_token_id
|
101 |
+
|
102 |
+
input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')
|
103 |
+
|
104 |
+
if config.training.guidance_scale > 0:
|
105 |
+
uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
|
106 |
+
attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
|
107 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
108 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
109 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
110 |
+
rm_pad_in_image=True)
|
111 |
+
else:
|
112 |
+
attention_mask = create_attention_mask_predict_next(input_ids,
|
113 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
114 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
115 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
116 |
+
rm_pad_in_image=True)
|
117 |
+
uncond_input_ids = None
|
118 |
+
|
119 |
+
if config.get("mask_schedule", None) is not None:
|
120 |
+
schedule = config.mask_schedule.schedule
|
121 |
+
args = config.mask_schedule.get("params", {})
|
122 |
+
mask_schedule = get_mask_chedule(schedule, **args)
|
123 |
+
else:
|
124 |
+
mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
|
125 |
+
|
126 |
+
with torch.no_grad():
|
127 |
+
gen_token_ids = model.t2i_generate(
|
128 |
+
input_ids=input_ids,
|
129 |
+
uncond_input_ids=uncond_input_ids,
|
130 |
+
attention_mask=attention_mask,
|
131 |
+
guidance_scale=config.training.guidance_scale,
|
132 |
+
temperature=config.training.get("generation_temperature", 1.0),
|
133 |
+
timesteps=config.training.generation_timesteps,
|
134 |
+
noise_schedule=mask_schedule,
|
135 |
+
noise_type=config.training.get("noise_type", "mask"),
|
136 |
+
seq_len=config.model.showo.num_vq_tokens,
|
137 |
+
uni_prompting=uni_prompting,
|
138 |
+
config=config,
|
139 |
+
)
|
140 |
+
|
141 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
142 |
+
images = vq_model.decode_code(gen_token_ids)
|
143 |
+
|
144 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
145 |
+
images *= 255.0
|
146 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
147 |
+
pil_images = [Image.fromarray(image) for image in images]
|
148 |
+
# import ipdb
|
149 |
+
# ipdb.set_trace()
|
150 |
+
wandb_images.extend([wandb.Image(image, caption=prompt[i]) for i, image in enumerate(pil_images)])
|
151 |
+
wandb.log({"generated_images": wandb_images}, step=0)
|
152 |
+
|
153 |
+
elif config.mode == 'extrapolation':
|
154 |
+
|
155 |
+
prompt = [p for p in config.prompt.split(" *** ") if len(p) != 0]
|
156 |
+
extra_direction = [d for d in config.extra_direction.split(" *** ") if len(d) != 0]
|
157 |
+
print(prompt, extra_direction)
|
158 |
+
W = config.dataset.params.resolution // 16
|
159 |
+
for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
|
160 |
+
prt = [prt] * config.training.batch_size
|
161 |
+
if id == 0:
|
162 |
+
extrapolation_image = Image.open(config.image_path).convert("RGB")
|
163 |
+
extrapolation_image = image_transform(extrapolation_image,
|
164 |
+
resolution=config.dataset.params.resolution).to(device)
|
165 |
+
|
166 |
+
B, _, _ = extrapolation_image.shape
|
167 |
+
extrapolation_image = extrapolation_image.unsqueeze(0)
|
168 |
+
extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
|
169 |
+
extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
|
170 |
+
config.dataset.params.resolution // 16,
|
171 |
+
config.dataset.params.resolution // 16)
|
172 |
+
extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
|
173 |
+
else:
|
174 |
+
|
175 |
+
|
176 |
+
extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)
|
177 |
+
|
178 |
+
image_left_part = extrapolation_image_tokens[:, :, :-(W//2-config.offset)] - len(uni_prompting.text_tokenizer)
|
179 |
+
image_right_part = extrapolation_image_tokens[:, :, W//2-config.offset:] - len(uni_prompting.text_tokenizer)
|
180 |
+
image_up_part = extrapolation_image_tokens[:, :-(W//2-config.offset), :] - len(uni_prompting.text_tokenizer)
|
181 |
+
image_down_part = extrapolation_image_tokens[:, W//2-config.offset:, :] - len(uni_prompting.text_tokenizer)
|
182 |
+
|
183 |
+
if direction in ['left', 'right']:
|
184 |
+
extrapolation_mask = torch.zeros((config.training.batch_size,
|
185 |
+
config.dataset.params.resolution // 16,
|
186 |
+
config.dataset.params.resolution // 16 // 2 + config.offset),
|
187 |
+
dtype=torch.int64, device=device) + mask_token_id
|
188 |
+
else:
|
189 |
+
extrapolation_mask = torch.zeros((config.training.batch_size,
|
190 |
+
config.dataset.params.resolution // 16 // 2 + config.offset,
|
191 |
+
config.dataset.params.resolution // 16),
|
192 |
+
dtype=torch.int64, device=device) + mask_token_id
|
193 |
+
|
194 |
+
if direction == 'left':
|
195 |
+
extrapolation_image_tokens = torch.cat(
|
196 |
+
[extrapolation_mask, extrapolation_image_tokens[:, :, :W//2-config.offset]], dim=-1)
|
197 |
+
elif direction == 'right':
|
198 |
+
extrapolation_image_tokens = torch.cat(
|
199 |
+
[extrapolation_image_tokens[:, :, -(W//2-config.offset):], extrapolation_mask], dim=-1)
|
200 |
+
elif direction == 'up':
|
201 |
+
extrapolation_image_tokens = torch.cat(
|
202 |
+
[extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
|
203 |
+
else:
|
204 |
+
extrapolation_image_tokens = torch.cat(
|
205 |
+
[extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)
|
206 |
+
|
207 |
+
extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)
|
208 |
+
|
209 |
+
input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')
|
210 |
+
|
211 |
+
if config.training.guidance_scale > 0:
|
212 |
+
uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
|
213 |
+
attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
|
214 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
215 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
216 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
217 |
+
rm_pad_in_image=True)
|
218 |
+
else:
|
219 |
+
attention_mask = create_attention_mask_predict_next(input_ids,
|
220 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
221 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
222 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
223 |
+
rm_pad_in_image=True)
|
224 |
+
uncond_input_ids = None
|
225 |
+
|
226 |
+
if config.get("mask_schedule", None) is not None:
|
227 |
+
schedule = config.mask_schedule.schedule
|
228 |
+
args = config.mask_schedule.get("params", {})
|
229 |
+
mask_schedule = get_mask_chedule(schedule, **args)
|
230 |
+
else:
|
231 |
+
mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
|
232 |
+
|
233 |
+
with torch.no_grad():
|
234 |
+
gen_token_ids = model.t2i_generate(
|
235 |
+
input_ids=input_ids,
|
236 |
+
uncond_input_ids=uncond_input_ids,
|
237 |
+
attention_mask=attention_mask,
|
238 |
+
guidance_scale=config.training.guidance_scale,
|
239 |
+
temperature=config.training.get("generation_temperature", 1.0),
|
240 |
+
timesteps=config.training.generation_timesteps,
|
241 |
+
noise_schedule=mask_schedule,
|
242 |
+
noise_type=config.training.get("noise_type", "mask"),
|
243 |
+
seq_len=config.model.showo.num_vq_tokens,
|
244 |
+
uni_prompting=uni_prompting,
|
245 |
+
config=config,
|
246 |
+
)
|
247 |
+
|
248 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
249 |
+
gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
|
250 |
+
config.dataset.params.resolution // 16,
|
251 |
+
config.dataset.params.resolution // 16)
|
252 |
+
if direction == 'left':
|
253 |
+
gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
|
254 |
+
elif direction == 'right':
|
255 |
+
gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
|
256 |
+
elif direction == 'up':
|
257 |
+
gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
|
258 |
+
else:
|
259 |
+
gen_token_ids = torch.cat([image_up_part, gen_token_ids], dim=-2)
|
260 |
+
|
261 |
+
_, h, w = gen_token_ids.shape
|
262 |
+
gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
|
263 |
+
images = vq_model.decode_code(gen_token_ids, shape=(h, w))
|
264 |
+
|
265 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
266 |
+
images *= 255.0
|
267 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
268 |
+
pil_images = [Image.fromarray(image) for image in images]
|
269 |
+
|
270 |
+
wandb_images = [wandb.Image(image, caption=' '.join(prompt)) for i, image in enumerate(pil_images)]
|
271 |
+
wandb.log({"generated_images": wandb_images}, step=0)
|
272 |
+
|
273 |
+
elif config.mode == 't2i':
|
274 |
+
with open(config.dataset.params.validation_prompts_file, "r") as f:
|
275 |
+
validation_prompts = f.read().splitlines()
|
276 |
+
|
277 |
+
for step in tqdm(range(0, len(validation_prompts), config.training.batch_size)):
|
278 |
+
prompts = validation_prompts[step:step + config.training.batch_size]
|
279 |
+
|
280 |
+
image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
|
281 |
+
dtype=torch.long, device=device) * mask_token_id
|
282 |
+
|
283 |
+
input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
|
284 |
+
|
285 |
+
if config.training.guidance_scale > 0:
|
286 |
+
uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
|
287 |
+
attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
|
288 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
289 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
290 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
291 |
+
rm_pad_in_image=True)
|
292 |
+
else:
|
293 |
+
attention_mask = create_attention_mask_predict_next(input_ids,
|
294 |
+
pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
|
295 |
+
soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
|
296 |
+
eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
|
297 |
+
rm_pad_in_image=True)
|
298 |
+
uncond_input_ids = None
|
299 |
+
|
300 |
+
if config.get("mask_schedule", None) is not None:
|
301 |
+
schedule = config.mask_schedule.schedule
|
302 |
+
args = config.mask_schedule.get("params", {})
|
303 |
+
mask_schedule = get_mask_chedule(schedule, **args)
|
304 |
+
else:
|
305 |
+
mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
|
306 |
+
|
307 |
+
with torch.no_grad():
|
308 |
+
gen_token_ids = model.t2i_generate(
|
309 |
+
input_ids=input_ids,
|
310 |
+
uncond_input_ids=uncond_input_ids,
|
311 |
+
attention_mask=attention_mask,
|
312 |
+
guidance_scale=config.training.guidance_scale,
|
313 |
+
temperature=config.training.get("generation_temperature", 1.0),
|
314 |
+
timesteps=config.training.generation_timesteps,
|
315 |
+
noise_schedule=mask_schedule,
|
316 |
+
noise_type=config.training.get("noise_type", "mask"),
|
317 |
+
seq_len=config.model.showo.num_vq_tokens,
|
318 |
+
uni_prompting=uni_prompting,
|
319 |
+
config=config,
|
320 |
+
)
|
321 |
+
|
322 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
323 |
+
images = vq_model.decode_code(gen_token_ids)
|
324 |
+
|
325 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
326 |
+
images *= 255.0
|
327 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
328 |
+
pil_images = [Image.fromarray(image) for image in images]
|
329 |
+
|
330 |
+
wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)]
|
331 |
+
wandb.log({"generated_images": wandb_images}, step=step)
|
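t2i_generate above decodes the masked VQ tokens over config.training.generation_timesteps steps, progressively revealing tokens according to a mask schedule (cosine by default). The actual schedule lives in models/sampling.py (get_mask_chedule), which is not shown in this diff, so the sketch below only illustrates what a cosine schedule typically does; the token count and step count are assumptions.

import math

def cosine_mask_ratio(t: float) -> float:
    """Fraction of image tokens still masked at normalized step t in (0, 1]."""
    return math.cos(math.pi / 2.0 * t)

num_vq_tokens = 1024       # illustrative stand-in for config.model.showo.num_vq_tokens
timesteps = 18             # illustrative stand-in for config.training.generation_timesteps
for step in range(timesteps):
    t = (step + 1) / timesteps
    still_masked = int(cosine_mask_ratio(t) * num_vq_tokens)
    print(f"step {step:2d}: {still_masked:4d} of {num_vq_tokens} tokens remain masked")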
inpainting_validation/.DS_Store
ADDED
Binary file (6.15 kB)
inpainting_validation/alpine_lake.jpg
ADDED
inpainting_validation/bedroom.jpg
ADDED
inpainting_validation/bedroom_mask.webp
ADDED
inpainting_validation/bench.jpg
ADDED
inpainting_validation/bench_mask.webp
ADDED
inpainting_validation/bus.jpg
ADDED
inpainting_validation/bus_mask.webp
ADDED
inpainting_validation/lake_mountain.jpg
ADDED
inpainting_validation/maya.png
ADDED
inpainting_validation/river.png
ADDED
inpainting_validation/train.jpg
ADDED
inpainting_validation/train_mask.webp
ADDED
inpainting_validation/truebsee.jpg
ADDED
inpainting_validation/truebsee_mask.webp
ADDED
inpainting_validation/wukong1.jpg
ADDED
inpainting_validation/wukong2.jpg
ADDED
mmu_validation/sofa_under_water.jpg
ADDED
models/__init__.py
ADDED
@@ -0,0 +1,4 @@
1 |
+
from .modeling_showo import Showo
|
2 |
+
from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2
|
3 |
+
from .sampling import *
|
4 |
+
from .clip_encoder import CLIPVisionTower
|
models/clip_encoder.py
ADDED
@@ -0,0 +1,140 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
5 |
+
|
6 |
+
class CLIPVisionTower(nn.Module):
|
7 |
+
def __init__(self, vision_tower):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.is_loaded = False
|
11 |
+
|
12 |
+
self.vision_tower_name = vision_tower
|
13 |
+
self.select_layer = -2
|
14 |
+
self.select_feature = "patch"
|
15 |
+
self.load_model()
|
16 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
17 |
+
|
18 |
+
def load_model(self, device_map=None):
|
19 |
+
if self.is_loaded:
|
20 |
+
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
|
21 |
+
return
|
22 |
+
|
23 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
24 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
25 |
+
self.vision_tower.requires_grad_(False)
|
26 |
+
|
27 |
+
self.is_loaded = True
|
28 |
+
|
29 |
+
def feature_select(self, image_forward_outs):
|
30 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
31 |
+
if self.select_feature == 'patch':
|
32 |
+
image_features = image_features[:, 1:]
|
33 |
+
elif self.select_feature == 'cls_patch':
|
34 |
+
image_features = image_features
|
35 |
+
else:
|
36 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
37 |
+
return image_features
|
38 |
+
|
39 |
+
@torch.no_grad()
|
40 |
+
def forward(self, images):
|
41 |
+
if type(images) is list:
|
42 |
+
image_features = []
|
43 |
+
for image in images:
|
44 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
45 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
46 |
+
image_features.append(image_feature)
|
47 |
+
else:
|
48 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
49 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
50 |
+
|
51 |
+
return image_features
|
52 |
+
|
53 |
+
@property
|
54 |
+
def dummy_feature(self):
|
55 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
56 |
+
|
57 |
+
@property
|
58 |
+
def dtype(self):
|
59 |
+
return self.vision_tower.dtype
|
60 |
+
|
61 |
+
@property
|
62 |
+
def device(self):
|
63 |
+
return self.vision_tower.device
|
64 |
+
|
65 |
+
@property
|
66 |
+
def config(self):
|
67 |
+
if self.is_loaded:
|
68 |
+
return self.vision_tower.config
|
69 |
+
else:
|
70 |
+
return self.cfg_only
|
71 |
+
|
72 |
+
@property
|
73 |
+
def hidden_size(self):
|
74 |
+
return self.config.hidden_size
|
75 |
+
|
76 |
+
@property
|
77 |
+
def num_patches_per_side(self):
|
78 |
+
return self.config.image_size // self.config.patch_size
|
79 |
+
|
80 |
+
@property
|
81 |
+
def num_patches(self):
|
82 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
83 |
+
|
84 |
+
|
85 |
+
class CLIPVisionTowerS2(CLIPVisionTower):
|
86 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
87 |
+
super().__init__(vision_tower, args, delay_load)
|
88 |
+
|
89 |
+
self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
|
90 |
+
self.s2_scales = list(map(int, self.s2_scales.split(',')))
|
91 |
+
self.s2_scales.sort()
|
92 |
+
self.s2_split_size = self.s2_scales[0]
|
93 |
+
self.s2_image_size = self.s2_scales[-1]
|
94 |
+
|
95 |
+
try:
|
96 |
+
from s2wrapper import forward as multiscale_forward
|
97 |
+
except ImportError:
|
98 |
+
raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
|
99 |
+
self.multiscale_forward = multiscale_forward
|
100 |
+
|
101 |
+
# change resize/crop size in preprocessing to the largest image size in s2_scale
|
102 |
+
if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
|
103 |
+
self.image_processor.size['shortest_edge'] = self.s2_image_size
|
104 |
+
self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
|
105 |
+
|
106 |
+
def load_model(self, device_map=None):
|
107 |
+
if self.is_loaded:
|
108 |
+
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
|
109 |
+
return
|
110 |
+
|
111 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
112 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
113 |
+
self.vision_tower.requires_grad_(False)
|
114 |
+
|
115 |
+
self.image_processor.size['shortest_edge'] = self.s2_image_size
|
116 |
+
self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
|
117 |
+
|
118 |
+
self.is_loaded = True
|
119 |
+
|
120 |
+
@torch.no_grad()
|
121 |
+
def forward_feature(self, images):
|
122 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
123 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
124 |
+
return image_features
|
125 |
+
|
126 |
+
@torch.no_grad()
|
127 |
+
def forward(self, images):
|
128 |
+
if type(images) is list:
|
129 |
+
image_features = []
|
130 |
+
for image in images:
|
131 |
+
image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
|
132 |
+
image_features.append(image_feature)
|
133 |
+
else:
|
134 |
+
image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
|
135 |
+
|
136 |
+
return image_features
|
137 |
+
|
138 |
+
@property
|
139 |
+
def hidden_size(self):
|
140 |
+
return self.config.hidden_size * len(self.s2_scales)
|
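A hedged usage sketch for the CLIPVisionTower defined above. The import path follows this repo layout (models/clip_encoder.py), the checkpoint name matches the one used in app.py, and the commented output shape assumes openai/clip-vit-large-patch14-336 (24 x 24 = 576 patch tokens, hidden size 1024); treat those numbers as assumptions rather than guarantees.

import torch
from PIL import Image
from transformers import CLIPImageProcessor
from models.clip_encoder import CLIPVisionTower

name = "openai/clip-vit-large-patch14-336"
tower = CLIPVisionTower(name)                              # weights are loaded in __init__
processor = CLIPImageProcessor.from_pretrained(name)

image = Image.new("RGB", (512, 512), "white")              # placeholder input image
pixel_values = processor.preprocess(image, return_tensors="pt")["pixel_values"]

with torch.no_grad():
    features = tower(pixel_values)                         # CLS token dropped (select_feature="patch")
print(features.shape)                                      # expected: torch.Size([1, 576, 1024])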
models/common_modules.py
ADDED
@@ -0,0 +1,407 @@
1 |
+
"""
|
2 |
+
Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from typing import Tuple, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
from einops.layers.torch import Rearrange
|
14 |
+
|
15 |
+
|
16 |
+
def nonlinearity(x):
|
17 |
+
# swish
|
18 |
+
return x * torch.sigmoid(x)
|
19 |
+
|
20 |
+
|
21 |
+
def Normalize(in_channels):
|
22 |
+
return torch.nn.GroupNorm(
|
23 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class Upsample(nn.Module):
|
28 |
+
def __init__(self, in_channels, with_conv):
|
29 |
+
super().__init__()
|
30 |
+
self.with_conv = with_conv
|
31 |
+
if self.with_conv:
|
32 |
+
self.conv = torch.nn.Conv2d(
|
33 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
38 |
+
if self.with_conv:
|
39 |
+
x = self.conv(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class DepthToSpaceUpsample(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
in_channels,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
conv = nn.Conv2d(in_channels, in_channels * 4, 1)
|
50 |
+
|
51 |
+
self.net = nn.Sequential(
|
52 |
+
conv,
|
53 |
+
nn.SiLU(),
|
54 |
+
Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
|
55 |
+
)
|
56 |
+
|
57 |
+
self.init_conv_(conv)
|
58 |
+
|
59 |
+
def init_conv_(self, conv):
|
60 |
+
o, i, h, w = conv.weight.shape
|
61 |
+
conv_weight = torch.empty(o // 4, i, h, w)
|
62 |
+
nn.init.kaiming_uniform_(conv_weight)
|
63 |
+
conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
|
64 |
+
|
65 |
+
conv.weight.data.copy_(conv_weight)
|
66 |
+
nn.init.zeros_(conv.bias.data)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
out = self.net(x)
|
70 |
+
return out
|
71 |
+
|
72 |
+
|
73 |
+
class Downsample(nn.Module):
|
74 |
+
def __init__(self, in_channels, with_conv):
|
75 |
+
super().__init__()
|
76 |
+
self.with_conv = with_conv
|
77 |
+
if self.with_conv:
|
78 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
79 |
+
self.conv = torch.nn.Conv2d(
|
80 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
81 |
+
)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
if self.with_conv:
|
85 |
+
pad = (0, 1, 0, 1)
|
86 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
87 |
+
x = self.conv(x)
|
88 |
+
else:
|
89 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
90 |
+
return x
|
91 |
+
|
92 |
+
|
93 |
+
def unpack_time(t, batch):
|
94 |
+
_, c, w, h = t.size()
|
95 |
+
out = torch.reshape(t, [batch, -1, c, w, h])
|
96 |
+
out = rearrange(out, "b t c h w -> b c t h w")
|
97 |
+
return out
|
98 |
+
|
99 |
+
|
100 |
+
def pack_time(t):
|
101 |
+
out = rearrange(t, "b c t h w -> b t c h w")
|
102 |
+
_, _, c, w, h = out.size()
|
103 |
+
return torch.reshape(out, [-1, c, w, h])
|
104 |
+
|
105 |
+
|
106 |
+
class TimeDownsample2x(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
dim,
|
110 |
+
dim_out=None,
|
111 |
+
kernel_size=3,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
if dim_out is None:
|
115 |
+
dim_out = dim
|
116 |
+
self.time_causal_padding = (kernel_size - 1, 0)
|
117 |
+
self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
x = rearrange(x, "b c t h w -> b h w c t")
|
121 |
+
b, h, w, c, t = x.size()
|
122 |
+
x = torch.reshape(x, [-1, c, t])
|
123 |
+
|
124 |
+
x = F.pad(x, self.time_causal_padding)
|
125 |
+
out = self.conv(x)
|
126 |
+
|
127 |
+
out = torch.reshape(out, [b, h, w, c, t])
|
128 |
+
out = rearrange(out, "b h w c t -> b c t h w")
|
130 |
+
return out
|
131 |
+
|
132 |
+
|
133 |
+
class TimeUpsample2x(nn.Module):
|
134 |
+
def __init__(self, dim, dim_out=None):
|
135 |
+
super().__init__()
|
136 |
+
if dim_out is None:
|
137 |
+
dim_out = dim
|
138 |
+
conv = nn.Conv1d(dim, dim_out * 2, 1)
|
139 |
+
|
140 |
+
self.net = nn.Sequential(
|
141 |
+
nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2)
|
142 |
+
)
|
143 |
+
|
144 |
+
self.init_conv_(conv)
|
145 |
+
|
146 |
+
def init_conv_(self, conv):
|
147 |
+
o, i, t = conv.weight.shape
|
148 |
+
conv_weight = torch.empty(o // 2, i, t)
|
149 |
+
nn.init.kaiming_uniform_(conv_weight)
|
150 |
+
conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
|
151 |
+
|
152 |
+
conv.weight.data.copy_(conv_weight)
|
153 |
+
nn.init.zeros_(conv.bias.data)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
x = rearrange(x, "b c t h w -> b h w c t")
|
157 |
+
b, h, w, c, t = x.size()
|
158 |
+
x = torch.reshape(x, [-1, c, t])
|
159 |
+
|
160 |
+
out = self.net(x)
|
161 |
+
out = out[:, :, 1:].contiguous()
|
162 |
+
|
163 |
+
out = torch.reshape(out, [b, h, w, c, t])
|
164 |
+
out = rearrange(out, "b h w c t -> b c t h w")
|
165 |
+
return out
|
166 |
+
|
167 |
+
|
168 |
+
class AttnBlock(nn.Module):
|
169 |
+
def __init__(self, in_channels):
|
170 |
+
super().__init__()
|
171 |
+
self.in_channels = in_channels
|
172 |
+
|
173 |
+
self.norm = Normalize(in_channels)
|
174 |
+
self.q = torch.nn.Conv2d(
|
175 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
176 |
+
)
|
177 |
+
self.k = torch.nn.Conv2d(
|
178 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
179 |
+
)
|
180 |
+
self.v = torch.nn.Conv2d(
|
181 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
182 |
+
)
|
183 |
+
self.proj_out = torch.nn.Conv2d(
|
184 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
185 |
+
)
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
h_ = x
|
189 |
+
h_ = self.norm(h_)
|
190 |
+
q = self.q(h_)
|
191 |
+
k = self.k(h_)
|
192 |
+
v = self.v(h_)
|
193 |
+
|
194 |
+
# compute attention
|
195 |
+
b, c, h, w = q.shape
|
196 |
+
q = q.reshape(b, c, h * w)
|
197 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
198 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
199 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
200 |
+
w_ = w_ * (int(c) ** (-0.5))
|
201 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
202 |
+
|
203 |
+
# attend to values
|
204 |
+
v = v.reshape(b, c, h * w)
|
205 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
206 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
207 |
+
h_ = h_.reshape(b, c, h, w)
|
208 |
+
|
209 |
+
h_ = self.proj_out(h_)
|
210 |
+
|
211 |
+
return x + h_
|
212 |
+
|
213 |
+
|
214 |
+
class TimeAttention(AttnBlock):
|
215 |
+
def forward(self, x, *args, **kwargs):
|
216 |
+
x = rearrange(x, "b c t h w -> b h w t c")
|
217 |
+
b, h, w, t, c = x.size()
|
218 |
+
x = torch.reshape(x, (-1, t, c))
|
219 |
+
|
220 |
+
x = super().forward(x, *args, **kwargs)
|
221 |
+
|
222 |
+
x = torch.reshape(x, [b, h, w, t, c])
|
223 |
+
return rearrange(x, "b h w t c -> b c t h w")
|
224 |
+
|
225 |
+
|
226 |
+
class Residual(nn.Module):
|
227 |
+
def __init__(self, fn: nn.Module):
|
228 |
+
super().__init__()
|
229 |
+
self.fn = fn
|
230 |
+
|
231 |
+
def forward(self, x, **kwargs):
|
232 |
+
return self.fn(x, **kwargs) + x
|
233 |
+
|
234 |
+
|
235 |
+
def cast_tuple(t, length=1):
|
236 |
+
return t if isinstance(t, tuple) else ((t,) * length)
|
237 |
+
|
238 |
+
|
239 |
+
class CausalConv3d(nn.Module):
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
chan_in,
|
243 |
+
chan_out,
|
244 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
245 |
+
pad_mode="constant",
|
246 |
+
**kwargs
|
247 |
+
):
|
248 |
+
super().__init__()
|
249 |
+
kernel_size = cast_tuple(kernel_size, 3)
|
250 |
+
|
251 |
+
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
252 |
+
|
253 |
+
dilation = kwargs.pop("dilation", 1)
|
254 |
+
stride = kwargs.pop("stride", 1)
|
255 |
+
|
256 |
+
self.pad_mode = pad_mode
|
257 |
+
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
|
258 |
+
height_pad = height_kernel_size // 2
|
259 |
+
width_pad = width_kernel_size // 2
|
260 |
+
|
261 |
+
self.time_pad = time_pad
|
262 |
+
self.time_causal_padding = (
|
263 |
+
width_pad,
|
264 |
+
width_pad,
|
265 |
+
height_pad,
|
266 |
+
height_pad,
|
267 |
+
time_pad,
|
268 |
+
0,
|
269 |
+
)
|
270 |
+
|
271 |
+
stride = (stride, 1, 1)
|
272 |
+
dilation = (dilation, 1, 1)
|
273 |
+
self.conv = nn.Conv3d(
|
274 |
+
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
|
275 |
+
)
|
276 |
+
|
277 |
+
def forward(self, x):
|
278 |
+
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
|
279 |
+
|
280 |
+
x = F.pad(x, self.time_causal_padding, mode=pad_mode)
|
281 |
+
return self.conv(x)
|
282 |
+
|
283 |
+
|
284 |
+
def ResnetBlockCausal3D(
|
285 |
+
dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"
|
286 |
+
):
|
287 |
+
net = nn.Sequential(
|
288 |
+
Normalize(dim),
|
289 |
+
nn.SiLU(),
|
290 |
+
CausalConv3d(dim, dim, kernel_size, pad_mode),
|
291 |
+
Normalize(dim),
|
292 |
+
nn.SiLU(),
|
293 |
+
CausalConv3d(dim, dim, kernel_size, pad_mode),
|
294 |
+
)
|
295 |
+
return Residual(net)
|
296 |
+
|
297 |
+
|
298 |
+
class ResnetBlock(nn.Module):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
*,
|
302 |
+
in_channels,
|
303 |
+
out_channels=None,
|
304 |
+
conv_shortcut=False,
|
305 |
+
dropout,
|
306 |
+
temb_channels=512
|
307 |
+
):
|
308 |
+
super().__init__()
|
309 |
+
self.in_channels = in_channels
|
310 |
+
out_channels = in_channels if out_channels is None else out_channels
|
311 |
+
self.out_channels = out_channels
|
312 |
+
self.use_conv_shortcut = conv_shortcut
|
313 |
+
|
314 |
+
self.norm1 = Normalize(in_channels)
|
315 |
+
self.conv1 = torch.nn.Conv2d(
|
316 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
317 |
+
)
|
318 |
+
if temb_channels > 0:
|
319 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
320 |
+
else:
|
321 |
+
self.temb_proj = None
|
322 |
+
self.norm2 = Normalize(out_channels)
|
323 |
+
self.dropout = torch.nn.Dropout(dropout)
|
324 |
+
self.conv2 = torch.nn.Conv2d(
|
325 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
326 |
+
)
|
327 |
+
if self.in_channels != self.out_channels:
|
328 |
+
if self.use_conv_shortcut:
|
329 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
330 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
334 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
335 |
+
)
|
336 |
+
|
337 |
+
def forward(self, x, temb):
|
338 |
+
h = x
|
339 |
+
h = self.norm1(h)
|
340 |
+
h = nonlinearity(h)
|
341 |
+
h = self.conv1(h)
|
342 |
+
|
343 |
+
if temb is not None:
|
344 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
345 |
+
|
346 |
+
h = self.norm2(h)
|
347 |
+
h = nonlinearity(h)
|
348 |
+
h = self.dropout(h)
|
349 |
+
h = self.conv2(h)
|
350 |
+
|
351 |
+
if self.in_channels != self.out_channels:
|
352 |
+
if self.use_conv_shortcut:
|
353 |
+
x = self.conv_shortcut(x)
|
354 |
+
else:
|
355 |
+
x = self.nin_shortcut(x)
|
356 |
+
|
357 |
+
return x + h
|
358 |
+
|
359 |
+
|
360 |
+
class DinoV2Model(nn.Module):
|
361 |
+
def __init__(
|
362 |
+
self,
|
363 |
+
model_name,
|
364 |
+
local_checkpoint_path="",
|
365 |
+
renorm_input=False,
|
366 |
+
old_input_mean=0.5,
|
367 |
+
old_input_std=0.5,
|
368 |
+
freeze_model=False,
|
369 |
+
):
|
370 |
+
super().__init__()
|
371 |
+
if local_checkpoint_path != "":
|
372 |
+
self._model = torch.hub.load(
|
373 |
+
local_checkpoint_path, model_name, source="local"
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
self._model = torch.hub.load("facebookresearch/dinov2", model_name)
|
377 |
+
self.register_buffer(
|
378 |
+
"_dino_input_mean",
|
379 |
+
torch.tensor([0.485, 0.456, 0.406]).float()[None, :, None, None],
|
380 |
+
)
|
381 |
+
self.register_buffer(
|
382 |
+
"_dino_input_std",
|
383 |
+
torch.tensor([0.229, 0.224, 0.225]).float()[None, :, None, None],
|
384 |
+
)
|
385 |
+
self._old_input_mean = old_input_mean
|
386 |
+
self._old_input_std = old_input_std
|
387 |
+
self._renorm_input = renorm_input
|
388 |
+
if freeze_model:
|
389 |
+
for param in self._model.parameters():
|
390 |
+
param.requires_grad = False
|
391 |
+
|
392 |
+
def forward(self, inputs):
|
393 |
+
batch, _, height, width = inputs.size()
|
394 |
+
if self._renorm_input:
|
395 |
+
inputs = inputs * self._old_input_mean + self._old_input_std
|
396 |
+
inputs = (inputs - self._dino_input_mean) / self._dino_input_std
|
397 |
+
# TODO(yanwan): If we want to remove this resizing, have to modify the decoder to support upscaling by a factor of 14.
|
398 |
+
# Reduce both height and width to 7/8 of their original values while maintaining aspect ratio to fit dinov2 requirement.
|
399 |
+
new_height = height // 8 * 7
|
400 |
+
new_width = width // 8 * 7
|
401 |
+
inputs = F.interpolate(inputs, (new_height, new_width), mode="bilinear")
|
402 |
+
features = self._model.forward_features(inputs)["x_norm_patchtokens"]
|
403 |
+
features = torch.transpose(features, 1, 2).contiguous()
|
404 |
+
features = torch.reshape(
|
405 |
+
features, (batch, -1, new_height // 14, new_width // 14)
|
406 |
+
)
|
407 |
+
return features
|
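The CausalConv3d block above pads only the left side of the time axis by (time_kernel_size - 1), so an output frame never depends on future frames. The snippet below is a stand-alone illustration of that property using a 1D convolution; it is not repo code.

import torch
import torch.nn.functional as F

kernel_size = 3
x = torch.arange(8, dtype=torch.float32).view(1, 1, 8)     # one channel, 8 "frames"
weight = torch.ones(1, 1, kernel_size)                     # simple box filter

causal_pad = (kernel_size - 1, 0)                          # pad the past only, never the future
y = F.conv1d(F.pad(x, causal_pad), weight)

x_future = x.clone()
x_future[..., 5] = 100.0                                   # perturb a "future" frame
y_future = F.conv1d(F.pad(x_future, causal_pad), weight)

# Outputs strictly before the perturbed frame are unchanged -> the convolution is causal.
print(torch.equal(y[..., :5], y_future[..., :5]))          # True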
models/logging.py
ADDED
@@ -0,0 +1,338 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Optuna, Hugging Face
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Logging utilities."""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
import threading
|
21 |
+
from logging import CRITICAL # NOQA
|
22 |
+
from logging import DEBUG # NOQA
|
23 |
+
from logging import ERROR # NOQA
|
24 |
+
from logging import FATAL # NOQA
|
25 |
+
from logging import INFO # NOQA
|
26 |
+
from logging import NOTSET # NOQA
|
27 |
+
from logging import WARN # NOQA
|
28 |
+
from logging import WARNING # NOQA
|
29 |
+
from typing import Optional
|
30 |
+
|
31 |
+
from tqdm import auto as tqdm_lib
|
32 |
+
|
33 |
+
_lock = threading.Lock()
|
34 |
+
_default_handler: Optional[logging.Handler] = None
|
35 |
+
|
36 |
+
log_levels = {
|
37 |
+
"debug": logging.DEBUG,
|
38 |
+
"info": logging.INFO,
|
39 |
+
"warning": logging.WARNING,
|
40 |
+
"error": logging.ERROR,
|
41 |
+
"critical": logging.CRITICAL,
|
42 |
+
}
|
43 |
+
|
44 |
+
_default_log_level = logging.WARNING
|
45 |
+
|
46 |
+
_tqdm_active = True
|
47 |
+
|
48 |
+
|
49 |
+
def _get_default_logging_level():
|
50 |
+
"""
|
51 |
+
If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
52 |
+
not - fall back to `_default_log_level`
|
53 |
+
"""
|
54 |
+
env_level_str = os.getenv("muse_VERBOSITY", None)
|
55 |
+
if env_level_str:
|
56 |
+
if env_level_str in log_levels:
|
57 |
+
return log_levels[env_level_str]
|
58 |
+
else:
|
59 |
+
logging.getLogger().warning(
|
60 |
+
f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
|
61 |
+
)
|
62 |
+
return _default_log_level
|
63 |
+
|
64 |
+
|
65 |
+
def _get_library_name() -> str:
|
66 |
+
return __name__.split(".")[0]
|
67 |
+
|
68 |
+
|
69 |
+
def _get_library_root_logger() -> logging.Logger:
|
70 |
+
return logging.getLogger(_get_library_name())
|
71 |
+
|
72 |
+
|
73 |
+
def _configure_library_root_logger() -> None:
|
74 |
+
global _default_handler
|
75 |
+
|
76 |
+
with _lock:
|
77 |
+
if _default_handler:
|
78 |
+
# This library has already configured the library root logger.
|
79 |
+
return
|
80 |
+
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
81 |
+
_default_handler.flush = sys.stderr.flush
|
82 |
+
|
83 |
+
# Apply our default configuration to the library root logger.
|
84 |
+
library_root_logger = _get_library_root_logger()
|
85 |
+
library_root_logger.addHandler(_default_handler)
|
86 |
+
library_root_logger.setLevel(_get_default_logging_level())
|
87 |
+
library_root_logger.propagate = False
|
88 |
+
|
89 |
+
|
90 |
+
def _reset_library_root_logger() -> None:
|
91 |
+
global _default_handler
|
92 |
+
|
93 |
+
with _lock:
|
94 |
+
if not _default_handler:
|
95 |
+
return
|
96 |
+
|
97 |
+
library_root_logger = _get_library_root_logger()
|
98 |
+
library_root_logger.removeHandler(_default_handler)
|
99 |
+
library_root_logger.setLevel(logging.NOTSET)
|
100 |
+
_default_handler = None
|
101 |
+
|
102 |
+
|
103 |
+
def get_log_levels_dict():
|
104 |
+
return log_levels
|
105 |
+
|
106 |
+
|
107 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
108 |
+
"""
|
109 |
+
Return a logger with the specified name.
|
110 |
+
|
111 |
+
This function is not supposed to be directly accessed unless you are writing a custom muse module.
|
112 |
+
"""
|
113 |
+
|
114 |
+
if name is None:
|
115 |
+
name = _get_library_name()
|
116 |
+
|
117 |
+
_configure_library_root_logger()
|
118 |
+
return logging.getLogger(name)
|
119 |
+
|
120 |
+
|
121 |
+
def get_verbosity() -> int:
|
122 |
+
"""
|
123 |
+
Return the current level for the 🤗 muse' root logger as an int.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
`int`: The logging level.
|
127 |
+
|
128 |
+
<Tip>
|
129 |
+
|
130 |
+
🤗 muse has following logging levels:
|
131 |
+
|
132 |
+
- 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
|
133 |
+
- 40: `muse.logging.ERROR`
|
134 |
+
- 30: `muse.logging.WARNING` or `muse.logging.WARN`
|
135 |
+
- 20: `muse.logging.INFO`
|
136 |
+
- 10: `muse.logging.DEBUG`
|
137 |
+
|
138 |
+
</Tip>"""
|
139 |
+
|
140 |
+
_configure_library_root_logger()
|
141 |
+
return _get_library_root_logger().getEffectiveLevel()
|
142 |
+
|
143 |
+
|
144 |
+
def set_verbosity(verbosity: int) -> None:
|
145 |
+
"""
|
146 |
+
Set the verbosity level for the 🤗 muse' root logger.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
verbosity (`int`):
|
150 |
+
Logging level, e.g., one of:
|
151 |
+
|
152 |
+
- `muse.logging.CRITICAL` or `muse.logging.FATAL`
|
153 |
+
- `muse.logging.ERROR`
|
154 |
+
- `muse.logging.WARNING` or `muse.logging.WARN`
|
155 |
+
- `muse.logging.INFO`
|
156 |
+
- `muse.logging.DEBUG`
|
157 |
+
"""
|
158 |
+
|
159 |
+
_configure_library_root_logger()
|
160 |
+
_get_library_root_logger().setLevel(verbosity)
|
161 |
+
|
162 |
+
|
163 |
+
def set_verbosity_info():
|
164 |
+
"""Set the verbosity to the `INFO` level."""
|
165 |
+
return set_verbosity(INFO)
|
166 |
+
|
167 |
+
|
168 |
+
def set_verbosity_warning():
|
169 |
+
"""Set the verbosity to the `WARNING` level."""
|
170 |
+
return set_verbosity(WARNING)
|
171 |
+
|
172 |
+
|
173 |
+
def set_verbosity_debug():
|
174 |
+
"""Set the verbosity to the `DEBUG` level."""
|
175 |
+
return set_verbosity(DEBUG)
|
176 |
+
|
177 |
+
|
178 |
+
def set_verbosity_error():
|
179 |
+
"""Set the verbosity to the `ERROR` level."""
|
180 |
+
return set_verbosity(ERROR)
|
181 |
+
|
182 |
+
|
183 |
+
def disable_default_handler() -> None:
|
184 |
+
"""Disable the default handler of the HuggingFace muse' root logger."""
|
185 |
+
|
186 |
+
_configure_library_root_logger()
|
187 |
+
|
188 |
+
assert _default_handler is not None
|
189 |
+
_get_library_root_logger().removeHandler(_default_handler)
|
190 |
+
|
191 |
+
|
192 |
+
def enable_default_handler() -> None:
|
193 |
+
"""Enable the default handler of the HuggingFace muse' root logger."""
|
194 |
+
|
195 |
+
_configure_library_root_logger()
|
196 |
+
|
197 |
+
assert _default_handler is not None
|
198 |
+
_get_library_root_logger().addHandler(_default_handler)
|
199 |
+
|
200 |
+
|
201 |
+
def add_handler(handler: logging.Handler) -> None:
|
202 |
+
"""adds a handler to the HuggingFace muse' root logger."""
|
203 |
+
|
204 |
+
_configure_library_root_logger()
|
205 |
+
|
206 |
+
assert handler is not None
|
207 |
+
_get_library_root_logger().addHandler(handler)
|
208 |
+
|
209 |
+
|
210 |
+
def remove_handler(handler: logging.Handler) -> None:
|
211 |
+
"""removes given handler from the HuggingFace muse' root logger."""
|
212 |
+
|
213 |
+
_configure_library_root_logger()
|
214 |
+
|
215 |
+
assert handler is not None and handler not in _get_library_root_logger().handlers
|
216 |
+
_get_library_root_logger().removeHandler(handler)
|
217 |
+
|
218 |
+
|
219 |
+
def disable_propagation() -> None:
|
220 |
+
"""
|
221 |
+
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
222 |
+
"""
|
223 |
+
|
224 |
+
_configure_library_root_logger()
|
225 |
+
_get_library_root_logger().propagate = False
|
226 |
+
|
227 |
+
|
228 |
+
def enable_propagation() -> None:
|
229 |
+
"""
|
230 |
+
Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent
|
231 |
+
double logging if the root logger has been configured.
|
232 |
+
"""
|
233 |
+
|
234 |
+
_configure_library_root_logger()
|
235 |
+
_get_library_root_logger().propagate = True
|
236 |
+
|
237 |
+
|
238 |
+
def enable_explicit_format() -> None:
|
239 |
+
"""
|
240 |
+
Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows:
|
241 |
+
```
|
242 |
+
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
243 |
+
```
|
244 |
+
All handlers currently bound to the root logger are affected by this method.
|
245 |
+
"""
|
246 |
+
handlers = _get_library_root_logger().handlers
|
247 |
+
|
248 |
+
for handler in handlers:
|
249 |
+
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
|
250 |
+
handler.setFormatter(formatter)
|
251 |
+
|
252 |
+
|
253 |
+
def reset_format() -> None:
|
254 |
+
"""
|
255 |
+
Resets the formatting for HuggingFace muse' loggers.
|
256 |
+
|
257 |
+
All handlers currently bound to the root logger are affected by this method.
|
258 |
+
"""
|
259 |
+
handlers = _get_library_root_logger().handlers
|
260 |
+
|
261 |
+
for handler in handlers:
|
262 |
+
handler.setFormatter(None)
|
263 |
+
|
264 |
+
|
265 |
+
def warning_advice(self, *args, **kwargs):
|
266 |
+
"""
|
267 |
+
This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
|
268 |
+
warning will not be printed
|
269 |
+
"""
|
270 |
+
no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
|
271 |
+
if no_advisory_warnings:
|
272 |
+
return
|
273 |
+
self.warning(*args, **kwargs)
|
274 |
+
|
275 |
+
|
276 |
+
logging.Logger.warning_advice = warning_advice
|
277 |
+
|
278 |
+
|
279 |
+
class EmptyTqdm:
|
280 |
+
"""Dummy tqdm which doesn't do anything."""
|
281 |
+
|
282 |
+
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
283 |
+
self._iterator = args[0] if args else None
|
284 |
+
|
285 |
+
def __iter__(self):
|
286 |
+
return iter(self._iterator)
|
287 |
+
|
288 |
+
def __getattr__(self, _):
|
289 |
+
"""Return empty function."""
|
290 |
+
|
291 |
+
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
292 |
+
return
|
293 |
+
|
294 |
+
return empty_fn
|
295 |
+
|
296 |
+
def __enter__(self):
|
297 |
+
return self
|
298 |
+
|
299 |
+
def __exit__(self, type_, value, traceback):
|
300 |
+
return
|
301 |
+
|
302 |
+
|
303 |
+
class _tqdm_cls:
|
304 |
+
def __call__(self, *args, **kwargs):
|
305 |
+
if _tqdm_active:
|
306 |
+
return tqdm_lib.tqdm(*args, **kwargs)
|
307 |
+
else:
|
308 |
+
return EmptyTqdm(*args, **kwargs)
|
309 |
+
|
310 |
+
def set_lock(self, *args, **kwargs):
|
311 |
+
self._lock = None
|
312 |
+
if _tqdm_active:
|
313 |
+
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
314 |
+
|
315 |
+
def get_lock(self):
|
316 |
+
if _tqdm_active:
|
317 |
+
return tqdm_lib.tqdm.get_lock()
|
318 |
+
|
319 |
+
|
320 |
+
tqdm = _tqdm_cls()
|
321 |
+
|
322 |
+
|
323 |
+
def is_progress_bar_enabled() -> bool:
|
324 |
+
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
325 |
+
global _tqdm_active
|
326 |
+
return bool(_tqdm_active)
|
327 |
+
|
328 |
+
|
329 |
+
def enable_progress_bar():
|
330 |
+
"""Enable tqdm progress bar."""
|
331 |
+
global _tqdm_active
|
332 |
+
_tqdm_active = True
|
333 |
+
|
334 |
+
|
335 |
+
def disable_progress_bar():
|
336 |
+
"""Disable tqdm progress bar."""
|
337 |
+
global _tqdm_active
|
338 |
+
_tqdm_active = False
|
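The helpers above (handler management, propagation toggles, explicit formatting, and the `tqdm` wrapper) mirror the Hugging Face `transformers`-style logging utilities. A minimal usage sketch, assuming the module is importable as `models.logging` (consistent with the `from .logging import get_logger` import in models/lr_schedulers.py below):

```
from models.logging import (
    get_logger,
    enable_explicit_format,
    disable_progress_bar,
)

logger = get_logger(__name__)   # child of the library root logger
enable_explicit_format()        # [LEVELNAME|FILENAME:LINENO] TIME >> MESSAGE on every bound handler
disable_progress_bar()          # the exported `tqdm` wrapper now returns the no-op EmptyTqdm
logger.info("library logging configured")
```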
models/lr_schedulers.py
ADDED
@@ -0,0 +1,292 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch optimization for diffusion models."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
from enum import Enum
|
19 |
+
from typing import Optional, Union
|
20 |
+
|
21 |
+
from torch.optim import Optimizer
|
22 |
+
from torch.optim.lr_scheduler import LambdaLR
|
23 |
+
|
24 |
+
from .logging import get_logger
|
25 |
+
|
26 |
+
logger = get_logger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
class SchedulerType(Enum):
|
30 |
+
LINEAR = "linear"
|
31 |
+
COSINE = "cosine"
|
32 |
+
COSINE_WITH_RESTARTS = "cosine_with_restarts"
|
33 |
+
POLYNOMIAL = "polynomial"
|
34 |
+
CONSTANT = "constant"
|
35 |
+
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
36 |
+
|
37 |
+
|
38 |
+
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
39 |
+
"""
|
40 |
+
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
44 |
+
The optimizer for which to schedule the learning rate.
|
45 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
46 |
+
The index of the last epoch when resuming training.
|
47 |
+
|
48 |
+
Return:
|
49 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
50 |
+
"""
|
51 |
+
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
52 |
+
|
53 |
+
|
54 |
+
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
|
55 |
+
"""
|
56 |
+
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
|
57 |
+
increases linearly between 0 and the initial lr set in the optimizer.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
61 |
+
The optimizer for which to schedule the learning rate.
|
62 |
+
num_warmup_steps (`int`):
|
63 |
+
The number of steps for the warmup phase.
|
64 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
65 |
+
The index of the last epoch when resuming training.
|
66 |
+
|
67 |
+
Return:
|
68 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
69 |
+
"""
|
70 |
+
|
71 |
+
def lr_lambda(current_step: int):
|
72 |
+
if current_step < num_warmup_steps:
|
73 |
+
return float(current_step) / float(max(1.0, num_warmup_steps))
|
74 |
+
return 1.0
|
75 |
+
|
76 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
77 |
+
|
78 |
+
|
79 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
80 |
+
"""
|
81 |
+
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
82 |
+
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
86 |
+
The optimizer for which to schedule the learning rate.
|
87 |
+
num_warmup_steps (`int`):
|
88 |
+
The number of steps for the warmup phase.
|
89 |
+
num_training_steps (`int`):
|
90 |
+
The total number of training steps.
|
91 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
92 |
+
The index of the last epoch when resuming training.
|
93 |
+
|
94 |
+
Return:
|
95 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def lr_lambda(current_step: int):
|
99 |
+
if current_step < num_warmup_steps:
|
100 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
101 |
+
return max(
|
102 |
+
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
103 |
+
)
|
104 |
+
|
105 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
106 |
+
|
107 |
+
|
108 |
+
def get_cosine_schedule_with_warmup(
|
109 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
|
110 |
+
):
|
111 |
+
"""
|
112 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
113 |
+
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
114 |
+
initial lr set in the optimizer.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
118 |
+
The optimizer for which to schedule the learning rate.
|
119 |
+
num_warmup_steps (`int`):
|
120 |
+
The number of steps for the warmup phase.
|
121 |
+
num_training_steps (`int`):
|
122 |
+
The total number of training steps.
|
123 |
+
num_cycles (`float`, *optional*, defaults to 0.5):
|
124 |
+
The number of periods of the cosine function in a schedule (the default is to just decrease from the max
|
125 |
+
value to 0 following a half-cosine).
|
126 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
127 |
+
The index of the last epoch when resuming training.
|
128 |
+
|
129 |
+
Return:
|
130 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
131 |
+
"""
|
132 |
+
|
133 |
+
def lr_lambda(current_step):
|
134 |
+
if current_step < num_warmup_steps:
|
135 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
136 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
137 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
138 |
+
|
139 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
140 |
+
|
141 |
+
|
142 |
+
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
143 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
|
144 |
+
):
|
145 |
+
"""
|
146 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
147 |
+
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
|
148 |
+
linearly between 0 and the initial lr set in the optimizer.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
152 |
+
The optimizer for which to schedule the learning rate.
|
153 |
+
num_warmup_steps (`int`):
|
154 |
+
The number of steps for the warmup phase.
|
155 |
+
num_training_steps (`int`):
|
156 |
+
The total number of training steps.
|
157 |
+
num_cycles (`int`, *optional*, defaults to 1):
|
158 |
+
The number of hard restarts to use.
|
159 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
160 |
+
The index of the last epoch when resuming training.
|
161 |
+
|
162 |
+
Return:
|
163 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
164 |
+
"""
|
165 |
+
|
166 |
+
def lr_lambda(current_step):
|
167 |
+
if current_step < num_warmup_steps:
|
168 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
169 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
170 |
+
if progress >= 1.0:
|
171 |
+
return 0.0
|
172 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
|
173 |
+
|
174 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
175 |
+
|
176 |
+
|
177 |
+
def get_polynomial_decay_schedule_with_warmup(
|
178 |
+
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
|
182 |
+
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
|
183 |
+
initial lr set in the optimizer.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
187 |
+
The optimizer for which to schedule the learning rate.
|
188 |
+
num_warmup_steps (`int`):
|
189 |
+
The number of steps for the warmup phase.
|
190 |
+
num_training_steps (`int`):
|
191 |
+
The total number of training steps.
|
192 |
+
lr_end (`float`, *optional*, defaults to 1e-7):
|
193 |
+
The end LR.
|
194 |
+
power (`float`, *optional*, defaults to 1.0):
|
195 |
+
Power factor.
|
196 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
197 |
+
The index of the last epoch when resuming training.
|
198 |
+
|
199 |
+
Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
|
200 |
+
implementation at
|
201 |
+
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
|
202 |
+
|
203 |
+
Return:
|
204 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
205 |
+
|
206 |
+
"""
|
207 |
+
|
208 |
+
lr_init = optimizer.defaults["lr"]
|
209 |
+
if not (lr_init > lr_end):
|
210 |
+
raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
|
211 |
+
|
212 |
+
def lr_lambda(current_step: int):
|
213 |
+
if current_step < num_warmup_steps:
|
214 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
215 |
+
elif current_step > num_training_steps:
|
216 |
+
return lr_end / lr_init # as LambdaLR multiplies by lr_init
|
217 |
+
else:
|
218 |
+
lr_range = lr_init - lr_end
|
219 |
+
decay_steps = num_training_steps - num_warmup_steps
|
220 |
+
pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
|
221 |
+
decay = lr_range * pct_remaining**power + lr_end
|
222 |
+
return decay / lr_init # as LambdaLR multiplies by lr_init
|
223 |
+
|
224 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
225 |
+
|
226 |
+
|
227 |
+
TYPE_TO_SCHEDULER_FUNCTION = {
|
228 |
+
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
229 |
+
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
230 |
+
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
|
231 |
+
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
|
232 |
+
SchedulerType.CONSTANT: get_constant_schedule,
|
233 |
+
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
def get_scheduler(
|
238 |
+
name: Union[str, SchedulerType],
|
239 |
+
optimizer: Optimizer,
|
240 |
+
num_warmup_steps: Optional[int] = None,
|
241 |
+
num_training_steps: Optional[int] = None,
|
242 |
+
num_cycles: int = 1,
|
243 |
+
power: float = 1.0,
|
244 |
+
):
|
245 |
+
"""
|
246 |
+
Unified API to get any scheduler from its name.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
name (`str` or `SchedulerType`):
|
250 |
+
The name of the scheduler to use.
|
251 |
+
optimizer (`torch.optim.Optimizer`):
|
252 |
+
The optimizer that will be used during training.
|
253 |
+
num_warmup_steps (`int`, *optional*):
|
254 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
255 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
256 |
+
num_training_steps (`int`, *optional*):
|
257 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
258 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
259 |
+
num_cycles (`int`, *optional*):
|
260 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
261 |
+
power (`float`, *optional*, defaults to 1.0):
|
262 |
+
Power factor. See `POLYNOMIAL` scheduler
|
263 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
264 |
+
The index of the last epoch when resuming training.
|
265 |
+
"""
|
266 |
+
name = SchedulerType(name)
|
267 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
268 |
+
if name == SchedulerType.CONSTANT:
|
269 |
+
return schedule_func(optimizer)
|
270 |
+
|
271 |
+
# All other schedulers require `num_warmup_steps`
|
272 |
+
if num_warmup_steps is None:
|
273 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
274 |
+
|
275 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
276 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
277 |
+
|
278 |
+
# All other schedulers require `num_training_steps`
|
279 |
+
if num_training_steps is None:
|
280 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
281 |
+
|
282 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
283 |
+
return schedule_func(
|
284 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
285 |
+
)
|
286 |
+
|
287 |
+
if name == SchedulerType.POLYNOMIAL:
|
288 |
+
return schedule_func(
|
289 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
290 |
+
)
|
291 |
+
|
292 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
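For reference, a short sketch of how the `get_scheduler` factory above might be wired to a torch optimizer; the module path, model, and hyperparameters are illustrative and not taken from the training scripts in this commit:

```
import torch
from models.lr_schedulers import get_scheduler

model = torch.nn.Linear(16, 16)                       # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    "cosine",                                         # SchedulerType.COSINE
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=10_000,
)

for step in range(10_000):
    loss = model(torch.randn(4, 16)).pow(2).mean()    # dummy loss for illustration
    loss.backward()
    optimizer.step()
    lr_scheduler.step()                               # linear warmup, then half-cosine decay to 0
    optimizer.zero_grad()
```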
models/misc.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
from omegaconf import OmegaConf
|
2 |
+
import torch
|
3 |
+
from typing import (
|
4 |
+
Any,
|
5 |
+
Callable,
|
6 |
+
Dict,
|
7 |
+
Iterable,
|
8 |
+
List,
|
9 |
+
NamedTuple,
|
10 |
+
NewType,
|
11 |
+
Optional,
|
12 |
+
Sized,
|
13 |
+
Tuple,
|
14 |
+
Type,
|
15 |
+
TypeVar,
|
16 |
+
Union,
|
17 |
+
)
|
18 |
+
try:
|
19 |
+
from typing import Literal
|
20 |
+
except ImportError:
|
21 |
+
from typing_extensions import Literal
|
22 |
+
|
23 |
+
# Tensor dtype
|
24 |
+
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
|
25 |
+
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
|
26 |
+
|
27 |
+
# Config type
|
28 |
+
from omegaconf import DictConfig
|
29 |
+
|
30 |
+
# PyTorch Tensor type
|
31 |
+
from torch import Tensor
|
32 |
+
|
33 |
+
# Runtime type checking decorator
|
34 |
+
from typeguard import typechecked as typechecker
|
35 |
+
|
36 |
+
|
37 |
+
def broadcast(tensor, src=0):
|
38 |
+
if not _distributed_available():
|
39 |
+
return tensor
|
40 |
+
else:
|
41 |
+
torch.distributed.broadcast(tensor, src=src)
|
42 |
+
return tensor
|
43 |
+
|
44 |
+
def _distributed_available():
|
45 |
+
return torch.distributed.is_available() and torch.distributed.is_initialized()
|
46 |
+
|
47 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
48 |
+
# added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
|
49 |
+
if '--local-rank' in cfg:
|
50 |
+
del cfg['--local-rank']
|
51 |
+
# added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
|
52 |
+
scfg = OmegaConf.structured(fields(**cfg))
|
53 |
+
return scfg
|
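`parse_structured` simply instantiates a dataclass from a plain dict (after dropping a stray `--local-rank` key) and wraps it in an OmegaConf structured config. A hedged sketch, where the `EncoderConfig` dataclass is illustrative rather than taken from this commit:

```
from dataclasses import dataclass
from models.misc import parse_structured

@dataclass
class EncoderConfig:           # illustrative dataclass, not part of the commit
    ch: int = 128
    resolution: int = 256

cfg = parse_structured(EncoderConfig, {"resolution": 512})
print(cfg.ch, cfg.resolution)  # 128 512
```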
models/modeling_magvitv2.py
ADDED
@@ -0,0 +1,440 @@
1 |
+
from dataclasses import dataclass, field
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from .common_modules import *
|
6 |
+
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
|
7 |
+
from .misc import *
|
8 |
+
import math
|
9 |
+
|
10 |
+
class Updateable:
|
11 |
+
def do_update_step(
|
12 |
+
self, epoch: int, global_step: int, on_load_weights: bool = False
|
13 |
+
):
|
14 |
+
for attr in self.__dir__():
|
15 |
+
if attr.startswith("_"):
|
16 |
+
continue
|
17 |
+
try:
|
18 |
+
module = getattr(self, attr)
|
19 |
+
except:
|
20 |
+
continue # ignore attributes like property, which can't be retrieved using getattr
|
21 |
+
if isinstance(module, Updateable):
|
22 |
+
module.do_update_step(
|
23 |
+
epoch, global_step, on_load_weights=on_load_weights
|
24 |
+
)
|
25 |
+
self.update_step(epoch, global_step, on_load_weights=on_load_weights)
|
26 |
+
|
27 |
+
def do_update_step_end(self, epoch: int, global_step: int):
|
28 |
+
for attr in self.__dir__():
|
29 |
+
if attr.startswith("_"):
|
30 |
+
continue
|
31 |
+
try:
|
32 |
+
module = getattr(self, attr)
|
33 |
+
except:
|
34 |
+
continue # ignore attributes like property, which can't be retrieved using getattr
|
35 |
+
if isinstance(module, Updateable):
|
36 |
+
module.do_update_step_end(epoch, global_step)
|
37 |
+
self.update_step_end(epoch, global_step)
|
38 |
+
|
39 |
+
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
40 |
+
# override this method to implement custom update logic
|
41 |
+
# if on_load_weights is True, you should be careful doing things related to model evaluations,
|
42 |
+
# as the models and tensors are not guaranteed to be on the same device
|
43 |
+
pass
|
44 |
+
|
45 |
+
def update_step_end(self, epoch: int, global_step: int):
|
46 |
+
pass
|
47 |
+
|
48 |
+
class VQGANEncoder(ModelMixin, ConfigMixin):
|
49 |
+
@dataclass
|
50 |
+
class Config:
|
51 |
+
ch: int = 128
|
52 |
+
ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4])
|
53 |
+
num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4])
|
54 |
+
attn_resolutions: List[int] = field(default_factory=lambda: [5])
|
55 |
+
dropout: float = 0.0
|
56 |
+
in_ch: int = 3
|
57 |
+
out_ch: int = 3
|
58 |
+
resolution: int = 256
|
59 |
+
z_channels: int = 13
|
60 |
+
double_z: bool = False
|
61 |
+
|
62 |
+
def __init__(self,
|
63 |
+
ch: int = 128,
|
64 |
+
ch_mult: List[int] = [1, 2, 2, 4, 4],
|
65 |
+
num_res_blocks: List[int] = [4, 3, 4, 3, 4],
|
66 |
+
attn_resolutions: List[int] = [5],
|
67 |
+
dropout: float = 0.0,
|
68 |
+
in_ch: int = 3,
|
69 |
+
out_ch: int = 3,
|
70 |
+
resolution: int = 256,
|
71 |
+
z_channels: int = 13,
|
72 |
+
double_z: bool = False):
|
73 |
+
super().__init__()
|
74 |
+
self.ch = ch
|
75 |
+
self.temb_ch = 0
|
76 |
+
self.num_resolutions = len(ch_mult)
|
77 |
+
self.num_res_blocks = num_res_blocks
|
78 |
+
self.resolution = resolution
|
79 |
+
self.in_ch = in_ch
|
80 |
+
# downsampling
|
81 |
+
self.conv_in = torch.nn.Conv2d(
|
82 |
+
self.in_ch, self.ch, kernel_size=3, stride=1, padding=1
|
83 |
+
)
|
84 |
+
|
85 |
+
curr_res = self.resolution
|
86 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
87 |
+
self.down = nn.ModuleList()
|
88 |
+
for i_level in range(self.num_resolutions):
|
89 |
+
block = nn.ModuleList()
|
90 |
+
attn = nn.ModuleList()
|
91 |
+
block_in = self.ch * in_ch_mult[i_level]
|
92 |
+
block_out = self.ch * ch_mult[i_level]
|
93 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
94 |
+
block.append(
|
95 |
+
ResnetBlock(
|
96 |
+
in_channels=block_in,
|
97 |
+
out_channels=block_out,
|
98 |
+
temb_channels=self.temb_ch,
|
99 |
+
dropout=dropout,
|
100 |
+
)
|
101 |
+
)
|
102 |
+
block_in = block_out
|
103 |
+
if curr_res in attn_resolutions:
|
104 |
+
attn.append(AttnBlock(block_in))
|
105 |
+
down = nn.Module()
|
106 |
+
down.block = block
|
107 |
+
down.attn = attn
|
108 |
+
if i_level != self.num_resolutions - 1:
|
109 |
+
down.downsample = Downsample(block_in, True)
|
110 |
+
curr_res = curr_res // 2
|
111 |
+
self.down.append(down)
|
112 |
+
|
113 |
+
# middle
|
114 |
+
self.mid = nn.Module()
|
115 |
+
self.mid.block_1 = ResnetBlock(
|
116 |
+
in_channels=block_in,
|
117 |
+
out_channels=block_in,
|
118 |
+
temb_channels=self.temb_ch,
|
119 |
+
dropout=dropout,
|
120 |
+
)
|
121 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
122 |
+
self.mid.block_2 = ResnetBlock(
|
123 |
+
in_channels=block_in,
|
124 |
+
out_channels=block_in,
|
125 |
+
temb_channels=self.temb_ch,
|
126 |
+
dropout=dropout,
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
self.norm_out = Normalize(block_in)
|
131 |
+
self.conv_out = torch.nn.Conv2d(
|
132 |
+
block_in,
|
133 |
+
2 * z_channels if double_z else z_channels,
|
134 |
+
kernel_size=3,
|
135 |
+
stride=1,
|
136 |
+
padding=1,
|
137 |
+
)
|
138 |
+
|
139 |
+
self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
|
140 |
+
# for param in self.parameters():
|
141 |
+
# broadcast(param, src=0)
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
# timestep embedding
|
145 |
+
temb = None
|
146 |
+
|
147 |
+
# downsampling
|
148 |
+
hs = [self.conv_in(x)]
|
149 |
+
for i_level in range(self.num_resolutions):
|
150 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
151 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
152 |
+
if len(self.down[i_level].attn) > 0:
|
153 |
+
h = self.down[i_level].attn[i_block](h)
|
154 |
+
hs.append(h)
|
155 |
+
if i_level != self.num_resolutions - 1:
|
156 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
157 |
+
|
158 |
+
# middle
|
159 |
+
h = hs[-1]
|
160 |
+
h = self.mid.block_1(h, temb)
|
161 |
+
h = self.mid.attn_1(h)
|
162 |
+
h = self.mid.block_2(h, temb)
|
163 |
+
|
164 |
+
# end
|
165 |
+
h = self.norm_out(h)
|
166 |
+
h = nonlinearity(h)
|
167 |
+
h = self.conv_out(h)
|
168 |
+
h = self.quant_conv(h)
|
169 |
+
return h
|
170 |
+
|
171 |
+
|
172 |
+
class LFQuantizer(nn.Module):
|
173 |
+
def __init__(self, num_codebook_entry: int = -1,
|
174 |
+
codebook_dim: int = 13,
|
175 |
+
beta: float = 0.25,
|
176 |
+
entropy_multiplier: float = 0.1,
|
177 |
+
commit_loss_multiplier: float = 0.1, ):
|
178 |
+
super().__init__()
|
179 |
+
self.codebook_size = 2 ** codebook_dim
|
180 |
+
print(
|
181 |
+
f"Look-up free quantizer with codebook size: {self.codebook_size}"
|
182 |
+
)
|
183 |
+
self.e_dim = codebook_dim
|
184 |
+
self.beta = beta
|
185 |
+
|
186 |
+
indices = torch.arange(self.codebook_size)
|
187 |
+
|
188 |
+
binary = (
|
189 |
+
indices.unsqueeze(1)
|
190 |
+
>> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long)
|
191 |
+
) & 1
|
192 |
+
|
193 |
+
embedding = binary.float() * 2 - 1
|
194 |
+
self.register_buffer("embedding", embedding)
|
195 |
+
self.register_buffer(
|
196 |
+
"power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1)
|
197 |
+
)
|
198 |
+
self.commit_loss_multiplier = commit_loss_multiplier
|
199 |
+
self.entropy_multiplier = entropy_multiplier
|
200 |
+
|
201 |
+
def get_indices(self, z_q):
|
202 |
+
return (
|
203 |
+
(self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float())
|
204 |
+
.sum(1, keepdim=True)
|
205 |
+
.long()
|
206 |
+
)
|
207 |
+
|
208 |
+
def get_codebook_entry(self, indices, shape=None):
|
209 |
+
if shape is None:
|
210 |
+
h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1]))
|
211 |
+
else:
|
212 |
+
h, w = shape
|
213 |
+
b, _ = indices.shape
|
214 |
+
indices = indices.reshape(-1)
|
215 |
+
z_q = self.embedding[indices]
|
216 |
+
z_q = z_q.view(b, h, w, -1)
|
217 |
+
|
218 |
+
# reshape back to match original input shape
|
219 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
220 |
+
|
221 |
+
return z_q
|
222 |
+
|
223 |
+
def forward(self, z, get_code=False):
|
224 |
+
"""
|
225 |
+
Inputs the output of the encoder network z and maps it to a discrete
|
226 |
+
one-hot vector that is the index of the closest embedding vector e_j
|
227 |
+
z (continuous) -> z_q (discrete)
|
228 |
+
z.shape = (batch, channel, height, width)
|
229 |
+
quantization pipeline:
|
230 |
+
1. get encoder input (B,C,H,W)
|
231 |
+
2. flatten input to (B*H*W,C)
|
232 |
+
"""
|
233 |
+
if get_code:
|
234 |
+
return self.get_codebook_entry(z)
|
235 |
+
|
236 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
237 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
238 |
+
z_flattened = z.view(-1, self.e_dim)
|
239 |
+
ge_zero = (z_flattened > 0).float()
|
240 |
+
ones = torch.ones_like(z_flattened)
|
241 |
+
z_q = ones * ge_zero + -ones * (1 - ge_zero)
|
242 |
+
|
243 |
+
# preserve gradients
|
244 |
+
z_q = z_flattened + (z_q - z_flattened).detach()
|
245 |
+
|
246 |
+
# compute entropy loss
|
247 |
+
CatDist = torch.distributions.categorical.Categorical
|
248 |
+
logit = torch.stack(
|
249 |
+
[
|
250 |
+
-(z_flattened - torch.ones_like(z_q)).pow(2),
|
251 |
+
-(z_flattened - torch.ones_like(z_q) * -1).pow(2),
|
252 |
+
],
|
253 |
+
dim=-1,
|
254 |
+
)
|
255 |
+
cat_dist = CatDist(logits=logit)
|
256 |
+
entropy = cat_dist.entropy().mean()
|
257 |
+
mean_prob = cat_dist.probs.mean(0)
|
258 |
+
mean_entropy = CatDist(probs=mean_prob).entropy().mean()
|
259 |
+
|
260 |
+
# compute loss for embedding
|
261 |
+
commit_loss = torch.mean(
|
262 |
+
(z_q.detach() - z_flattened) ** 2
|
263 |
+
) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)
|
264 |
+
|
265 |
+
# reshape back to match original input shape
|
266 |
+
z_q = z_q.view(z.shape)
|
267 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
268 |
+
|
269 |
+
return {
|
270 |
+
"z": z_q,
|
271 |
+
"quantizer_loss": commit_loss * self.commit_loss_multiplier,
|
272 |
+
"entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier,
|
273 |
+
"indices": self.get_indices(z_q),
|
274 |
+
}
|
275 |
+
|
276 |
+
|
277 |
+
class VQGANDecoder(ModelMixin, ConfigMixin):
|
278 |
+
def __init__(self, ch: int = 128,
|
279 |
+
ch_mult: List[int] = [1, 1, 2, 2, 4],
|
280 |
+
num_res_blocks: List[int] = [4, 4, 3, 4, 3],
|
281 |
+
attn_resolutions: List[int] = [5],
|
282 |
+
dropout: float = 0.0,
|
283 |
+
in_ch: int = 3,
|
284 |
+
out_ch: int = 3,
|
285 |
+
resolution: int = 256,
|
286 |
+
z_channels: int = 13,
|
287 |
+
double_z: bool = False):
|
288 |
+
super().__init__()
|
289 |
+
self.ch = ch
|
290 |
+
self.temb_ch = 0
|
291 |
+
self.num_resolutions = len(ch_mult)
|
292 |
+
self.num_res_blocks = num_res_blocks
|
293 |
+
self.resolution = resolution
|
294 |
+
self.in_ch = in_ch
|
295 |
+
self.give_pre_end = False
|
296 |
+
|
297 |
+
self.z_channels = z_channels
|
298 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
299 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
300 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
301 |
+
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
|
302 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
303 |
+
print(
|
304 |
+
"Working with z of shape {} = {} dimensions.".format(
|
305 |
+
self.z_shape, np.prod(self.z_shape)
|
306 |
+
)
|
307 |
+
)
|
308 |
+
|
309 |
+
# z to block_in
|
310 |
+
self.conv_in = torch.nn.Conv2d(
|
311 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
312 |
+
)
|
313 |
+
|
314 |
+
# middle
|
315 |
+
self.mid = nn.Module()
|
316 |
+
self.mid.block_1 = ResnetBlock(
|
317 |
+
in_channels=block_in,
|
318 |
+
out_channels=block_in,
|
319 |
+
temb_channels=self.temb_ch,
|
320 |
+
dropout=dropout,
|
321 |
+
)
|
322 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
323 |
+
self.mid.block_2 = ResnetBlock(
|
324 |
+
in_channels=block_in,
|
325 |
+
out_channels=block_in,
|
326 |
+
temb_channels=self.temb_ch,
|
327 |
+
dropout=dropout,
|
328 |
+
)
|
329 |
+
|
330 |
+
# upsampling
|
331 |
+
self.up = nn.ModuleList()
|
332 |
+
for i_level in reversed(range(self.num_resolutions)):
|
333 |
+
block = nn.ModuleList()
|
334 |
+
attn = nn.ModuleList()
|
335 |
+
block_out = ch * ch_mult[i_level]
|
336 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
337 |
+
block.append(
|
338 |
+
ResnetBlock(
|
339 |
+
in_channels=block_in,
|
340 |
+
out_channels=block_out,
|
341 |
+
temb_channels=self.temb_ch,
|
342 |
+
dropout=dropout,
|
343 |
+
)
|
344 |
+
)
|
345 |
+
block_in = block_out
|
346 |
+
if curr_res in attn_resolutions:
|
347 |
+
attn.append(AttnBlock(block_in))
|
348 |
+
up = nn.Module()
|
349 |
+
up.block = block
|
350 |
+
up.attn = attn
|
351 |
+
if i_level != 0:
|
352 |
+
up.upsample = Upsample(block_in, True)
|
353 |
+
curr_res = curr_res * 2
|
354 |
+
self.up.insert(0, up) # prepend to get consistent order
|
355 |
+
|
356 |
+
self.norm_out = Normalize(block_in)
|
357 |
+
self.conv_out = torch.nn.Conv2d(
|
358 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
359 |
+
)
|
360 |
+
self.post_quant_conv = torch.nn.Conv2d(
|
361 |
+
z_channels, z_channels, 1
|
362 |
+
)
|
363 |
+
|
364 |
+
|
365 |
+
def forward(self, z):
|
366 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
367 |
+
self.last_z_shape = z.shape
|
368 |
+
# timestep embedding
|
369 |
+
temb = None
|
370 |
+
output = dict()
|
371 |
+
z = self.post_quant_conv(z)
|
372 |
+
|
373 |
+
# z to block_in
|
374 |
+
h = self.conv_in(z)
|
375 |
+
|
376 |
+
# middle
|
377 |
+
h = self.mid.block_1(h, temb)
|
378 |
+
h = self.mid.attn_1(h)
|
379 |
+
h = self.mid.block_2(h, temb)
|
380 |
+
|
381 |
+
# upsampling
|
382 |
+
for i_level in reversed(range(self.num_resolutions)):
|
383 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
384 |
+
h = self.up[i_level].block[i_block](h, temb)
|
385 |
+
if len(self.up[i_level].attn) > 0:
|
386 |
+
h = self.up[i_level].attn[i_block](h)
|
387 |
+
if i_level != 0:
|
388 |
+
h = self.up[i_level].upsample(h)
|
389 |
+
|
390 |
+
# end
|
391 |
+
output["output"] = h
|
392 |
+
if self.give_pre_end:
|
393 |
+
return output
|
394 |
+
|
395 |
+
h = self.norm_out(h)
|
396 |
+
h = nonlinearity(h)
|
397 |
+
h = self.conv_out(h)
|
398 |
+
output["output"] = h
|
399 |
+
return output
|
400 |
+
|
401 |
+
|
402 |
+
class MAGVITv2(ModelMixin, ConfigMixin):
|
403 |
+
@register_to_config
|
404 |
+
def __init__(
|
405 |
+
self,
|
406 |
+
):
|
407 |
+
super().__init__()
|
408 |
+
|
409 |
+
self.encoder = VQGANEncoder()
|
410 |
+
self.decoder = VQGANDecoder()
|
411 |
+
self.quantize = LFQuantizer()
|
412 |
+
|
413 |
+
def forward(self, pixel_values, return_loss=False):
|
414 |
+
pass
|
415 |
+
|
416 |
+
def encode(self, pixel_values, return_loss=False):
|
417 |
+
hidden_states = self.encoder(pixel_values)
|
418 |
+
quantized_states = self.quantize(hidden_states)['z']
|
419 |
+
codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1)
|
420 |
+
output = (quantized_states, codebook_indices)
|
421 |
+
return output
|
422 |
+
|
423 |
+
def get_code(self, pixel_values):
|
424 |
+
hidden_states = self.encoder(pixel_values)
|
425 |
+
codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1)
|
426 |
+
|
427 |
+
return codebook_indices
|
428 |
+
|
429 |
+
def decode_code(self, codebook_indices, shape=None):
|
430 |
+
z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape)
|
431 |
+
|
432 |
+
reconstructed_pixel_values = self.decoder(z_q)["output"]
|
433 |
+
return reconstructed_pixel_values
|
434 |
+
|
435 |
+
|
436 |
+
if __name__ == '__main__':
|
437 |
+
encoder = VQGANEncoder()
|
438 |
+
import ipdb
|
439 |
+
ipdb.set_trace()
|
440 |
+
print()
|
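Taken together, `VQGANEncoder`, `LFQuantizer`, and `VQGANDecoder` give `MAGVITv2` a standard tokenize/detokenize interface: with the default configuration, a 256x256 image becomes a 16x16 grid of 2^13-way codes. A round-trip sketch; checkpoint handling is an assumption and is not shown in this commit:

```
import torch
from models.modeling_magvitv2 import MAGVITv2

vq_model = MAGVITv2().eval()                  # randomly initialized here; load trained weights in practice
pixel_values = torch.rand(1, 3, 256, 256) * 2 - 1

with torch.no_grad():
    codes = vq_model.get_code(pixel_values)   # (1, 256) token ids in [0, 8191], a 16x16 grid
    recon = vq_model.decode_code(codes)       # (1, 3, 256, 256) reconstructed pixels
```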
models/modeling_showo.py
ADDED
@@ -0,0 +1,206 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from transformers import AutoConfig
|
4 |
+
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
|
5 |
+
from .sampling import cosine_schedule, mask_by_random_topk
|
6 |
+
from .phi import PhiForCausalLM
|
7 |
+
|
8 |
+
try:
|
9 |
+
import xformers.ops as xops
|
10 |
+
|
11 |
+
is_xformers_available = True
|
12 |
+
except ImportError:
|
13 |
+
is_xformers_available = False
|
14 |
+
|
15 |
+
|
16 |
+
class Showo(ModelMixin, ConfigMixin):
|
17 |
+
_supports_gradient_checkpointing = True
|
18 |
+
|
19 |
+
@register_to_config
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
w_clip_vit,
|
23 |
+
vocab_size,
|
24 |
+
llm_vocab_size,
|
25 |
+
llm_model_path='',
|
26 |
+
codebook_size=8192,
|
27 |
+
num_vq_tokens=256,
|
28 |
+
**kwargs,
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
self.vocab_size = vocab_size
|
32 |
+
self.register_to_config(mask_token_id=vocab_size - 1)
|
33 |
+
config = AutoConfig.from_pretrained(llm_model_path)
|
34 |
+
self.showo = PhiForCausalLM(config)
|
35 |
+
self.showo.resize_token_embeddings(self.vocab_size)
|
36 |
+
self.output_size = self.vocab_size
|
37 |
+
|
38 |
+
if self.w_clip_vit:
|
39 |
+
self.mm_projector = torch.nn.Sequential(
|
40 |
+
torch.nn.Linear(1024, 2048),
|
41 |
+
torch.nn.GELU(),
|
42 |
+
torch.nn.Linear(2048, 2048)
|
43 |
+
)
|
44 |
+
|
45 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
46 |
+
self.gradient_checkpointing = True
|
47 |
+
|
48 |
+
def forward(
|
49 |
+
self,
|
50 |
+
input_ids,
|
51 |
+
input_embeddings=None,
|
52 |
+
attention_mask=None,
|
53 |
+
labels=None,
|
54 |
+
label_smoothing=0.0,
|
55 |
+
config=None,
|
56 |
+
labels_mask_text=None,
|
57 |
+
labels_mask_image=None,
|
58 |
+
**kwargs,
|
59 |
+
):
|
60 |
+
|
61 |
+
if input_embeddings is None:
|
62 |
+
logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
|
63 |
+
else:
|
64 |
+
logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']
|
65 |
+
|
66 |
+
if labels is not None:
|
67 |
+
raise NotImplementedError
|
68 |
+
|
69 |
+
return logits
|
70 |
+
|
71 |
+
def t2i_generate(
|
72 |
+
self,
|
73 |
+
input_ids: torch.LongTensor = None,
|
74 |
+
uncond_input_ids: torch.LongTensor = None,
|
75 |
+
attention_mask=None,
|
76 |
+
temperature=1.0,
|
77 |
+
timesteps=18, # ideal number of steps is 18 in maskgit paper
|
78 |
+
guidance_scale=0,
|
79 |
+
noise_schedule=cosine_schedule,
|
80 |
+
generator: torch.Generator = None,
|
81 |
+
uni_prompting=None,
|
82 |
+
config=None,
|
83 |
+
**kwargs,
|
84 |
+
):
|
85 |
+
"""
|
86 |
+
Generate 1:1 similar to the original MaskGit repo
|
87 |
+
https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
|
88 |
+
"""
|
89 |
+
# begin with all image token ids masked
|
90 |
+
mask_token_id = self.config.mask_token_id
|
91 |
+
seq_len = config.model.showo.num_vq_tokens
|
92 |
+
|
93 |
+
input_ids_minus_lm_vocab_size = input_ids[:, -(seq_len + 1):-1].clone()
|
94 |
+
input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id,
|
95 |
+
mask_token_id,
|
96 |
+
input_ids_minus_lm_vocab_size - config.model.showo.llm_vocab_size - 10)
|
97 |
+
# import ipdb
|
98 |
+
# ipdb.set_trace()
|
99 |
+
if uncond_input_ids is not None:
|
100 |
+
uncond_prefix = uncond_input_ids[:, :config.dataset.preprocessing.max_seq_length + 1]
|
101 |
+
|
102 |
+
for step in range(timesteps):
|
103 |
+
if uncond_input_ids is not None and guidance_scale > 0:
|
104 |
+
uncond_input_ids = torch.cat(
|
105 |
+
[uncond_prefix, input_ids[:, config.dataset.preprocessing.max_seq_length + 1:]], dim=1)
|
106 |
+
model_input = torch.cat([input_ids, uncond_input_ids])
|
107 |
+
cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
|
108 |
+
# logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
109 |
+
# it seems that muse has different cfg setting
|
110 |
+
logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
|
111 |
+
logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
|
112 |
+
else:
|
113 |
+
logits = self(input_ids, attention_mask=attention_mask)
|
114 |
+
logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
|
115 |
+
|
116 |
+
probs = logits.softmax(dim=-1)
|
117 |
+
sampled = probs.reshape(-1, logits.size(-1))
|
118 |
+
sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])
|
119 |
+
|
120 |
+
unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
|
121 |
+
sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
|
122 |
+
# Defines the mask ratio for the next round. The number to mask out is
|
123 |
+
# determined by mask_ratio * unknown_number_in_the_beginning.
|
124 |
+
ratio = 1.0 * (step + 1) / timesteps
|
125 |
+
mask_ratio = noise_schedule(torch.tensor(ratio))
|
126 |
+
# Computes the probabilities of each selected tokens.
|
127 |
+
selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
|
128 |
+
selected_probs = selected_probs.squeeze(-1)
|
129 |
+
|
130 |
+
# Ignores the tokens given in the input by overwriting their confidence.
|
131 |
+
selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
|
132 |
+
# Gets mask lens for each sample in the batch according to the mask ratio.
|
133 |
+
mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
|
134 |
+
# Keeps at least one of prediction in this round and also masks out at least
|
135 |
+
# one and for the next iteration
|
136 |
+
mask_len = torch.max(
|
137 |
+
torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
138 |
+
)
|
139 |
+
# Adds noise for randomness
|
140 |
+
temperature = temperature * (1.0 - ratio)
|
141 |
+
masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
|
142 |
+
# Masks tokens with lower confidence.
|
143 |
+
input_ids[:, -(seq_len + 1):-1] = torch.where(masking, mask_token_id,
|
144 |
+
sampled_ids + config.model.showo.llm_vocab_size + 10)
|
145 |
+
input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
|
146 |
+
|
147 |
+
return sampled_ids
|
148 |
+
|
149 |
+
@torch.no_grad()
|
150 |
+
def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
|
151 |
+
"""
|
152 |
+
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
|
153 |
+
the sequence max_new_tokens times, feeding the predictions back into the model each time.
|
154 |
+
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
155 |
+
"""
|
156 |
+
try:
|
157 |
+
device = idx.device
|
158 |
+
except:
|
159 |
+
device = input_embeddings.device
|
160 |
+
|
161 |
+
result = []
|
162 |
+
for _ in range(max_new_tokens):
|
163 |
+
# if the sequence context is growing too long we must crop it at block_size
|
164 |
+
# idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
|
165 |
+
# forward the model to get the logits for the index in the sequence
|
166 |
+
# logits, _ = self(idx_cond)
|
167 |
+
logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)
|
168 |
+
|
169 |
+
L = attention_mask.shape[-1]
|
170 |
+
attention_mask = attention_mask.squeeze()
|
171 |
+
attention_mask_a = torch.hstack(
|
172 |
+
[
|
173 |
+
attention_mask, # L, L
|
174 |
+
torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
|
175 |
+
]
|
176 |
+
)
|
177 |
+
attention_mask_b = torch.vstack(
|
178 |
+
[
|
179 |
+
attention_mask_a, # L, L+1
|
180 |
+
torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
|
181 |
+
]
|
182 |
+
)
|
183 |
+
attention_mask = attention_mask_b
|
184 |
+
|
185 |
+
# pluck the logits at the final step and scale by desired temperature
|
186 |
+
logits = logits[:, -1, :] / temperature
|
187 |
+
# optionally crop the logits to only the top k options
|
188 |
+
if top_k is not None:
|
189 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
190 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
191 |
+
# apply softmax to convert logits to (normalized) probabilities
|
192 |
+
probs = F.softmax(logits, dim=-1)
|
193 |
+
# sample from the distribution
|
194 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
195 |
+
result.append(idx_next[0][0])
|
196 |
+
# append sampled index to the running sequence and continue
|
197 |
+
if self.config.w_clip_vit:
|
198 |
+
idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
|
199 |
+
input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
|
200 |
+
else:
|
201 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
202 |
+
|
203 |
+
if eot_token is not None and idx_next.cpu() == eot_token:
|
204 |
+
break
|
205 |
+
|
206 |
+
return result
|
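The `t2i_generate` loop above follows MaskGit-style parallel decoding: at each step the number of image tokens that stay masked shrinks according to `noise_schedule`. The snippet below just traces that budget; the explicit `cosine_schedule` definition is an assumption about `models/sampling.py` (the usual MaskGit choice, cos(pi/2 * r)), not code from this commit:

```
import math
import torch

def cosine_schedule(r: torch.Tensor) -> torch.Tensor:
    # assumed form of models.sampling.cosine_schedule
    return torch.cos(r * math.pi / 2)

seq_len, timesteps = 256, 18          # num_vq_tokens and the default step count above
for step in range(timesteps):
    ratio = 1.0 * (step + 1) / timesteps
    mask_ratio = cosine_schedule(torch.tensor(ratio))
    mask_len = int((seq_len * mask_ratio).floor())
    print(f"step {step:2d}: re-mask {mask_len:3d} of {seq_len} image tokens")
```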
models/modeling_utils.py
ADDED
@@ -0,0 +1,1207 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
import itertools
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
import re
|
22 |
+
from collections import OrderedDict
|
23 |
+
from functools import partial
|
24 |
+
from pathlib import Path
|
25 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
26 |
+
|
27 |
+
import safetensors
|
28 |
+
import torch
|
29 |
+
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
|
30 |
+
from huggingface_hub.utils import validate_hf_hub_args
|
31 |
+
from torch import Tensor, nn
|
32 |
+
|
33 |
+
from diffusers import __version__
|
34 |
+
from diffusers.utils import (
|
35 |
+
FLAX_WEIGHTS_NAME,
|
36 |
+
SAFE_WEIGHTS_INDEX_NAME,
|
37 |
+
WEIGHTS_INDEX_NAME,
|
38 |
+
_add_variant,
|
39 |
+
_get_checkpoint_shard_files,
|
40 |
+
_get_model_file,
|
41 |
+
deprecate,
|
42 |
+
is_accelerate_available,
|
43 |
+
is_torch_version,
|
44 |
+
logging,
|
45 |
+
)
|
46 |
+
|
47 |
+
CONFIG_NAME = "config.json"
|
48 |
+
WEIGHTS_NAME = "pytorch_model.bin"
|
49 |
+
SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
|
50 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
51 |
+
|
52 |
+
from diffusers.utils.hub_utils import (
|
53 |
+
PushToHubMixin,
|
54 |
+
load_or_create_model_card,
|
55 |
+
populate_model_card,
|
56 |
+
)
|
57 |
+
from diffusers.models.model_loading_utils import (
|
58 |
+
_determine_device_map,
|
59 |
+
_fetch_index_file,
|
60 |
+
_load_state_dict_into_model,
|
61 |
+
load_model_dict_into_meta,
|
62 |
+
load_state_dict,
|
63 |
+
)
|
64 |
+
|
65 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
66 |
+
|
67 |
+
logger = logging.get_logger(__name__)
|
68 |
+
|
69 |
+
_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
70 |
+
|
71 |
+
|
72 |
+
if is_torch_version(">=", "1.9.0"):
|
73 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
74 |
+
else:
|
75 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
76 |
+
|
77 |
+
|
78 |
+
if is_accelerate_available():
|
79 |
+
import accelerate
|
80 |
+
|
81 |
+
|
82 |
+
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
83 |
+
try:
|
84 |
+
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
85 |
+
return next(parameters_and_buffers).device
|
86 |
+
except StopIteration:
|
87 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
88 |
+
|
89 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
90 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
91 |
+
return tuples
|
92 |
+
|
93 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
94 |
+
first_tuple = next(gen)
|
95 |
+
return first_tuple[1].device
|
96 |
+
|
97 |
+
|
98 |
+
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
99 |
+
try:
|
100 |
+
params = tuple(parameter.parameters())
|
101 |
+
if len(params) > 0:
|
102 |
+
return params[0].dtype
|
103 |
+
|
104 |
+
buffers = tuple(parameter.buffers())
|
105 |
+
if len(buffers) > 0:
|
106 |
+
return buffers[0].dtype
|
107 |
+
|
108 |
+
except StopIteration:
|
109 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
110 |
+
|
111 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
112 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
113 |
+
return tuples
|
114 |
+
|
115 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
116 |
+
first_tuple = next(gen)
|
117 |
+
return first_tuple[1].dtype
|
118 |
+
|
119 |
+
|
120 |
+
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
121 |
+
r"""
|
122 |
+
Base class for all models.
|
123 |
+
|
124 |
+
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
125 |
+
saving models.
|
126 |
+
|
127 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
128 |
+
"""
|
129 |
+
|
130 |
+
config_name = CONFIG_NAME
|
131 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
132 |
+
_supports_gradient_checkpointing = False
|
133 |
+
_keys_to_ignore_on_load_unexpected = None
|
134 |
+
_no_split_modules = None
|
135 |
+
|
136 |
+
def __init__(self):
|
137 |
+
super().__init__()
|
138 |
+
|
139 |
+
def __getattr__(self, name: str) -> Any:
|
140 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
141 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
142 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
143 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
144 |
+
"""
|
145 |
+
|
146 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
147 |
+
is_attribute = name in self.__dict__
|
148 |
+
|
149 |
+
if is_in_config and not is_attribute:
|
150 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
151 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
152 |
+
return self._internal_dict[name]
|
153 |
+
|
154 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
155 |
+
return super().__getattr__(name)
|
156 |
+
|
157 |
+
@property
|
158 |
+
def is_gradient_checkpointing(self) -> bool:
|
159 |
+
"""
|
160 |
+
Whether gradient checkpointing is activated for this model or not.
|
161 |
+
"""
|
162 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
163 |
+
|
164 |
+
def enable_gradient_checkpointing(self) -> None:
|
165 |
+
"""
|
166 |
+
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
167 |
+
*checkpoint activations* in other frameworks).
|
168 |
+
"""
|
169 |
+
if not self._supports_gradient_checkpointing:
|
170 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
171 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
172 |
+
|
173 |
+
def disable_gradient_checkpointing(self) -> None:
|
174 |
+
"""
|
175 |
+
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
176 |
+
*checkpoint activations* in other frameworks).
|
177 |
+
"""
|
178 |
+
if self._supports_gradient_checkpointing:
|
179 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
180 |
+
|
181 |
+
def set_use_npu_flash_attention(self, valid: bool) -> None:
|
182 |
+
r"""
|
183 |
+
Set the switch for the npu flash attention.
|
184 |
+
"""
|
185 |
+
|
186 |
+
def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
|
187 |
+
if hasattr(module, "set_use_npu_flash_attention"):
|
188 |
+
module.set_use_npu_flash_attention(valid)
|
189 |
+
|
190 |
+
for child in module.children():
|
191 |
+
fn_recursive_set_npu_flash_attention(child)
|
192 |
+
|
193 |
+
for module in self.children():
|
194 |
+
if isinstance(module, torch.nn.Module):
|
195 |
+
fn_recursive_set_npu_flash_attention(module)
|
196 |
+
|
197 |
+
def enable_npu_flash_attention(self) -> None:
|
198 |
+
r"""
|
199 |
+
Enable NPU flash attention from torch_npu.
|
200 |
+
|
201 |
+
"""
|
202 |
+
self.set_use_npu_flash_attention(True)
|
203 |
+
|
204 |
+
def disable_npu_flash_attention(self) -> None:
|
205 |
+
r"""
|
206 |
+
Disable NPU flash attention from torch_npu.
|
207 |
+
|
208 |
+
"""
|
209 |
+
self.set_use_npu_flash_attention(False)
|
210 |
+
|
211 |
+
def set_use_memory_efficient_attention_xformers(
|
212 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
213 |
+
) -> None:
|
214 |
+
# Recursively walk through all the children.
|
215 |
+
# Any child module which exposes the set_use_memory_efficient_attention_xformers method
|
216 |
+
# gets the message
|
217 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
218 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
219 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
220 |
+
|
221 |
+
for child in module.children():
|
222 |
+
fn_recursive_set_mem_eff(child)
|
223 |
+
|
224 |
+
for module in self.children():
|
225 |
+
if isinstance(module, torch.nn.Module):
|
226 |
+
fn_recursive_set_mem_eff(module)
|
227 |
+
|
228 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
|
229 |
+
r"""
|
230 |
+
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
231 |
+
|
232 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
|
233 |
+
inference. Speed up during training is not guaranteed.
|
234 |
+
|
235 |
+
<Tip warning={true}>
|
236 |
+
|
237 |
+
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
|
238 |
+
precedence.
|
239 |
+
|
240 |
+
</Tip>
|
241 |
+
|
242 |
+
Parameters:
|
243 |
+
attention_op (`Callable`, *optional*):
|
244 |
+
Override the default `None` operator for use as `op` argument to the
|
245 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
246 |
+
function of xFormers.
|
247 |
+
|
248 |
+
Examples:
|
249 |
+
|
250 |
+
```py
|
251 |
+
>>> import torch
|
252 |
+
>>> from diffusers import UNet2DConditionModel
|
253 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
254 |
+
|
255 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
256 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
257 |
+
... )
|
258 |
+
>>> model = model.to("cuda")
|
259 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
260 |
+
```
|
261 |
+
"""
|
262 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
263 |
+
|
264 |
+
def disable_xformers_memory_efficient_attention(self) -> None:
|
265 |
+
r"""
|
266 |
+
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
267 |
+
"""
|
268 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
269 |
+
|
270 |
+
def save_pretrained(
|
271 |
+
self,
|
272 |
+
save_directory: Union[str, os.PathLike],
|
273 |
+
is_main_process: bool = True,
|
274 |
+
save_function: Optional[Callable] = None,
|
275 |
+
safe_serialization: bool = True,
|
276 |
+
variant: Optional[str] = None,
|
277 |
+
max_shard_size: Union[int, str] = "10GB",
|
278 |
+
push_to_hub: bool = False,
|
279 |
+
**kwargs,
|
280 |
+
):
|
281 |
+
"""
|
282 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
283 |
+
[`~models.ModelMixin.from_pretrained`] class method.
|
284 |
+
|
285 |
+
Arguments:
|
286 |
+
save_directory (`str` or `os.PathLike`):
|
287 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
288 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
289 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
290 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
291 |
+
process to avoid race conditions.
|
292 |
+
save_function (`Callable`):
|
293 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
294 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
295 |
+
`DIFFUSERS_SAVE_MODE`.
|
296 |
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
297 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
298 |
+
variant (`str`, *optional*):
|
299 |
+
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
300 |
+
max_shard_size (`int` or `str`, defaults to `"10GB"`):
|
301 |
+
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
|
302 |
+
lower than this size. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`).
|
303 |
+
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
|
304 |
+
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
|
305 |
+
This is to establish a common default size for this argument across different libraries in the Hugging
|
306 |
+
Face ecosystem (`transformers`, and `accelerate`, for example).
|
307 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
308 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
309 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
310 |
+
namespace).
|
311 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
312 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
313 |
+
"""
|
314 |
+
if os.path.isfile(save_directory):
|
315 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
316 |
+
return
|
317 |
+
|
318 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
319 |
+
weights_name = _add_variant(weights_name, variant)
|
320 |
+
weight_name_split = weights_name.split(".")
|
321 |
+
if len(weight_name_split) in [2, 3]:
|
322 |
+
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
|
323 |
+
else:
|
324 |
+
raise ValueError(f"Invalid {weights_name} provided.")
|
325 |
+
|
326 |
+
os.makedirs(save_directory, exist_ok=True)
|
327 |
+
|
328 |
+
if push_to_hub:
|
329 |
+
commit_message = kwargs.pop("commit_message", None)
|
330 |
+
private = kwargs.pop("private", False)
|
331 |
+
create_pr = kwargs.pop("create_pr", False)
|
332 |
+
token = kwargs.pop("token", None)
|
333 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
334 |
+
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
335 |
+
|
336 |
+
# Only save the model itself if we are using distributed training
|
337 |
+
model_to_save = self
|
338 |
+
|
339 |
+
# Attach architecture to the config
|
340 |
+
# Save the config
|
341 |
+
if is_main_process:
|
342 |
+
model_to_save.save_config(save_directory)
|
343 |
+
|
344 |
+
# Save the model
|
345 |
+
state_dict = model_to_save.state_dict()
|
346 |
+
|
347 |
+
# Save the model
|
348 |
+
state_dict_split = split_torch_state_dict_into_shards(
|
349 |
+
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
350 |
+
)
|
351 |
+
|
352 |
+
# Clean the folder from a previous save
|
353 |
+
if is_main_process:
|
354 |
+
for filename in os.listdir(save_directory):
|
355 |
+
if filename in state_dict_split.filename_to_tensors.keys():
|
356 |
+
continue
|
357 |
+
full_filename = os.path.join(save_directory, filename)
|
358 |
+
if not os.path.isfile(full_filename):
|
359 |
+
continue
|
360 |
+
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
361 |
+
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
362 |
+
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
363 |
+
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
364 |
+
if (
|
365 |
+
filename.startswith(weights_without_ext)
|
366 |
+
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
367 |
+
):
|
368 |
+
os.remove(full_filename)
|
369 |
+
|
370 |
+
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
371 |
+
shard = {tensor: state_dict[tensor] for tensor in tensors}
|
372 |
+
filepath = os.path.join(save_directory, filename)
|
373 |
+
if safe_serialization:
|
374 |
+
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
375 |
+
# joyfulness), but for now this is enough.
|
376 |
+
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
377 |
+
else:
|
378 |
+
torch.save(shard, filepath)
|
379 |
+
|
380 |
+
if state_dict_split.is_sharded:
|
381 |
+
index = {
|
382 |
+
"metadata": state_dict_split.metadata,
|
383 |
+
"weight_map": state_dict_split.tensor_to_filename,
|
384 |
+
}
|
385 |
+
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
386 |
+
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
387 |
+
# Save the index as well
|
388 |
+
with open(save_index_file, "w", encoding="utf-8") as f:
|
389 |
+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
390 |
+
f.write(content)
|
391 |
+
logger.info(
|
392 |
+
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
393 |
+
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
394 |
+
f"index located at {save_index_file}."
|
395 |
+
)
|
396 |
+
else:
|
397 |
+
path_to_weights = os.path.join(save_directory, weights_name)
|
398 |
+
logger.info(f"Model weights saved in {path_to_weights}")
|
399 |
+
|
400 |
+
if push_to_hub:
|
401 |
+
# Create a new empty model card and eventually tag it
|
402 |
+
model_card = load_or_create_model_card(repo_id, token=token)
|
403 |
+
model_card = populate_model_card(model_card)
|
404 |
+
model_card.save(Path(save_directory, "README.md").as_posix())
|
405 |
+
|
406 |
+
self._upload_folder(
|
407 |
+
save_directory,
|
408 |
+
repo_id,
|
409 |
+
token=token,
|
410 |
+
commit_message=commit_message,
|
411 |
+
create_pr=create_pr,
|
412 |
+
)
|
413 |
+
|
414 |
+
@classmethod
|
415 |
+
@validate_hf_hub_args
|
416 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
417 |
+
r"""
|
418 |
+
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
419 |
+
|
420 |
+
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
421 |
+
train the model, set it back in training mode with `model.train()`.
|
422 |
+
|
423 |
+
Parameters:
|
424 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
425 |
+
Can be either:
|
426 |
+
|
427 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
428 |
+
the Hub.
|
429 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
430 |
+
with [`~ModelMixin.save_pretrained`].
|
431 |
+
|
432 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
433 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
434 |
+
is not used.
|
435 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
436 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
437 |
+
dtype is automatically derived from the model's weights.
|
438 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
439 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
440 |
+
cached versions if they exist.
|
441 |
+
proxies (`Dict[str, str]`, *optional*):
|
442 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
443 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
444 |
+
output_loading_info (`bool`, *optional*, defaults to `False`):
|
445 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
446 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
447 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
448 |
+
won't be downloaded from the Hub.
|
449 |
+
token (`str` or *bool*, *optional*):
|
450 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
451 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
452 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
453 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
454 |
+
allowed by Git.
|
455 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
456 |
+
Load the model weights from a Flax checkpoint save file.
|
457 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
458 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
459 |
+
mirror (`str`, *optional*):
|
460 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
461 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
462 |
+
information.
|
463 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
464 |
+
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
465 |
+
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
466 |
+
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
467 |
+
|
468 |
+
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
469 |
+
more information about each option see [designing a device
|
470 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
471 |
+
max_memory (`Dict`, *optional*):
|
472 |
+
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
473 |
+
each GPU and the available CPU RAM if unset.
|
474 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
475 |
+
The path to offload weights if `device_map` contains the value `"disk"`.
|
476 |
+
offload_state_dict (`bool`, *optional*):
|
477 |
+
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
478 |
+
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
479 |
+
when there is some disk offload.
|
480 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
481 |
+
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
482 |
+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
483 |
+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
484 |
+
argument to `True` will raise an error.
|
485 |
+
variant (`str`, *optional*):
|
486 |
+
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
487 |
+
loading `from_flax`.
|
488 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
489 |
+
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
490 |
+
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
491 |
+
weights. If set to `False`, `safetensors` weights are not loaded.
|
492 |
+
|
493 |
+
<Tip>
|
494 |
+
|
495 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
496 |
+
`huggingface-cli login`. You can also activate the special
|
497 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
498 |
+
firewalled environment.
|
499 |
+
|
500 |
+
</Tip>
|
501 |
+
|
502 |
+
Example:
|
503 |
+
|
504 |
+
```py
|
505 |
+
from diffusers import UNet2DConditionModel
|
506 |
+
|
507 |
+
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
508 |
+
```
|
509 |
+
|
510 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
511 |
+
|
512 |
+
```bash
|
513 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
514 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
515 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
516 |
+
```
|
517 |
+
"""
|
518 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
519 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
520 |
+
force_download = kwargs.pop("force_download", False)
|
521 |
+
from_flax = kwargs.pop("from_flax", False)
|
522 |
+
proxies = kwargs.pop("proxies", None)
|
523 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
524 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
525 |
+
token = kwargs.pop("token", None)
|
526 |
+
revision = kwargs.pop("revision", None)
|
527 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
528 |
+
subfolder = kwargs.pop("subfolder", None)
|
529 |
+
device_map = kwargs.pop("device_map", None)
|
530 |
+
max_memory = kwargs.pop("max_memory", None)
|
531 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
532 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
533 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
534 |
+
variant = kwargs.pop("variant", None)
|
535 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
536 |
+
|
537 |
+
allow_pickle = False
|
538 |
+
if use_safetensors is None:
|
539 |
+
use_safetensors = True
|
540 |
+
allow_pickle = True
|
541 |
+
|
542 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
543 |
+
low_cpu_mem_usage = False
|
544 |
+
logger.warning(
|
545 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
546 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
547 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
548 |
+
" install accelerate\n```\n."
|
549 |
+
)
|
550 |
+
|
551 |
+
if device_map is not None and not is_accelerate_available():
|
552 |
+
raise NotImplementedError(
|
553 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
554 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
555 |
+
)
|
556 |
+
|
557 |
+
# Check if we can handle device_map and dispatching the weights
|
558 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
559 |
+
raise NotImplementedError(
|
560 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
561 |
+
" `device_map=None`."
|
562 |
+
)
|
563 |
+
|
564 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
565 |
+
raise NotImplementedError(
|
566 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
567 |
+
" `low_cpu_mem_usage=False`."
|
568 |
+
)
|
569 |
+
|
570 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
571 |
+
raise ValueError(
|
572 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
573 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
574 |
+
)
|
575 |
+
|
576 |
+
# change device_map into a map if we passed an int, a str or a torch.device
|
577 |
+
if isinstance(device_map, torch.device):
|
578 |
+
device_map = {"": device_map}
|
579 |
+
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
580 |
+
try:
|
581 |
+
device_map = {"": torch.device(device_map)}
|
582 |
+
except RuntimeError:
|
583 |
+
raise ValueError(
|
584 |
+
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
|
585 |
+
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
|
586 |
+
)
|
587 |
+
elif isinstance(device_map, int):
|
588 |
+
if device_map < 0:
|
589 |
+
raise ValueError(
|
590 |
+
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
|
591 |
+
)
|
592 |
+
else:
|
593 |
+
device_map = {"": device_map}
|
594 |
+
|
595 |
+
if device_map is not None:
|
596 |
+
if low_cpu_mem_usage is None:
|
597 |
+
low_cpu_mem_usage = True
|
598 |
+
elif not low_cpu_mem_usage:
|
599 |
+
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
|
600 |
+
|
601 |
+
if low_cpu_mem_usage:
|
602 |
+
if device_map is not None and not is_torch_version(">=", "1.10"):
|
603 |
+
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
|
604 |
+
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
|
605 |
+
|
606 |
+
# Load config if we don't provide a configuration
|
607 |
+
config_path = pretrained_model_name_or_path
|
608 |
+
|
609 |
+
user_agent = {
|
610 |
+
"diffusers": __version__,
|
611 |
+
"file_type": "model",
|
612 |
+
"framework": "pytorch",
|
613 |
+
}
|
614 |
+
|
615 |
+
# load config
|
616 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
617 |
+
config_path,
|
618 |
+
cache_dir=cache_dir,
|
619 |
+
return_unused_kwargs=True,
|
620 |
+
return_commit_hash=True,
|
621 |
+
force_download=force_download,
|
622 |
+
proxies=proxies,
|
623 |
+
local_files_only=local_files_only,
|
624 |
+
token=token,
|
625 |
+
revision=revision,
|
626 |
+
subfolder=subfolder,
|
627 |
+
user_agent=user_agent,
|
628 |
+
**kwargs,
|
629 |
+
)
|
630 |
+
|
631 |
+
# Determine if we're loading from a directory of sharded checkpoints.
|
632 |
+
is_sharded = False
|
633 |
+
index_file = None
|
634 |
+
is_local = os.path.isdir(pretrained_model_name_or_path)
|
635 |
+
index_file = _fetch_index_file(
|
636 |
+
is_local=is_local,
|
637 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
638 |
+
subfolder=subfolder or "",
|
639 |
+
use_safetensors=use_safetensors,
|
640 |
+
cache_dir=cache_dir,
|
641 |
+
variant=variant,
|
642 |
+
force_download=force_download,
|
643 |
+
proxies=proxies,
|
644 |
+
local_files_only=local_files_only,
|
645 |
+
token=token,
|
646 |
+
revision=revision,
|
647 |
+
user_agent=user_agent,
|
648 |
+
commit_hash=commit_hash,
|
649 |
+
)
|
650 |
+
if index_file is not None and index_file.is_file():
|
651 |
+
is_sharded = True
|
652 |
+
|
653 |
+
if is_sharded and from_flax:
|
654 |
+
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
|
655 |
+
|
656 |
+
# load model
|
657 |
+
model_file = None
|
658 |
+
if from_flax:
|
659 |
+
model_file = _get_model_file(
|
660 |
+
pretrained_model_name_or_path,
|
661 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
662 |
+
cache_dir=cache_dir,
|
663 |
+
force_download=force_download,
|
664 |
+
proxies=proxies,
|
665 |
+
local_files_only=local_files_only,
|
666 |
+
token=token,
|
667 |
+
revision=revision,
|
668 |
+
subfolder=subfolder,
|
669 |
+
user_agent=user_agent,
|
670 |
+
commit_hash=commit_hash,
|
671 |
+
)
|
672 |
+
model = cls.from_config(config, **unused_kwargs)
|
673 |
+
|
674 |
+
# Convert the weights
|
675 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
676 |
+
|
677 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
678 |
+
else:
|
679 |
+
if is_sharded:
|
680 |
+
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
|
681 |
+
pretrained_model_name_or_path,
|
682 |
+
index_file,
|
683 |
+
cache_dir=cache_dir,
|
684 |
+
proxies=proxies,
|
685 |
+
local_files_only=local_files_only,
|
686 |
+
token=token,
|
687 |
+
user_agent=user_agent,
|
688 |
+
revision=revision,
|
689 |
+
subfolder=subfolder or "",
|
690 |
+
)
|
691 |
+
|
692 |
+
elif use_safetensors and not is_sharded:
|
693 |
+
try:
|
694 |
+
model_file = _get_model_file(
|
695 |
+
pretrained_model_name_or_path,
|
696 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
697 |
+
cache_dir=cache_dir,
|
698 |
+
force_download=force_download,
|
699 |
+
proxies=proxies,
|
700 |
+
local_files_only=local_files_only,
|
701 |
+
token=token,
|
702 |
+
revision=revision,
|
703 |
+
subfolder=subfolder,
|
704 |
+
user_agent=user_agent,
|
705 |
+
commit_hash=commit_hash,
|
706 |
+
)
|
707 |
+
|
708 |
+
except IOError as e:
|
709 |
+
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
710 |
+
if not allow_pickle:
|
711 |
+
raise
|
712 |
+
logger.warning(
|
713 |
+
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
714 |
+
)
|
715 |
+
|
716 |
+
if model_file is None and not is_sharded:
|
717 |
+
model_file = _get_model_file(
|
718 |
+
pretrained_model_name_or_path,
|
719 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
720 |
+
cache_dir=cache_dir,
|
721 |
+
force_download=force_download,
|
722 |
+
proxies=proxies,
|
723 |
+
local_files_only=local_files_only,
|
724 |
+
token=token,
|
725 |
+
revision=revision,
|
726 |
+
subfolder=subfolder,
|
727 |
+
user_agent=user_agent,
|
728 |
+
commit_hash=commit_hash,
|
729 |
+
)
|
730 |
+
|
731 |
+
if low_cpu_mem_usage:
|
732 |
+
# Instantiate model with empty weights
|
733 |
+
with accelerate.init_empty_weights():
|
734 |
+
model = cls.from_config(config, **unused_kwargs)
|
735 |
+
|
736 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
737 |
+
if device_map is None and not is_sharded:
|
738 |
+
param_device = "cpu"
|
739 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
740 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
741 |
+
# move the params from meta device to cpu
|
742 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
743 |
+
if len(missing_keys) > 0:
|
744 |
+
raise ValueError(
|
745 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
746 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
747 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
748 |
+
" those weights or else make sure your checkpoint file is correct."
|
749 |
+
)
|
750 |
+
|
751 |
+
unexpected_keys = load_model_dict_into_meta(
|
752 |
+
model,
|
753 |
+
state_dict,
|
754 |
+
device=param_device,
|
755 |
+
dtype=torch_dtype,
|
756 |
+
model_name_or_path=pretrained_model_name_or_path,
|
757 |
+
)
|
758 |
+
|
759 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
760 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
761 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
762 |
+
|
763 |
+
if len(unexpected_keys) > 0:
|
764 |
+
logger.warning(
|
765 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
766 |
+
)
|
767 |
+
|
768 |
+
else: # else let accelerate handle loading and dispatching.
|
769 |
+
# Load weights and dispatch according to the device_map
|
770 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
771 |
+
force_hook = True
|
772 |
+
device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
|
773 |
+
if device_map is None and is_sharded:
|
774 |
+
# we load the parameters on the cpu
|
775 |
+
device_map = {"": "cpu"}
|
776 |
+
force_hook = False
|
777 |
+
try:
|
778 |
+
accelerate.load_checkpoint_and_dispatch(
|
779 |
+
model,
|
780 |
+
model_file if not is_sharded else index_file,
|
781 |
+
device_map,
|
782 |
+
max_memory=max_memory,
|
783 |
+
offload_folder=offload_folder,
|
784 |
+
offload_state_dict=offload_state_dict,
|
785 |
+
dtype=torch_dtype,
|
786 |
+
force_hooks=force_hook,
|
787 |
+
strict=True,
|
788 |
+
)
|
789 |
+
except AttributeError as e:
|
790 |
+
# When using accelerate loading, we do not have the ability to load the state
|
791 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
792 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
793 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
794 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
795 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
796 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
797 |
+
# the weights so we don't have to do this again.
|
798 |
+
|
799 |
+
if "'Attention' object has no attribute" in str(e):
|
800 |
+
logger.warning(
|
801 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
802 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
803 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
804 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
805 |
+
" please also re-upload it or open a PR on the original repository."
|
806 |
+
)
|
807 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
808 |
+
accelerate.load_checkpoint_and_dispatch(
|
809 |
+
model,
|
810 |
+
model_file if not is_sharded else index_file,
|
811 |
+
device_map,
|
812 |
+
max_memory=max_memory,
|
813 |
+
offload_folder=offload_folder,
|
814 |
+
offload_state_dict=offload_state_dict,
|
815 |
+
dtype=torch_dtype,
|
816 |
+
force_hooks=force_hook,
|
817 |
+
strict=True,
|
818 |
+
)
|
819 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
820 |
+
else:
|
821 |
+
raise e
|
822 |
+
|
823 |
+
loading_info = {
|
824 |
+
"missing_keys": [],
|
825 |
+
"unexpected_keys": [],
|
826 |
+
"mismatched_keys": [],
|
827 |
+
"error_msgs": [],
|
828 |
+
}
|
829 |
+
else:
|
830 |
+
model = cls.from_config(config, **unused_kwargs)
|
831 |
+
|
832 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
833 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
834 |
+
|
835 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
836 |
+
model,
|
837 |
+
state_dict,
|
838 |
+
model_file,
|
839 |
+
pretrained_model_name_or_path,
|
840 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
841 |
+
)
|
842 |
+
|
843 |
+
loading_info = {
|
844 |
+
"missing_keys": missing_keys,
|
845 |
+
"unexpected_keys": unexpected_keys,
|
846 |
+
"mismatched_keys": mismatched_keys,
|
847 |
+
"error_msgs": error_msgs,
|
848 |
+
}
|
849 |
+
|
850 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
851 |
+
raise ValueError(
|
852 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
853 |
+
)
|
854 |
+
elif torch_dtype is not None:
|
855 |
+
model = model.to(torch_dtype)
|
856 |
+
|
857 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
858 |
+
|
859 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
860 |
+
model.eval()
|
861 |
+
if output_loading_info:
|
862 |
+
return model, loading_info
|
863 |
+
|
864 |
+
return model
|
865 |
+
|
866 |
+
@classmethod
|
867 |
+
def _load_pretrained_model(
|
868 |
+
cls,
|
869 |
+
model,
|
870 |
+
state_dict: OrderedDict,
|
871 |
+
resolved_archive_file,
|
872 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
873 |
+
ignore_mismatched_sizes: bool = False,
|
874 |
+
):
|
875 |
+
# Retrieve missing & unexpected_keys
|
876 |
+
model_state_dict = model.state_dict()
|
877 |
+
loaded_keys = list(state_dict.keys())
|
878 |
+
|
879 |
+
expected_keys = list(model_state_dict.keys())
|
880 |
+
|
881 |
+
original_loaded_keys = loaded_keys
|
882 |
+
|
883 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
884 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
885 |
+
|
886 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
887 |
+
model_to_load = model
|
888 |
+
|
889 |
+
def _find_mismatched_keys(
|
890 |
+
state_dict,
|
891 |
+
model_state_dict,
|
892 |
+
loaded_keys,
|
893 |
+
ignore_mismatched_sizes,
|
894 |
+
):
|
895 |
+
mismatched_keys = []
|
896 |
+
if ignore_mismatched_sizes:
|
897 |
+
for checkpoint_key in loaded_keys:
|
898 |
+
model_key = checkpoint_key
|
899 |
+
|
900 |
+
if (
|
901 |
+
model_key in model_state_dict
|
902 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
903 |
+
):
|
904 |
+
mismatched_keys.append(
|
905 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
906 |
+
)
|
907 |
+
del state_dict[checkpoint_key]
|
908 |
+
return mismatched_keys
|
909 |
+
|
910 |
+
if state_dict is not None:
|
911 |
+
# Whole checkpoint
|
912 |
+
mismatched_keys = _find_mismatched_keys(
|
913 |
+
state_dict,
|
914 |
+
model_state_dict,
|
915 |
+
original_loaded_keys,
|
916 |
+
ignore_mismatched_sizes,
|
917 |
+
)
|
918 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
919 |
+
|
920 |
+
if len(error_msgs) > 0:
|
921 |
+
error_msg = "\n\t".join(error_msgs)
|
922 |
+
if "size mismatch" in error_msg:
|
923 |
+
error_msg += (
|
924 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
925 |
+
)
|
926 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
927 |
+
|
928 |
+
if len(unexpected_keys) > 0:
|
929 |
+
logger.warning(
|
930 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
931 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
932 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
933 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
934 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
935 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
936 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
937 |
+
" BertForSequenceClassification model)."
|
938 |
+
)
|
939 |
+
else:
|
940 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
941 |
+
if len(missing_keys) > 0:
|
942 |
+
logger.warning(
|
943 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
944 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
945 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
946 |
+
)
|
947 |
+
elif len(mismatched_keys) == 0:
|
948 |
+
logger.info(
|
949 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
950 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
951 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
952 |
+
" without further training."
|
953 |
+
)
|
954 |
+
if len(mismatched_keys) > 0:
|
955 |
+
mismatched_warning = "\n".join(
|
956 |
+
[
|
957 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
958 |
+
for key, shape1, shape2 in mismatched_keys
|
959 |
+
]
|
960 |
+
)
|
961 |
+
logger.warning(
|
962 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
963 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
964 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
965 |
+
" able to use it for predictions and inference."
|
966 |
+
)
|
967 |
+
|
968 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
969 |
+
|
970 |
+
@classmethod
|
971 |
+
def _get_signature_keys(cls, obj):
|
972 |
+
parameters = inspect.signature(obj.__init__).parameters
|
973 |
+
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
974 |
+
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
975 |
+
expected_modules = set(required_parameters.keys()) - {"self"}
|
976 |
+
|
977 |
+
return expected_modules, optional_parameters
|
978 |
+
|
979 |
+
# Adapted from `transformers` modeling_utils.py
|
980 |
+
def _get_no_split_modules(self, device_map: str):
|
981 |
+
"""
|
982 |
+
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
|
983 |
+
get the underlying `_no_split_modules`.
|
984 |
+
|
985 |
+
Args:
|
986 |
+
device_map (`str`):
|
987 |
+
The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
|
988 |
+
|
989 |
+
Returns:
|
990 |
+
`List[str]`: List of modules that should not be split
|
991 |
+
"""
|
992 |
+
_no_split_modules = set()
|
993 |
+
modules_to_check = [self]
|
994 |
+
while len(modules_to_check) > 0:
|
995 |
+
module = modules_to_check.pop(-1)
|
996 |
+
# if the module does not appear in _no_split_modules, we also check the children
|
997 |
+
if module.__class__.__name__ not in _no_split_modules:
|
998 |
+
if isinstance(module, ModelMixin):
|
999 |
+
if module._no_split_modules is None:
|
1000 |
+
raise ValueError(
|
1001 |
+
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
|
1002 |
+
"class needs to implement the `_no_split_modules` attribute."
|
1003 |
+
)
|
1004 |
+
else:
|
1005 |
+
_no_split_modules = _no_split_modules | set(module._no_split_modules)
|
1006 |
+
modules_to_check += list(module.children())
|
1007 |
+
return list(_no_split_modules)
|
1008 |
+
|
1009 |
+
@property
|
1010 |
+
def device(self) -> torch.device:
|
1011 |
+
"""
|
1012 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
1013 |
+
device).
|
1014 |
+
"""
|
1015 |
+
return get_parameter_device(self)
|
1016 |
+
|
1017 |
+
@property
|
1018 |
+
def dtype(self) -> torch.dtype:
|
1019 |
+
"""
|
1020 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
1021 |
+
"""
|
1022 |
+
return get_parameter_dtype(self)
|
1023 |
+
|
1024 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
1025 |
+
"""
|
1026 |
+
Get number of (trainable or non-embedding) parameters in the module.
|
1027 |
+
|
1028 |
+
Args:
|
1029 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
1030 |
+
Whether or not to return only the number of trainable parameters.
|
1031 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
1032 |
+
Whether or not to return only the number of non-embedding parameters.
|
1033 |
+
|
1034 |
+
Returns:
|
1035 |
+
`int`: The number of parameters.
|
1036 |
+
|
1037 |
+
Example:
|
1038 |
+
|
1039 |
+
```py
|
1040 |
+
from diffusers import UNet2DConditionModel
|
1041 |
+
|
1042 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
1043 |
+
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
|
1044 |
+
unet.num_parameters(only_trainable=True)
|
1045 |
+
859520964
|
1046 |
+
```
|
1047 |
+
"""
|
1048 |
+
|
1049 |
+
if exclude_embeddings:
|
1050 |
+
embedding_param_names = [
|
1051 |
+
f"{name}.weight"
|
1052 |
+
for name, module_type in self.named_modules()
|
1053 |
+
if isinstance(module_type, torch.nn.Embedding)
|
1054 |
+
]
|
1055 |
+
non_embedding_parameters = [
|
1056 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
1057 |
+
]
|
1058 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
1059 |
+
else:
|
1060 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
1061 |
+
|
1062 |
+
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
1063 |
+
deprecated_attention_block_paths = []
|
1064 |
+
|
1065 |
+
def recursive_find_attn_block(name, module):
|
1066 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1067 |
+
deprecated_attention_block_paths.append(name)
|
1068 |
+
|
1069 |
+
for sub_name, sub_module in module.named_children():
|
1070 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
1071 |
+
recursive_find_attn_block(sub_name, sub_module)
|
1072 |
+
|
1073 |
+
recursive_find_attn_block("", self)
|
1074 |
+
|
1075 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
1076 |
+
# because it is possible we are loading from a state dict that was already
|
1077 |
+
# converted
|
1078 |
+
|
1079 |
+
for path in deprecated_attention_block_paths:
|
1080 |
+
# group_norm path stays the same
|
1081 |
+
|
1082 |
+
# query -> to_q
|
1083 |
+
if f"{path}.query.weight" in state_dict:
|
1084 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
1085 |
+
if f"{path}.query.bias" in state_dict:
|
1086 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
1087 |
+
|
1088 |
+
# key -> to_k
|
1089 |
+
if f"{path}.key.weight" in state_dict:
|
1090 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
1091 |
+
if f"{path}.key.bias" in state_dict:
|
1092 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
1093 |
+
|
1094 |
+
# value -> to_v
|
1095 |
+
if f"{path}.value.weight" in state_dict:
|
1096 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
1097 |
+
if f"{path}.value.bias" in state_dict:
|
1098 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
1099 |
+
|
1100 |
+
# proj_attn -> to_out.0
|
1101 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
1102 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
1103 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
1104 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
1105 |
+
|
1106 |
+
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
1107 |
+
deprecated_attention_block_modules = []
|
1108 |
+
|
1109 |
+
def recursive_find_attn_block(module):
|
1110 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1111 |
+
deprecated_attention_block_modules.append(module)
|
1112 |
+
|
1113 |
+
for sub_module in module.children():
|
1114 |
+
recursive_find_attn_block(sub_module)
|
1115 |
+
|
1116 |
+
recursive_find_attn_block(self)
|
1117 |
+
|
1118 |
+
for module in deprecated_attention_block_modules:
|
1119 |
+
module.query = module.to_q
|
1120 |
+
module.key = module.to_k
|
1121 |
+
module.value = module.to_v
|
1122 |
+
module.proj_attn = module.to_out[0]
|
1123 |
+
|
1124 |
+
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
1125 |
+
# that _all_ the weights are loaded into the new attributes and we're not
|
1126 |
+
# making an incorrect assumption that this model should be converted when
|
1127 |
+
# it really shouldn't be.
|
1128 |
+
del module.to_q
|
1129 |
+
del module.to_k
|
1130 |
+
del module.to_v
|
1131 |
+
del module.to_out
|
1132 |
+
|
1133 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
1134 |
+
deprecated_attention_block_modules = []
|
1135 |
+
|
1136 |
+
def recursive_find_attn_block(module) -> None:
|
1137 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1138 |
+
deprecated_attention_block_modules.append(module)
|
1139 |
+
|
1140 |
+
for sub_module in module.children():
|
1141 |
+
recursive_find_attn_block(sub_module)
|
1142 |
+
|
1143 |
+
recursive_find_attn_block(self)
|
1144 |
+
|
1145 |
+
for module in deprecated_attention_block_modules:
|
1146 |
+
module.to_q = module.query
|
1147 |
+
module.to_k = module.key
|
1148 |
+
module.to_v = module.value
|
1149 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
1150 |
+
|
1151 |
+
del module.query
|
1152 |
+
del module.key
|
1153 |
+
del module.value
|
1154 |
+
del module.proj_attn
|
1155 |
+
|
1156 |
+
|
1157 |
+
class LegacyModelMixin(ModelMixin):
|
1158 |
+
r"""
|
1159 |
+
A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
|
1160 |
+
pipeline-specific classes (like `DiTTransformer2DModel`).
|
1161 |
+
"""
|
1162 |
+
|
1163 |
+
@classmethod
|
1164 |
+
@validate_hf_hub_args
|
1165 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
1166 |
+
# To prevent dependency import problem.
|
1167 |
+
from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config
|
1168 |
+
|
1169 |
+
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
|
1170 |
+
kwargs_copy = kwargs.copy()
|
1171 |
+
|
1172 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
1173 |
+
force_download = kwargs.pop("force_download", False)
|
1174 |
+
proxies = kwargs.pop("proxies", None)
|
1175 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
1176 |
+
token = kwargs.pop("token", None)
|
1177 |
+
revision = kwargs.pop("revision", None)
|
1178 |
+
subfolder = kwargs.pop("subfolder", None)
|
1179 |
+
|
1180 |
+
# Load config if we don't provide a configuration
|
1181 |
+
config_path = pretrained_model_name_or_path
|
1182 |
+
|
1183 |
+
user_agent = {
|
1184 |
+
"diffusers": __version__,
|
1185 |
+
"file_type": "model",
|
1186 |
+
"framework": "pytorch",
|
1187 |
+
}
|
1188 |
+
|
1189 |
+
# load config
|
1190 |
+
config, _, _ = cls.load_config(
|
1191 |
+
config_path,
|
1192 |
+
cache_dir=cache_dir,
|
1193 |
+
return_unused_kwargs=True,
|
1194 |
+
return_commit_hash=True,
|
1195 |
+
force_download=force_download,
|
1196 |
+
proxies=proxies,
|
1197 |
+
local_files_only=local_files_only,
|
1198 |
+
token=token,
|
1199 |
+
revision=revision,
|
1200 |
+
subfolder=subfolder,
|
1201 |
+
user_agent=user_agent,
|
1202 |
+
**kwargs,
|
1203 |
+
)
|
1204 |
+
# resolve remapping
|
1205 |
+
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
1206 |
+
|
1207 |
+
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
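The `models/modeling_utils.py` file above appears to be vendored from diffusers' `ModelMixin`, so any model class in this repo that subclasses it inherits `save_pretrained`/`from_pretrained`, safetensors serialization, sharding, and the `device`/`dtype`/`num_parameters` helpers shown above. A minimal usage sketch follows; the `Showo` class name, the import path, and the checkpoint directories are illustrative assumptions, not part of this diff.

```py
# Minimal sketch (assumptions: a `Showo` class subclassing the ModelMixin above,
# and a local checkpoint directory previously written with save_pretrained).
import torch
from models import Showo  # hypothetical import; adjust to the actual export

model = Showo.from_pretrained("./showo_checkpoint", torch_dtype=torch.float16)
print(model.device, model.dtype, model.num_parameters(only_trainable=False))

# Weights are written as safetensors by default and sharded once they exceed max_shard_size.
model.save_pretrained("./showo_checkpoint_resaved", max_shard_size="5GB")
```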
models/phi.py
ADDED
@@ -0,0 +1,1489 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""PyTorch Phi model."""
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import torch.utils.checkpoint
|
24 |
+
from packaging import version
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from transformers.cache_utils import Cache, DynamicCache
|
30 |
+
from transformers.modeling_attn_mask_utils import (
|
31 |
+
_prepare_4d_causal_attention_mask,
|
32 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
33 |
+
)
|
34 |
+
from transformers.modeling_outputs import (
|
35 |
+
BaseModelOutputWithPast,
|
36 |
+
CausalLMOutputWithPast,
|
37 |
+
SequenceClassifierOutputWithPast,
|
38 |
+
TokenClassifierOutput,
|
39 |
+
)
|
40 |
+
from transformers.modeling_utils import PreTrainedModel
|
41 |
+
from transformers.utils import (
|
42 |
+
add_code_sample_docstrings,
|
43 |
+
add_start_docstrings,
|
44 |
+
add_start_docstrings_to_model_forward,
|
45 |
+
get_torch_version,
|
46 |
+
is_flash_attn_2_available,
|
47 |
+
is_flash_attn_greater_or_equal_2_10,
|
48 |
+
logging,
|
49 |
+
replace_return_docstrings,
|
50 |
+
)
|
51 |
+
from transformers.models.phi.configuration_phi import PhiConfig
|
52 |
+
|
53 |
+
|
54 |
+
if is_flash_attn_2_available():
|
55 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
56 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
57 |
+
|
58 |
+
|
59 |
+
logger = logging.get_logger(__name__)
|
60 |
+
|
61 |
+
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
|
62 |
+
_CONFIG_FOR_DOC = "PhiConfig"
|
63 |
+
|
64 |
+
|
65 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
66 |
+
def _get_unpad_data(attention_mask):
|
67 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
68 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
69 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
70 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
71 |
+
return (
|
72 |
+
indices,
|
73 |
+
cu_seqlens,
|
74 |
+
max_seqlen_in_batch,
|
75 |
+
)
|
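For intuition, a minimal sketch of what `_get_unpad_data` returns for a toy padding mask; the values in the comments follow directly from the function above.

```python
import torch
import torch.nn.functional as F

# Batch of 2 sequences with lengths 3 and 1 (1 = real token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 1])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(indices)                        # positions of real tokens in the flattened batch
print(cu_seqlens)                     # tensor([0, 3, 4]) -> per-sequence boundaries for flash_attn_varlen_func
print(seqlens_in_batch.max().item())  # 3 -> max_seqlen_in_batch
```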
76 |
+
|
77 |
+
|
78 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
|
79 |
+
class PhiRotaryEmbedding(nn.Module):
|
80 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.dim = dim
|
84 |
+
self.max_position_embeddings = max_position_embeddings
|
85 |
+
self.base = base
|
86 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
87 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
88 |
+
|
89 |
+
# Build here to make `torch.jit.trace` work.
|
90 |
+
self._set_cos_sin_cache(
|
91 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
92 |
+
)
|
93 |
+
|
94 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
95 |
+
self.max_seq_len_cached = seq_len
|
96 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
97 |
+
|
98 |
+
freqs = torch.outer(t, self.inv_freq)
|
99 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
100 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
101 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
102 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
103 |
+
|
104 |
+
def forward(self, x, seq_len=None):
|
105 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
106 |
+
if seq_len > self.max_seq_len_cached:
|
107 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
108 |
+
|
109 |
+
return (
|
110 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
111 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
|
116 |
+
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
117 |
+
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
118 |
+
|
119 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
120 |
+
self.scaling_factor = scaling_factor
|
121 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
122 |
+
|
123 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
124 |
+
self.max_seq_len_cached = seq_len
|
125 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
126 |
+
t = t / self.scaling_factor
|
127 |
+
|
128 |
+
freqs = torch.outer(t, self.inv_freq)
|
129 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
130 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
131 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
132 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
133 |
+
|
134 |
+
|
135 |
+
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
|
136 |
+
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
|
137 |
+
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
138 |
+
|
139 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
140 |
+
self.scaling_factor = scaling_factor
|
141 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
142 |
+
|
143 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
144 |
+
self.max_seq_len_cached = seq_len
|
145 |
+
|
146 |
+
if seq_len > self.max_position_embeddings:
|
147 |
+
base = self.base * (
|
148 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
149 |
+
) ** (self.dim / (self.dim - 2))
|
150 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
151 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
152 |
+
|
153 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
154 |
+
|
155 |
+
freqs = torch.outer(t, self.inv_freq)
|
156 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
157 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
158 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
159 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
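As a worked example of the rescaled base above (illustrative numbers only: dim=32, base=10000, max_position_embeddings=2048, scaling_factor=2.0, requested seq_len=4096):

```python
# When the requested sequence length exceeds max_position_embeddings, the base is enlarged,
# which lowers the rotary frequencies and stretches the usable context.
dim, base, max_pos, factor, seq_len = 32, 10000.0, 2048, 2.0, 4096
new_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
print(new_base)  # roughly 3.2x the original base
```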
160 |
+
|
161 |
+
|
162 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
163 |
+
def rotate_half(x):
|
164 |
+
"""Rotates half the hidden dims of the input."""
|
165 |
+
x1 = x[..., : x.shape[-1] // 2]
|
166 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
167 |
+
return torch.cat((-x2, x1), dim=-1)
|
168 |
+
|
169 |
+
|
170 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
171 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
172 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
q (`torch.Tensor`): The query tensor.
|
176 |
+
k (`torch.Tensor`): The key tensor.
|
177 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
178 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
179 |
+
position_ids (`torch.Tensor`):
|
180 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
181 |
+
used to pass offset position ids when working with a KV-cache.
|
182 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
183 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
184 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
185 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
186 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
187 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
188 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
189 |
+
Returns:
|
190 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
191 |
+
"""
|
192 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
193 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
194 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
195 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
196 |
+
return q_embed, k_embed
|
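A small usage sketch of the rotary helpers defined above (toy shapes; assumes `PhiRotaryEmbedding` and `apply_rotary_pos_emb` from this file are in scope):

```python
import torch

bsz, n_heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

# cos/sin caches of shape [seq_len, head_dim], gathered by position_ids inside the helper.
rotary = PhiRotaryEmbedding(head_dim, max_position_embeddings=16)
cos, sin = rotary(k, seq_len=seq_len)

position_ids = torch.arange(seq_len).unsqueeze(0)   # [1, seq_len]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)                      # both torch.Size([1, 2, 4, 8])
```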
197 |
+
|
198 |
+
|
199 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
|
200 |
+
class PhiMLP(nn.Module):
|
201 |
+
def __init__(self, config):
|
202 |
+
super().__init__()
|
203 |
+
self.config = config
|
204 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
205 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
206 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
207 |
+
|
208 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
209 |
+
hidden_states = self.fc1(hidden_states)
|
210 |
+
hidden_states = self.activation_fn(hidden_states)
|
211 |
+
hidden_states = self.fc2(hidden_states)
|
212 |
+
return hidden_states
|
213 |
+
|
214 |
+
|
215 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
|
216 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
217 |
+
"""
|
218 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
219 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
220 |
+
"""
|
221 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
222 |
+
if n_rep == 1:
|
223 |
+
return hidden_states
|
224 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
225 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
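A quick shape check of `repeat_kv` above, which is how grouped-query attention expands the key/value heads to match the query heads:

```python
import torch

kv = torch.randn(1, 2, 5, 16)       # (batch, num_key_value_heads=2, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=3)   # -> 2 * 3 = 6 attention heads
print(expanded.shape)               # torch.Size([1, 6, 5, 16])

# Equivalent to repeating each KV head in place along the head dimension.
print(torch.equal(expanded, kv.repeat_interleave(3, dim=1)))  # True
```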
226 |
+
|
227 |
+
|
228 |
+
class PhiAttention(nn.Module):
|
229 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
230 |
+
|
231 |
+
def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
|
232 |
+
super().__init__()
|
233 |
+
self.config = config
|
234 |
+
self.layer_idx = layer_idx
|
235 |
+
if layer_idx is None:
|
236 |
+
logger.warning_once(
|
237 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
238 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
239 |
+
"when creating this class."
|
240 |
+
)
|
241 |
+
|
242 |
+
self.attention_dropout = config.attention_dropout
|
243 |
+
self.hidden_size = config.hidden_size
|
244 |
+
self.num_heads = config.num_attention_heads
|
245 |
+
self.head_dim = self.hidden_size // self.num_heads
|
246 |
+
self.num_key_value_heads = config.num_key_value_heads
|
247 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
248 |
+
self.max_position_embeddings = config.max_position_embeddings
|
249 |
+
self.rope_theta = config.rope_theta
|
250 |
+
self.partial_rotary_factor = config.partial_rotary_factor
|
251 |
+
self.is_causal = True
|
252 |
+
|
253 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
254 |
+
raise ValueError(
|
255 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
256 |
+
f" and `num_heads`: {self.num_heads})."
|
257 |
+
)
|
258 |
+
|
259 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
260 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
261 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
262 |
+
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
|
263 |
+
|
264 |
+
self.qk_layernorm = config.qk_layernorm
|
265 |
+
if self.qk_layernorm:
|
266 |
+
self.q_layernorm = nn.LayerNorm(
|
267 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
268 |
+
)
|
269 |
+
self.k_layernorm = nn.LayerNorm(
|
270 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
271 |
+
)
|
272 |
+
|
273 |
+
self._init_rope()
|
274 |
+
|
275 |
+
def _init_rope(self):
|
276 |
+
if self.config.rope_scaling is None:
|
277 |
+
self.rotary_emb = PhiRotaryEmbedding(
|
278 |
+
int(self.partial_rotary_factor * self.head_dim),
|
279 |
+
max_position_embeddings=self.max_position_embeddings,
|
280 |
+
base=self.rope_theta,
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
scaling_type = self.config.rope_scaling["type"]
|
284 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
285 |
+
if scaling_type == "linear":
|
286 |
+
self.rotary_emb = PhiLinearScalingRotaryEmbedding(
|
287 |
+
int(self.partial_rotary_factor * self.head_dim),
|
288 |
+
max_position_embeddings=self.max_position_embeddings,
|
289 |
+
scaling_factor=scaling_factor,
|
290 |
+
base=self.rope_theta,
|
291 |
+
)
|
292 |
+
elif scaling_type == "dynamic":
|
293 |
+
self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
|
294 |
+
int(self.partial_rotary_factor * self.head_dim),
|
295 |
+
max_position_embeddings=self.max_position_embeddings,
|
296 |
+
scaling_factor=scaling_factor,
|
297 |
+
base=self.rope_theta,
|
298 |
+
)
|
299 |
+
else:
|
300 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
301 |
+
|
302 |
+
def forward(
|
303 |
+
self,
|
304 |
+
hidden_states: torch.Tensor,
|
305 |
+
attention_mask: Optional[torch.Tensor] = None,
|
306 |
+
position_ids: Optional[torch.LongTensor] = None,
|
307 |
+
past_key_value: Optional[Cache] = None,
|
308 |
+
output_attentions: bool = False,
|
309 |
+
use_cache: bool = False,
|
310 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
311 |
+
bsz, q_len, _ = hidden_states.size()
|
312 |
+
|
313 |
+
query_states = self.q_proj(hidden_states)
|
314 |
+
key_states = self.k_proj(hidden_states)
|
315 |
+
value_states = self.v_proj(hidden_states)
|
316 |
+
|
317 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
318 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
319 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
320 |
+
|
321 |
+
if self.qk_layernorm:
|
322 |
+
query_states = self.q_layernorm(query_states)
|
323 |
+
key_states = self.k_layernorm(key_states)
|
324 |
+
|
325 |
+
kv_seq_len = key_states.shape[-2]
|
326 |
+
if past_key_value is not None:
|
327 |
+
if self.layer_idx is None:
|
328 |
+
raise ValueError(
|
329 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
330 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
331 |
+
"with a layer index."
|
332 |
+
)
|
333 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
334 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
335 |
+
|
336 |
+
# Partial rotary embedding
|
337 |
+
query_rot, query_pass = (
|
338 |
+
query_states[..., : self.rotary_emb.dim],
|
339 |
+
query_states[..., self.rotary_emb.dim :],
|
340 |
+
)
|
341 |
+
key_rot, key_pass = (
|
342 |
+
key_states[..., : self.rotary_emb.dim],
|
343 |
+
key_states[..., self.rotary_emb.dim :],
|
344 |
+
)
|
345 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
346 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
347 |
+
|
348 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
349 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
350 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
351 |
+
|
352 |
+
if past_key_value is not None:
|
353 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
354 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
355 |
+
|
356 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
357 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
358 |
+
|
359 |
+
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
360 |
+
attn_weights = torch.matmul(
|
361 |
+
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
362 |
+
) / math.sqrt(self.head_dim)
|
363 |
+
|
364 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
365 |
+
raise ValueError(
|
366 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
367 |
+
f" {attn_weights.size()}"
|
368 |
+
)
|
369 |
+
|
370 |
+
if attention_mask is not None:
|
371 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
372 |
+
raise ValueError(
|
373 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
374 |
+
)
|
375 |
+
attn_weights = attn_weights + attention_mask
|
376 |
+
|
377 |
+
# upcast attention to fp32
|
378 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
|
379 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
380 |
+
|
381 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
382 |
+
|
383 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
384 |
+
raise ValueError(
|
385 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
386 |
+
f" {attn_output.size()}"
|
387 |
+
)
|
388 |
+
|
389 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
390 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
391 |
+
|
392 |
+
attn_output = self.dense(attn_output)
|
393 |
+
|
394 |
+
if not output_attentions:
|
395 |
+
attn_weights = None
|
396 |
+
|
397 |
+
return attn_output, attn_weights, past_key_value
|
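Condensed, the eager path above computes softmax(QK^T / sqrt(head_dim))V, upcasting the score matmul and softmax to fp32. A standalone sketch of just that core (illustrative, not part of this file):

```python
import math
import torch

def eager_attention_core(q, k, v, mask=None, dropout_p=0.0):
    # q, k, v: (batch, num_heads, seq_len, head_dim); mask is additive and broadcastable.
    scores = torch.matmul(q.float(), k.float().transpose(2, 3)) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    weights = torch.softmax(scores, dim=-1).to(v.dtype)
    weights = torch.nn.functional.dropout(weights, p=dropout_p)
    return torch.matmul(weights, v)

q = torch.randn(1, 4, 6, 8)
print(eager_attention_core(q, q, q).shape)  # torch.Size([1, 4, 6, 8])
```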
398 |
+
|
399 |
+
|
400 |
+
class PhiFlashAttention2(PhiAttention):
|
401 |
+
"""
|
402 |
+
Phi flash attention module. This module inherits from `PhiAttention`, as the weights of the module stay
|
403 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
404 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
405 |
+
"""
|
406 |
+
|
407 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
408 |
+
def __init__(self, *args, **kwargs):
|
409 |
+
super().__init__(*args, **kwargs)
|
410 |
+
|
411 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
412 |
+
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
413 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
414 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
415 |
+
|
416 |
+
def forward(
|
417 |
+
self,
|
418 |
+
hidden_states: torch.Tensor,
|
419 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
420 |
+
position_ids: Optional[torch.LongTensor] = None,
|
421 |
+
past_key_value: Optional[Cache] = None,
|
422 |
+
output_attentions: bool = False,
|
423 |
+
use_cache: bool = False,
|
424 |
+
**kwargs,
|
425 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
426 |
+
# PhiFlashAttention2 attention does not support output_attentions
|
427 |
+
|
428 |
+
output_attentions = False
|
429 |
+
|
430 |
+
bsz, q_len, _ = hidden_states.size()
|
431 |
+
|
432 |
+
query_states = self.q_proj(hidden_states)
|
433 |
+
key_states = self.k_proj(hidden_states)
|
434 |
+
value_states = self.v_proj(hidden_states)
|
435 |
+
|
436 |
+
# Flash attention requires the input to have the shape
|
437 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
438 |
+
# therefore we just need to keep the original shape
|
439 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
440 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
441 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
442 |
+
|
443 |
+
if self.qk_layernorm:
|
444 |
+
query_states = self.q_layernorm(query_states)
|
445 |
+
key_states = self.k_layernorm(key_states)
|
446 |
+
|
447 |
+
kv_seq_len = key_states.shape[-2]
|
448 |
+
if past_key_value is not None:
|
449 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
450 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
451 |
+
|
452 |
+
# Partial rotary embedding
|
453 |
+
query_rot, query_pass = (
|
454 |
+
query_states[..., : self.rotary_emb.dim],
|
455 |
+
query_states[..., self.rotary_emb.dim :],
|
456 |
+
)
|
457 |
+
key_rot, key_pass = (
|
458 |
+
key_states[..., : self.rotary_emb.dim],
|
459 |
+
key_states[..., self.rotary_emb.dim :],
|
460 |
+
)
|
461 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
462 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
463 |
+
|
464 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
465 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
466 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
467 |
+
|
468 |
+
if past_key_value is not None:
|
469 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
470 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
471 |
+
|
472 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
473 |
+
# to be able to avoid many of these transpose/reshape/view.
|
474 |
+
query_states = query_states.transpose(1, 2)
|
475 |
+
key_states = key_states.transpose(1, 2)
|
476 |
+
value_states = value_states.transpose(1, 2)
|
477 |
+
|
478 |
+
attn_dropout = self.attention_dropout if self.training else 0.0
|
479 |
+
|
480 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
481 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
482 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
483 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
484 |
+
# in fp32.
|
485 |
+
|
486 |
+
if query_states.dtype == torch.float32:
|
487 |
+
if torch.is_autocast_enabled():
|
488 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
489 |
+
# Handle the case where the model is quantized
|
490 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
491 |
+
target_dtype = self.config._pre_quantization_dtype
|
492 |
+
else:
|
493 |
+
target_dtype = self.q_proj.weight.dtype
|
494 |
+
|
495 |
+
logger.warning_once(
|
496 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
497 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
498 |
+
f" {target_dtype}."
|
499 |
+
)
|
500 |
+
|
501 |
+
query_states = query_states.to(target_dtype)
|
502 |
+
key_states = key_states.to(target_dtype)
|
503 |
+
value_states = value_states.to(target_dtype)
|
504 |
+
|
505 |
+
attn_output = self._flash_attention_forward(
|
506 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
|
507 |
+
)
|
508 |
+
|
509 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
510 |
+
attn_output = self.dense(attn_output)
|
511 |
+
|
512 |
+
if not output_attentions:
|
513 |
+
attn_weights = None
|
514 |
+
|
515 |
+
return attn_output, attn_weights, past_key_value
|
516 |
+
|
517 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
518 |
+
def _flash_attention_forward(
|
519 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
520 |
+
):
|
521 |
+
"""
|
522 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
|
523 |
+
it first unpads the input, then computes the attention scores and pads the final attention scores.
|
524 |
+
|
525 |
+
Args:
|
526 |
+
query_states (`torch.Tensor`):
|
527 |
+
Input query states to be passed to Flash Attention API
|
528 |
+
key_states (`torch.Tensor`):
|
529 |
+
Input key states to be passed to Flash Attention API
|
530 |
+
value_states (`torch.Tensor`):
|
531 |
+
Input value states to be passed to Flash Attention API
|
532 |
+
attention_mask (`torch.Tensor`):
|
533 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
534 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
535 |
+
dropout (`float`):
|
536 |
+
Attention dropout
|
537 |
+
softmax_scale (`float`, *optional*):
|
538 |
+
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
|
539 |
+
"""
|
540 |
+
if not self._flash_attn_uses_top_left_mask:
|
541 |
+
causal = self.is_causal
|
542 |
+
else:
|
543 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
544 |
+
causal = self.is_causal and query_length != 1
|
545 |
+
|
546 |
+
# Contains at least one padding token in the sequence
|
547 |
+
if attention_mask is not None:
|
548 |
+
batch_size = query_states.shape[0]
|
549 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
550 |
+
query_states, key_states, value_states, attention_mask, query_length
|
551 |
+
)
|
552 |
+
|
553 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
554 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
555 |
+
|
556 |
+
attn_output_unpad = flash_attn_varlen_func(
|
557 |
+
query_states,
|
558 |
+
key_states,
|
559 |
+
value_states,
|
560 |
+
cu_seqlens_q=cu_seqlens_q,
|
561 |
+
cu_seqlens_k=cu_seqlens_k,
|
562 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
563 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
564 |
+
dropout_p=dropout,
|
565 |
+
softmax_scale=softmax_scale,
|
566 |
+
causal=causal,
|
567 |
+
)
|
568 |
+
|
569 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
570 |
+
else:
|
571 |
+
attn_output = flash_attn_func(
|
572 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
573 |
+
)
|
574 |
+
|
575 |
+
return attn_output
|
576 |
+
|
577 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
578 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
579 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
580 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
581 |
+
|
582 |
+
key_layer = index_first_axis(
|
583 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
584 |
+
)
|
585 |
+
value_layer = index_first_axis(
|
586 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
587 |
+
)
|
588 |
+
if query_length == kv_seq_len:
|
589 |
+
query_layer = index_first_axis(
|
590 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
591 |
+
)
|
592 |
+
cu_seqlens_q = cu_seqlens_k
|
593 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
594 |
+
indices_q = indices_k
|
595 |
+
elif query_length == 1:
|
596 |
+
max_seqlen_in_batch_q = 1
|
597 |
+
cu_seqlens_q = torch.arange(
|
598 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
599 |
+
) # There is a memcpy here, that is very bad.
|
600 |
+
indices_q = cu_seqlens_q[:-1]
|
601 |
+
query_layer = query_layer.squeeze(1)
|
602 |
+
else:
|
603 |
+
# The -q_len: slice assumes left padding.
|
604 |
+
attention_mask = attention_mask[:, -query_length:]
|
605 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
606 |
+
|
607 |
+
return (
|
608 |
+
query_layer,
|
609 |
+
key_layer,
|
610 |
+
value_layer,
|
611 |
+
indices_q,
|
612 |
+
(cu_seqlens_q, cu_seqlens_k),
|
613 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
614 |
+
)
|
615 |
+
|
616 |
+
|
617 |
+
class PhiSdpaAttention(PhiAttention):
|
618 |
+
def __init__(self, *args, **kwargs):
|
619 |
+
super().__init__(*args, **kwargs)
|
620 |
+
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
|
621 |
+
|
622 |
+
"""
|
623 |
+
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
624 |
+
`PhiAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
|
625 |
+
SDPA API.
|
626 |
+
"""
|
627 |
+
|
628 |
+
# Adapted from PhiAttention.forward
|
629 |
+
def forward(
|
630 |
+
self,
|
631 |
+
hidden_states: torch.Tensor,
|
632 |
+
attention_mask: Optional[torch.Tensor] = None,
|
633 |
+
position_ids: Optional[torch.LongTensor] = None,
|
634 |
+
past_key_value: Optional[Cache] = None,
|
635 |
+
output_attentions: bool = False,
|
636 |
+
use_cache: bool = False,
|
637 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
638 |
+
if output_attentions:
|
639 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
640 |
+
logger.warning_once(
|
641 |
+
"PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
|
642 |
+
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
|
643 |
+
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
|
644 |
+
'be removed using the argument `attn_implementation="eager"` when loading the model.'
|
645 |
+
)
|
646 |
+
return super().forward(
|
647 |
+
hidden_states=hidden_states,
|
648 |
+
attention_mask=attention_mask,
|
649 |
+
position_ids=position_ids,
|
650 |
+
past_key_value=past_key_value,
|
651 |
+
output_attentions=output_attentions,
|
652 |
+
use_cache=use_cache,
|
653 |
+
)
|
654 |
+
|
655 |
+
bsz, q_len, _ = hidden_states.size()
|
656 |
+
|
657 |
+
query_states = self.q_proj(hidden_states)
|
658 |
+
key_states = self.k_proj(hidden_states)
|
659 |
+
value_states = self.v_proj(hidden_states)
|
660 |
+
|
661 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
662 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
663 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
664 |
+
|
665 |
+
if self.qk_layernorm:
|
666 |
+
query_states = self.q_layernorm(query_states)
|
667 |
+
key_states = self.k_layernorm(key_states)
|
668 |
+
|
669 |
+
kv_seq_len = key_states.shape[-2]
|
670 |
+
if past_key_value is not None:
|
671 |
+
if self.layer_idx is None:
|
672 |
+
raise ValueError(
|
673 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
674 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
675 |
+
"with a layer index."
|
676 |
+
)
|
677 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
678 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
679 |
+
|
680 |
+
# Partial rotary embedding
|
681 |
+
query_rot, query_pass = (
|
682 |
+
query_states[..., : self.rotary_emb.dim],
|
683 |
+
query_states[..., self.rotary_emb.dim :],
|
684 |
+
)
|
685 |
+
key_rot, key_pass = (
|
686 |
+
key_states[..., : self.rotary_emb.dim],
|
687 |
+
key_states[..., self.rotary_emb.dim :],
|
688 |
+
)
|
689 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
690 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
691 |
+
|
692 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
693 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
694 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
695 |
+
|
696 |
+
if past_key_value is not None:
|
697 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
698 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
699 |
+
|
700 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
701 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
702 |
+
|
703 |
+
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
704 |
+
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
|
705 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
706 |
+
if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None:
|
707 |
+
query_states = query_states.contiguous()
|
708 |
+
key_states = key_states.contiguous()
|
709 |
+
value_states = value_states.contiguous()
|
710 |
+
|
711 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
712 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
713 |
+
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
|
714 |
+
|
715 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
716 |
+
query_states,
|
717 |
+
key_states,
|
718 |
+
value_states,
|
719 |
+
attn_mask=attention_mask,
|
720 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
721 |
+
is_causal=is_causal,
|
722 |
+
)
|
723 |
+
|
724 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
725 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
726 |
+
|
727 |
+
attn_output = self.dense(attn_output)
|
728 |
+
|
729 |
+
return attn_output, None, past_key_value
|
730 |
+
|
731 |
+
|
732 |
+
PHI_ATTENTION_CLASSES = {
|
733 |
+
"eager": PhiAttention,
|
734 |
+
"flash_attention_2": PhiFlashAttention2,
|
735 |
+
"sdpa": PhiSdpaAttention,
|
736 |
+
}
|
737 |
+
|
738 |
+
|
739 |
+
class PhiDecoderLayer(nn.Module):
|
740 |
+
def __init__(self, config: PhiConfig, layer_idx: int):
|
741 |
+
super().__init__()
|
742 |
+
self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
743 |
+
self.mlp = PhiMLP(config)
|
744 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
745 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
746 |
+
|
747 |
+
def forward(
|
748 |
+
self,
|
749 |
+
hidden_states: torch.Tensor,
|
750 |
+
attention_mask: Optional[torch.Tensor] = None,
|
751 |
+
position_ids: Optional[torch.LongTensor] = None,
|
752 |
+
output_attentions: Optional[bool] = False,
|
753 |
+
use_cache: Optional[bool] = False,
|
754 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
755 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
756 |
+
"""
|
757 |
+
Args:
|
758 |
+
hidden_states (`torch.FloatTensor`):
|
759 |
+
input to the layer of shape `(batch, seq_len, embed_dim)`
|
760 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
761 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
762 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
763 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
|
764 |
+
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
765 |
+
output_attentions (`bool`, *optional*):
|
766 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
767 |
+
returned tensors for more detail.
|
768 |
+
use_cache (`bool`, *optional*):
|
769 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
770 |
+
(see `past_key_values`).
|
771 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
772 |
+
"""
|
773 |
+
|
774 |
+
residual = hidden_states
|
775 |
+
|
776 |
+
hidden_states = self.input_layernorm(hidden_states)
|
777 |
+
|
778 |
+
# Self Attention
|
779 |
+
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
|
780 |
+
hidden_states=hidden_states,
|
781 |
+
attention_mask=attention_mask,
|
782 |
+
position_ids=position_ids,
|
783 |
+
past_key_value=past_key_value,
|
784 |
+
output_attentions=output_attentions,
|
785 |
+
use_cache=use_cache,
|
786 |
+
)
|
787 |
+
attn_outputs = self.resid_dropout(attn_outputs)
|
788 |
+
|
789 |
+
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
790 |
+
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
791 |
+
outputs = (hidden_states,)
|
792 |
+
|
793 |
+
if output_attentions:
|
794 |
+
outputs += (self_attn_weights,)
|
795 |
+
|
796 |
+
if use_cache:
|
797 |
+
outputs += (present_key_value,)
|
798 |
+
|
799 |
+
return outputs
|
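Note the residual layout above: a single `input_layernorm` feeds both the attention and the MLP, and their (dropped-out) outputs are added to the residual in one step. A minimal self-contained sketch of that parallel block, using `nn.MultiheadAttention` as a stand-in for `PhiAttention`:

```python
import torch
from torch import nn

class ParallelBlock(nn.Module):
    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        h = self.ln(x)                                    # one shared pre-norm
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        return x + attn_out + self.mlp(h)                 # attention + MLP + residual, summed once

x = torch.randn(2, 10, 64)
print(ParallelBlock()(x).shape)  # torch.Size([2, 10, 64])
```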
800 |
+
|
801 |
+
|
802 |
+
PHI_START_DOCSTRING = r"""
|
803 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
804 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
805 |
+
etc.)
|
806 |
+
|
807 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
808 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
|
809 |
+
and behavior.
|
810 |
+
|
811 |
+
Parameters:
|
812 |
+
config ([`PhiConfig`]):
|
813 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
814 |
+
load the weights associated with the model, only the configuration. Check out the
|
815 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
816 |
+
"""
|
817 |
+
|
818 |
+
|
819 |
+
@add_start_docstrings(
|
820 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
821 |
+
PHI_START_DOCSTRING,
|
822 |
+
)
|
823 |
+
class PhiPreTrainedModel(PreTrainedModel):
|
824 |
+
config_class = PhiConfig
|
825 |
+
base_model_prefix = "model"
|
826 |
+
supports_gradient_checkpointing = True
|
827 |
+
_no_split_modules = ["PhiDecoderLayer"]
|
828 |
+
_skip_keys_device_placement = "past_key_values"
|
829 |
+
_supports_flash_attn_2 = True
|
830 |
+
_supports_sdpa = True
|
831 |
+
_supports_cache_class = True
|
832 |
+
|
833 |
+
def _init_weights(self, module):
|
834 |
+
std = self.config.initializer_range
|
835 |
+
if isinstance(module, nn.Linear):
|
836 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
837 |
+
if module.bias is not None:
|
838 |
+
module.bias.data.zero_()
|
839 |
+
elif isinstance(module, nn.Embedding):
|
840 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
841 |
+
if module.padding_idx is not None:
|
842 |
+
module.weight.data[module.padding_idx].zero_()
|
843 |
+
|
844 |
+
|
845 |
+
PHI_INPUTS_DOCSTRING = r"""
|
846 |
+
Args:
|
847 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
848 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
849 |
+
it.
|
850 |
+
|
851 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
852 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
853 |
+
|
854 |
+
[What are input IDs?](../glossary#input-ids)
|
855 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
856 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
857 |
+
|
858 |
+
- 1 for tokens that are **not masked**,
|
859 |
+
- 0 for tokens that are **masked**.
|
860 |
+
|
861 |
+
[What are attention masks?](../glossary#attention-mask)
|
862 |
+
|
863 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
864 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
865 |
+
|
866 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
867 |
+
`past_key_values`).
|
868 |
+
|
869 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
870 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
871 |
+
information on the default strategy.
|
872 |
+
|
873 |
+
- 1 indicates the head is **not masked**,
|
874 |
+
- 0 indicates the head is **masked**.
|
875 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
876 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
877 |
+
config.n_positions - 1]`.
|
878 |
+
|
879 |
+
[What are position IDs?](../glossary#position-ids)
|
880 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
881 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
882 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
883 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
884 |
+
|
885 |
+
Two formats are allowed:
|
886 |
+
- a [`~cache_utils.Cache`] instance;
|
887 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
888 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
889 |
+
cache format.
|
890 |
+
|
891 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
892 |
+
legacy cache format will be returned.
|
893 |
+
|
894 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
895 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
896 |
+
of shape `(batch_size, sequence_length)`.
|
897 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
898 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
899 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
900 |
+
model's internal embedding lookup matrix.
|
901 |
+
use_cache (`bool`, *optional*):
|
902 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
903 |
+
`past_key_values`).
|
904 |
+
output_attentions (`bool`, *optional*):
|
905 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
906 |
+
tensors for more detail.
|
907 |
+
output_hidden_states (`bool`, *optional*):
|
908 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
909 |
+
more detail.
|
910 |
+
return_dict (`bool`, *optional*):
|
911 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
912 |
+
"""
|
913 |
+
|
914 |
+
|
915 |
+
@add_start_docstrings(
|
916 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
917 |
+
PHI_START_DOCSTRING,
|
918 |
+
)
|
919 |
+
class PhiModel(PhiPreTrainedModel):
|
920 |
+
"""
|
921 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
|
922 |
+
|
923 |
+
Args:
|
924 |
+
config: PhiConfig
|
925 |
+
"""
|
926 |
+
|
927 |
+
def __init__(self, config: PhiConfig):
|
928 |
+
super().__init__(config)
|
929 |
+
self.padding_idx = config.pad_token_id
|
930 |
+
self.vocab_size = config.vocab_size
|
931 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
932 |
+
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
933 |
+
print("attention implementation: ", config._attn_implementation)
|
934 |
+
self.layers = nn.ModuleList(
|
935 |
+
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
936 |
+
)
|
937 |
+
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
938 |
+
|
939 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
940 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
941 |
+
|
942 |
+
self.gradient_checkpointing = False
|
943 |
+
# Initialize weights and apply final processing
|
944 |
+
self.post_init()
|
945 |
+
|
946 |
+
def get_input_embeddings(self):
|
947 |
+
return self.embed_tokens
|
948 |
+
|
949 |
+
def set_input_embeddings(self, value):
|
950 |
+
self.embed_tokens = value
|
951 |
+
|
952 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
953 |
+
def forward(
|
954 |
+
self,
|
955 |
+
input_ids: torch.LongTensor = None,
|
956 |
+
attention_mask: Optional[torch.Tensor] = None,
|
957 |
+
position_ids: Optional[torch.LongTensor] = None,
|
958 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
959 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
960 |
+
use_cache: Optional[bool] = None,
|
961 |
+
output_attentions: Optional[bool] = None,
|
962 |
+
output_hidden_states: Optional[bool] = None,
|
963 |
+
return_dict: Optional[bool] = None,
|
964 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
965 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
966 |
+
output_hidden_states = (
|
967 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
968 |
+
)
|
969 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
970 |
+
|
971 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
972 |
+
|
973 |
+
# retrieve input_ids and inputs_embeds
|
974 |
+
if input_ids is not None and inputs_embeds is not None:
|
975 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
976 |
+
elif input_ids is not None:
|
977 |
+
batch_size, seq_length = input_ids.shape[:2]
|
978 |
+
elif inputs_embeds is not None:
|
979 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
980 |
+
else:
|
981 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
982 |
+
|
983 |
+
past_key_values_length = 0
|
984 |
+
|
985 |
+
if self.gradient_checkpointing and self.training:
|
986 |
+
if use_cache:
|
987 |
+
logger.warning_once(
|
988 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
989 |
+
)
|
990 |
+
use_cache = False
|
991 |
+
|
992 |
+
if use_cache:
|
993 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
994 |
+
if use_legacy_cache:
|
995 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
996 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
997 |
+
|
998 |
+
if position_ids is None:
|
999 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1000 |
+
position_ids = torch.arange(
|
1001 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
1002 |
+
)
|
1003 |
+
position_ids = position_ids.unsqueeze(0)
|
1004 |
+
|
1005 |
+
if inputs_embeds is None:
|
1006 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1007 |
+
|
1008 |
+
inputs_embeds = self.embed_dropout(inputs_embeds)
|
1009 |
+
# commented by Xavier
|
1010 |
+
# Attention mask.
|
1011 |
+
# if self._use_flash_attention_2:
|
1012 |
+
# # 2d mask is passed through the layers
|
1013 |
+
# attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
1014 |
+
# elif self._use_sdpa and not output_attentions:
|
1015 |
+
# attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
1016 |
+
# attention_mask,
|
1017 |
+
# (batch_size, seq_length),
|
1018 |
+
# inputs_embeds,
|
1019 |
+
# past_key_values_length,
|
1020 |
+
# )
|
1021 |
+
# else:
|
1022 |
+
# # 4d mask is passed through the layers
|
1023 |
+
# attention_mask = _prepare_4d_causal_attention_mask(
|
1024 |
+
# attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
1025 |
+
# )
|
1026 |
+
# commented by Xavier
|
1027 |
+
|
1028 |
+
hidden_states = inputs_embeds
|
1029 |
+
|
1030 |
+
# decoder layers
|
1031 |
+
all_hidden_states = () if output_hidden_states else None
|
1032 |
+
all_self_attns = () if output_attentions else None
|
1033 |
+
next_decoder_cache = None
|
1034 |
+
for decoder_layer in self.layers:
|
1035 |
+
if output_hidden_states:
|
1036 |
+
all_hidden_states += (hidden_states,)
|
1037 |
+
|
1038 |
+
if self.gradient_checkpointing and self.training:
|
1039 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1040 |
+
decoder_layer.__call__,
|
1041 |
+
hidden_states,
|
1042 |
+
attention_mask,
|
1043 |
+
position_ids,
|
1044 |
+
past_key_values,
|
1045 |
+
output_attentions,
|
1046 |
+
)
|
1047 |
+
else:
|
1048 |
+
layer_outputs = decoder_layer(
|
1049 |
+
hidden_states,
|
1050 |
+
attention_mask=attention_mask,
|
1051 |
+
position_ids=position_ids,
|
1052 |
+
past_key_value=past_key_values,
|
1053 |
+
output_attentions=output_attentions,
|
1054 |
+
use_cache=use_cache,
|
1055 |
+
)
|
1056 |
+
|
1057 |
+
hidden_states = layer_outputs[0]
|
1058 |
+
|
1059 |
+
if use_cache:
|
1060 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
1061 |
+
|
1062 |
+
if output_attentions:
|
1063 |
+
all_self_attns += (layer_outputs[1],)
|
1064 |
+
|
1065 |
+
hidden_states = self.final_layernorm(hidden_states)
|
1066 |
+
|
1067 |
+
# add hidden states from the last decoder layer
|
1068 |
+
if output_hidden_states:
|
1069 |
+
all_hidden_states += (hidden_states,)
|
1070 |
+
|
1071 |
+
next_cache = None
|
1072 |
+
if use_cache:
|
1073 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
1074 |
+
if not return_dict:
|
1075 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
1076 |
+
return BaseModelOutputWithPast(
|
1077 |
+
last_hidden_state=hidden_states,
|
1078 |
+
past_key_values=next_cache,
|
1079 |
+
hidden_states=all_hidden_states,
|
1080 |
+
attentions=all_self_attns,
|
1081 |
+
)
|
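The forward pass above accepts the legacy per-layer tuple cache and converts it to a `DynamicCache`. A small round-trip sketch (assuming the `transformers` version this file targets):

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each of shape
# (batch, num_heads, cached_seq_len, head_dim).
legacy = tuple((torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())        # 3 tokens already cached
print(len(cache.to_legacy_cache()))  # 2 layers, converted back to the tuple format
```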
1082 |
+
|
1083 |
+
|
1084 |
+
class PhiForCausalLM(PhiPreTrainedModel):
|
1085 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1086 |
+
def __init__(self, config):
|
1087 |
+
super().__init__(config)
|
1088 |
+
config.qk_layernorm = True
|
1089 |
+
config.use_cache = False
|
1090 |
+
self.model = PhiModel(config)
|
1091 |
+
self.vocab_size = config.vocab_size
|
1092 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
1093 |
+
|
1094 |
+
# Initialize weights and apply final processing
|
1095 |
+
self.post_init()
|
1096 |
+
|
1097 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
|
1098 |
+
def get_input_embeddings(self):
|
1099 |
+
return self.model.embed_tokens
|
1100 |
+
|
1101 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
|
1102 |
+
def set_input_embeddings(self, value):
|
1103 |
+
self.model.embed_tokens = value
|
1104 |
+
|
1105 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
|
1106 |
+
def get_output_embeddings(self):
|
1107 |
+
return self.lm_head
|
1108 |
+
|
1109 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
|
1110 |
+
def set_output_embeddings(self, new_embeddings):
|
1111 |
+
self.lm_head = new_embeddings
|
1112 |
+
|
1113 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
|
1114 |
+
def set_decoder(self, decoder):
|
1115 |
+
self.model = decoder
|
1116 |
+
|
1117 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
|
1118 |
+
def get_decoder(self):
|
1119 |
+
return self.model
|

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PhiForCausalLM

        >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

        >>> prompt = "This is an example script ."
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


@add_start_docstrings(
    """
    The PhiModel with a sequence classification head on top (linear layer).

    [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    PHI_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
class PhiForSequenceClassification(PhiPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = PhiModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = model_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + model_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )


@add_start_docstrings(
    """
    PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    PHI_START_DOCSTRING,
)
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
class PhiForTokenClassification(PhiPreTrainedModel):
    def __init__(self, config: PhiConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = PhiModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = model_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (logits,) + model_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
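The sequence-classification head above pools the logit of the last non-padding token in each row, using an `argmax`-plus-modulo trick instead of reverse indexing for ONNX compatibility. A minimal standalone sketch of that indexing step; the token ids and `pad_token_id` below are made up for illustration:

import torch

pad_token_id = 0  # hypothetical pad id
input_ids = torch.tensor([
    [5, 7, 9, 0, 0],   # padded row: last real token at index 2
    [3, 4, 6, 8, 2],   # unpadded row: last real token at index 4
])

# Same computation as in PhiForSequenceClassification.forward:
# argmax finds the first pad position, -1 steps back to the last real token,
# and the modulo wraps -1 to the final index when a row has no padding.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])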
models/sampling.py
ADDED
@@ -0,0 +1,118 @@
# Adapted from https://github.com/lucidrains/muse-maskgit-pytorch

import math
from functools import partial

import torch
import torch.nn.functional as F


def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t, generator=None):
    noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)


def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim=-1)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(2, ind, val)
    return probs


def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator)
    sorted_confidence = torch.sort(confidence, dim=-1).values
    cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
    masking = confidence < cut_off
    return masking


def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)


def linear_schedule(t):
    mask_ratio = 1 - t
    mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
    return mask_ratio


def pow(t, method):
    exponent = float(method.replace("pow", ""))
    mask_ratio = 1.0 - t**exponent
    mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
    return mask_ratio


def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6):
    for item in [t, start, end, tau]:
        item = torch.tensor(item) if not torch.is_tensor(item) else item

    # A gamma function based on sigmoid function.
    v_start = torch.sigmoid(torch.tensor(start / tau))
    v_end = torch.sigmoid(torch.tensor(end / tau))
    output = torch.sigmoid((t * (end - start) + start) / tau)
    output = (v_end - output) / (v_end - v_start)
    return torch.clip(output, clip_min, 1.0)


def get_mask_chedule(method, **schedule_kwargs):
    if method == "cosine":
        return cosine_schedule
    elif method == "linear":
        return linear_schedule
    elif "pow" in method:
        return partial(pow, method=method)
    elif method == "sigmoid":
        return partial(sigmoid_schedule, **schedule_kwargs)
    else:
        raise ValueError("Unknown schedule method: {}".format(method))

def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
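The helpers above are the building blocks of MaskGIT-style iterative decoding: a schedule from `get_mask_chedule` decides how many tokens stay masked at a given step, `gumbel_sample` draws a token per position, and `mask_by_random_topk` re-masks the least-confident positions. A minimal sketch of one refinement step, assuming the repository root is on `PYTHONPATH`; the sizes, `mask_id`, and temperatures are placeholders, not the values used by the Show-o inference scripts:

import torch
from models.sampling import get_mask_chedule, gumbel_sample, mask_by_random_topk

batch, seq_len, vocab = 2, 16, 32
mask_id = vocab - 1                      # assumed id of the mask token
logits = torch.randn(batch, seq_len, vocab)

schedule = get_mask_chedule("cosine")    # mask-ratio schedule

# One refinement step at fractional timestep t in (0, 1).
t = torch.tensor(0.25)
mask_ratio = schedule(t)                                   # fraction of tokens to keep masked
mask_len = torch.floor(mask_ratio * seq_len).long().clamp(min=1)
mask_len = mask_len.reshape(1, 1).expand(batch, 1)         # per-sample cutoff, shape (B, 1)

sampled = gumbel_sample(logits, temperature=1.0)           # one token per position, shape (B, L)
probs = torch.softmax(logits, dim=-1)
confidence = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)

# Re-mask the least-confident positions for the next step.
remask = mask_by_random_topk(mask_len, confidence, temperature=0.5)
next_ids = torch.where(remask, torch.full_like(sampled, mask_id), sampled)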
models/training_utils.py
ADDED
@@ -0,0 +1,455 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import random
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F


def enable_full_determinism(seed: int):
    """
    Helper function for reproducible behavior during distributed training. See
    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
    """
    # set seed first
    set_seed(seed)

    # Enable PyTorch deterministic mode. This potentially requires either the environment
    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
    # depending on the CUDA version, so we set them both here
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.use_deterministic_algorithms(True)

    # Enable CUDNN deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def set_seed(seed: int):
    """
    Args:
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMA:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls) -> "EMA":
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Args:
        Save the current parameters for restoring later.
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Args:
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
        affecting the original optimization process. Store the parameters before the `copy_to()` method. After
        validation (or model saving), use this to restore the former parameters.
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Args:
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
        ema state dict.
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")


# calculates entropy over each pixel distribution
def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
    # only calculated entropy over image tokens that were masked in the original image
    masked_tokens = input_ids == mask_id
    num_masked_pixels = masked_tokens.sum(-1)

    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    entropy_per_pixel = -((probs * log_probs).sum(-1))

    # the predictions for non-masked aren't used, so set their entropies to zero
    entropy_per_pixel[~masked_tokens] = 0

    entropy_per_image_numerator = entropy_per_pixel.sum(-1)
    entropy_per_image = entropy_per_image_numerator / num_masked_pixels

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)

    return entropy_by_masked_bucket


# calculates entropy over the averaged distribution of pixels for the whole image
def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
    # only calculated entropy over image tokens that were masked in the original image
    masked_tokens = input_ids == mask_id
    num_masked_pixels = masked_tokens.sum(-1, keepdim=True)

    pixel_probs = F.softmax(logits, dim=-1)
    pixel_probs[~masked_tokens] = 0
    image_probs_numerator = pixel_probs.sum(-2)
    image_probs = image_probs_numerator / num_masked_pixels

    image_log_probs = image_probs.log()

    entropy_per_image = -((image_probs * image_log_probs).sum(-1))

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)

    return entropy_by_masked_bucket


def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
    cross_entropy_per_image = F.cross_entropy(
        logits.view(-1, output_size),
        labels.view(-1),
        ignore_index=-100,
        label_smoothing=label_smoothing,
        reduction="none",
    )

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets)

    return cross_entropy_by_percent_masked_bucket


def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
    probs = F.softmax(logits, dim=-1)

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    data = []

    for bucket_idx in range(total_buckets):
        indices_for_bucket = masked_buckets[masked_buckets == bucket_idx]

        # It's ok if none were noised in the range of this bucket. This
        # function will be called for a later training step where it's likely
        # there will be an element noised in the range.
        if indices_for_bucket.shape[0] == 0:
            continue

        index_for_bucket = indices_for_bucket[0]

        image_probs = probs[index_for_bucket]

        # find the index of a masked pixel for the image
        input_ids_for_image = input_ids[index_for_bucket]
        masked_pixels_probs = image_probs[input_ids_for_image == mask_id]

        masked_pixel_probs = masked_pixels_probs[0]

        masked_pixel_probs = masked_pixel_probs.cpu().numpy()

        for masked_pixel_prob in masked_pixel_probs:
            data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob})

    df = pd.DataFrame(data)

    return df


def average_by_buckets(values, masked_buckets, total_buckets):
    unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True)

    numerator = torch.zeros(total_buckets, device=values.device)

    numerator.scatter_add_(0, masked_buckets, values)

    # default value is one because the buckets for which there aren't
    # any values will have a numerator of zero. So we just need to not divide
    # by zero.
    denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long)
    denominator[unique_buckets] = bucket_counts

    averaged_by_buckets = numerator / denominator

    return averaged_by_buckets


def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10):
    assert total_buckets == 10

    masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1]

    # we do not formally use timesteps to noise images. Instead, we mask a percent
    # of the pixels. We don't want to log entropy for every mask percent between 0 and 1,
    # and we also want to track how the entropy evolves over time w/in a range of mask
    # percents that should have similar entropy. So we bucket the masked percents into a
    # fixed number of buckets

    # we could generalize this later if needed but for now, let's just assume a fixed
    # number of 10 buckets.

    # How this maps to a bucket index:
    # (mask) * bucket_index +
    # (mask_1) * bucket_index_1
    #
    # -> Where the mask is true will be set to the expected bucket index,
    # where the mask is false will be set to 0.
    #
    # Given the probabilities are between 0 and 1, each masked_percent will get mapped
    # to a timestep by one and only one of the masks.

    masked_buckets = (
        ((0 < masked_percent) & (masked_percent <= 0.1)) * 0
        + ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1
        + ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2
        + ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3
        + ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4
        + ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5
        + ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6
        + ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7
        + ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8
        + ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9
    )

    return masked_buckets
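`EMA` keeps a shadow copy of the weights and nudges it toward the live parameters after every optimizer step; `store`/`copy_to`/`restore` let you evaluate with the averaged weights and then return to the raw ones. A minimal usage sketch, assuming the repository root is importable; the `Linear` model, learning rate, and step count are placeholders, not the Show-o training setup:

import torch
from models.training_utils import EMA

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMA(model.parameters(), decay=0.9999, use_ema_warmup=True)

for _ in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())      # update the shadow weights after each optimizer step

# Swap in the averaged weights for validation, then restore the training weights.
ema.store(model.parameters())
ema.copy_to(model.parameters())
# ... run validation with the EMA weights here ...
ema.restore(model.parameters())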
prompting_utils.py
ADDED
@@ -0,0 +1,528 @@
import torch

class UniversalPrompting():
    def __init__(self, text_tokenizer,
                 special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                 max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1):
        """
        :param text_tokenizer: original text tokenizer
        """
        self.text_tokenizer = text_tokenizer
        self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.text_tokenizer.add_tokens(list(special_tokens))
        self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in
                            special_tokens}
        self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
        self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
        self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])
        # plus 1 because at this time we add a task token before
        self.max_text_len = max_text_len + 1
        self.pad_id = self.text_tokenizer.convert_tokens_to_ids('[PAD]')
        self.ignore_id = ignore_id
        self.cond_dropout_prob = cond_dropout_prob

    def t2i_prompt_predict_next(self, text_ids, image_ids, labels):

        device = image_ids.device
        sequence_ids = []
        attention_masks = []
        label_ids = []
        probs = torch.rand(len(text_ids))
        for i in range(len(text_ids)):

            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]

            temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]

            # randomly dropout text condition
            if probs[i] < self.cond_dropout_prob:
                temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]

            if self.max_text_len >= len(temp_ids):
                temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
                temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
            else:
                # should add the eos token
                temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_label_ids = torch.cat([
                # should we predict text tokens when doing image reconstruction?
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                labels[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)

            temp_ids = torch.cat([
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_masks = torch.tensor(temp_masks).to(device)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))
            label_ids.append(temp_label_ids.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)

    def t2i_gen_prompt(self, text_ids, image_ids):

        device = image_ids.device
        sequence_ids = []
        attention_masks = []
        for i in range(len(text_ids)):
            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
            # note that, llama3 tokenizer automatically add the bot token at first but without eot
            temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
            if self.max_text_len >= len(temp_ids):
                temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
                temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
            else:
                temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * len(temp_ids)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_ids = torch.cat([
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_masks = torch.tensor(temp_masks).to(device)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)

    # language modeling
    def lm_prompt(self, text_ids, max_seq_len):

        sequence_ids = []
        attention_masks = []
        label_ids = []
        for i in range(len(text_ids)):
            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.eos_token_id] + text_ids[i]

            temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]

            if max_seq_len >= len(temp_ids):
                temp_labels_ids = temp_ids + [self.ignore_id] * (max_seq_len - len(temp_ids))
                temp_ids = temp_ids + [self.pad_id] * (max_seq_len - len(temp_ids))
                temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids))
            else:
                # In language modeling, we only process text tokens. We do not add the eos token if the text length
                # exceeds the max sequence length
                temp_labels_ids = temp_ids[:max_seq_len]
                temp_ids = temp_ids[:max_seq_len]
                temp_masks = [1] * len(temp_ids)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_ids = torch.tensor(temp_ids)
            temp_masks = torch.tensor(temp_masks)
            temp_labels_ids = torch.tensor(temp_labels_ids)

            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))
            label_ids.append(temp_labels_ids.unsqueeze(0))

        # input_ids, masks, labels
        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)

    def mmu_prompt(self, image_ids, text_ids):
        device = image_ids.device
        sequence_ids = []
        attention_masks = []
        label_ids = []
        max_text_len = self.max_text_len - 1
        for i in range(len(text_ids)):
            # note that, llama3 tokenizer automatically add the bot token at first but without eot
            # for empty list []

            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.eos_token_id] + text_ids[i]

            temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]

            if max_text_len >= len(temp_ids):
                # minus 1 because task token was prepended to the former image tokens
                temp_ids = temp_ids + [self.pad_id] * (max_text_len - len(temp_ids))
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids))
            else:
                # should add the eos token
                temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_label_ids = torch.cat([
                torch.tensor([self.ignore_id]).to(device),
                torch.tensor([self.ignore_id]).to(device),
                torch.ones_like(image_ids[i]) * self.ignore_id,
                torch.tensor([self.ignore_id]).to(device),
                torch.tensor(temp_ids).to(device),
            ], dim=0)

            temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)

            temp_ids = torch.cat([
                self.sptids_dict['<|mmu|>'].to(device),  # task token
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device),
                torch.tensor(temp_ids).to(device),
            ], dim=0)

            temp_masks = torch.tensor(temp_masks).to(device)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))
            label_ids.append(temp_label_ids.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)

    def t2v_prompt(self, text_ids, video_ids):
        """
        :param text_ids:
        :param video_ids:
        :return:
        """
        pass

    def i2v_prompt(self, image_ids, video_ids):
        """
        :param image_ids:
        :param video_ids:
        :return:
        """
        pass

    def lvg_prompt(self, text_ids, image_ids, labels):

        device = image_ids.device
        sequence_ids = []
        attention_masks = []
        label_ids = []
        probs = torch.rand(len(text_ids))
        probs2 = torch.rand(len(text_ids))
        for i in range(len(text_ids)):

            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]

            temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]

            # randomly dropout text condition
            if probs[i] < self.cond_dropout_prob:
                temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id,
                            self.text_tokenizer.eos_token_id]

            if self.max_text_len >= len(temp_ids):
                temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
                temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
            else:
                # should add the eos token
                temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_label_ids = torch.cat([
                # should we predict text tokens when doing image reconstruction?
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                labels[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)

            temp_ids = torch.cat([
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_masks = torch.tensor(temp_masks).to(device)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))
            label_ids.append(temp_label_ids.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)

    def lvg_gen_prompt(self, text_ids, image_ids):

        device = image_ids.device
        sequence_ids = []
        attention_masks = []
        for i in range(len(text_ids)):
            if len(text_ids[i]) == 0:
                text_ids[i] = [self.text_tokenizer.bos_token_id]
            elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
                text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
            # note that, llama3 tokenizer automatically add the bot token at first but without eot
            temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
            if self.max_text_len >= len(temp_ids):
                temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
                temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
            else:
                temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
                temp_masks = [1] * len(temp_ids)  # +2 for two special tokens

            # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
            temp_ids = torch.cat([
                torch.tensor(temp_ids).to(device),
                self.sptids_dict['<|soi|>'].to(device),
                image_ids[i],
                self.sptids_dict['<|eoi|>'].to(device)
            ], dim=0)

            temp_masks = torch.tensor(temp_masks).to(device)
            sequence_ids.append(temp_ids.unsqueeze(0))
            attention_masks.append(temp_masks.unsqueeze(0))

        return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)

    def mask_prompt(self):
        pass

    def __call__(self, input, task, padding=True, config=None):
        """
        input (tuple) : data pairs contain text(str), image(tensor), or videos(tensor).
        task (str) : a flag indicates the current task.
        """
        if task == "t2i":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2])

        elif task == "t2i_predict_next":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2i_prompt_predict_next(text_ids, image_ids, input[2])

        elif task == "t2i_predict_next_plus_lm":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2i_prompt_predict_next(text_ids[:config.training.batch_size], image_ids,
                                                                   input[2])
            sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3])
            return sequence_ids_with_masks, sequence_ids_with_masks_lm

        elif task == "t2i_gen":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids)

        elif task == "lm":
            text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids']  # (B, max_len)
            sequence_ids_with_masks = self.lm_prompt(text_ids, input[1])

        elif task == "mmu":
            image_ids = input[0]
            text_ids = self.text_tokenizer(input[1])['input_ids']
            sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids)

        elif task == "t2v":
            text_ids = self.text_tokenizer(input[0]['input_ids'])
            video_ids = self.vision_tokenizer(input[1])
            sequence_ids_with_masks = self.t2v_prompt(text_ids, video_ids)

        elif task == "i2v":
            image_ids = self.text_tokenizer(input[0])
            video_ids = self.vision_tokenizer(input[1])
            sequence_ids_with_masks = self.i2v_prompt(image_ids, video_ids)

        elif task == "lvg":
            text_ids = self.text_tokenizer(input[0])['input_ids']  # (B, max_len)
            image_ids = input[1]  # (B, #tokens)
            sequence_ids_with_masks = self.lvg_prompt(text_ids, image_ids, input[2])
|
357 |
+
|
358 |
+
elif task == "lvg_gen":
|
359 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
360 |
+
image_ids = input[1] # (B, #tokens)
|
361 |
+
sequence_ids_with_masks = self.lvg_gen_prompt(text_ids, image_ids)
|
362 |
+
else:
|
363 |
+
raise NotImplementedError
|
364 |
+
|
365 |
+
return sequence_ids_with_masks
|
366 |
+
|
367 |
+
def create_attention_mask_predict_next(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, rm_pad_in_image=False,
|
368 |
+
return_inverse_mask=True):
|
369 |
+
# sequence is expected to be of shape [N, L]
|
370 |
+
N, L = sequence.shape
|
371 |
+
|
372 |
+
# Masks to identify different types of tokens
|
373 |
+
is_padding = sequence == pad_id
|
374 |
+
|
375 |
+
is_start_image = sequence == soi_id
|
376 |
+
|
377 |
+
is_end_image = sequence == eoi_id
|
378 |
+
|
379 |
+
# Create cumulative sum masks to identify regions of image tokens
|
380 |
+
cumulative_start = torch.cumsum(is_start_image, dim=1)
|
381 |
+
cumulative_end = torch.cumsum(is_end_image, dim=1)
|
382 |
+
in_image_segment = (cumulative_start > cumulative_end) | is_start_image | is_end_image
|
383 |
+
|
384 |
+
is_text = ~(in_image_segment)
|
385 |
+
|
386 |
+
causal_mask = torch.tril(torch.ones((L, L), dtype=torch.bool)).to(sequence.device)
|
387 |
+
|
388 |
+
mask_text = is_text[:, :, None] * causal_mask[None, :, :]
|
389 |
+
|
390 |
+
is_text_image = is_text | in_image_segment
|
391 |
+
|
392 |
+
mask_text_image_bi = is_text_image[:, :, None] * is_text_image[:, None, :]
|
393 |
+
if rm_pad_in_image:
|
394 |
+
sid_img = torch.where(sequence == soi_id)[1]
|
395 |
+
for i in range(mask_text_image_bi.shape[0]):
|
396 |
+
pad_end_idx = torch.where(sequence[i] == pad_id)
|
397 |
+
if len(pad_end_idx[0]) != 0:
|
398 |
+
pad_end_idx = pad_end_idx[0][-1]
|
399 |
+
mask_text[i][pad_end_idx + 1:, :pad_end_idx + 1] = 0
|
400 |
+
id_padding = torch.where(is_padding[i] == True)
|
401 |
+
mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
|
402 |
+
|
403 |
+
mask_text[in_image_segment] = mask_text_image_bi[in_image_segment]
|
404 |
+
# No token attends to padding tokens and padding tokens do not attend to any token
|
405 |
+
if return_inverse_mask:
|
406 |
+
inverted_mask = 1.0 - mask_text.type(sequence.dtype)
|
407 |
+
inverted_mask = inverted_mask.masked_fill(
|
408 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
409 |
+
)
|
410 |
+
return inverted_mask.unsqueeze(1)
|
411 |
+
else:
|
412 |
+
return mask_text.unsqueeze(1)
|
413 |
+
|
414 |
+
def create_attention_mask_lvg(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, return_inverse_mask=True):
|
415 |
+
# sequence is expected to be of shape [N, L]
|
416 |
+
N, L = sequence.shape
|
417 |
+
# Masks to identify different types of tokens
|
418 |
+
is_padding = sequence == pad_id
|
419 |
+
mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device)
|
420 |
+
|
421 |
+
sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
|
422 |
+
sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
423 |
+
eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
424 |
+
for i in range(N):
|
425 |
+
id_padding = torch.where(is_padding[i] == True)
|
426 |
+
mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
|
427 |
+
for j in range(sid_img_for_bi.shape[-1]):
|
428 |
+
mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1,
|
429 |
+
sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1] = 1
|
430 |
+
|
431 |
+
# No token attends to padding tokens and padding tokens do not attend to any token
|
432 |
+
if return_inverse_mask:
|
433 |
+
inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
|
434 |
+
inverted_mask = inverted_mask.masked_fill(
|
435 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
436 |
+
)
|
437 |
+
return inverted_mask.unsqueeze(1)
|
438 |
+
else:
|
439 |
+
return mask_text_image_bi.unsqueeze(1)
|
440 |
+
|
441 |
+
# texts without attending image regions
|
442 |
+
def create_attention_mask_lvg_v2(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, sot_id=1000, eot_id=1001, return_inverse_mask=True):
|
443 |
+
# sequence is expected to be of shape [N, L]
|
444 |
+
N, L = sequence.shape
|
445 |
+
# Masks to identify different types of tokens
|
446 |
+
is_padding = sequence == pad_id
|
447 |
+
# is_text = torch.where(sequence < 2000, True, False)
|
448 |
+
is_text = torch.where(sequence < pad_id, True, False)
|
449 |
+
mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device).int()
|
450 |
+
sid_text_for_bi = torch.where(sequence == sot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
451 |
+
eid_text_for_bi = torch.where(sequence == eot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
452 |
+
# import ipdb
|
453 |
+
# ipdb.set_trace()
|
454 |
+
if sot_id == eot_id:
|
455 |
+
if sid_text_for_bi.shape[-1] % 2 != 0:
|
456 |
+
sid_text_for_bi = sid_text_for_bi[:, :-1]
|
457 |
+
eid_text_for_bi = eid_text_for_bi[:, :-1]
|
458 |
+
select_idx = [i for i in range(0, sid_text_for_bi.shape[1], 2)]
|
459 |
+
sid_text_for_bi = sid_text_for_bi[:, select_idx]
|
460 |
+
select_idx = [i+1 for i in range(0, eid_text_for_bi.shape[1], 2)]
|
461 |
+
eid_text_for_bi = eid_text_for_bi[:, select_idx]
|
462 |
+
sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
463 |
+
eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
|
464 |
+
all_zeros = torch.zeros_like(mask_text_image_bi).int()
|
465 |
+
for i in range(N):
|
466 |
+
all_zeros[i, :, is_text[i]] = 1
|
467 |
+
for j in range(sid_text_for_bi.shape[-1]):
|
468 |
+
all_zeros[i][is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
|
469 |
+
all_zeros[i][~is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
|
470 |
+
for j in range(sid_img_for_bi.shape[-1]):
|
471 |
+
all_zeros[i][~is_text[i], sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
|
472 |
+
mask_text_image_bi = mask_text_image_bi * all_zeros
|
473 |
+
sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
|
474 |
+
|
475 |
+
for i in range(N):
|
476 |
+
id_padding = torch.where(is_padding[i] == True)
|
477 |
+
mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
|
478 |
+
for j in range(sid_img_for_bi.shape[-1]):
|
479 |
+
mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1, sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
|
480 |
+
|
481 |
+
mask_text_image_bi[:, :, 0] = 1
|
482 |
+
# No token attends to padding tokens and padding tokens do not attend to any token
|
483 |
+
if return_inverse_mask:
|
484 |
+
inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
|
485 |
+
inverted_mask = inverted_mask.masked_fill(
|
486 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
487 |
+
)
|
488 |
+
return inverted_mask.unsqueeze(1)
|
489 |
+
else:
|
490 |
+
return mask_text_image_bi.unsqueeze(1)
|
491 |
+
|
492 |
+
def create_attention_mask_for_mmu(sequence, eoi_id=128258, return_inverse_mask=True):
|
493 |
+
N, L = sequence.shape
|
494 |
+
causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
|
495 |
+
eoi_image = torch.where(sequence == eoi_id)[1]
|
496 |
+
causal_mask[:, :, :, :eoi_image[0] + 1] = 1
|
497 |
+
|
498 |
+
if return_inverse_mask:
|
499 |
+
inverted_mask = 1.0 - causal_mask.type(sequence.dtype)
|
500 |
+
inverted_mask = inverted_mask.masked_fill(
|
501 |
+
inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
|
502 |
+
)
|
503 |
+
return inverted_mask
|
504 |
+
else:
|
505 |
+
return causal_mask
|
506 |
+
|
507 |
+
def create_attention_mask_for_mmu_vit(
|
508 |
+
sequence,
|
509 |
+
return_inverse_mask=True,
|
510 |
+
system_prompt_len=0
|
511 |
+
):
|
512 |
+
N, L, H = sequence.shape
|
513 |
+
causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
|
514 |
+
index = 1 + system_prompt_len + 1 + 576
|
515 |
+
|
516 |
+
causal_mask[:, :, :, :index] = 1
|
517 |
+
if return_inverse_mask:
|
518 |
+
inverted_mask = 1.0 - causal_mask.type(torch.int64)
|
519 |
+
inverted_mask = inverted_mask.masked_fill(
|
520 |
+
inverted_mask.to(torch.bool), torch.iinfo(torch.int64).min
|
521 |
+
)
|
522 |
+
return inverted_mask
|
523 |
+
else:
|
524 |
+
return causal_mask
|
525 |
+
|
526 |
+
|
527 |
+
if __name__ == '__main__':
|
528 |
+
pass
|
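A minimal sketch of how the masking helper above might be exercised, assuming the file is importable as prompting_utils and using arbitrary toy token ids (the ids and batch below are illustrative, not part of the diff):

    import torch
    from prompting_utils import create_attention_mask_predict_next  # assumed import path

    # toy batch: [pad] [text tokens] <|soi|> [image tokens] <|eoi|>, ids chosen arbitrarily
    pad_id, soi_id, eoi_id = 128256, 128257, 128258
    seq = torch.tensor([[pad_id, 5, 6, soi_id, 900, 901, eoi_id]])
    mask = create_attention_mask_predict_next(seq, pad_id=pad_id, soi_id=soi_id, eoi_id=eoi_id,
                                              rm_pad_in_image=True, return_inverse_mask=True)
    print(mask.shape)  # torch.Size([1, 1, 7, 7]); additive mask, 0 where attention is allowed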
requirements.txt
ADDED
@@ -0,0 +1,228 @@
accelerate==0.21.0
aiohttp==3.9.5
aiosignal==1.3.1
albumentations==0.3.2
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anykeystore==0.2
asn1crypto==1.5.1
asttokens==2.4.1
async-timeout==4.0.3
attrs==21.2.0
bidict==0.23.1
blessed==1.20.0
boto3==1.34.113
botocore==1.34.113
braceexpand==0.1.7
cachetools==5.3.3
certifi==2024.2.2
cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
clip==0.2.0
clip-openai==1.0.post20230121
cmake==3.29.3
cramjam==2.8.3
crcmod==1.7
cryptacular==1.6.2
cryptography==39.0.2
cycler==0.12.1
datasets
diffusers==0.30.1
decorator==5.1.1
decord==0.6.0
deepspeed==0.14.2
defusedxml==0.7.1
Deprecated==1.2.14
descartes==1.1.0
dill==0.3.8
distlib==0.3.8
distro-info==1.0
dnspython==2.6.1
docker-pycreds==0.4.0
docstring_parser==0.16
ecdsa==0.19.0
einops==0.6.0
exceptiongroup==1.2.1
executing==2.0.1
fairscale==0.4.13
fastparquet==2024.5.0
ffmpegcv==0.3.13
filelock==3.14.0
fire==0.6.0
fonttools==4.51.0
frozenlist==1.4.1
fsspec==2023.6.0
ftfy==6.2.0
gitdb==4.0.11
GitPython==3.1.43
gpustat==1.1.1
greenlet==3.0.3
grpcio==1.64.0
h11==0.14.0
hjson==3.1.0
huggingface-hub==0.23.2
hupper==1.12.1
idna==3.7
imageio==2.34.1
imgaug==0.2.6
iniconfig==2.0.0
ipaddress==1.0.23
ipdb==0.13.13
ipython==8.18.1
jaxtyping==0.2.28
jedi==0.19.1
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonargparse==4.14.1
jsonlines==4.0.0
kiwisolver==1.4.5
kornia==0.7.2
kornia_rs==0.1.3
lazy_loader==0.4
lightning==2.2.3
lightning-utilities==0.11.2
lit==18.1.6
MarkupSafe==2.1.5
matplotlib==3.5.3
matplotlib-inline==0.1.7
miscreant==0.3.0
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multiprocess==0.70.16
natsort==8.4.0
networkx==3.2.1
ninja==1.11.1.1
numpy==1.24.4
nuscenes-devkit==1.1.11
oauthlib==3.2.2
omegaconf==2.3.0
open-clip-torch==2.24.0
openai-clip
opencv-python==4.9.0.80
opencv-python-headless==3.4.18.65
packaging==22.0
pandas==1.5.3
parquet==1.3.1
parso==0.8.4
PasteDeploy==3.1.0
pathlib2==2.3.7.post1
pathtools==0.1.2
pbkdf2==1.3
pexpect==4.9.0
pillow==10.3.0
plaster==1.1.2
plaster-pastedeploy==1.0.1
platformdirs==4.2.2
plotly==5.22.0
pluggy==1.5.0
ply==3.11
promise==2.3
prompt-toolkit==3.0.43
protobuf==3.20.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py==1.11.0
py-cpuinfo==9.0.0
py-spy==0.3.14
pyarrow==11.0.0
pyarrow-hotfix==0.6
pyasn1==0.6.0
pycocotools==2.0.7
pycparser==2.22
pycryptodomex==3.20.0
pycurl==7.43.0.6
pydantic==1.10.15
pydantic_core==2.18.3
Pygments==2.18.0
PyJWT==2.8.0
pynvml==11.5.0
pyope==0.2.2
pyOpenSSL==23.2.0
pyparsing==3.1.2
pyquaternion==0.9.9
pyramid==2.0.2
pyramid-mailer==0.15.1
pytest==6.2.5
python-consul==1.1.0
python-dateutil==2.9.0.post0
python-engineio==4.9.1
python-etcd==0.4.5
python-jose==3.3.0
python-socketio==5.11.2
python3-openid==3.2.0
pytorch-extension==0.2
pytorch-lightning==2.2.3
pytz==2024.1
PyYAML==6.0.1
regex==2024.5.15
repoze.sendmail==4.4.1
requests==2.31.0
requests-oauthlib==2.0.0
rsa==4.9
s3transfer==0.10.1
safetensors==0.4.3
schedule==1.2.2
scikit-image==0.22.0
scikit-learn==1.5.0
scipy==1.13.1
sentencepiece==0.2.0
sentry-sdk==2.3.1
setproctitle==1.3.3
Shapely==1.8.5.post1
shortuuid==1.0.13
simple-websocket==1.0.0
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.30
stack-data==0.6.3
sympy==1.12
taming-transformers-rom1504==0.0.6
tenacity==8.3.0
tensorboardX==2.6.2.2
termcolor==2.4.0
threadpoolctl==3.5.0
thriftpy2==0.5.0
tifffile==2024.5.22
timm==1.0.3
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
torch==2.2.1
torch-fidelity==0.3.0
torchmetrics==1.4.0.post0
torchvision==0.17.1
tox==3.28.0
tqdm==4.66.4
traitlets==5.14.3
transaction==4.0
transformers==4.41.1
translationstring==1.4
triton==2.2.0
typeguard==2.13.3
typing_extensions==4.12.0
tzdata==2024.1
urllib3==1.26.18
velruse==1.1.1
venusian==3.1.0
virtualenv==20.26.2
wandb==0.17.0
watchdog==4.0.1
wcwidth==0.2.13
webdataset==0.2.86
WebOb==1.8.7
websocket-client==1.8.0
wrapt==1.16.0
wsproto==1.2.0
WTForms==3.1.2
wtforms-recaptcha==0.3.2
xformers==0.0.25
xxhash==3.4.1
yarl==1.9.4
zope.deprecation==5.0
zope.interface==6.4.post2
zope.sqlalchemy==3.1
training/__init__.py
ADDED
@@ -0,0 +1 @@
training/conversation.py
ADDED
@@ -0,0 +1,432 @@
# Modified from LLaVA: https://github.com/haotian-liu/LLaVA.git
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
        if image_process_mode == "Pad":
            def expand2square(pil_img, background_color=(122, 116, 104)):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image)
        elif image_process_mode in ["Default", "Crop"]:
            pass
        elif image_process_mode == "Resize":
            image = image.resize((336, 336))
        else:
            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
        if max(image.size) > max_len:
            max_hw, min_hw = max(image.size), min(image.size)
            aspect_ratio = max_hw / min_hw
            shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
            longest_edge = int(shortest_edge * aspect_ratio)
            W, H = image.size
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))
        if return_pil:
            return image
        else:
            buffered = BytesIO()
            image.save(buffered, format=image_format)
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            return img_b64_str

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    image = self.process_image(image, image_process_mode, return_pil=return_pil)
                    images.append(image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    img_b64_str = self.process_image(
                        image, "Default", return_pil=False,
                        image_format='JPEG')
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="</s>",
)

conv_chatml_direct = Conversation(
    system="""<|im_start|>system
Answer the questions.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_phi3_instruct = Conversation(
    system="""<|system|>\nYou are a helpful AI assistant.""",
    roles=("\n<|user|>\n", "\n<|assistant|>\n"),
    version="phi3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|end|>",
)

# conv_phi_v0 = Conversation(
#     system="A chat between a curious user and an artificial intelligence assistant. "
#            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
#     roles=("USER", "ASSISTANT"),
#     version="v0",
#     messages=(),
#     offset=0,
#     sep_style=SeparatorStyle.TWO,
#     sep=" ",
#     sep2="<|endoftext|>",
# )

conv_phi_v0 = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="v0",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

default_conversation = conv_vicuna_v1
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
    "mistral_instruct": conv_mistral_instruct,
    "chatml_direct": conv_chatml_direct,
    "mistral_direct": conv_chatml_direct,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "phi3_instruct": conv_phi3_instruct,
    "phi1.5": conv_phi_v0,

    "mpt": conv_mpt,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
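A rough usage sketch for the templates above, assuming the file is importable as training.conversation (the example user message is made up):

    from training.conversation import conv_templates  # assumed import path

    conv = conv_templates["phi1.5"].copy()    # copy() turns the empty messages tuple into a list
    conv.append_message(conv.roles[0], "Describe this image.\n<image>")
    conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
    prompt = conv.get_prompt()                # -> " USER: Describe this image.\n<image> ASSISTANT:"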
training/utils.py
ADDED
@@ -0,0 +1,185 @@
import math
import random
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Any, List, Tuple, Union


##################################################
# config utils
##################################################
def get_config():
    cli_conf = OmegaConf.from_cli()
    yaml_conf = OmegaConf.load(cli_conf.config)
    conf = OmegaConf.merge(yaml_conf, cli_conf)

    return conf


def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]:
    ret = []

    def handle_dict(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
        return [(f"{key}.{k1}", v1) for k1, v1 in flatten_omega_conf(value, resolve=resolve)]

    def handle_list(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
        return [(f"{key}.{idx}", v1) for idx, v1 in flatten_omega_conf(value, resolve=resolve)]

    if isinstance(cfg, DictConfig):
        for k, v in cfg.items_ex(resolve=resolve):
            if isinstance(v, DictConfig):
                ret.extend(handle_dict(k, v, resolve=resolve))
            elif isinstance(v, ListConfig):
                ret.extend(handle_list(k, v, resolve=resolve))
            else:
                ret.append((str(k), v))
    elif isinstance(cfg, ListConfig):
        for idx, v in enumerate(cfg._iter_ex(resolve=resolve)):
            if isinstance(v, DictConfig):
                ret.extend(handle_dict(idx, v, resolve=resolve))
            elif isinstance(v, ListConfig):
                ret.extend(handle_list(idx, v, resolve=resolve))
            else:
                ret.append((str(idx), v))
    else:
        assert False

    return ret


##################################################
# training utils
##################################################
def soft_target_cross_entropy(logits, targets, soft_targets):
    # ignore the first token from logits and targets (class id token)
    logits = logits[:, 1:]
    targets = targets[:, 1:]

    logits = logits[..., : soft_targets.shape[-1]]

    log_probs = F.log_softmax(logits, dim=-1)
    padding_mask = targets.eq(-100)

    loss = torch.sum(-soft_targets * log_probs, dim=-1)
    loss.masked_fill_(padding_mask, 0.0)

    # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    loss = loss.sum() / num_active_elements
    return loss


def get_loss_weight(t, mask, min_val=0.3):
    return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None]


def mask_or_random_replace_tokens(image_tokens, mask_id, config, mask_schedule, is_train=True):
    batch_size, seq_len = image_tokens.shape

    if not is_train and config.training.get("eval_mask_ratios", None):
        mask_prob = random.choices(config.training.eval_mask_ratios, k=batch_size)
        mask_prob = torch.tensor(mask_prob, device=image_tokens.device)
    else:
        # Sample a random timestep for each image
        timesteps = torch.rand(batch_size, device=image_tokens.device)
        # Sample a random mask probability for each image using timestep and cosine schedule
        mask_prob = mask_schedule(timesteps)
        mask_prob = mask_prob.clip(config.training.min_masking_rate)

    # creat a random mask for each image
    num_token_masked = (seq_len * mask_prob).round().clamp(min=1)

    mask_contiguous_region_prob = config.training.get("mask_contiguous_region_prob", None)

    if mask_contiguous_region_prob is None:
        mask_contiguous_region = False
    else:
        mask_contiguous_region = random.random() < mask_contiguous_region_prob

    if not mask_contiguous_region:
        batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
        mask = batch_randperm < num_token_masked.unsqueeze(-1)
    else:
        resolution = int(seq_len ** 0.5)
        mask = torch.zeros((batch_size, resolution, resolution), device=image_tokens.device)

        # TODO - would be nice to vectorize
        for batch_idx, num_token_masked_ in enumerate(num_token_masked):
            num_token_masked_ = int(num_token_masked_.item())

            # NOTE: a bit handwavy with the bounds but gets a rectangle of ~num_token_masked_
            num_token_masked_height = random.randint(
                math.ceil(num_token_masked_ / resolution), min(resolution, num_token_masked_)
            )
            num_token_masked_height = min(num_token_masked_height, resolution)

            num_token_masked_width = math.ceil(num_token_masked_ / num_token_masked_height)
            num_token_masked_width = min(num_token_masked_width, resolution)

            start_idx_height = random.randint(0, resolution - num_token_masked_height)
            start_idx_width = random.randint(0, resolution - num_token_masked_width)

            mask[
                batch_idx,
                start_idx_height: start_idx_height + num_token_masked_height,
                start_idx_width: start_idx_width + num_token_masked_width,
            ] = 1

        mask = mask.reshape(batch_size, seq_len)
        mask = mask.to(torch.bool)

    # mask images and create input and labels
    if config.training.get("noise_type", "mask"):
        input_ids = torch.where(mask, mask_id, image_tokens)
    elif config.training.get("noise_type", "random_replace"):
        # sample random tokens from the vocabulary
        random_tokens = torch.randint_like(
            image_tokens, low=0, high=config.model.codebook_size, device=image_tokens.device
        )
        input_ids = torch.where(mask, random_tokens, image_tokens)
    else:
        raise ValueError(f"noise_type {config.training.noise_type} not supported")

    if (
        config.training.get("predict_all_tokens", False)
        or config.training.get("noise_type", "mask") == "random_replace"
    ):
        labels = image_tokens
        loss_weight = get_loss_weight(mask_prob, mask.long())
    else:
        labels = torch.where(mask, image_tokens, -100)
        loss_weight = None

    return input_ids, labels, loss_weight, mask_prob


##################################################
# misc
##################################################
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


from torchvision import transforms
def image_transform(image, resolution=256, normalize=True):
    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
    image = transforms.CenterCrop((resolution, resolution))(image)
    image = transforms.ToTensor()(image)
    if normalize:
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
    return image
training_utils.py
ADDED
@@ -0,0 +1,185 @@
import math
import random
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Any, List, Tuple, Union


##################################################
# config utils
##################################################
def get_config():
    cli_conf = OmegaConf.from_cli()
    yaml_conf = OmegaConf.load(cli_conf.config)
    conf = OmegaConf.merge(yaml_conf, cli_conf)

    return conf


def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]:
    ret = []

    def handle_dict(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
        return [(f"{key}.{k1}", v1) for k1, v1 in flatten_omega_conf(value, resolve=resolve)]

    def handle_list(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
        return [(f"{key}.{idx}", v1) for idx, v1 in flatten_omega_conf(value, resolve=resolve)]

    if isinstance(cfg, DictConfig):
        for k, v in cfg.items_ex(resolve=resolve):
            if isinstance(v, DictConfig):
                ret.extend(handle_dict(k, v, resolve=resolve))
            elif isinstance(v, ListConfig):
                ret.extend(handle_list(k, v, resolve=resolve))
            else:
                ret.append((str(k), v))
    elif isinstance(cfg, ListConfig):
        for idx, v in enumerate(cfg._iter_ex(resolve=resolve)):
            if isinstance(v, DictConfig):
                ret.extend(handle_dict(idx, v, resolve=resolve))
            elif isinstance(v, ListConfig):
                ret.extend(handle_list(idx, v, resolve=resolve))
            else:
                ret.append((str(idx), v))
    else:
        assert False

    return ret


##################################################
# training utils
##################################################
def soft_target_cross_entropy(logits, targets, soft_targets):
    # ignore the first token from logits and targets (class id token)
    logits = logits[:, 1:]
    targets = targets[:, 1:]

    logits = logits[..., : soft_targets.shape[-1]]

    log_probs = F.log_softmax(logits, dim=-1)
    padding_mask = targets.eq(-100)

    loss = torch.sum(-soft_targets * log_probs, dim=-1)
    loss.masked_fill_(padding_mask, 0.0)

    # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    loss = loss.sum() / num_active_elements
    return loss


def get_loss_weight(t, mask, min_val=0.3):
    return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None]


def mask_or_random_replace_tokens(image_tokens, mask_id, config, mask_schedule, is_train=True):
    batch_size, seq_len = image_tokens.shape

    if not is_train and config.training.get("eval_mask_ratios", None):
        mask_prob = random.choices(config.training.eval_mask_ratios, k=batch_size)
        mask_prob = torch.tensor(mask_prob, device=image_tokens.device)
    else:
        # Sample a random timestep for each image
        timesteps = torch.rand(batch_size, device=image_tokens.device)
        # Sample a random mask probability for each image using timestep and cosine schedule
        mask_prob = mask_schedule(timesteps)
        mask_prob = mask_prob.clip(config.training.min_masking_rate)

    # creat a random mask for each image
    num_token_masked = (seq_len * mask_prob).round().clamp(min=1)

    mask_contiguous_region_prob = config.training.get("mask_contiguous_region_prob", None)

    if mask_contiguous_region_prob is None:
        mask_contiguous_region = False
    else:
        mask_contiguous_region = random.random() < mask_contiguous_region_prob

    if not mask_contiguous_region:
        batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
        mask = batch_randperm < num_token_masked.unsqueeze(-1)
    else:
        resolution = int(seq_len ** 0.5)
        mask = torch.zeros((batch_size, resolution, resolution), device=image_tokens.device)

        # TODO - would be nice to vectorize
        for batch_idx, num_token_masked_ in enumerate(num_token_masked):
            num_token_masked_ = int(num_token_masked_.item())

            # NOTE: a bit handwavy with the bounds but gets a rectangle of ~num_token_masked_
            num_token_masked_height = random.randint(
                math.ceil(num_token_masked_ / resolution), min(resolution, num_token_masked_)
            )
            num_token_masked_height = min(num_token_masked_height, resolution)

            num_token_masked_width = math.ceil(num_token_masked_ / num_token_masked_height)
            num_token_masked_width = min(num_token_masked_width, resolution)

            start_idx_height = random.randint(0, resolution - num_token_masked_height)
            start_idx_width = random.randint(0, resolution - num_token_masked_width)

            mask[
                batch_idx,
                start_idx_height: start_idx_height + num_token_masked_height,
                start_idx_width: start_idx_width + num_token_masked_width,
            ] = 1

        mask = mask.reshape(batch_size, seq_len)
        mask = mask.to(torch.bool)

    # mask images and create input and labels
    if config.training.get("noise_type", "mask"):
        input_ids = torch.where(mask, mask_id, image_tokens)
    elif config.training.get("noise_type", "random_replace"):
        # sample random tokens from the vocabulary
        random_tokens = torch.randint_like(
            image_tokens, low=0, high=config.model.codebook_size, device=image_tokens.device
        )
        input_ids = torch.where(mask, random_tokens, image_tokens)
    else:
        raise ValueError(f"noise_type {config.training.noise_type} not supported")

    if (
        config.training.get("predict_all_tokens", False)
        or config.training.get("noise_type", "mask") == "random_replace"
    ):
        labels = image_tokens
        loss_weight = get_loss_weight(mask_prob, mask.long())
    else:
        labels = torch.where(mask, image_tokens, -100)
        loss_weight = None

    return input_ids, labels, loss_weight, mask_prob


##################################################
# misc
##################################################
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


from torchvision import transforms
def image_transform(image, resolution=256, normalize=True):
    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
    image = transforms.CenterCrop((resolution, resolution))(image)
    image = transforms.ToTensor()(image)
    if normalize:
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
    return image
validation_prompts/showoprompts.txt
ADDED
@@ -0,0 +1,24 @@
Close-up view of a computer screen, with the screen displaying a webpage.
A tranquil scene of a lotus pond with koi fish swimming gracefully in a peaceful Chinese ink painting.
Paper artwork, layered paper, colorful Chinese dragon surrounded by clouds.
Pixel art character riding a dragon through the clouds.
A peaceful village nestled at the foot of towering mountains in a tranquil East Asian watercolor scene.
A person with swirling patterns of teal paint on their face and a shimmering silver crescent moon placed above their eyebrow, symbolizing mystery and magic.
A dynamic scene of a rally car race.
The breathtaking view of Santorini, a renowned landmark in Greece. The white-washed buildings with blue domes overlook the deep blue waters of the Aegean Sea, creating a stunning contrast against the vibrant sunset.
An abstract portrait of a pensive face, rendered in cool shades of blues, purples, and grays.
A punk rock frog in a studded leather jacket shouting into a microphone while standing on a boulder.
A rebellious squirrel in a studded denim vest, strumming an electric guitar with fervor in a forest clearing.
a captivating watercolor portrait of a dog's head, rendered in a vibrant palette of colors.
A captivating watercolor portrait of a cat's face, rendered in a soft palette of pastels.
A captivating watercolor portrait of a rabbit's profile, rendered in gentle hues of pinks and browns.
a white Lamborghini Gallardo Spyder is parked on a cobblestone street.
the breathtaking beauty of Whitehaven Beach.
The breathtaking view of Moraine Lake, a renowned landmark in Canada. The turquoise waters of the lake reflect the rugged peaks of the Valley of the Ten Peaks, creating a scene of unparalleled natural beauty.
The breathtaking view of Mount Fuji, a renowned landmark in Japan. The iconic snow-capped peak rises majestically above the surrounding landscape, mirrored perfectly in the tranquil waters of Lake Kawaguchi.
A bustling Asian market at night, with colorful lanterns, street food vendors, and a mix of traditional and modern architecture.
A stunning coastal cliffside at sunset, with waves crashing against the rocks and the sky painted in shades of orange, pink, and purple.
A tranquil island paradise, with a white sandy beach, crystal-clear water, and palm trees swaying in the gentle breeze.
A carnival of dreams where carousel horses gallop into the sky and cotton candy clouds drift by.
A floating market in the sky where clouds serve as stalls for trading dreams.
Intricate paper-cut creation featuring a vibrant peacock perched among blooming flowers.