JosephPai commited on
Commit
8741abe
1 Parent(s): 36c1777
Files changed (48) hide show
  1. LICENSE +21 -0
  2. README.md +2 -2
  3. app.py +526 -0
  4. configs/showo_demo.yaml +49 -0
  5. configs/showo_demo_w_clip_vit.yaml +49 -0
  6. gradio/app.py +488 -0
  7. gradio/app_gradio.py +281 -0
  8. gradio/app_w_clip.py +559 -0
  9. gradio/share_btn.py +113 -0
  10. inference_mmu.py +174 -0
  11. inference_t2i.py +331 -0
  12. inpainting_validation/.DS_Store +0 -0
  13. inpainting_validation/alpine_lake.jpg +0 -0
  14. inpainting_validation/bedroom.jpg +0 -0
  15. inpainting_validation/bedroom_mask.webp +0 -0
  16. inpainting_validation/bench.jpg +0 -0
  17. inpainting_validation/bench_mask.webp +0 -0
  18. inpainting_validation/bus.jpg +0 -0
  19. inpainting_validation/bus_mask.webp +0 -0
  20. inpainting_validation/lake_mountain.jpg +0 -0
  21. inpainting_validation/maya.png +0 -0
  22. inpainting_validation/river.png +0 -0
  23. inpainting_validation/train.jpg +0 -0
  24. inpainting_validation/train_mask.webp +0 -0
  25. inpainting_validation/truebsee.jpg +0 -0
  26. inpainting_validation/truebsee_mask.webp +0 -0
  27. inpainting_validation/wukong1.jpg +0 -0
  28. inpainting_validation/wukong2.jpg +0 -0
  29. mmu_validation/sofa_under_water.jpg +0 -0
  30. models/__init__.py +4 -0
  31. models/clip_encoder.py +140 -0
  32. models/common_modules.py +407 -0
  33. models/logging.py +338 -0
  34. models/lr_schedulers.py +292 -0
  35. models/misc.py +53 -0
  36. models/modeling_magvitv2.py +440 -0
  37. models/modeling_showo.py +206 -0
  38. models/modeling_utils.py +1207 -0
  39. models/phi.py +1489 -0
  40. models/sampling.py +118 -0
  41. models/training_utils.py +455 -0
  42. prompting_utils.py +528 -0
  43. requirements.txt +228 -0
  44. training/__init__.py +1 -0
  45. training/conversation.py +432 -0
  46. training/utils.py +185 -0
  47. training_utils.py +185 -0
  48. validation_prompts/showoprompts.txt +24 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Jinheng Xie
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4081eb457c27efd213e1df5d71e6075fdd1a969a7ac408bf8e7968345250d360
3
- size 227
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df8da49db1cd3db14e34d015e398bd3d5ab51c3988fd98976405701ce1838ef5
3
+ size 224
app.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
4
+ import numpy as np
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+ from omegaconf import OmegaConf
11
+ from transformers import AutoTokenizer
12
+
13
+ from prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu
14
+ from training_utils import image_transform
15
+ from models import Showo, MAGVITv2, get_mask_chedule
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+ config = OmegaConf.load("configs/showo_demo.yaml")
20
+ tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
21
+
22
+ uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
23
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>",
24
+ "<|t2v|>", "<|v2v|>", "<|lvg|>"),
25
+ ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
26
+
27
+ vq_model = MAGVITv2()
28
+ vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
29
+ vq_model.requires_grad_(False)
30
+ vq_model.eval()
31
+
32
+ model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
33
+ model.eval()
34
+ mask_token_id = model.config.mask_token_id
35
+
36
+
37
+ @spaces.GPU
38
+ def text_to_image_generation(input_text, guidance_scale=1.75, generation_timesteps=18):
39
+ prompts = [input_text]
40
+ config.training.batch_size = config.batch_size = 1
41
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
42
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
43
+
44
+ image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
45
+ dtype=torch.long, device=device) * mask_token_id
46
+
47
+ input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
48
+
49
+ if config.training.guidance_scale > 0:
50
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
51
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
52
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
53
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
54
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
55
+ rm_pad_in_image=True)
56
+ else:
57
+ attention_mask = create_attention_mask_predict_next(input_ids,
58
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
59
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
60
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
61
+ rm_pad_in_image=True)
62
+ uncond_input_ids = None
63
+
64
+ if config.get("mask_schedule", None) is not None:
65
+ schedule = config.mask_schedule.schedule
66
+ args = config.mask_schedule.get("params", {})
67
+ mask_schedule = get_mask_chedule(schedule, **args)
68
+ else:
69
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
70
+
71
+ with torch.no_grad():
72
+ gen_token_ids = model.t2i_generate(
73
+ input_ids=input_ids,
74
+ uncond_input_ids=uncond_input_ids,
75
+ attention_mask=attention_mask,
76
+ guidance_scale=config.training.guidance_scale,
77
+ temperature=config.training.get("generation_temperature", 1.0),
78
+ timesteps=config.training.generation_timesteps,
79
+ noise_schedule=mask_schedule,
80
+ noise_type=config.training.get("noise_type", "mask"),
81
+ seq_len=config.model.showo.num_vq_tokens,
82
+ uni_prompting=uni_prompting,
83
+ config=config,
84
+ )
85
+
86
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
87
+ images = vq_model.decode_code(gen_token_ids)
88
+
89
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
90
+ images *= 255.0
91
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
92
+
93
+ return images[0]
94
+
95
+
96
+ @spaces.GPU
97
+ def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidance_scale=1.75, generation_timesteps=16):
98
+ prompt = [input_text]
99
+
100
+ config.training.batch_size = config.batch_size = 1
101
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
102
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
103
+
104
+ inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
105
+ inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)
106
+
107
+ inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)
108
+
109
+ inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
110
+ inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
111
+ inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)
112
+
113
+ inpainting_mask[inpainting_mask < 0.5] = 0
114
+ inpainting_mask[inpainting_mask >= 0.5] = 1
115
+
116
+ inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
117
+ inpainting_mask = inpainting_mask.to(torch.bool)
118
+
119
+ inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
120
+ inpainting_image_tokens[inpainting_mask] = mask_token_id
121
+
122
+ input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')
123
+
124
+ if config.training.guidance_scale > 0:
125
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
126
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
127
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
128
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
129
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
130
+ rm_pad_in_image=True)
131
+ else:
132
+ attention_mask = create_attention_mask_predict_next(input_ids,
133
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
134
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
135
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
136
+ rm_pad_in_image=True)
137
+ uncond_input_ids = None
138
+
139
+ if config.get("mask_schedule", None) is not None:
140
+ schedule = config.mask_schedule.schedule
141
+ args = config.mask_schedule.get("params", {})
142
+ mask_schedule = get_mask_chedule(schedule, **args)
143
+ else:
144
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
145
+
146
+ with torch.no_grad():
147
+ gen_token_ids = model.t2i_generate(
148
+ input_ids=input_ids,
149
+ uncond_input_ids=uncond_input_ids,
150
+ attention_mask=attention_mask,
151
+ guidance_scale=config.training.guidance_scale,
152
+ temperature=config.training.get("generation_temperature", 1.0),
153
+ timesteps=config.training.generation_timesteps,
154
+ noise_schedule=mask_schedule,
155
+ noise_type=config.training.get("noise_type", "mask"),
156
+ seq_len=config.model.showo.num_vq_tokens,
157
+ uni_prompting=uni_prompting,
158
+ config=config,
159
+ )
160
+
161
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
162
+ images = vq_model.decode_code(gen_token_ids)
163
+
164
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
165
+ images *= 255.0
166
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
167
+
168
+ return images[0]
169
+
170
+
171
+ @spaces.GPU
172
+ def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidance_scale=1.75, generation_timesteps=16):
173
+ config.offset = 0
174
+ config.training.batch_size = config.batch_size = 1
175
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
176
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
177
+
178
+ extra_direction = ['right'] * int(right_ext) + ['left'] * int(left_ext)
179
+ prompt = [input_text] * len(extra_direction)
180
+ W = config.dataset.params.resolution // 16
181
+ for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
182
+ prt = [prt] * config.training.batch_size
183
+ if id == 0:
184
+ # extrapolation_image = Image.open(config.image_path).convert("RGB")
185
+ extrapolation_image = input_img
186
+ extrapolation_image = image_transform(extrapolation_image,
187
+ resolution=config.dataset.params.resolution).to(device)
188
+
189
+ B, _, _ = extrapolation_image.shape
190
+ extrapolation_image = extrapolation_image.unsqueeze(0)
191
+ extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
192
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
193
+ config.dataset.params.resolution // 16,
194
+ config.dataset.params.resolution // 16)
195
+ extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
196
+ else:
197
+
198
+ extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)
199
+
200
+ image_left_part = extrapolation_image_tokens[:, :, :-(W // 2 - config.offset)] - len(
201
+ uni_prompting.text_tokenizer)
202
+ image_right_part = extrapolation_image_tokens[:, :, W // 2 - config.offset:] - len(uni_prompting.text_tokenizer)
203
+ image_up_part = extrapolation_image_tokens[:, :-(W // 2 - config.offset), :] - len(uni_prompting.text_tokenizer)
204
+ image_down_part = extrapolation_image_tokens[:, W // 2 - config.offset:, :] - len(uni_prompting.text_tokenizer)
205
+
206
+ if direction in ['left', 'right']:
207
+ extrapolation_mask = torch.zeros((config.training.batch_size,
208
+ config.dataset.params.resolution // 16,
209
+ config.dataset.params.resolution // 16 // 2 + config.offset),
210
+ dtype=torch.int64, device=device) + mask_token_id
211
+ else:
212
+ extrapolation_mask = torch.zeros((config.training.batch_size,
213
+ config.dataset.params.resolution // 16 // 2 + config.offset,
214
+ config.dataset.params.resolution // 16),
215
+ dtype=torch.int64, device=device) + mask_token_id
216
+
217
+ if direction == 'left':
218
+ extrapolation_image_tokens = torch.cat(
219
+ [extrapolation_mask, extrapolation_image_tokens[:, :, :W // 2 - config.offset]], dim=-1)
220
+ elif direction == 'right':
221
+ extrapolation_image_tokens = torch.cat(
222
+ [extrapolation_image_tokens[:, :, -(W // 2 - config.offset):], extrapolation_mask], dim=-1)
223
+ elif direction == 'up':
224
+ extrapolation_image_tokens = torch.cat(
225
+ [extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
226
+ else:
227
+ extrapolation_image_tokens = torch.cat(
228
+ [extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)
229
+
230
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)
231
+
232
+ input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')
233
+
234
+ if config.training.guidance_scale > 0:
235
+ uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
236
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
237
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
238
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
239
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
240
+ rm_pad_in_image=True)
241
+ else:
242
+ attention_mask = create_attention_mask_predict_next(input_ids,
243
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
244
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
245
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
246
+ rm_pad_in_image=True)
247
+ uncond_input_ids = None
248
+
249
+ if config.get("mask_schedule", None) is not None:
250
+ schedule = config.mask_schedule.schedule
251
+ args = config.mask_schedule.get("params", {})
252
+ mask_schedule = get_mask_chedule(schedule, **args)
253
+ else:
254
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
255
+
256
+ with torch.no_grad():
257
+ gen_token_ids = model.t2i_generate(
258
+ input_ids=input_ids,
259
+ uncond_input_ids=uncond_input_ids,
260
+ attention_mask=attention_mask,
261
+ guidance_scale=config.training.guidance_scale,
262
+ temperature=config.training.get("generation_temperature", 1.0),
263
+ timesteps=config.training.generation_timesteps,
264
+ noise_schedule=mask_schedule,
265
+ noise_type=config.training.get("noise_type", "mask"),
266
+ seq_len=config.model.showo.num_vq_tokens,
267
+ uni_prompting=uni_prompting,
268
+ config=config,
269
+ )
270
+
271
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
272
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
273
+ config.dataset.params.resolution // 16,
274
+ config.dataset.params.resolution // 16)
275
+ if direction == 'left':
276
+ gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
277
+ elif direction == 'right':
278
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
279
+ elif direction == 'up':
280
+ gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
281
+ else:
282
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-2)
283
+
284
+ _, h, w = gen_token_ids.shape
285
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
286
+ images = vq_model.decode_code(gen_token_ids, shape=(h, w))
287
+
288
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
289
+ images *= 255.0
290
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
291
+
292
+ return images[0]
293
+
294
+
295
+ @spaces.GPU
296
+ def multimodal_understanding(input_img, input_text, chat_history):
297
+ top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability
298
+
299
+ image_ori = input_img
300
+ image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
301
+ image = image.unsqueeze(0)
302
+ image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)
303
+
304
+ question = input_text
305
+ input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[
306
+ 'input_ids']
307
+ input_ids = torch.tensor(input_ids).to(device)
308
+
309
+ input_ids = torch.cat([
310
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device),
311
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
312
+ image_tokens,
313
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
314
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device),
315
+ input_ids
316
+ ], dim=1).long()
317
+
318
+ attention_mask = create_attention_mask_for_mmu(input_ids.to(device),
319
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))
320
+
321
+ cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask,
322
+ max_new_tokens=100, top_k=top_k,
323
+ eot_token=uni_prompting.sptids_dict['<|eot|>'])
324
+
325
+ cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
326
+
327
+ output_text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
328
+
329
+ output_text = output_text[0].strip()
330
+
331
+ chat_history.append((input_text, output_text))
332
+
333
+ return "", chat_history
334
+
335
+
336
+ with gr.Blocks() as demo:
337
+ gr.HTML("""
338
+ <h1 class="display-2 fw-bold title">
339
+ <a style="color: #70a8dc;">S</a><a style="color: #6fb051;">h</a><a style="color: #e06766;">o</a><a style="color: #f7b26b;">w</a>-o
340
+ </h1>
341
+ <p>This is the official Gradio demo for the Show-o model, a unified model that can do multimodal understanding and generation.</p>
342
+
343
+ <strong>Paper:</strong> <a href="https://arxiv.org/abs/2408.12528" target="_blank">Show-o: One Single Transformer To Unify Multimodal Understanding and Generation </a>
344
+ <br/>
345
+ <strong>Project Website:</strong> <a href="https://showlab.github.io/Show-o/" target="_blank">Show-o Website</a>
346
+ <br/>
347
+ <strong>Code and Models:</strong> <a href="https://github.com/showlab/Show-o" target="_blank">GitHub</a>
348
+ <br/>
349
+ <br/>
350
+ """)
351
+
352
+ with gr.Row():
353
+ with gr.Column():
354
+ text_prompt_t2i = gr.Textbox(
355
+ label="Text prompt",
356
+ lines=2,
357
+ placeholder="Input the text prompt here for image generation."
358
+ )
359
+ guidance_scale_t2i = gr.Slider(
360
+ label="guidance scale",
361
+ minimum=0,
362
+ maximum=5,
363
+ step=0.05,
364
+ value=1.75
365
+ )
366
+ generation_timesteps_t2i = gr.Slider(
367
+ label="timesteps",
368
+ minimum=1,
369
+ maximum=30,
370
+ step=1,
371
+ value=18
372
+ )
373
+ generated_img_t2i = gr.Image(
374
+ label="Output image"
375
+ )
376
+ examples_t2i = gr.Examples(
377
+ label="Text to image generation examples",
378
+ examples=[
379
+ "A dynamic scene of a rally car race.",
380
+ "Paper artwork, layered paper, colorful Chinese dragon surrounded by clouds.",
381
+ "Pixel art character riding a dragon through the clouds.",
382
+ ],
383
+ inputs=text_prompt_t2i,
384
+ )
385
+ submit_btn_t2i = gr.Button("Generate: Text-to-image")
386
+ submit_btn_t2i.click(text_to_image_generation,
387
+ [text_prompt_t2i, guidance_scale_t2i, generation_timesteps_t2i],
388
+ [generated_img_t2i])
389
+
390
+ with gr.Row():
391
+ inpainting_input_img = gr.Image(
392
+ label="Input image",
393
+ type="pil",
394
+ )
395
+ inpainting_input_mask = gr.Image(
396
+ label="Inpainting mask",
397
+ image_mode="L",
398
+ type="pil",
399
+ )
400
+
401
+ with gr.Column():
402
+ text_prompt_inpainting = gr.Textbox(
403
+ label="Text prompt",
404
+ lines=2,
405
+ placeholder="Input the text prompt here for image inpainting."
406
+ )
407
+ guidance_scale_inpainting = gr.Slider(
408
+ label="guidance scale",
409
+ minimum=0,
410
+ maximum=5,
411
+ step=0.05,
412
+ value=1.75
413
+ )
414
+ generation_timesteps_inpainting = gr.Slider(
415
+ label="timesteps",
416
+ minimum=1,
417
+ maximum=30,
418
+ step=1,
419
+ value=16
420
+ )
421
+ generated_img_inpainting = gr.Image(
422
+ label="Output image"
423
+ )
424
+ examples_inpainting = gr.Examples(
425
+ label="Text-guided inpainting examples",
426
+ examples=[
427
+ [
428
+ "a blue sports car with sleek curves and tinted windows, parked on a bustling city street.",
429
+ Image.open("./inpainting_validation/bus.jpg").convert("RGB"),
430
+ Image.open("./inpainting_validation/bus_mask.webp").convert("L"),
431
+ ],
432
+ [
433
+ "a clear, shallow river with some vibrant flowers in it.",
434
+ Image.open("./inpainting_validation/train.jpg").convert("RGB"),
435
+ Image.open("./inpainting_validation/train_mask.webp").convert("L"),
436
+ ],
437
+ ],
438
+ inputs=[text_prompt_inpainting, inpainting_input_img, inpainting_input_mask],
439
+ )
440
+ submit_btn_inpainting = gr.Button("Generate: Text-guided Inpainting")
441
+ submit_btn_inpainting.click(text_guided_inpainting,
442
+ [text_prompt_inpainting, inpainting_input_img, inpainting_input_mask,
443
+ guidance_scale_inpainting, generation_timesteps_inpainting],
444
+ [generated_img_inpainting])
445
+
446
+ with gr.Row():
447
+ extra_input_img = gr.Image(
448
+ label="Input image",
449
+ type="pil",
450
+ image_mode="RGB",
451
+ )
452
+
453
+ with gr.Column():
454
+ text_prompt_extrapolation = gr.Textbox(
455
+ label="Text prompt",
456
+ lines=1,
457
+ placeholder="Input the text prompt here for image extrapolation."
458
+ )
459
+ guidance_scale_extrapolation = gr.Slider(
460
+ label="guidance scale",
461
+ minimum=0,
462
+ maximum=5,
463
+ step=0.05,
464
+ value=1.75
465
+ )
466
+ generation_timesteps_extrapolation = gr.Slider(
467
+ label="timesteps",
468
+ minimum=1,
469
+ maximum=30,
470
+ step=1,
471
+ value=16
472
+ )
473
+ left_extrapolation = gr.Slider(
474
+ label="left extrapolation",
475
+ minimum=0,
476
+ maximum=5,
477
+ step=1,
478
+ value=1
479
+ )
480
+ right_extrapolation = gr.Slider(
481
+ label="right extrapolation",
482
+ minimum=0,
483
+ maximum=5,
484
+ step=1,
485
+ value=1
486
+ )
487
+ generated_img_extrapolation = gr.Image(
488
+ label="Output image"
489
+ )
490
+ examples_extra = gr.Examples(
491
+ label="Text-guided extrapolation examples",
492
+ examples=[
493
+ [
494
+ Image.open("./inpainting_validation/wukong2.jpg").convert("RGB"),
495
+ "the continuous mountain ranges and jungles, with meandering rivers occasionally appearing.",
496
+ 2,
497
+ 2,
498
+ ],
499
+ [
500
+ Image.open("./inpainting_validation/alpine_lake.jpg").convert("RGB"),
501
+ "a serene natural landscape featuring a clear, blue lake surrounded by lush green trees.",
502
+ 2,
503
+ 2,
504
+ ],
505
+ ],
506
+ inputs=[extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation],
507
+ )
508
+ submit_btn_inpainting = gr.Button("Generate: Text-guided Extrapolation")
509
+ submit_btn_inpainting.click(text_guided_extrapolation,
510
+ [extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation,
511
+ guidance_scale_extrapolation, generation_timesteps_extrapolation],
512
+ [generated_img_extrapolation])
513
+ with gr.Row():
514
+ with gr.Row():
515
+ chat_input_img = gr.Image(
516
+ label="Input image",
517
+ type="pil",
518
+ image_mode="RGB",
519
+ )
520
+ with gr.Column():
521
+ chatbot = gr.Chatbot()
522
+ msg = gr.Textbox(label="Press Enter to send a message for chat")
523
+ clear = gr.ClearButton([msg, chatbot])
524
+ msg.submit(multimodal_understanding, [chat_input_img, msg, chatbot], [msg, chatbot])
525
+
526
+ demo.launch()
configs/showo_demo.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb:
2
+ entity: null
3
+ # run_id: askkz9i2
4
+ resume: 'auto'
5
+
6
+ experiment:
7
+ project: "demo"
8
+ name: "show-o-demo"
9
+ output_dir: "show-o-demo"
10
+
11
+ model:
12
+ vq_model:
13
+ type: "magvitv2"
14
+ vq_model_name: "showlab/magvitv2"
15
+
16
+ showo:
17
+ pretrained_model_path: "showlab/show-o"
18
+ w_clip_vit: False
19
+ vocab_size: 58498
20
+ llm_vocab_size: 50295
21
+ llm_model_path: 'microsoft/phi-1_5'
22
+ codebook_size: 8192
23
+ num_vq_tokens: 256
24
+
25
+ gradient_checkpointing: True
26
+ enable_xformers_memory_efficient_attention: True
27
+
28
+
29
+ dataset:
30
+ gen_type: "t2i"
31
+ und_type: "large_cap"
32
+ params:
33
+ batch_size: ${training.batch_size}
34
+ shuffle_buffer_size: 1000
35
+ num_workers: 32
36
+ resolution: 256
37
+ pin_memory: True
38
+ persistent_workers: True
39
+
40
+ preprocessing:
41
+ max_seq_length: 128
42
+ resolution: 256
43
+ center_crop: False
44
+ random_flip: False
45
+
46
+ training:
47
+ gradient_accumulation_steps: 1
48
+ cond_dropout_prob: 0.1
49
+ batch_size: 20
configs/showo_demo_w_clip_vit.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb:
2
+ entity: null
3
+ # run_id: askkz9i2
4
+ resume: 'auto'
5
+
6
+ experiment:
7
+ project: "demo"
8
+ name: "show-o-demo"
9
+ output_dir: "show-o-demo"
10
+
11
+ model:
12
+ vq_model:
13
+ type: "magvitv2"
14
+ vq_model_name: "showlab/magvitv2"
15
+
16
+ showo:
17
+ pretrained_model_path: "showlab/show-o-w-clip-vit"
18
+ w_clip_vit: True
19
+ vocab_size: 58498
20
+ llm_vocab_size: 50295
21
+ llm_model_path: 'microsoft/phi-1_5'
22
+ codebook_size: 8192
23
+ num_vq_tokens: 256
24
+
25
+ gradient_checkpointing: True
26
+ enable_xformers_memory_efficient_attention: True
27
+
28
+
29
+ dataset:
30
+ gen_type: "t2i"
31
+ und_type: "large_cap"
32
+ params:
33
+ batch_size: ${training.batch_size}
34
+ shuffle_buffer_size: 1000
35
+ num_workers: 32
36
+ resolution: 256
37
+ pin_memory: True
38
+ persistent_workers: True
39
+
40
+ preprocessing:
41
+ max_seq_length: 128
42
+ resolution: 256
43
+ center_crop: False
44
+ random_flip: False
45
+
46
+ training:
47
+ gradient_accumulation_steps: 1
48
+ cond_dropout_prob: 0.1
49
+ batch_size: 20
gradio/app.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
4
+ import numpy as np
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ from omegaconf import OmegaConf
9
+ from transformers import AutoTokenizer
10
+ import torch.nn.functional as F
11
+ from transformers import CLIPImageProcessor
12
+
13
+ import sys
14
+ sys.path.insert(0, ".")
15
+ from training import conversation as conversation_lib
16
+ from prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu
17
+ from training_utils import image_transform
18
+ from models import Showo, MAGVITv2, get_mask_chedule, CLIPVisionTower
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]
22
+ SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
23
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
24
+ SYSTEM_PROMPT_LEN = 28
25
+
26
+
27
+ config = OmegaConf.load("configs/showo_demo.yaml")
28
+ tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
29
+
30
+ uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
31
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>",
32
+ "<|t2v|>", "<|v2v|>", "<|lvg|>"),
33
+ ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
34
+
35
+ vq_model = MAGVITv2()
36
+ vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
37
+ vq_model.requires_grad_(False)
38
+ vq_model.eval()
39
+
40
+ model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
41
+ model.eval()
42
+ mask_token_id = model.config.mask_token_id
43
+
44
+
45
+ def text_to_image_generation(input_text, guidance_scale, generation_timesteps):
46
+ prompts = [input_text]
47
+ config.training.batch_size = config.batch_size = 1
48
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
49
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
50
+
51
+ image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
52
+ dtype=torch.long, device=device) * mask_token_id
53
+
54
+ input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
55
+
56
+ if config.training.guidance_scale > 0:
57
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
58
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
59
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
60
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
61
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
62
+ rm_pad_in_image=True)
63
+ else:
64
+ attention_mask = create_attention_mask_predict_next(input_ids,
65
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
66
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
67
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
68
+ rm_pad_in_image=True)
69
+ uncond_input_ids = None
70
+
71
+ if config.get("mask_schedule", None) is not None:
72
+ schedule = config.mask_schedule.schedule
73
+ args = config.mask_schedule.get("params", {})
74
+ mask_schedule = get_mask_chedule(schedule, **args)
75
+ else:
76
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
77
+
78
+ with torch.no_grad():
79
+ gen_token_ids = model.t2i_generate(
80
+ input_ids=input_ids,
81
+ uncond_input_ids=uncond_input_ids,
82
+ attention_mask=attention_mask,
83
+ guidance_scale=config.training.guidance_scale,
84
+ temperature=config.training.get("generation_temperature", 1.0),
85
+ timesteps=config.training.generation_timesteps,
86
+ noise_schedule=mask_schedule,
87
+ noise_type=config.training.get("noise_type", "mask"),
88
+ seq_len=config.model.showo.num_vq_tokens,
89
+ uni_prompting=uni_prompting,
90
+ config=config,
91
+ )
92
+
93
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
94
+ images = vq_model.decode_code(gen_token_ids)
95
+
96
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
97
+ images *= 255.0
98
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
99
+
100
+ return images[0]
101
+
102
+
103
+ def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidance_scale, generation_timesteps):
104
+ prompt = [input_text]
105
+
106
+ config.training.batch_size = config.batch_size = 1
107
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
108
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
109
+
110
+ inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
111
+ inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)
112
+
113
+ inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)
114
+
115
+ inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
116
+ inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
117
+ inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)
118
+
119
+ inpainting_mask[inpainting_mask < 0.5] = 0
120
+ inpainting_mask[inpainting_mask >= 0.5] = 1
121
+
122
+ inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
123
+ inpainting_mask = inpainting_mask.to(torch.bool)
124
+
125
+ inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
126
+ inpainting_image_tokens[inpainting_mask] = mask_token_id
127
+
128
+ input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')
129
+
130
+ if config.training.guidance_scale > 0:
131
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
132
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
133
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
134
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
135
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
136
+ rm_pad_in_image=True)
137
+ else:
138
+ attention_mask = create_attention_mask_predict_next(input_ids,
139
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
140
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
141
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
142
+ rm_pad_in_image=True)
143
+ uncond_input_ids = None
144
+
145
+ if config.get("mask_schedule", None) is not None:
146
+ schedule = config.mask_schedule.schedule
147
+ args = config.mask_schedule.get("params", {})
148
+ mask_schedule = get_mask_chedule(schedule, **args)
149
+ else:
150
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
151
+
152
+ with torch.no_grad():
153
+ gen_token_ids = model.t2i_generate(
154
+ input_ids=input_ids,
155
+ uncond_input_ids=uncond_input_ids,
156
+ attention_mask=attention_mask,
157
+ guidance_scale=config.training.guidance_scale,
158
+ temperature=config.training.get("generation_temperature", 1.0),
159
+ timesteps=config.training.generation_timesteps,
160
+ noise_schedule=mask_schedule,
161
+ noise_type=config.training.get("noise_type", "mask"),
162
+ seq_len=config.model.showo.num_vq_tokens,
163
+ uni_prompting=uni_prompting,
164
+ config=config,
165
+ )
166
+
167
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
168
+ images = vq_model.decode_code(gen_token_ids)
169
+
170
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
171
+ images *= 255.0
172
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
173
+
174
+ return images[0]
175
+
176
+
177
+ def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidance_scale, generation_timesteps):
178
+ config.offset = 0
179
+ config.training.batch_size = config.batch_size = 1
180
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
181
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
182
+
183
+ extra_direction = ['right'] * int(right_ext) + ['left'] * int(left_ext)
184
+ prompt = [input_text] * len(extra_direction)
185
+ W = config.dataset.params.resolution // 16
186
+ for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
187
+ prt = [prt] * config.training.batch_size
188
+ if id == 0:
189
+ # extrapolation_image = Image.open(config.image_path).convert("RGB")
190
+ extrapolation_image = input_img
191
+ extrapolation_image = image_transform(extrapolation_image,
192
+ resolution=config.dataset.params.resolution).to(device)
193
+
194
+ B, _, _ = extrapolation_image.shape
195
+ extrapolation_image = extrapolation_image.unsqueeze(0)
196
+ extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
197
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
198
+ config.dataset.params.resolution // 16,
199
+ config.dataset.params.resolution // 16)
200
+ extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
201
+ else:
202
+
203
+ extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)
204
+
205
+ image_left_part = extrapolation_image_tokens[:, :, :-(W // 2 - config.offset)] - len(
206
+ uni_prompting.text_tokenizer)
207
+ image_right_part = extrapolation_image_tokens[:, :, W // 2 - config.offset:] - len(uni_prompting.text_tokenizer)
208
+ image_up_part = extrapolation_image_tokens[:, :-(W // 2 - config.offset), :] - len(uni_prompting.text_tokenizer)
209
+ image_down_part = extrapolation_image_tokens[:, W // 2 - config.offset:, :] - len(uni_prompting.text_tokenizer)
210
+
211
+ if direction in ['left', 'right']:
212
+ extrapolation_mask = torch.zeros((config.training.batch_size,
213
+ config.dataset.params.resolution // 16,
214
+ config.dataset.params.resolution // 16 // 2 + config.offset),
215
+ dtype=torch.int64, device=device) + mask_token_id
216
+ else:
217
+ extrapolation_mask = torch.zeros((config.training.batch_size,
218
+ config.dataset.params.resolution // 16 // 2 + config.offset,
219
+ config.dataset.params.resolution // 16),
220
+ dtype=torch.int64, device=device) + mask_token_id
221
+
222
+ if direction == 'left':
223
+ extrapolation_image_tokens = torch.cat(
224
+ [extrapolation_mask, extrapolation_image_tokens[:, :, :W // 2 - config.offset]], dim=-1)
225
+ elif direction == 'right':
226
+ extrapolation_image_tokens = torch.cat(
227
+ [extrapolation_image_tokens[:, :, -(W // 2 - config.offset):], extrapolation_mask], dim=-1)
228
+ elif direction == 'up':
229
+ extrapolation_image_tokens = torch.cat(
230
+ [extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
231
+ else:
232
+ extrapolation_image_tokens = torch.cat(
233
+ [extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)
234
+
235
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)
236
+
237
+ input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')
238
+
239
+ if config.training.guidance_scale > 0:
240
+ uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
241
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
242
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
243
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
244
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
245
+ rm_pad_in_image=True)
246
+ else:
247
+ attention_mask = create_attention_mask_predict_next(input_ids,
248
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
249
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
250
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
251
+ rm_pad_in_image=True)
252
+ uncond_input_ids = None
253
+
254
+ if config.get("mask_schedule", None) is not None:
255
+ schedule = config.mask_schedule.schedule
256
+ args = config.mask_schedule.get("params", {})
257
+ mask_schedule = get_mask_chedule(schedule, **args)
258
+ else:
259
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
260
+
261
+ with torch.no_grad():
262
+ gen_token_ids = model.t2i_generate(
263
+ input_ids=input_ids,
264
+ uncond_input_ids=uncond_input_ids,
265
+ attention_mask=attention_mask,
266
+ guidance_scale=config.training.guidance_scale,
267
+ temperature=config.training.get("generation_temperature", 1.0),
268
+ timesteps=config.training.generation_timesteps,
269
+ noise_schedule=mask_schedule,
270
+ noise_type=config.training.get("noise_type", "mask"),
271
+ seq_len=config.model.showo.num_vq_tokens,
272
+ uni_prompting=uni_prompting,
273
+ config=config,
274
+ )
275
+
276
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
277
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
278
+ config.dataset.params.resolution // 16,
279
+ config.dataset.params.resolution // 16)
280
+ if direction == 'left':
281
+ gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
282
+ elif direction == 'right':
283
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
284
+ elif direction == 'up':
285
+ gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
286
+ else:
287
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-2)
288
+
289
+ _, h, w = gen_token_ids.shape
290
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
291
+ images = vq_model.decode_code(gen_token_ids, shape=(h, w))
292
+
293
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
294
+ images *= 255.0
295
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
296
+
297
+ return images[0]
298
+
299
+
300
+ def multimodal_understanding(input_img, input_text, chat_history):
301
+ top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability
302
+
303
+ image_ori = input_img
304
+ image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
305
+ image = image.unsqueeze(0)
306
+ image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)
307
+
308
+ question = input_text
309
+ input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[
310
+ 'input_ids']
311
+ input_ids = torch.tensor(input_ids).to(device)
312
+
313
+ input_ids = torch.cat([
314
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device),
315
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
316
+ image_tokens,
317
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
318
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device),
319
+ input_ids
320
+ ], dim=1).long()
321
+
322
+ attention_mask = create_attention_mask_for_mmu(input_ids.to(device),
323
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))
324
+
325
+ cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask,
326
+ max_new_tokens=100, top_k=top_k,
327
+ eot_token=uni_prompting.sptids_dict['<|eot|>'])
328
+
329
+ cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
330
+
331
+ output_text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
332
+
333
+ output_text = output_text[0].strip()
334
+
335
+ chat_history.append((input_text, output_text))
336
+
337
+ return "", chat_history
338
+
339
+
340
+ with gr.Blocks() as demo:
341
+ gr.HTML("""
342
+ <h1 class="display-2 fw-bold title">
343
+ <a style="color: #70a8dc;">S</a><a style="color: #6fb051;">h</a><a style="color: #e06766;">o</a><a style="color: #f7b26b;">w</a>-o
344
+ </h1>
345
+ <p>This is the official Gradio demo for the Show-o model, a unified model that can do multimodal understanding and generation.</p>
346
+
347
+ <strong>Paper:</strong> <a href="https://arxiv.org/abs/2408.12528" target="_blank">Show-o: One Single Transformer To Unify Multimodal Understanding and Generation </a>
348
+ <br/>
349
+ <strong>Project Website:</strong> <a href="https://showlab.github.io/Show-o/" target="_blank">Show-o Website</a>
350
+ <br/>
351
+ <strong>Code and Models:</strong> <a href="https://github.com/showlab/Show-o" target="_blank">GitHub</a>
352
+ <br/>
353
+ <br/>
354
+ """)
355
+
356
+ with gr.Row():
357
+ with gr.Column():
358
+ text_prompt_t2i = gr.Textbox(
359
+ label="Text prompt",
360
+ lines=2,
361
+ placeholder="Input the text prompt here for image generation."
362
+ )
363
+ guidance_scale_t2i = gr.Slider(
364
+ label="guidance scale",
365
+ minimum=0,
366
+ maximum=5,
367
+ step=0.05,
368
+ value=1.75
369
+ )
370
+ generation_timesteps_t2i = gr.Slider(
371
+ label="timesteps",
372
+ minimum=1,
373
+ maximum=30,
374
+ step=1,
375
+ value=18
376
+ )
377
+ generated_img_t2i = gr.Image(
378
+ label="Output image"
379
+ )
380
+ submit_btn_t2i = gr.Button("Generate: Text-to-image")
381
+ submit_btn_t2i.click(text_to_image_generation,
382
+ [text_prompt_t2i, guidance_scale_t2i, generation_timesteps_t2i],
383
+ [generated_img_t2i])
384
+
385
+ with gr.Row():
386
+ inpainting_input_img = gr.Image(
387
+ label="Input image",
388
+ type="pil",
389
+ )
390
+ inpainting_input_mask = gr.Image(
391
+ label="Inpainting mask",
392
+ image_mode="L",
393
+ type="pil",
394
+ )
395
+
396
+ with gr.Column():
397
+ text_prompt_inpainting = gr.Textbox(
398
+ label="Text prompt",
399
+ lines=2,
400
+ placeholder="Input the text prompt here for image inpainting."
401
+ )
402
+ guidance_scale_inpainting = gr.Slider(
403
+ label="guidance scale",
404
+ minimum=0,
405
+ maximum=5,
406
+ step=0.05,
407
+ value=1.75
408
+ )
409
+ generation_timesteps_inpainting = gr.Slider(
410
+ label="timesteps",
411
+ minimum=1,
412
+ maximum=30,
413
+ step=1,
414
+ value=16
415
+ )
416
+ generated_img_inpainting = gr.Image(
417
+ label="Output image"
418
+ )
419
+ submit_btn_inpainting = gr.Button("Generate: Text-guided Inpainting")
420
+ submit_btn_inpainting.click(text_guided_inpainting,
421
+ [text_prompt_inpainting, inpainting_input_img, inpainting_input_mask,
422
+ guidance_scale_inpainting, generation_timesteps_inpainting],
423
+ [generated_img_inpainting])
424
+
425
+ with gr.Row():
426
+ extra_input_img = gr.Image(
427
+ label="Input image",
428
+ type="pil",
429
+ image_mode="RGB",
430
+ )
431
+
432
+ with gr.Column():
433
+ text_prompt_extrapolation = gr.Textbox(
434
+ label="Text prompt",
435
+ lines=1,
436
+ placeholder="Input the text prompt here for image extrapolation."
437
+ )
438
+ guidance_scale_extrapolation = gr.Slider(
439
+ label="guidance scale",
440
+ minimum=0,
441
+ maximum=5,
442
+ step=0.05,
443
+ value=1.75
444
+ )
445
+ generation_timesteps_extrapolation = gr.Slider(
446
+ label="timesteps",
447
+ minimum=1,
448
+ maximum=30,
449
+ step=1,
450
+ value=16
451
+ )
452
+ left_extrapolation = gr.Slider(
453
+ label="left extrapolation",
454
+ minimum=0,
455
+ maximum=5,
456
+ step=1,
457
+ value=1
458
+ )
459
+ right_extrapolation = gr.Slider(
460
+ label="right extrapolation",
461
+ minimum=0,
462
+ maximum=5,
463
+ step=1,
464
+ value=1
465
+ )
466
+ generated_img_extrapolation = gr.Image(
467
+ label="Output image"
468
+ )
469
+ submit_btn_inpainting = gr.Button("Generate: Text-guided Extrapolation")
470
+ submit_btn_inpainting.click(text_guided_extrapolation,
471
+ [extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation,
472
+ guidance_scale_extrapolation, generation_timesteps_extrapolation],
473
+ [generated_img_extrapolation])
474
+
475
+ with gr.Row():
476
+ with gr.Row():
477
+ chat_input_img = gr.Image(
478
+ label="Input image",
479
+ type="pil",
480
+ image_mode="RGB",
481
+ )
482
+ with gr.Column():
483
+ chatbot = gr.Chatbot()
484
+ msg = gr.Textbox(label="Press Enter to send a message for chat")
485
+ clear = gr.ClearButton([msg, chatbot])
486
+ msg.submit(multimodal_understanding, [chat_input_img, msg, chatbot], [msg, chatbot])
487
+
488
+ demo.launch()
gradio/app_gradio.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
4
+ import tempfile
5
+ from share_btn import share_js, save_js
6
+ import gradio as gr
7
+ from PIL import Image
8
+ import torch
9
+ from omegaconf import OmegaConf
10
+ from transformers import AutoTokenizer
11
+
12
+ from models import Showo, MAGVITv2, get_mask_chedule
13
+ from prompting_utils import UniversalPrompting, create_attention_mask_predict_next
14
+
15
+
16
+ # Prepare model
17
+ config = OmegaConf.load("configs/showo_demo.yaml")
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
20
+
21
+ uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
22
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
23
+ ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
24
+
25
+ vq_model = MAGVITv2(config.model.vq_model.type)
26
+ vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
27
+ vq_model.requires_grad_(False)
28
+ vq_model.eval()
29
+
30
+ model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
31
+ model.eval()
32
+
33
+ mask_token_id = model.config.mask_token_id
34
+
35
+
36
+ css = """
37
+ #chatbot { min-height: 300px; }
38
+ #save-btn {
39
+ background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
40
+ }
41
+ #save-btn:hover {
42
+ background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
43
+ }
44
+ #share-btn {
45
+ background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
46
+ }
47
+ #share-btn:hover {
48
+ background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
49
+ }
50
+ #gallery { z-index: 999999; }
51
+ #gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;}
52
+ #gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
53
+ @media (hover: none) {
54
+ #gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; 0;}
55
+ }
56
+ .html2canvas-container { width: 3000px !important; height: 3000px !important; }
57
+ """
58
+
59
+
60
+ def upload_image(state, image_input):
61
+ conversation = state[0]
62
+ chat_history = state[1]
63
+ input_image = Image.open(image_input.name).resize(
64
+ (224, 224)).convert('RGB')
65
+ input_image.save(image_input.name) # Overwrite with smaller image.
66
+ conversation += [(f'<img src="./file={image_input.name}" style="display: inline-block;">', "")]
67
+ return [conversation, chat_history + [input_image, ""]], conversation
68
+
69
+
70
+ def reset():
71
+ return [[], []], []
72
+
73
+
74
+ def reset_last(state):
75
+ conversation = state[0][:-1]
76
+ chat_history = state[1][:-2]
77
+ return [conversation, chat_history], conversation
78
+
79
+
80
+ def save_image_to_local(image: Image.Image):
81
+ filename = next(tempfile._get_candidate_names()) + '.png'
82
+ image.save(filename)
83
+ return filename
84
+
85
+
86
+ def text_to_image_generation(input_text, state, guidance_scale, generation_timesteps):
87
+ prompts = [input_text]
88
+ config.training.batch_size = config.batch_size = 1
89
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
90
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
91
+
92
+ image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
93
+ dtype=torch.long, device=device) * mask_token_id
94
+
95
+ input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
96
+
97
+ if config.training.guidance_scale > 0:
98
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
99
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
100
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
101
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
102
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
103
+ rm_pad_in_image=True)
104
+ else:
105
+ attention_mask = create_attention_mask_predict_next(input_ids,
106
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
107
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
108
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
109
+ rm_pad_in_image=True)
110
+ uncond_input_ids = None
111
+
112
+ if config.get("mask_schedule", None) is not None:
113
+ schedule = config.mask_schedule.schedule
114
+ args = config.mask_schedule.get("params", {})
115
+ mask_schedule = get_mask_chedule(schedule, **args)
116
+ else:
117
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
118
+
119
+ with torch.no_grad():
120
+ gen_token_ids = model.t2i_generate(
121
+ input_ids=input_ids,
122
+ uncond_input_ids=uncond_input_ids,
123
+ attention_mask=attention_mask,
124
+ guidance_scale=config.training.guidance_scale,
125
+ temperature=config.training.get("generation_temperature", 1.0),
126
+ timesteps=config.training.generation_timesteps,
127
+ noise_schedule=mask_schedule,
128
+ noise_type=config.training.get("noise_type", "mask"),
129
+ seq_len=config.model.showo.num_vq_tokens,
130
+ uni_prompting=uni_prompting,
131
+ config=config,
132
+ )
133
+
134
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
135
+ images = vq_model.decode_code(gen_token_ids)
136
+
137
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
138
+ images *= 255.0
139
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
140
+ pil_images = [Image.fromarray(image) for image in images]
141
+
142
+ wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)]
143
+ wandb.log({"generated_images": wandb_images}, step=step)
144
+
145
+
146
+ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
147
+ g_cuda = torch.Generator(device='cuda').manual_seed(1337)
148
+
149
+ # Ignore empty inputs.
150
+ if len(input_text) == 0:
151
+ return state, state[0], gr.update(visible=True)
152
+
153
+ input_prompt = 'Q: ' + input_text + '\nA:'
154
+ conversation = state[0]
155
+ chat_history = state[1]
156
+ print('Generating for', chat_history, flush=True)
157
+
158
+ # If an image was uploaded, prepend it to the model.
159
+ model_inputs = chat_history
160
+ model_inputs.append(input_prompt)
161
+ # Remove empty text.
162
+ model_inputs = [s for s in model_inputs if s != '']
163
+
164
+ top_p = 1.0
165
+ if temperature != 0.0:
166
+ top_p = 0.95
167
+
168
+ print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
169
+ model_outputs = model.generate_for_images_and_texts(model_inputs,
170
+ num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
171
+ temperature=temperature, max_num_rets=1,
172
+ num_inference_steps=50, generator=g_cuda)
173
+ print('model_outputs', model_outputs, ret_scale_factor, flush=True)
174
+
175
+ response = ''
176
+ text_outputs = []
177
+ for output_i, p in enumerate(model_outputs):
178
+ if type(p) == str:
179
+ if output_i > 0:
180
+ response += '<br/>'
181
+ # Remove the image tokens for output.
182
+ text_outputs.append(p.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', ''))
183
+ response += p
184
+ if len(model_outputs) > 1:
185
+ response += '<br/>'
186
+ elif type(p) == dict:
187
+ # Decide whether to generate or retrieve.
188
+ if p['decision'] is not None and p['decision'][0] == 'gen':
189
+ image = p['gen'][0][0]#.resize((224, 224))
190
+ filename = save_image_to_local(image)
191
+ response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Generated)</p>'
192
+ else:
193
+ image = p['ret'][0][0]#.resize((224, 224))
194
+ filename = save_image_to_local(image)
195
+ response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Retrieved)</p>'
196
+
197
+ chat_history = model_inputs + \
198
+ [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
199
+ # Remove [RET] from outputs.
200
+ conversation.append((input_text, response.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', '')))
201
+
202
+ # Set input image to None.
203
+ print('state', state, flush=True)
204
+ print('updated state', [conversation, chat_history], flush=True)
205
+ return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True)
206
+
207
+
208
+ with gr.Blocks(css=css) as demo:
209
+ gr.HTML("""
210
+ <h1>🐟 GILL</h1>
211
+ <p>This is the official Gradio demo for the GILL model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.</p>
212
+
213
+ <strong>Paper:</strong> <a href="https://arxiv.org/abs/2305.17216" target="_blank">Generating Images with Multimodal Language Models</a>
214
+ <br/>
215
+ <strong>Project Website:</strong> <a href="https://jykoh.com/gill" target="_blank">GILL Website</a>
216
+ <br/>
217
+ <strong>Code and Models:</strong> <a href="https://github.com/kohjingyu/gill" target="_blank">GitHub</a>
218
+ <br/>
219
+ <br/>
220
+
221
+ <strong>Tips:</strong>
222
+ <ul>
223
+ <li>Start by inputting either image or text prompts (or both) and chat with GILL to get image-and-text replies.</li>
224
+ <li>Tweak the level of sensitivity to images and text using the parameters on the right.</li>
225
+ <li>Check out cool conversations in the examples or community tab for inspiration and share your own!</li>
226
+ <li>If the model outputs a blank image, it is because Stable Diffusion's safety filter detected inappropriate content. Please try again with a different prompt.</li>
227
+ <li>Outputs may differ slightly from the paper due to slight implementation differences. For reproducing paper results, please use our <a href="https://github.com/kohjingyu/gill" target="_blank">official code</a>.</li>
228
+ <li>For faster inference without waiting in queue, you may duplicate the space and use your own GPU: <a href="https://huggingface.co/spaces/jykoh/gill?duplicate=true"><img style="display: inline-block; margin-top: 0em; margin-bottom: 0em" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></li>
229
+ </ul>
230
+ """)
231
+
232
+ gr_state = gr.State([[], []]) # conversation, chat_history
233
+
234
+ with gr.Row():
235
+ with gr.Column(scale=0.7, min_width=500):
236
+ with gr.Row():
237
+ chatbot = gr.Chatbot(elem_id="chatbot", label="🐟 GILL Chatbot")
238
+ with gr.Row():
239
+ image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])
240
+
241
+ text_input = gr.Textbox(label="Message", placeholder="Type a message")
242
+
243
+ with gr.Column():
244
+ submit_btn = gr.Button("Submit", interactive=True, variant="primary")
245
+ clear_last_btn = gr.Button("Undo")
246
+ clear_btn = gr.Button("Reset All")
247
+ with gr.Row(visible=False) as save_group:
248
+ save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn")
249
+
250
+ with gr.Row(visible=False) as share_group:
251
+ share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")
252
+
253
+ with gr.Column(scale=0.3, min_width=400):
254
+ ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.3, step=0.1, interactive=True,
255
+ label="Frequency multiplier for returning images (higher means more frequent)")
256
+ gr_max_len = gr.Slider(minimum=1, maximum=64, value=32,
257
+ step=1, interactive=True, label="Max # of words")
258
+ gr_temperature = gr.Slider(
259
+ minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)")
260
+
261
+ gallery = gr.Gallery(
262
+ value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery",
263
+ ).style(grid=[2], height="auto")
264
+
265
+ text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
266
+ gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
267
+ text_input.submit(lambda: "", None, text_input) # Reset chatbox.
268
+
269
+ submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
270
+ gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
271
+ submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
272
+
273
+ image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
274
+ clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
275
+ clear_btn.click(reset, [], [gr_state, chatbot])
276
+ share_button.click(None, [], [], _js=share_js)
277
+ save_button.click(None, [], [], _js=save_js)
278
+
279
+
280
+ demo.queue(concurrency_count=1, api_open=False, max_size=16)
281
+ demo.launch(debug=True, server_name="0.0.0.0")
gradio/app_w_clip.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
4
+ import numpy as np
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ from omegaconf import OmegaConf
9
+ from transformers import AutoTokenizer
10
+ import torch.nn.functional as F
11
+ from transformers import CLIPImageProcessor
12
+
13
+ import sys
14
+ sys.path.insert(0, ".")
15
+ from training import conversation as conversation_lib
16
+ from prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu_vit
17
+ from training_utils import image_transform
18
+ from models import Showo, MAGVITv2, get_mask_chedule, CLIPVisionTower
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]
22
+ SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
23
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
24
+ SYSTEM_PROMPT_LEN = 28
25
+
26
+
27
+ def load_discrete_checkpoint():
28
+ config = OmegaConf.load("configs/showo_demo.yaml")
29
+ tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
30
+
31
+ uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
32
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>",
33
+ "<|t2v|>", "<|v2v|>", "<|lvg|>"),
34
+ ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
35
+
36
+ vq_model = MAGVITv2()
37
+ vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
38
+ vq_model.requires_grad_(False)
39
+ vq_model.eval()
40
+
41
+ model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
42
+ model.eval()
43
+ mask_token_id = model.config.mask_token_id
44
+
45
+ return config, uni_prompting, tokenizer, vq_model, model, mask_token_id
46
+
47
+
48
+ config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen, mask_token_id = load_discrete_checkpoint()
49
+
50
+
51
+ def load_continuous_checkpoint():
52
+ config = OmegaConf.load("configs/showo_demo_w_clip_vit.yaml")
53
+
54
+ tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
55
+
56
+ uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
57
+ special_tokens=(
58
+ "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>",
59
+ "<|v2v|>", "<|lvg|>"),
60
+ ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
61
+
62
+ vision_tower_name = "openai/clip-vit-large-patch14-336"
63
+ vision_tower = CLIPVisionTower(vision_tower_name).to(device)
64
+ clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
65
+
66
+ model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
67
+ model.eval()
68
+
69
+ return config, uni_prompting, tokenizer, model, vision_tower, clip_image_processor
70
+
71
+
72
+ config_mmu = uni_prompting_mmu = tokenizer_mmu = model_mmu = vision_tower = clip_image_processor = None
73
+
74
+
75
+ def text_to_image_generation(input_text, guidance_scale, generation_timesteps):
76
+ config, uni_prompting, tokenizer, vq_model, model = config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen
77
+
78
+ prompts = [input_text]
79
+ config.training.batch_size = config.batch_size = 1
80
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
81
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
82
+
83
+ image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
84
+ dtype=torch.long, device=device) * mask_token_id
85
+
86
+ input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
87
+
88
+ if config.training.guidance_scale > 0:
89
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
90
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
91
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
92
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
93
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
94
+ rm_pad_in_image=True)
95
+ else:
96
+ attention_mask = create_attention_mask_predict_next(input_ids,
97
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
98
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
99
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
100
+ rm_pad_in_image=True)
101
+ uncond_input_ids = None
102
+
103
+ if config.get("mask_schedule", None) is not None:
104
+ schedule = config.mask_schedule.schedule
105
+ args = config.mask_schedule.get("params", {})
106
+ mask_schedule = get_mask_chedule(schedule, **args)
107
+ else:
108
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
109
+
110
+ with torch.no_grad():
111
+ gen_token_ids = model.t2i_generate(
112
+ input_ids=input_ids,
113
+ uncond_input_ids=uncond_input_ids,
114
+ attention_mask=attention_mask,
115
+ guidance_scale=config.training.guidance_scale,
116
+ temperature=config.training.get("generation_temperature", 1.0),
117
+ timesteps=config.training.generation_timesteps,
118
+ noise_schedule=mask_schedule,
119
+ noise_type=config.training.get("noise_type", "mask"),
120
+ seq_len=config.model.showo.num_vq_tokens,
121
+ uni_prompting=uni_prompting,
122
+ config=config,
123
+ )
124
+
125
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
126
+ images = vq_model.decode_code(gen_token_ids)
127
+
128
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
129
+ images *= 255.0
130
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
131
+
132
+ return images[0]
133
+
134
+
135
+ def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidance_scale, generation_timesteps):
136
+ config, uni_prompting, tokenizer, vq_model, model = config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen
137
+
138
+ prompt = [input_text]
139
+
140
+ config.training.batch_size = config.batch_size = 1
141
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
142
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
143
+
144
+ inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
145
+ inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)
146
+
147
+ inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)
148
+
149
+ inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
150
+ inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
151
+ inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)
152
+
153
+ inpainting_mask[inpainting_mask < 0.5] = 0
154
+ inpainting_mask[inpainting_mask >= 0.5] = 1
155
+
156
+ inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
157
+ inpainting_mask = inpainting_mask.to(torch.bool)
158
+
159
+ inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
160
+ inpainting_image_tokens[inpainting_mask] = mask_token_id
161
+
162
+ input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')
163
+
164
+ if config.training.guidance_scale > 0:
165
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
166
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
167
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
168
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
169
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
170
+ rm_pad_in_image=True)
171
+ else:
172
+ attention_mask = create_attention_mask_predict_next(input_ids,
173
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
174
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
175
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
176
+ rm_pad_in_image=True)
177
+ uncond_input_ids = None
178
+
179
+ if config.get("mask_schedule", None) is not None:
180
+ schedule = config.mask_schedule.schedule
181
+ args = config.mask_schedule.get("params", {})
182
+ mask_schedule = get_mask_chedule(schedule, **args)
183
+ else:
184
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
185
+
186
+ with torch.no_grad():
187
+ gen_token_ids = model.t2i_generate(
188
+ input_ids=input_ids,
189
+ uncond_input_ids=uncond_input_ids,
190
+ attention_mask=attention_mask,
191
+ guidance_scale=config.training.guidance_scale,
192
+ temperature=config.training.get("generation_temperature", 1.0),
193
+ timesteps=config.training.generation_timesteps,
194
+ noise_schedule=mask_schedule,
195
+ noise_type=config.training.get("noise_type", "mask"),
196
+ seq_len=config.model.showo.num_vq_tokens,
197
+ uni_prompting=uni_prompting,
198
+ config=config,
199
+ )
200
+
201
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
202
+ images = vq_model.decode_code(gen_token_ids)
203
+
204
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
205
+ images *= 255.0
206
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
207
+
208
+ return images[0]
209
+
210
+
211
+ def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidance_scale, generation_timesteps):
212
+ config, uni_prompting, tokenizer, vq_model, model = config_gen, uni_prompting_gen, tokenizer_gen, vq_model_gen, model_gen
213
+
214
+ config.offset = 0
215
+ config.training.batch_size = config.batch_size = 1
216
+ config.training.guidance_scale = config.guidance_scale = guidance_scale
217
+ config.training.generation_timesteps = config.generation_timesteps = generation_timesteps
218
+
219
+ extra_direction = ['right'] * int(right_ext) + ['left'] * int(left_ext)
220
+ prompt = [input_text] * len(extra_direction)
221
+ W = config.dataset.params.resolution // 16
222
+ for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
223
+ prt = [prt] * config.training.batch_size
224
+ if id == 0:
225
+ # extrapolation_image = Image.open(config.image_path).convert("RGB")
226
+ extrapolation_image = input_img
227
+ extrapolation_image = image_transform(extrapolation_image,
228
+ resolution=config.dataset.params.resolution).to(device)
229
+
230
+ B, _, _ = extrapolation_image.shape
231
+ extrapolation_image = extrapolation_image.unsqueeze(0)
232
+ extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
233
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
234
+ config.dataset.params.resolution // 16,
235
+ config.dataset.params.resolution // 16)
236
+ extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
237
+ else:
238
+
239
+ extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)
240
+
241
+ image_left_part = extrapolation_image_tokens[:, :, :-(W // 2 - config.offset)] - len(
242
+ uni_prompting.text_tokenizer)
243
+ image_right_part = extrapolation_image_tokens[:, :, W // 2 - config.offset:] - len(uni_prompting.text_tokenizer)
244
+ image_up_part = extrapolation_image_tokens[:, :-(W // 2 - config.offset), :] - len(uni_prompting.text_tokenizer)
245
+ image_down_part = extrapolation_image_tokens[:, W // 2 - config.offset:, :] - len(uni_prompting.text_tokenizer)
246
+
247
+ if direction in ['left', 'right']:
248
+ extrapolation_mask = torch.zeros((config.training.batch_size,
249
+ config.dataset.params.resolution // 16,
250
+ config.dataset.params.resolution // 16 // 2 + config.offset),
251
+ dtype=torch.int64, device=device) + mask_token_id
252
+ else:
253
+ extrapolation_mask = torch.zeros((config.training.batch_size,
254
+ config.dataset.params.resolution // 16 // 2 + config.offset,
255
+ config.dataset.params.resolution // 16),
256
+ dtype=torch.int64, device=device) + mask_token_id
257
+
258
+ if direction == 'left':
259
+ extrapolation_image_tokens = torch.cat(
260
+ [extrapolation_mask, extrapolation_image_tokens[:, :, :W // 2 - config.offset]], dim=-1)
261
+ elif direction == 'right':
262
+ extrapolation_image_tokens = torch.cat(
263
+ [extrapolation_image_tokens[:, :, -(W // 2 - config.offset):], extrapolation_mask], dim=-1)
264
+ elif direction == 'up':
265
+ extrapolation_image_tokens = torch.cat(
266
+ [extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
267
+ else:
268
+ extrapolation_image_tokens = torch.cat(
269
+ [extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)
270
+
271
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)
272
+
273
+ input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')
274
+
275
+ if config.training.guidance_scale > 0:
276
+ uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
277
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
278
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
279
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
280
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
281
+ rm_pad_in_image=True)
282
+ else:
283
+ attention_mask = create_attention_mask_predict_next(input_ids,
284
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
285
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
286
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
287
+ rm_pad_in_image=True)
288
+ uncond_input_ids = None
289
+
290
+ if config.get("mask_schedule", None) is not None:
291
+ schedule = config.mask_schedule.schedule
292
+ args = config.mask_schedule.get("params", {})
293
+ mask_schedule = get_mask_chedule(schedule, **args)
294
+ else:
295
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
296
+
297
+ with torch.no_grad():
298
+ gen_token_ids = model.t2i_generate(
299
+ input_ids=input_ids,
300
+ uncond_input_ids=uncond_input_ids,
301
+ attention_mask=attention_mask,
302
+ guidance_scale=config.training.guidance_scale,
303
+ temperature=config.training.get("generation_temperature", 1.0),
304
+ timesteps=config.training.generation_timesteps,
305
+ noise_schedule=mask_schedule,
306
+ noise_type=config.training.get("noise_type", "mask"),
307
+ seq_len=config.model.showo.num_vq_tokens,
308
+ uni_prompting=uni_prompting,
309
+ config=config,
310
+ )
311
+
312
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
313
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
314
+ config.dataset.params.resolution // 16,
315
+ config.dataset.params.resolution // 16)
316
+ if direction == 'left':
317
+ gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
318
+ elif direction == 'right':
319
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
320
+ elif direction == 'up':
321
+ gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
322
+ else:
323
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-2)
324
+
325
+ _, h, w = gen_token_ids.shape
326
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
327
+ images = vq_model.decode_code(gen_token_ids, shape=(h, w))
328
+
329
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
330
+ images *= 255.0
331
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
332
+
333
+ return images[0]
334
+
335
+
336
+ def multimodal_understanding(input_img, input_text, chat_history):
337
+ global config_mmu, uni_prompting_mmu, tokenizer_mmu, model_mmu, vision_tower, clip_image_processor
338
+ if model_mmu is None:
339
+ config_mmu, uni_prompting_mmu, tokenizer_mmu, model_mmu, vision_tower, clip_image_processor = load_continuous_checkpoint()
340
+ config, uni_prompting, tokenizer, model = config_mmu, uni_prompting_mmu, tokenizer_mmu, model_mmu
341
+
342
+ image_ori = input_img
343
+ pixel_values = clip_image_processor.preprocess(image_ori, return_tensors="pt")["pixel_values"][0]
344
+ batch_size = 1
345
+ question = input_text
346
+ top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability
347
+
348
+ conv = conversation_lib.default_conversation.copy()
349
+ conv.append_message(conv.roles[0], question)
350
+ conv.append_message(conv.roles[1], None)
351
+ prompt_question = conv.get_prompt()
352
+ question_input = []
353
+ question_input.append(prompt_question.strip())
354
+
355
+ input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors="pt", padding="longest").input_ids
356
+ for _ in range(batch_size)]
357
+ input_ids_system = torch.stack(input_ids_system, dim=0)
358
+ assert input_ids_system.shape[-1] == 28
359
+ input_ids_system = input_ids_system.to(device)
360
+ input_ids_system = input_ids_system[0]
361
+
362
+ input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors="pt", padding="longest").input_ids
363
+ for prompt in question_input]
364
+
365
+ input_ids = torch.stack(input_ids)
366
+ input_ids = torch.nn.utils.rnn.pad_sequence(
367
+ input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id
368
+ )
369
+ input_ids = torch.tensor(input_ids).to(device).squeeze(0)
370
+ input_ids_llava = torch.cat([
371
+ (torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device),
372
+ input_ids_system,
373
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
374
+ # place your img embedding here
375
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
376
+ input_ids,
377
+ ], dim=1).long()
378
+
379
+ images_embeddings = vision_tower(pixel_values[None])
380
+ images_embeddings = model.mm_projector(images_embeddings)
381
+
382
+ text_embeddings = model.showo.model.embed_tokens(input_ids_llava)
383
+
384
+ # Full input seq
385
+ part1 = text_embeddings[:, :2 + SYSTEM_PROMPT_LEN, :]
386
+ part2 = text_embeddings[:, 2 + SYSTEM_PROMPT_LEN:, :]
387
+ input_embeddings = torch.cat((part1, images_embeddings, part2), dim=1)
388
+
389
+ attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings,
390
+ system_prompt_len=SYSTEM_PROMPT_LEN)
391
+
392
+ cont_toks_list = model.mmu_generate(input_embeddings=input_embeddings,
393
+ attention_mask=attention_mask_llava[0].unsqueeze(0),
394
+ max_new_tokens=100,
395
+ top_k=top_k,
396
+ # eot_token=uni_prompting.sptids_dict['<|eot|>']
397
+ eot_token=tokenizer.eos_token_id
398
+ )
399
+
400
+ cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
401
+
402
+ output_text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
403
+
404
+ output_text = output_text[0].strip()
405
+
406
+ chat_history.append((input_text, output_text))
407
+
408
+ return "", chat_history
409
+
410
+
411
+ with gr.Blocks() as demo:
412
+ gr.HTML("""
413
+ <h1 class="display-2 fw-bold title">
414
+ <a style="color: #70a8dc;">S</a><a style="color: #6fb051;">h</a><a style="color: #e06766;">o</a><a style="color: #f7b26b;">w</a>-o
415
+ </h1>
416
+ <p>This is the official Gradio demo for the Show-o model, a unified model that can do multimodal understanding and generation.</p>
417
+
418
+ <strong>Paper:</strong> <a href="https://arxiv.org/abs/2408.12528" target="_blank">Show-o: One Single Transformer To Unify Multimodal Understanding and Generation </a>
419
+ <br/>
420
+ <strong>Project Website:</strong> <a href="https://showlab.github.io/Show-o/" target="_blank">Show-o Website</a>
421
+ <br/>
422
+ <strong>Code and Models:</strong> <a href="https://github.com/showlab/Show-o" target="_blank">GitHub</a>
423
+ <br/>
424
+ <br/>
425
+ """)
426
+
427
+ with gr.Row():
428
+ with gr.Column():
429
+ text_prompt_t2i = gr.Textbox(
430
+ label="Text prompt",
431
+ lines=2,
432
+ placeholder="Input the text prompt here for image generation."
433
+ )
434
+ guidance_scale_t2i = gr.Slider(
435
+ label="guidance scale",
436
+ minimum=0,
437
+ maximum=5,
438
+ step=0.05,
439
+ value=1.75
440
+ )
441
+ generation_timesteps_t2i = gr.Slider(
442
+ label="timesteps",
443
+ minimum=1,
444
+ maximum=30,
445
+ step=1,
446
+ value=18
447
+ )
448
+ generated_img_t2i = gr.Image(
449
+ label="Output image"
450
+ )
451
+ submit_btn_t2i = gr.Button("Generate: Text-to-image")
452
+ submit_btn_t2i.click(text_to_image_generation,
453
+ [text_prompt_t2i, guidance_scale_t2i, generation_timesteps_t2i],
454
+ [generated_img_t2i])
455
+
456
+ with gr.Row():
457
+ inpainting_input_img = gr.Image(
458
+ label="Input image",
459
+ type="pil",
460
+ )
461
+ inpainting_input_mask = gr.Image(
462
+ label="Inpainting mask",
463
+ image_mode="L",
464
+ type="pil",
465
+ )
466
+
467
+ with gr.Column():
468
+ text_prompt_inpainting = gr.Textbox(
469
+ label="Text prompt",
470
+ lines=2,
471
+ placeholder="Input the text prompt here for image inpainting."
472
+ )
473
+ guidance_scale_inpainting = gr.Slider(
474
+ label="guidance scale",
475
+ minimum=0,
476
+ maximum=5,
477
+ step=0.05,
478
+ value=1.75
479
+ )
480
+ generation_timesteps_inpainting = gr.Slider(
481
+ label="timesteps",
482
+ minimum=1,
483
+ maximum=30,
484
+ step=1,
485
+ value=16
486
+ )
487
+ generated_img_inpainting = gr.Image(
488
+ label="Output image"
489
+ )
490
+ submit_btn_inpainting = gr.Button("Generate: Text-guided Inpainting")
491
+ submit_btn_inpainting.click(text_guided_inpainting,
492
+ [text_prompt_inpainting, inpainting_input_img, inpainting_input_mask,
493
+ guidance_scale_inpainting, generation_timesteps_inpainting],
494
+ [generated_img_inpainting])
495
+
496
+ with gr.Row():
497
+ extra_input_img = gr.Image(
498
+ label="Input image",
499
+ type="pil",
500
+ image_mode="RGB",
501
+ )
502
+
503
+ with gr.Column():
504
+ text_prompt_extrapolation = gr.Textbox(
505
+ label="Text prompt",
506
+ lines=1,
507
+ placeholder="Input the text prompt here for image extrapolation."
508
+ )
509
+ guidance_scale_extrapolation = gr.Slider(
510
+ label="guidance scale",
511
+ minimum=0,
512
+ maximum=5,
513
+ step=0.05,
514
+ value=1.75
515
+ )
516
+ generation_timesteps_extrapolation = gr.Slider(
517
+ label="timesteps",
518
+ minimum=1,
519
+ maximum=30,
520
+ step=1,
521
+ value=16
522
+ )
523
+ left_extrapolation = gr.Slider(
524
+ label="left extrapolation",
525
+ minimum=0,
526
+ maximum=5,
527
+ step=1,
528
+ value=1
529
+ )
530
+ right_extrapolation = gr.Slider(
531
+ label="right extrapolation",
532
+ minimum=0,
533
+ maximum=5,
534
+ step=1,
535
+ value=1
536
+ )
537
+ generated_img_extrapolation = gr.Image(
538
+ label="Output image"
539
+ )
540
+ submit_btn_inpainting = gr.Button("Generate: Text-guided Extrapolation")
541
+ submit_btn_inpainting.click(text_guided_extrapolation,
542
+ [extra_input_img, text_prompt_extrapolation, left_extrapolation, right_extrapolation,
543
+ guidance_scale_extrapolation, generation_timesteps_extrapolation],
544
+ [generated_img_extrapolation])
545
+
546
+ with gr.Row():
547
+ with gr.Row():
548
+ chat_input_img = gr.Image(
549
+ label="Input image",
550
+ type="pil",
551
+ image_mode="RGB",
552
+ )
553
+ with gr.Column():
554
+ chatbot = gr.Chatbot()
555
+ msg = gr.Textbox(label="Press Enter to send a message for chat")
556
+ clear = gr.ClearButton([msg, chatbot])
557
+ msg.submit(multimodal_understanding, [chat_input_img, msg, chatbot], [msg, chatbot])
558
+
559
+ demo.launch()
gradio/share_btn.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/79681cd8cb235160a27cdd100673346eb1784e53/share_btn.py
2
+
3
+ community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
4
+ <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
5
+ <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
6
+ </svg>"""
7
+
8
+ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
9
+ style="color: #ffffff;
10
+ "
11
+ xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
12
+
13
+ share_js = """
14
+ async () => {
15
+ const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
16
+ async function uploadFile(file) {
17
+ console.log(file.type)
18
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
19
+ const response = await fetch(UPLOAD_URL, {
20
+ method: 'POST',
21
+ headers: {
22
+ 'Content-Type': file.type,
23
+ 'X-Requested-With': 'XMLHttpRequest',
24
+ },
25
+ body: file, /// <- File inherits from Blob
26
+ });
27
+ const url = await response.text();
28
+ return url;
29
+ }
30
+ async function getImageFile(div) {
31
+ let chatbot = document.getElementById("chatbot");
32
+ chatbot.style.height = "";
33
+ return new Promise((resolve, reject) =>
34
+ html2canvas(div)
35
+ .then((canvas) => {
36
+ chatbot.style.height = "400px";
37
+ const imageBlob = canvas.toBlob((blob) => {
38
+ const imageId = Date.now();
39
+ const fileName = "GILL-" + imageId + ".jpg";
40
+ resolve(new File([blob], fileName, { type: 'image/jpeg' }));
41
+ }, 'image/jpeg', 0.95);
42
+ })
43
+
44
+ )
45
+ }
46
+
47
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
48
+ const chatbotEl = gradioEl.querySelector('#chatbot')
49
+ const imageFile = await getImageFile(chatbotEl);
50
+ console.log(imageFile);
51
+ const urlChatbotImage = await uploadFile(imageFile);
52
+ console.log(urlChatbotImage);
53
+ let titleTxt = `GILL Example`;
54
+
55
+ //const shareBtnEl = gradioEl.querySelector('#share-btn');
56
+ //shareBtnEl.style.pointerEvents = 'none';
57
+ const descriptionMd = `
58
+
59
+ <img src='${urlChatbotImage}'>
60
+ `;
61
+ const params = new URLSearchParams({
62
+ title: titleTxt,
63
+ description: descriptionMd,
64
+ });
65
+ const paramsStr = params.toString();
66
+ window.open(`https://huggingface.co/spaces/jykoh/gill/discussions/new?${paramsStr}`, '_blank');
67
+ //shareBtnEl.style.removeProperty('pointer-events');
68
+ }
69
+ """
70
+
71
+ save_js = """
72
+ async () => {
73
+ const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
74
+
75
+ function saveAs(uri, filename) {
76
+ var link = document.createElement('a');
77
+ if (typeof link.download === 'string') {
78
+ link.href = uri;
79
+ link.download = filename;
80
+
81
+ //Firefox requires the link to be in the body
82
+ document.body.appendChild(link);
83
+
84
+ //simulate click
85
+ link.click();
86
+
87
+ //remove the link when done
88
+ document.body.removeChild(link);
89
+ } else {
90
+ window.open(uri);
91
+ }
92
+ }
93
+
94
+ async function getImageFile(div) {
95
+ let chatbot = document.getElementById("chatbot");
96
+ chatbot.style.height = "";
97
+ return new Promise((resolve, reject) =>
98
+ html2canvas(div)
99
+ .then((canvas) => {
100
+ chatbot.style.height = "400px";
101
+ const imageId = Date.now();
102
+ const fileName = "GILL-" + imageId + ".png";
103
+ saveAs(canvas.toDataURL(), fileName);
104
+ })
105
+
106
+ )
107
+ }
108
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
109
+ const chatbotEl = gradioEl.querySelector('#chatbot')
110
+ const imageFile = await getImageFile(chatbotEl);
111
+ console.log(imageFile);
112
+ }
113
+ """
inference_mmu.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import torch
7
+ import wandb
8
+ from models import Showo, MAGVITv2
9
+ from prompting_utils import UniversalPrompting, create_attention_mask_for_mmu, create_attention_mask_for_mmu_vit
10
+ from training.utils import get_config, flatten_omega_conf, image_transform
11
+ from transformers import AutoTokenizer
12
+ from models.clip_encoder import CLIPVisionTower
13
+ from transformers import CLIPImageProcessor
14
+
15
+ # import.training.conversation as conversation_lib
16
+ from training import conversation as conversation_lib
17
+
18
+ conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]
19
+ SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
20
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
21
+ SYSTEM_PROMPT_LEN = 28
22
+
23
+ def get_vq_model_class(model_type):
24
+ if model_type == "magvitv2":
25
+ return MAGVITv2
26
+ else:
27
+ raise ValueError(f"model_type {model_type} not supported.")
28
+
29
+ if __name__ == '__main__':
30
+
31
+ config = get_config()
32
+
33
+ resume_wandb_run = config.wandb.resume
34
+ run_id = config.wandb.get("run_id", None)
35
+ if run_id is None:
36
+ resume_wandb_run = False
37
+ run_id = wandb.util.generate_id()
38
+ config.wandb.run_id = run_id
39
+
40
+ wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
41
+
42
+ wandb.init(
43
+ project="demo",
44
+ name=config.experiment.name + '_mmu',
45
+ config=wandb_config,
46
+ )
47
+
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
50
+
51
+ uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
52
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
53
+ ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
54
+
55
+ vq_model = get_vq_model_class(config.model.vq_model.type)
56
+ vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
57
+ vq_model.requires_grad_(False)
58
+ vq_model.eval()
59
+
60
+ vision_tower_name = "openai/clip-vit-large-patch14-336"
61
+ vision_tower = CLIPVisionTower(vision_tower_name).to(device)
62
+ clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
63
+
64
+ model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
65
+ model.eval()
66
+
67
+ temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
68
+ top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability
69
+
70
+ file_list = os.listdir(config.mmu_image_root)
71
+ responses = ['' for i in range(len(file_list))]
72
+ images = []
73
+ config.question = config.question.split(' *** ')
74
+ for i, file_name in enumerate(tqdm(file_list)):
75
+ image_path = os.path.join(config.mmu_image_root, file_name)
76
+ image_ori = Image.open(image_path).convert("RGB")
77
+ image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
78
+ image = image.unsqueeze(0)
79
+ images.append(image)
80
+
81
+ pixel_values = clip_image_processor.preprocess(image_ori, return_tensors="pt")["pixel_values"][0]
82
+
83
+ image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)
84
+ batch_size = 1
85
+
86
+ for question in config.question:
87
+ if config.model.showo.w_clip_vit:
88
+ conv = conversation_lib.default_conversation.copy()
89
+ conv.append_message(conv.roles[0], question)
90
+ conv.append_message(conv.roles[1], None)
91
+ prompt_question = conv.get_prompt()
92
+ question_input = []
93
+ question_input.append(prompt_question.strip())
94
+
95
+ input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors="pt", padding="longest").input_ids
96
+ for _ in range(batch_size)]
97
+ input_ids_system = torch.stack(input_ids_system, dim=0)
98
+ assert input_ids_system.shape[-1] == 28
99
+ input_ids_system = input_ids_system.to(device)
100
+ input_ids_system = input_ids_system[0]
101
+
102
+ input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors="pt", padding="longest").input_ids
103
+ for prompt in question_input]
104
+
105
+ input_ids = torch.stack(input_ids)
106
+ input_ids = torch.nn.utils.rnn.pad_sequence(
107
+ input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id
108
+ )
109
+ input_ids = torch.tensor(input_ids).to(device).squeeze(0)
110
+ # import pdb; pdb.set_trace()
111
+ input_ids_llava = torch.cat([
112
+ (torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device),
113
+ input_ids_system,
114
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
115
+ # place your img embedding here
116
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
117
+ input_ids,
118
+ ], dim=1).long()
119
+
120
+ images_embeddings = vision_tower(pixel_values[None])
121
+ images_embeddings = model.mm_projector(images_embeddings)
122
+
123
+ text_embeddings = model.showo.model.embed_tokens(input_ids_llava)
124
+
125
+ # Full input seq
126
+ part1 = text_embeddings[:, :2 + SYSTEM_PROMPT_LEN, :]
127
+ part2 = text_embeddings[:, 2 + SYSTEM_PROMPT_LEN:, :]
128
+ input_embeddings = torch.cat((part1, images_embeddings, part2), dim=1)
129
+
130
+ attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings,
131
+ system_prompt_len=SYSTEM_PROMPT_LEN)
132
+
133
+ cont_toks_list = model.mmu_generate(input_embeddings=input_embeddings,
134
+ attention_mask=attention_mask_llava[0].unsqueeze(0),
135
+ max_new_tokens=100,
136
+ top_k=top_k,
137
+ eot_token=uni_prompting.sptids_dict['<|eot|>']
138
+ )
139
+ else:
140
+ input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[
141
+ 'input_ids']
142
+ input_ids = torch.tensor(input_ids).to(device)
143
+
144
+ input_ids = torch.cat([
145
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device),
146
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
147
+ image_tokens,
148
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
149
+ (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device),
150
+ input_ids
151
+ ], dim=1).long()
152
+
153
+ attention_mask = create_attention_mask_for_mmu(input_ids.to(device),
154
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))
155
+
156
+ cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask,
157
+ max_new_tokens=100, top_k=top_k,
158
+ eot_token=uni_prompting.sptids_dict['<|eot|>'])
159
+
160
+ cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
161
+
162
+ text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
163
+ print(text)
164
+ responses[i] += f'User: ' + question + f'\n Answer : ' + text[0] + '\n'
165
+
166
+ images = torch.cat(images, dim=0)
167
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
168
+ images *= 255.0
169
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
170
+ pil_images = [Image.fromarray(image) for image in images]
171
+
172
+ wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)]
173
+ wandb.log({"multimodal understanding": wandb_images}, step=0)
174
+
inference_t2i.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import torch
7
+ import wandb
8
+ from models import Showo, MAGVITv2, get_mask_chedule
9
+ from prompting_utils import UniversalPrompting, create_attention_mask_predict_next
10
+ from training.utils import get_config, flatten_omega_conf, image_transform
11
+ from transformers import AutoTokenizer
12
+ import torch.nn.functional as F
13
+
14
+ def get_vq_model_class(model_type):
15
+ if model_type == "magvitv2":
16
+ return MAGVITv2
17
+ else:
18
+ raise ValueError(f"model_type {model_type} not supported.")
19
+
20
+ if __name__ == '__main__':
21
+
22
+ config = get_config()
23
+
24
+ resume_wandb_run = config.wandb.resume
25
+ run_id = config.wandb.get("run_id", None)
26
+ if run_id is None:
27
+ resume_wandb_run = False
28
+ run_id = wandb.util.generate_id()
29
+ config.wandb.run_id = run_id
30
+
31
+ wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
32
+
33
+ wandb.init(
34
+ project="demo",
35
+ name=config.experiment.name + '_t2i' + f'_{config.mode}',
36
+ config=wandb_config,
37
+ )
38
+
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
41
+
42
+ uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
43
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
44
+ ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)
45
+
46
+ vq_model = get_vq_model_class(config.model.vq_model.type)
47
+ vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
48
+ vq_model.requires_grad_(False)
49
+ vq_model.eval()
50
+
51
+ model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
52
+ model.eval()
53
+
54
+ mask_token_id = model.config.mask_token_id
55
+
56
+ # load from users passed arguments
57
+ if config.get("validation_prompts_file", None) is not None:
58
+ config.dataset.params.validation_prompts_file = config.validation_prompts_file
59
+ config.training.batch_size = config.batch_size
60
+ config.training.guidance_scale = config.guidance_scale
61
+ config.training.generation_timesteps = config.generation_timesteps
62
+ # load from users passed arguments
63
+
64
+ if config.mode == 'inpainting':
65
+
66
+ prompt = [config.prompt] * config.batch_size
67
+ inpainting_image = Image.open(config.image_path).convert("RGB")
68
+ inpainting_mask = Image.open(config.inpainting_mask_path).convert("L")
69
+
70
+ import pdb
71
+ pdb.set_trace()
72
+
73
+ inpainting_image = image_transform(inpainting_image, resolution=config.dataset.params.resolution).to(device)
74
+ inpainting_mask = image_transform(inpainting_mask, resolution=config.dataset.params.resolution, normalize=False)
75
+
76
+ # record original image and inpainting mask
77
+ images = torch.clamp(
78
+ (torch.stack([inpainting_image, inpainting_mask.repeat(3, 1, 1).to(device)], dim=0) + 1.0) / 2.0,
79
+ min=0.0, max=1.0)
80
+ images *= 255.0
81
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
82
+ pil_images = [Image.fromarray(image) for image in images]
83
+
84
+ labels = ['original image', 'inpainting mask']
85
+ wandb_images = [wandb.Image(image, caption=labels[i]) for i, image in enumerate(pil_images)]
86
+
87
+ inpainting_image = inpainting_image.unsqueeze(0).repeat(config.training.batch_size, 1, 1, 1)
88
+
89
+ inpainting_mask = inpainting_mask.unsqueeze(0).to(device)
90
+ inpainting_mask = F.interpolate(inpainting_mask, size=config.dataset.params.resolution // 16, mode='bicubic')
91
+ inpainting_mask = inpainting_mask.repeat(config.training.batch_size, 1, 1, 1)
92
+
93
+ inpainting_mask[inpainting_mask < 0.5] = 0
94
+ inpainting_mask[inpainting_mask >= 0.5] = 1
95
+
96
+ inpainting_mask = inpainting_mask.reshape(config.training.batch_size, -1)
97
+ inpainting_mask = inpainting_mask.to(torch.bool)
98
+
99
+ inpainting_image_tokens = vq_model.get_code(inpainting_image) + len(uni_prompting.text_tokenizer)
100
+ inpainting_image_tokens[inpainting_mask] = mask_token_id
101
+
102
+ input_ids, _ = uni_prompting((prompt, inpainting_image_tokens), 't2i_gen')
103
+
104
+ if config.training.guidance_scale > 0:
105
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompt), inpainting_image_tokens), 't2i_gen')
106
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
107
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
108
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
109
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
110
+ rm_pad_in_image=True)
111
+ else:
112
+ attention_mask = create_attention_mask_predict_next(input_ids,
113
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
114
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
115
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
116
+ rm_pad_in_image=True)
117
+ uncond_input_ids = None
118
+
119
+ if config.get("mask_schedule", None) is not None:
120
+ schedule = config.mask_schedule.schedule
121
+ args = config.mask_schedule.get("params", {})
122
+ mask_schedule = get_mask_chedule(schedule, **args)
123
+ else:
124
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
125
+
126
+ with torch.no_grad():
127
+ gen_token_ids = model.t2i_generate(
128
+ input_ids=input_ids,
129
+ uncond_input_ids=uncond_input_ids,
130
+ attention_mask=attention_mask,
131
+ guidance_scale=config.training.guidance_scale,
132
+ temperature=config.training.get("generation_temperature", 1.0),
133
+ timesteps=config.training.generation_timesteps,
134
+ noise_schedule=mask_schedule,
135
+ noise_type=config.training.get("noise_type", "mask"),
136
+ seq_len=config.model.showo.num_vq_tokens,
137
+ uni_prompting=uni_prompting,
138
+ config=config,
139
+ )
140
+
141
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
142
+ images = vq_model.decode_code(gen_token_ids)
143
+
144
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
145
+ images *= 255.0
146
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
147
+ pil_images = [Image.fromarray(image) for image in images]
148
+ # import ipdb
149
+ # ipdb.set_trace()
150
+ wandb_images.extend([wandb.Image(image, caption=prompt[i]) for i, image in enumerate(pil_images)])
151
+ wandb.log({"generated_images": wandb_images}, step=0)
152
+
153
+ elif config.mode == 'extrapolation':
154
+
155
+ prompt = [p for p in config.prompt.split(" *** ") if len(p) != 0]
156
+ extra_direction = [d for d in config.extra_direction.split(" *** ") if len(d) != 0]
157
+ print(prompt, extra_direction)
158
+ W = config.dataset.params.resolution // 16
159
+ for id, (prt, direction) in enumerate(zip(prompt, extra_direction)):
160
+ prt = [prt] * config.training.batch_size
161
+ if id == 0:
162
+ extrapolation_image = Image.open(config.image_path).convert("RGB")
163
+ extrapolation_image = image_transform(extrapolation_image,
164
+ resolution=config.dataset.params.resolution).to(device)
165
+
166
+ B, _, _ = extrapolation_image.shape
167
+ extrapolation_image = extrapolation_image.unsqueeze(0)
168
+ extrapolation_image_tokens = vq_model.get_code(extrapolation_image) + len(uni_prompting.text_tokenizer)
169
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(1,
170
+ config.dataset.params.resolution // 16,
171
+ config.dataset.params.resolution // 16)
172
+ extrapolation_image_tokens = extrapolation_image_tokens.repeat(config.training.batch_size, 1, 1)
173
+ else:
174
+
175
+
176
+ extrapolation_image_tokens = gen_token_ids + len(uni_prompting.text_tokenizer)
177
+
178
+ image_left_part = extrapolation_image_tokens[:, :, :-(W//2-config.offset)] - len(uni_prompting.text_tokenizer)
179
+ image_right_part = extrapolation_image_tokens[:, :, W//2-config.offset:] - len(uni_prompting.text_tokenizer)
180
+ image_up_part = extrapolation_image_tokens[:, :-(W//2-config.offset), :] - len(uni_prompting.text_tokenizer)
181
+ image_down_part = extrapolation_image_tokens[:, W//2-config.offset:, :] - len(uni_prompting.text_tokenizer)
182
+
183
+ if direction in ['left', 'right']:
184
+ extrapolation_mask = torch.zeros((config.training.batch_size,
185
+ config.dataset.params.resolution // 16,
186
+ config.dataset.params.resolution // 16 // 2 + config.offset),
187
+ dtype=torch.int64, device=device) + mask_token_id
188
+ else:
189
+ extrapolation_mask = torch.zeros((config.training.batch_size,
190
+ config.dataset.params.resolution // 16 // 2 + config.offset,
191
+ config.dataset.params.resolution // 16),
192
+ dtype=torch.int64, device=device) + mask_token_id
193
+
194
+ if direction == 'left':
195
+ extrapolation_image_tokens = torch.cat(
196
+ [extrapolation_mask, extrapolation_image_tokens[:, :, :W//2-config.offset]], dim=-1)
197
+ elif direction == 'right':
198
+ extrapolation_image_tokens = torch.cat(
199
+ [extrapolation_image_tokens[:, :, -(W//2-config.offset):], extrapolation_mask], dim=-1)
200
+ elif direction == 'up':
201
+ extrapolation_image_tokens = torch.cat(
202
+ [extrapolation_mask, extrapolation_image_tokens[:, :W // 2 - config.offset, :]], dim=-2)
203
+ else:
204
+ extrapolation_image_tokens = torch.cat(
205
+ [extrapolation_image_tokens[:, -(W // 2 - config.offset):, :], extrapolation_mask], dim=-2)
206
+
207
+ extrapolation_image_tokens = extrapolation_image_tokens.reshape(config.training.batch_size, -1)
208
+
209
+ input_ids, _ = uni_prompting((prt, extrapolation_image_tokens), 't2i_gen')
210
+
211
+ if config.training.guidance_scale > 0:
212
+ uncond_input_ids, _ = uni_prompting(([''] * len(prt), extrapolation_image_tokens), 't2i_gen')
213
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
214
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
215
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
216
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
217
+ rm_pad_in_image=True)
218
+ else:
219
+ attention_mask = create_attention_mask_predict_next(input_ids,
220
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
221
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
222
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
223
+ rm_pad_in_image=True)
224
+ uncond_input_ids = None
225
+
226
+ if config.get("mask_schedule", None) is not None:
227
+ schedule = config.mask_schedule.schedule
228
+ args = config.mask_schedule.get("params", {})
229
+ mask_schedule = get_mask_chedule(schedule, **args)
230
+ else:
231
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
232
+
233
+ with torch.no_grad():
234
+ gen_token_ids = model.t2i_generate(
235
+ input_ids=input_ids,
236
+ uncond_input_ids=uncond_input_ids,
237
+ attention_mask=attention_mask,
238
+ guidance_scale=config.training.guidance_scale,
239
+ temperature=config.training.get("generation_temperature", 1.0),
240
+ timesteps=config.training.generation_timesteps,
241
+ noise_schedule=mask_schedule,
242
+ noise_type=config.training.get("noise_type", "mask"),
243
+ seq_len=config.model.showo.num_vq_tokens,
244
+ uni_prompting=uni_prompting,
245
+ config=config,
246
+ )
247
+
248
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
249
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size,
250
+ config.dataset.params.resolution // 16,
251
+ config.dataset.params.resolution // 16)
252
+ if direction == 'left':
253
+ gen_token_ids = torch.cat([gen_token_ids, image_right_part], dim=-1)
254
+ elif direction == 'right':
255
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-1)
256
+ elif direction == 'up':
257
+ gen_token_ids = torch.cat([gen_token_ids, image_down_part], dim=-2)
258
+ else:
259
+ gen_token_ids = torch.cat([image_left_part, gen_token_ids], dim=-2)
260
+
261
+ _, h, w = gen_token_ids.shape
262
+ gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
263
+ images = vq_model.decode_code(gen_token_ids, shape=(h, w))
264
+
265
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
266
+ images *= 255.0
267
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
268
+ pil_images = [Image.fromarray(image) for image in images]
269
+
270
+ wandb_images = [wandb.Image(image, caption=' '.join(prompt)) for i, image in enumerate(pil_images)]
271
+ wandb.log({"generated_images": wandb_images}, step=0)
272
+
273
+ elif config.mode == 't2i':
274
+ with open(config.dataset.params.validation_prompts_file, "r") as f:
275
+ validation_prompts = f.read().splitlines()
276
+
277
+ for step in tqdm(range(0, len(validation_prompts), config.training.batch_size)):
278
+ prompts = validation_prompts[step:step + config.training.batch_size]
279
+
280
+ image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
281
+ dtype=torch.long, device=device) * mask_token_id
282
+
283
+ input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')
284
+
285
+ if config.training.guidance_scale > 0:
286
+ uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
287
+ attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
288
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
289
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
290
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
291
+ rm_pad_in_image=True)
292
+ else:
293
+ attention_mask = create_attention_mask_predict_next(input_ids,
294
+ pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
295
+ soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
296
+ eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
297
+ rm_pad_in_image=True)
298
+ uncond_input_ids = None
299
+
300
+ if config.get("mask_schedule", None) is not None:
301
+ schedule = config.mask_schedule.schedule
302
+ args = config.mask_schedule.get("params", {})
303
+ mask_schedule = get_mask_chedule(schedule, **args)
304
+ else:
305
+ mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
306
+
307
+ with torch.no_grad():
308
+ gen_token_ids = model.t2i_generate(
309
+ input_ids=input_ids,
310
+ uncond_input_ids=uncond_input_ids,
311
+ attention_mask=attention_mask,
312
+ guidance_scale=config.training.guidance_scale,
313
+ temperature=config.training.get("generation_temperature", 1.0),
314
+ timesteps=config.training.generation_timesteps,
315
+ noise_schedule=mask_schedule,
316
+ noise_type=config.training.get("noise_type", "mask"),
317
+ seq_len=config.model.showo.num_vq_tokens,
318
+ uni_prompting=uni_prompting,
319
+ config=config,
320
+ )
321
+
322
+ gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
323
+ images = vq_model.decode_code(gen_token_ids)
324
+
325
+ images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
326
+ images *= 255.0
327
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
328
+ pil_images = [Image.fromarray(image) for image in images]
329
+
330
+ wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)]
331
+ wandb.log({"generated_images": wandb_images}, step=step)
inpainting_validation/.DS_Store ADDED
Binary file (6.15 kB). View file
 
inpainting_validation/alpine_lake.jpg ADDED
inpainting_validation/bedroom.jpg ADDED
inpainting_validation/bedroom_mask.webp ADDED
inpainting_validation/bench.jpg ADDED
inpainting_validation/bench_mask.webp ADDED
inpainting_validation/bus.jpg ADDED
inpainting_validation/bus_mask.webp ADDED
inpainting_validation/lake_mountain.jpg ADDED
inpainting_validation/maya.png ADDED
inpainting_validation/river.png ADDED
inpainting_validation/train.jpg ADDED
inpainting_validation/train_mask.webp ADDED
inpainting_validation/truebsee.jpg ADDED
inpainting_validation/truebsee_mask.webp ADDED
inpainting_validation/wukong1.jpg ADDED
inpainting_validation/wukong2.jpg ADDED
mmu_validation/sofa_under_water.jpg ADDED
models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .modeling_showo import Showo
2
+ from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2
3
+ from .sampling import *
4
+ from .clip_encoder import CLIPVisionTower
models/clip_encoder.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+ class CLIPVisionTower(nn.Module):
7
+ def __init__(self, vision_tower):
8
+ super().__init__()
9
+
10
+ self.is_loaded = False
11
+
12
+ self.vision_tower_name = vision_tower
13
+ self.select_layer = -2
14
+ self.select_feature = "patch"
15
+ self.load_model()
16
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
17
+
18
+ def load_model(self, device_map=None):
19
+ if self.is_loaded:
20
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
21
+ return
22
+
23
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
24
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
25
+ self.vision_tower.requires_grad_(False)
26
+
27
+ self.is_loaded = True
28
+
29
+ def feature_select(self, image_forward_outs):
30
+ image_features = image_forward_outs.hidden_states[self.select_layer]
31
+ if self.select_feature == 'patch':
32
+ image_features = image_features[:, 1:]
33
+ elif self.select_feature == 'cls_patch':
34
+ image_features = image_features
35
+ else:
36
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
37
+ return image_features
38
+
39
+ @torch.no_grad()
40
+ def forward(self, images):
41
+ if type(images) is list:
42
+ image_features = []
43
+ for image in images:
44
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
45
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
46
+ image_features.append(image_feature)
47
+ else:
48
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
49
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
50
+
51
+ return image_features
52
+
53
+ @property
54
+ def dummy_feature(self):
55
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
56
+
57
+ @property
58
+ def dtype(self):
59
+ return self.vision_tower.dtype
60
+
61
+ @property
62
+ def device(self):
63
+ return self.vision_tower.device
64
+
65
+ @property
66
+ def config(self):
67
+ if self.is_loaded:
68
+ return self.vision_tower.config
69
+ else:
70
+ return self.cfg_only
71
+
72
+ @property
73
+ def hidden_size(self):
74
+ return self.config.hidden_size
75
+
76
+ @property
77
+ def num_patches_per_side(self):
78
+ return self.config.image_size // self.config.patch_size
79
+
80
+ @property
81
+ def num_patches(self):
82
+ return (self.config.image_size // self.config.patch_size) ** 2
83
+
84
+
85
+ class CLIPVisionTowerS2(CLIPVisionTower):
86
+ def __init__(self, vision_tower, args, delay_load=False):
87
+ super().__init__(vision_tower, args, delay_load)
88
+
89
+ self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
90
+ self.s2_scales = list(map(int, self.s2_scales.split(',')))
91
+ self.s2_scales.sort()
92
+ self.s2_split_size = self.s2_scales[0]
93
+ self.s2_image_size = self.s2_scales[-1]
94
+
95
+ try:
96
+ from s2wrapper import forward as multiscale_forward
97
+ except ImportError:
98
+ raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
99
+ self.multiscale_forward = multiscale_forward
100
+
101
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
102
+ if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
103
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
104
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
105
+
106
+ def load_model(self, device_map=None):
107
+ if self.is_loaded:
108
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
109
+ return
110
+
111
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
112
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
113
+ self.vision_tower.requires_grad_(False)
114
+
115
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
116
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
117
+
118
+ self.is_loaded = True
119
+
120
+ @torch.no_grad()
121
+ def forward_feature(self, images):
122
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
123
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
124
+ return image_features
125
+
126
+ @torch.no_grad()
127
+ def forward(self, images):
128
+ if type(images) is list:
129
+ image_features = []
130
+ for image in images:
131
+ image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
132
+ image_features.append(image_feature)
133
+ else:
134
+ image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
135
+
136
+ return image_features
137
+
138
+ @property
139
+ def hidden_size(self):
140
+ return self.config.hidden_size * len(self.s2_scales)
models/common_modules.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34
3
+ """
4
+
5
+ import math
6
+ from typing import Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+ from einops.layers.torch import Rearrange
14
+
15
+
16
+ def nonlinearity(x):
17
+ # swish
18
+ return x * torch.sigmoid(x)
19
+
20
+
21
+ def Normalize(in_channels):
22
+ return torch.nn.GroupNorm(
23
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
24
+ )
25
+
26
+
27
+ class Upsample(nn.Module):
28
+ def __init__(self, in_channels, with_conv):
29
+ super().__init__()
30
+ self.with_conv = with_conv
31
+ if self.with_conv:
32
+ self.conv = torch.nn.Conv2d(
33
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
34
+ )
35
+
36
+ def forward(self, x):
37
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
38
+ if self.with_conv:
39
+ x = self.conv(x)
40
+ return x
41
+
42
+
43
+ class DepthToSpaceUpsample(nn.Module):
44
+ def __init__(
45
+ self,
46
+ in_channels,
47
+ ):
48
+ super().__init__()
49
+ conv = nn.Conv2d(in_channels, in_channels * 4, 1)
50
+
51
+ self.net = nn.Sequential(
52
+ conv,
53
+ nn.SiLU(),
54
+ Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
55
+ )
56
+
57
+ self.init_conv_(conv)
58
+
59
+ def init_conv_(self, conv):
60
+ o, i, h, w = conv.weight.shape
61
+ conv_weight = torch.empty(o // 4, i, h, w)
62
+ nn.init.kaiming_uniform_(conv_weight)
63
+ conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
64
+
65
+ conv.weight.data.copy_(conv_weight)
66
+ nn.init.zeros_(conv.bias.data)
67
+
68
+ def forward(self, x):
69
+ out = self.net(x)
70
+ return out
71
+
72
+
73
+ class Downsample(nn.Module):
74
+ def __init__(self, in_channels, with_conv):
75
+ super().__init__()
76
+ self.with_conv = with_conv
77
+ if self.with_conv:
78
+ # no asymmetric padding in torch conv, must do it ourselves
79
+ self.conv = torch.nn.Conv2d(
80
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
81
+ )
82
+
83
+ def forward(self, x):
84
+ if self.with_conv:
85
+ pad = (0, 1, 0, 1)
86
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
87
+ x = self.conv(x)
88
+ else:
89
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
90
+ return x
91
+
92
+
93
+ def unpack_time(t, batch):
94
+ _, c, w, h = t.size()
95
+ out = torch.reshape(t, [batch, -1, c, w, h])
96
+ out = rearrange(out, "b t c h w -> b c t h w")
97
+ return out
98
+
99
+
100
+ def pack_time(t):
101
+ out = rearrange(t, "b c t h w -> b t c h w")
102
+ _, _, c, w, h = out.size()
103
+ return torch.reshape(out, [-1, c, w, h])
104
+
105
+
106
+ class TimeDownsample2x(nn.Module):
107
+ def __init__(
108
+ self,
109
+ dim,
110
+ dim_out=None,
111
+ kernel_size=3,
112
+ ):
113
+ super().__init__()
114
+ if dim_out is None:
115
+ dim_out = dim
116
+ self.time_causal_padding = (kernel_size - 1, 0)
117
+ self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)
118
+
119
+ def forward(self, x):
120
+ x = rearrange(x, "b c t h w -> b h w c t")
121
+ b, h, w, c, t = x.size()
122
+ x = torch.reshape(x, [-1, c, t])
123
+
124
+ x = F.pad(x, self.time_causal_padding)
125
+ out = self.conv(x)
126
+
127
+ out = torch.reshape(out, [b, h, w, c, t])
128
+ out = rearrange(out, "b h w c t -> b c t h w")
129
+ out = rearrange(out, "b h w c t -> b c t h w")
130
+ return out
131
+
132
+
133
+ class TimeUpsample2x(nn.Module):
134
+ def __init__(self, dim, dim_out=None):
135
+ super().__init__()
136
+ if dim_out is None:
137
+ dim_out = dim
138
+ conv = nn.Conv1d(dim, dim_out * 2, 1)
139
+
140
+ self.net = nn.Sequential(
141
+ nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2)
142
+ )
143
+
144
+ self.init_conv_(conv)
145
+
146
+ def init_conv_(self, conv):
147
+ o, i, t = conv.weight.shape
148
+ conv_weight = torch.empty(o // 2, i, t)
149
+ nn.init.kaiming_uniform_(conv_weight)
150
+ conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
151
+
152
+ conv.weight.data.copy_(conv_weight)
153
+ nn.init.zeros_(conv.bias.data)
154
+
155
+ def forward(self, x):
156
+ x = rearrange(x, "b c t h w -> b h w c t")
157
+ b, h, w, c, t = x.size()
158
+ x = torch.reshape(x, [-1, c, t])
159
+
160
+ out = self.net(x)
161
+ out = out[:, :, 1:].contiguous()
162
+
163
+ out = torch.reshape(out, [b, h, w, c, t])
164
+ out = rearrange(out, "b h w c t -> b c t h w")
165
+ return out
166
+
167
+
168
+ class AttnBlock(nn.Module):
169
+ def __init__(self, in_channels):
170
+ super().__init__()
171
+ self.in_channels = in_channels
172
+
173
+ self.norm = Normalize(in_channels)
174
+ self.q = torch.nn.Conv2d(
175
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
176
+ )
177
+ self.k = torch.nn.Conv2d(
178
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
179
+ )
180
+ self.v = torch.nn.Conv2d(
181
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
182
+ )
183
+ self.proj_out = torch.nn.Conv2d(
184
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
185
+ )
186
+
187
+ def forward(self, x):
188
+ h_ = x
189
+ h_ = self.norm(h_)
190
+ q = self.q(h_)
191
+ k = self.k(h_)
192
+ v = self.v(h_)
193
+
194
+ # compute attention
195
+ b, c, h, w = q.shape
196
+ q = q.reshape(b, c, h * w)
197
+ q = q.permute(0, 2, 1) # b,hw,c
198
+ k = k.reshape(b, c, h * w) # b,c,hw
199
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
200
+ w_ = w_ * (int(c) ** (-0.5))
201
+ w_ = torch.nn.functional.softmax(w_, dim=2)
202
+
203
+ # attend to values
204
+ v = v.reshape(b, c, h * w)
205
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
206
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
207
+ h_ = h_.reshape(b, c, h, w)
208
+
209
+ h_ = self.proj_out(h_)
210
+
211
+ return x + h_
212
+
213
+
214
+ class TimeAttention(AttnBlock):
215
+ def forward(self, x, *args, **kwargs):
216
+ x = rearrange(x, "b c t h w -> b h w t c")
217
+ b, h, w, t, c = x.size()
218
+ x = torch.reshape(x, (-1, t, c))
219
+
220
+ x = super().forward(x, *args, **kwargs)
221
+
222
+ x = torch.reshape(x, [b, h, w, t, c])
223
+ return rearrange(x, "b h w t c -> b c t h w")
224
+
225
+
226
+ class Residual(nn.Module):
227
+ def __init__(self, fn: nn.Module):
228
+ super().__init__()
229
+ self.fn = fn
230
+
231
+ def forward(self, x, **kwargs):
232
+ return self.fn(x, **kwargs) + x
233
+
234
+
235
+ def cast_tuple(t, length=1):
236
+ return t if isinstance(t, tuple) else ((t,) * length)
237
+
238
+
239
+ class CausalConv3d(nn.Module):
240
+ def __init__(
241
+ self,
242
+ chan_in,
243
+ chan_out,
244
+ kernel_size: Union[int, Tuple[int, int, int]],
245
+ pad_mode="constant",
246
+ **kwargs
247
+ ):
248
+ super().__init__()
249
+ kernel_size = cast_tuple(kernel_size, 3)
250
+
251
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
252
+
253
+ dilation = kwargs.pop("dilation", 1)
254
+ stride = kwargs.pop("stride", 1)
255
+
256
+ self.pad_mode = pad_mode
257
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
258
+ height_pad = height_kernel_size // 2
259
+ width_pad = width_kernel_size // 2
260
+
261
+ self.time_pad = time_pad
262
+ self.time_causal_padding = (
263
+ width_pad,
264
+ width_pad,
265
+ height_pad,
266
+ height_pad,
267
+ time_pad,
268
+ 0,
269
+ )
270
+
271
+ stride = (stride, 1, 1)
272
+ dilation = (dilation, 1, 1)
273
+ self.conv = nn.Conv3d(
274
+ chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
275
+ )
276
+
277
+ def forward(self, x):
278
+ pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
279
+
280
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
281
+ return self.conv(x)
282
+
283
+
284
+ def ResnetBlockCausal3D(
285
+ dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"
286
+ ):
287
+ net = nn.Sequential(
288
+ Normalize(dim),
289
+ nn.SiLU(),
290
+ CausalConv3d(dim, dim, kernel_size, pad_mode),
291
+ Normalize(dim),
292
+ nn.SiLU(),
293
+ CausalConv3d(dim, dim, kernel_size, pad_mode),
294
+ )
295
+ return Residual(net)
296
+
297
+
298
+ class ResnetBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ *,
302
+ in_channels,
303
+ out_channels=None,
304
+ conv_shortcut=False,
305
+ dropout,
306
+ temb_channels=512
307
+ ):
308
+ super().__init__()
309
+ self.in_channels = in_channels
310
+ out_channels = in_channels if out_channels is None else out_channels
311
+ self.out_channels = out_channels
312
+ self.use_conv_shortcut = conv_shortcut
313
+
314
+ self.norm1 = Normalize(in_channels)
315
+ self.conv1 = torch.nn.Conv2d(
316
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
317
+ )
318
+ if temb_channels > 0:
319
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
320
+ else:
321
+ self.temb_proj = None
322
+ self.norm2 = Normalize(out_channels)
323
+ self.dropout = torch.nn.Dropout(dropout)
324
+ self.conv2 = torch.nn.Conv2d(
325
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
326
+ )
327
+ if self.in_channels != self.out_channels:
328
+ if self.use_conv_shortcut:
329
+ self.conv_shortcut = torch.nn.Conv2d(
330
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
331
+ )
332
+ else:
333
+ self.nin_shortcut = torch.nn.Conv2d(
334
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
335
+ )
336
+
337
+ def forward(self, x, temb):
338
+ h = x
339
+ h = self.norm1(h)
340
+ h = nonlinearity(h)
341
+ h = self.conv1(h)
342
+
343
+ if temb is not None:
344
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
345
+
346
+ h = self.norm2(h)
347
+ h = nonlinearity(h)
348
+ h = self.dropout(h)
349
+ h = self.conv2(h)
350
+
351
+ if self.in_channels != self.out_channels:
352
+ if self.use_conv_shortcut:
353
+ x = self.conv_shortcut(x)
354
+ else:
355
+ x = self.nin_shortcut(x)
356
+
357
+ return x + h
358
+
359
+
360
+ class DinoV2Model(nn.Module):
361
+ def __init__(
362
+ self,
363
+ model_name,
364
+ local_checkpoint_path="",
365
+ renorm_input=False,
366
+ old_input_mean=0.5,
367
+ old_input_std=0.5,
368
+ freeze_model=False,
369
+ ):
370
+ super().__init__()
371
+ if local_checkpoint_path != "":
372
+ self._model = torch.hub.load(
373
+ local_checkpoint_path, model_name, source="local"
374
+ )
375
+ else:
376
+ self._model = torch.hub.load("facebookresearch/dinov2", model_name)
377
+ self.register_buffer(
378
+ "_dino_input_mean",
379
+ torch.tensor([0.485, 0.456, 0.406]).float()[None, :, None, None],
380
+ )
381
+ self.register_buffer(
382
+ "_dino_input_std",
383
+ torch.tensor([0.229, 0.224, 0.225]).float()[None, :, None, None],
384
+ )
385
+ self._old_input_mean = old_input_mean
386
+ self._old_input_std = old_input_std
387
+ self._renorm_input = renorm_input
388
+ if freeze_model:
389
+ for param in self._model.parameters():
390
+ param.requires_grad = False
391
+
392
+ def forward(self, inputs):
393
+ batch, _, height, width = inputs.size()
394
+ if self._renorm_input:
395
+ inputs = inputs * self._old_input_mean + self._old_input_std
396
+ inputs = (inputs - self._dino_input_mean) / self._dino_input_std
397
+ # TODO(yanwan): If we want to remove this resizing, have to modify the decoder to support upscaling by a factor of 14.
398
+ # Reduce both height and width to 7/8 of their original values while maintaining aspect ratio to fit dinov2 requirement.
399
+ new_height = height // 8 * 7
400
+ new_width = width // 8 * 7
401
+ inputs = F.interpolate(inputs, (new_height, new_width), mode="bilinear")
402
+ features = self._model.forward_features(inputs)["x_norm_patchtokens"]
403
+ features = torch.transpose(features, 1, 2).contiguous()
404
+ features = torch.reshape(
405
+ features, (batch, -1, new_height // 14, new_width // 14)
406
+ )
407
+ return features
models/logging.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Optuna, Hugging Face
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Logging utilities."""
16
+
17
+ import logging
18
+ import os
19
+ import sys
20
+ import threading
21
+ from logging import CRITICAL # NOQA
22
+ from logging import DEBUG # NOQA
23
+ from logging import ERROR # NOQA
24
+ from logging import FATAL # NOQA
25
+ from logging import INFO # NOQA
26
+ from logging import NOTSET # NOQA
27
+ from logging import WARN # NOQA
28
+ from logging import WARNING # NOQA
29
+ from typing import Optional
30
+
31
+ from tqdm import auto as tqdm_lib
32
+
33
+ _lock = threading.Lock()
34
+ _default_handler: Optional[logging.Handler] = None
35
+
36
+ log_levels = {
37
+ "debug": logging.DEBUG,
38
+ "info": logging.INFO,
39
+ "warning": logging.WARNING,
40
+ "error": logging.ERROR,
41
+ "critical": logging.CRITICAL,
42
+ }
43
+
44
+ _default_log_level = logging.WARNING
45
+
46
+ _tqdm_active = True
47
+
48
+
49
+ def _get_default_logging_level():
50
+ """
51
+ If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
52
+ not - fall back to `_default_log_level`
53
+ """
54
+ env_level_str = os.getenv("muse_VERBOSITY", None)
55
+ if env_level_str:
56
+ if env_level_str in log_levels:
57
+ return log_levels[env_level_str]
58
+ else:
59
+ logging.getLogger().warning(
60
+ f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
61
+ )
62
+ return _default_log_level
63
+
64
+
65
+ def _get_library_name() -> str:
66
+ return __name__.split(".")[0]
67
+
68
+
69
+ def _get_library_root_logger() -> logging.Logger:
70
+ return logging.getLogger(_get_library_name())
71
+
72
+
73
+ def _configure_library_root_logger() -> None:
74
+ global _default_handler
75
+
76
+ with _lock:
77
+ if _default_handler:
78
+ # This library has already configured the library root logger.
79
+ return
80
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
81
+ _default_handler.flush = sys.stderr.flush
82
+
83
+ # Apply our default configuration to the library root logger.
84
+ library_root_logger = _get_library_root_logger()
85
+ library_root_logger.addHandler(_default_handler)
86
+ library_root_logger.setLevel(_get_default_logging_level())
87
+ library_root_logger.propagate = False
88
+
89
+
90
+ def _reset_library_root_logger() -> None:
91
+ global _default_handler
92
+
93
+ with _lock:
94
+ if not _default_handler:
95
+ return
96
+
97
+ library_root_logger = _get_library_root_logger()
98
+ library_root_logger.removeHandler(_default_handler)
99
+ library_root_logger.setLevel(logging.NOTSET)
100
+ _default_handler = None
101
+
102
+
103
+ def get_log_levels_dict():
104
+ return log_levels
105
+
106
+
107
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
108
+ """
109
+ Return a logger with the specified name.
110
+
111
+ This function is not supposed to be directly accessed unless you are writing a custom muse module.
112
+ """
113
+
114
+ if name is None:
115
+ name = _get_library_name()
116
+
117
+ _configure_library_root_logger()
118
+ return logging.getLogger(name)
119
+
120
+
121
+ def get_verbosity() -> int:
122
+ """
123
+ Return the current level for the 🤗 muse' root logger as an int.
124
+
125
+ Returns:
126
+ `int`: The logging level.
127
+
128
+ <Tip>
129
+
130
+ 🤗 muse has following logging levels:
131
+
132
+ - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
133
+ - 40: `muse.logging.ERROR`
134
+ - 30: `muse.logging.WARNING` or `muse.logging.WARN`
135
+ - 20: `muse.logging.INFO`
136
+ - 10: `muse.logging.DEBUG`
137
+
138
+ </Tip>"""
139
+
140
+ _configure_library_root_logger()
141
+ return _get_library_root_logger().getEffectiveLevel()
142
+
143
+
144
+ def set_verbosity(verbosity: int) -> None:
145
+ """
146
+ Set the verbosity level for the 🤗 muse' root logger.
147
+
148
+ Args:
149
+ verbosity (`int`):
150
+ Logging level, e.g., one of:
151
+
152
+ - `muse.logging.CRITICAL` or `muse.logging.FATAL`
153
+ - `muse.logging.ERROR`
154
+ - `muse.logging.WARNING` or `muse.logging.WARN`
155
+ - `muse.logging.INFO`
156
+ - `muse.logging.DEBUG`
157
+ """
158
+
159
+ _configure_library_root_logger()
160
+ _get_library_root_logger().setLevel(verbosity)
161
+
162
+
163
+ def set_verbosity_info():
164
+ """Set the verbosity to the `INFO` level."""
165
+ return set_verbosity(INFO)
166
+
167
+
168
+ def set_verbosity_warning():
169
+ """Set the verbosity to the `WARNING` level."""
170
+ return set_verbosity(WARNING)
171
+
172
+
173
+ def set_verbosity_debug():
174
+ """Set the verbosity to the `DEBUG` level."""
175
+ return set_verbosity(DEBUG)
176
+
177
+
178
+ def set_verbosity_error():
179
+ """Set the verbosity to the `ERROR` level."""
180
+ return set_verbosity(ERROR)
181
+
182
+
183
+ def disable_default_handler() -> None:
184
+ """Disable the default handler of the HuggingFace muse' root logger."""
185
+
186
+ _configure_library_root_logger()
187
+
188
+ assert _default_handler is not None
189
+ _get_library_root_logger().removeHandler(_default_handler)
190
+
191
+
192
+ def enable_default_handler() -> None:
193
+ """Enable the default handler of the HuggingFace muse' root logger."""
194
+
195
+ _configure_library_root_logger()
196
+
197
+ assert _default_handler is not None
198
+ _get_library_root_logger().addHandler(_default_handler)
199
+
200
+
201
+ def add_handler(handler: logging.Handler) -> None:
202
+ """adds a handler to the HuggingFace muse' root logger."""
203
+
204
+ _configure_library_root_logger()
205
+
206
+ assert handler is not None
207
+ _get_library_root_logger().addHandler(handler)
208
+
209
+
210
+ def remove_handler(handler: logging.Handler) -> None:
211
+ """removes given handler from the HuggingFace muse' root logger."""
212
+
213
+ _configure_library_root_logger()
214
+
215
+ assert handler is not None and handler not in _get_library_root_logger().handlers
216
+ _get_library_root_logger().removeHandler(handler)
217
+
218
+
219
+ def disable_propagation() -> None:
220
+ """
221
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
222
+ """
223
+
224
+ _configure_library_root_logger()
225
+ _get_library_root_logger().propagate = False
226
+
227
+
228
+ def enable_propagation() -> None:
229
+ """
230
+ Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent
231
+ double logging if the root logger has been configured.
232
+ """
233
+
234
+ _configure_library_root_logger()
235
+ _get_library_root_logger().propagate = True
236
+
237
+
238
+ def enable_explicit_format() -> None:
239
+ """
240
+ Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows:
241
+ ```
242
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
243
+ ```
244
+ All handlers currently bound to the root logger are affected by this method.
245
+ """
246
+ handlers = _get_library_root_logger().handlers
247
+
248
+ for handler in handlers:
249
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
250
+ handler.setFormatter(formatter)
251
+
252
+
253
+ def reset_format() -> None:
254
+ """
255
+ Resets the formatting for HuggingFace muse' loggers.
256
+
257
+ All handlers currently bound to the root logger are affected by this method.
258
+ """
259
+ handlers = _get_library_root_logger().handlers
260
+
261
+ for handler in handlers:
262
+ handler.setFormatter(None)
263
+
264
+
265
+ def warning_advice(self, *args, **kwargs):
266
+ """
267
+ This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
268
+ warning will not be printed
269
+ """
270
+ no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
271
+ if no_advisory_warnings:
272
+ return
273
+ self.warning(*args, **kwargs)
274
+
275
+
276
+ logging.Logger.warning_advice = warning_advice
277
+
278
+
279
+ class EmptyTqdm:
280
+ """Dummy tqdm which doesn't do anything."""
281
+
282
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
283
+ self._iterator = args[0] if args else None
284
+
285
+ def __iter__(self):
286
+ return iter(self._iterator)
287
+
288
+ def __getattr__(self, _):
289
+ """Return empty function."""
290
+
291
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
292
+ return
293
+
294
+ return empty_fn
295
+
296
+ def __enter__(self):
297
+ return self
298
+
299
+ def __exit__(self, type_, value, traceback):
300
+ return
301
+
302
+
303
+ class _tqdm_cls:
304
+ def __call__(self, *args, **kwargs):
305
+ if _tqdm_active:
306
+ return tqdm_lib.tqdm(*args, **kwargs)
307
+ else:
308
+ return EmptyTqdm(*args, **kwargs)
309
+
310
+ def set_lock(self, *args, **kwargs):
311
+ self._lock = None
312
+ if _tqdm_active:
313
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
314
+
315
+ def get_lock(self):
316
+ if _tqdm_active:
317
+ return tqdm_lib.tqdm.get_lock()
318
+
319
+
320
+ tqdm = _tqdm_cls()
321
+
322
+
323
+ def is_progress_bar_enabled() -> bool:
324
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
325
+ global _tqdm_active
326
+ return bool(_tqdm_active)
327
+
328
+
329
+ def enable_progress_bar():
330
+ """Enable tqdm progress bar."""
331
+ global _tqdm_active
332
+ _tqdm_active = True
333
+
334
+
335
+ def disable_progress_bar():
336
+ """Disable tqdm progress bar."""
337
+ global _tqdm_active
338
+ _tqdm_active = False
models/lr_schedulers.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for diffusion models."""
16
+
17
+ import math
18
+ from enum import Enum
19
+ from typing import Optional, Union
20
+
21
+ from torch.optim import Optimizer
22
+ from torch.optim.lr_scheduler import LambdaLR
23
+
24
+ from .logging import get_logger
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
+ class SchedulerType(Enum):
30
+ LINEAR = "linear"
31
+ COSINE = "cosine"
32
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
33
+ POLYNOMIAL = "polynomial"
34
+ CONSTANT = "constant"
35
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
36
+
37
+
38
+ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
39
+ """
40
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
41
+
42
+ Args:
43
+ optimizer ([`~torch.optim.Optimizer`]):
44
+ The optimizer for which to schedule the learning rate.
45
+ last_epoch (`int`, *optional*, defaults to -1):
46
+ The index of the last epoch when resuming training.
47
+
48
+ Return:
49
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
50
+ """
51
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
52
+
53
+
54
+ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
55
+ """
56
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
57
+ increases linearly between 0 and the initial lr set in the optimizer.
58
+
59
+ Args:
60
+ optimizer ([`~torch.optim.Optimizer`]):
61
+ The optimizer for which to schedule the learning rate.
62
+ num_warmup_steps (`int`):
63
+ The number of steps for the warmup phase.
64
+ last_epoch (`int`, *optional*, defaults to -1):
65
+ The index of the last epoch when resuming training.
66
+
67
+ Return:
68
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
69
+ """
70
+
71
+ def lr_lambda(current_step: int):
72
+ if current_step < num_warmup_steps:
73
+ return float(current_step) / float(max(1.0, num_warmup_steps))
74
+ return 1.0
75
+
76
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
77
+
78
+
79
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
80
+ """
81
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
82
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
83
+
84
+ Args:
85
+ optimizer ([`~torch.optim.Optimizer`]):
86
+ The optimizer for which to schedule the learning rate.
87
+ num_warmup_steps (`int`):
88
+ The number of steps for the warmup phase.
89
+ num_training_steps (`int`):
90
+ The total number of training steps.
91
+ last_epoch (`int`, *optional*, defaults to -1):
92
+ The index of the last epoch when resuming training.
93
+
94
+ Return:
95
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
96
+ """
97
+
98
+ def lr_lambda(current_step: int):
99
+ if current_step < num_warmup_steps:
100
+ return float(current_step) / float(max(1, num_warmup_steps))
101
+ return max(
102
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
103
+ )
104
+
105
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
106
+
107
+
108
+ def get_cosine_schedule_with_warmup(
109
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
110
+ ):
111
+ """
112
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
113
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
114
+ initial lr set in the optimizer.
115
+
116
+ Args:
117
+ optimizer ([`~torch.optim.Optimizer`]):
118
+ The optimizer for which to schedule the learning rate.
119
+ num_warmup_steps (`int`):
120
+ The number of steps for the warmup phase.
121
+ num_training_steps (`int`):
122
+ The total number of training steps.
123
+ num_periods (`float`, *optional*, defaults to 0.5):
124
+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
125
+ value to 0 following a half-cosine).
126
+ last_epoch (`int`, *optional*, defaults to -1):
127
+ The index of the last epoch when resuming training.
128
+
129
+ Return:
130
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
131
+ """
132
+
133
+ def lr_lambda(current_step):
134
+ if current_step < num_warmup_steps:
135
+ return float(current_step) / float(max(1, num_warmup_steps))
136
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
137
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
138
+
139
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
140
+
141
+
142
+ def get_cosine_with_hard_restarts_schedule_with_warmup(
143
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
144
+ ):
145
+ """
146
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
147
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
148
+ linearly between 0 and the initial lr set in the optimizer.
149
+
150
+ Args:
151
+ optimizer ([`~torch.optim.Optimizer`]):
152
+ The optimizer for which to schedule the learning rate.
153
+ num_warmup_steps (`int`):
154
+ The number of steps for the warmup phase.
155
+ num_training_steps (`int`):
156
+ The total number of training steps.
157
+ num_cycles (`int`, *optional*, defaults to 1):
158
+ The number of hard restarts to use.
159
+ last_epoch (`int`, *optional*, defaults to -1):
160
+ The index of the last epoch when resuming training.
161
+
162
+ Return:
163
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
164
+ """
165
+
166
+ def lr_lambda(current_step):
167
+ if current_step < num_warmup_steps:
168
+ return float(current_step) / float(max(1, num_warmup_steps))
169
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
170
+ if progress >= 1.0:
171
+ return 0.0
172
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
173
+
174
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
175
+
176
+
177
+ def get_polynomial_decay_schedule_with_warmup(
178
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
179
+ ):
180
+ """
181
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
182
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
183
+ initial lr set in the optimizer.
184
+
185
+ Args:
186
+ optimizer ([`~torch.optim.Optimizer`]):
187
+ The optimizer for which to schedule the learning rate.
188
+ num_warmup_steps (`int`):
189
+ The number of steps for the warmup phase.
190
+ num_training_steps (`int`):
191
+ The total number of training steps.
192
+ lr_end (`float`, *optional*, defaults to 1e-7):
193
+ The end LR.
194
+ power (`float`, *optional*, defaults to 1.0):
195
+ Power factor.
196
+ last_epoch (`int`, *optional*, defaults to -1):
197
+ The index of the last epoch when resuming training.
198
+
199
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
200
+ implementation at
201
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
202
+
203
+ Return:
204
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
205
+
206
+ """
207
+
208
+ lr_init = optimizer.defaults["lr"]
209
+ if not (lr_init > lr_end):
210
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
211
+
212
+ def lr_lambda(current_step: int):
213
+ if current_step < num_warmup_steps:
214
+ return float(current_step) / float(max(1, num_warmup_steps))
215
+ elif current_step > num_training_steps:
216
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
217
+ else:
218
+ lr_range = lr_init - lr_end
219
+ decay_steps = num_training_steps - num_warmup_steps
220
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
221
+ decay = lr_range * pct_remaining**power + lr_end
222
+ return decay / lr_init # as LambdaLR multiplies by lr_init
223
+
224
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
225
+
226
+
227
+ TYPE_TO_SCHEDULER_FUNCTION = {
228
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
229
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
230
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
231
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
232
+ SchedulerType.CONSTANT: get_constant_schedule,
233
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
234
+ }
235
+
236
+
237
+ def get_scheduler(
238
+ name: Union[str, SchedulerType],
239
+ optimizer: Optimizer,
240
+ num_warmup_steps: Optional[int] = None,
241
+ num_training_steps: Optional[int] = None,
242
+ num_cycles: int = 1,
243
+ power: float = 1.0,
244
+ ):
245
+ """
246
+ Unified API to get any scheduler from its name.
247
+
248
+ Args:
249
+ name (`str` or `SchedulerType`):
250
+ The name of the scheduler to use.
251
+ optimizer (`torch.optim.Optimizer`):
252
+ The optimizer that will be used during training.
253
+ num_warmup_steps (`int`, *optional*):
254
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
255
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
256
+ num_training_steps (`int``, *optional*):
257
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
258
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
259
+ num_cycles (`int`, *optional*):
260
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
261
+ power (`float`, *optional*, defaults to 1.0):
262
+ Power factor. See `POLYNOMIAL` scheduler
263
+ last_epoch (`int`, *optional*, defaults to -1):
264
+ The index of the last epoch when resuming training.
265
+ """
266
+ name = SchedulerType(name)
267
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
268
+ if name == SchedulerType.CONSTANT:
269
+ return schedule_func(optimizer)
270
+
271
+ # All other schedulers require `num_warmup_steps`
272
+ if num_warmup_steps is None:
273
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
274
+
275
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
276
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
277
+
278
+ # All other schedulers require `num_training_steps`
279
+ if num_training_steps is None:
280
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
281
+
282
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
283
+ return schedule_func(
284
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
285
+ )
286
+
287
+ if name == SchedulerType.POLYNOMIAL:
288
+ return schedule_func(
289
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
290
+ )
291
+
292
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
models/misc.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ import torch
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ Iterable,
8
+ List,
9
+ NamedTuple,
10
+ NewType,
11
+ Optional,
12
+ Sized,
13
+ Tuple,
14
+ Type,
15
+ TypeVar,
16
+ Union,
17
+ )
18
+ try:
19
+ from typing import Literal
20
+ except ImportError:
21
+ from typing_extensions import Literal
22
+
23
+ # Tensor dtype
24
+ # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
25
+ from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
26
+
27
+ # Config type
28
+ from omegaconf import DictConfig
29
+
30
+ # PyTorch Tensor type
31
+ from torch import Tensor
32
+
33
+ # Runtime type checking decorator
34
+ from typeguard import typechecked as typechecker
35
+
36
+
37
+ def broadcast(tensor, src=0):
38
+ if not _distributed_available():
39
+ return tensor
40
+ else:
41
+ torch.distributed.broadcast(tensor, src=src)
42
+ return tensor
43
+
44
+ def _distributed_available():
45
+ return torch.distributed.is_available() and torch.distributed.is_initialized()
46
+
47
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
48
+ # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
49
+ if '--local-rank' in cfg:
50
+ del cfg['--local-rank']
51
+ # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
52
+ scfg = OmegaConf.structured(fields(**cfg))
53
+ return scfg
models/modeling_magvitv2.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from .common_modules import *
6
+ from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
7
+ from .misc import *
8
+ import math
9
+
10
+ class Updateable:
11
+ def do_update_step(
12
+ self, epoch: int, global_step: int, on_load_weights: bool = False
13
+ ):
14
+ for attr in self.__dir__():
15
+ if attr.startswith("_"):
16
+ continue
17
+ try:
18
+ module = getattr(self, attr)
19
+ except:
20
+ continue # ignore attributes like property, which can't be retrived using getattr?
21
+ if isinstance(module, Updateable):
22
+ module.do_update_step(
23
+ epoch, global_step, on_load_weights=on_load_weights
24
+ )
25
+ self.update_step(epoch, global_step, on_load_weights=on_load_weights)
26
+
27
+ def do_update_step_end(self, epoch: int, global_step: int):
28
+ for attr in self.__dir__():
29
+ if attr.startswith("_"):
30
+ continue
31
+ try:
32
+ module = getattr(self, attr)
33
+ except:
34
+ continue # ignore attributes like property, which can't be retrived using getattr?
35
+ if isinstance(module, Updateable):
36
+ module.do_update_step_end(epoch, global_step)
37
+ self.update_step_end(epoch, global_step)
38
+
39
+ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
40
+ # override this method to implement custom update logic
41
+ # if on_load_weights is True, you should be careful doing things related to model evaluations,
42
+ # as the models and tensors are not guarenteed to be on the same device
43
+ pass
44
+
45
+ def update_step_end(self, epoch: int, global_step: int):
46
+ pass
47
+
48
+ class VQGANEncoder(ModelMixin, ConfigMixin):
49
+ @dataclass
50
+ class Config:
51
+ ch: int = 128
52
+ ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4])
53
+ num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4])
54
+ attn_resolutions: List[int] = field(default_factory=lambda: [5])
55
+ dropout: float = 0.0
56
+ in_ch: int = 3
57
+ out_ch: int = 3
58
+ resolution: int = 256
59
+ z_channels: int = 13
60
+ double_z: bool = False
61
+
62
+ def __init__(self,
63
+ ch: int = 128,
64
+ ch_mult: List[int] = [1, 2, 2, 4, 4],
65
+ num_res_blocks: List[int] = [4, 3, 4, 3, 4],
66
+ attn_resolutions: List[int] = [5],
67
+ dropout: float = 0.0,
68
+ in_ch: int = 3,
69
+ out_ch: int = 3,
70
+ resolution: int = 256,
71
+ z_channels: int = 13,
72
+ double_z: bool = False):
73
+ super().__init__()
74
+ self.ch = ch
75
+ self.temb_ch = 0
76
+ self.num_resolutions = len(ch_mult)
77
+ self.num_res_blocks = num_res_blocks
78
+ self.resolution = resolution
79
+ self.in_ch = in_ch
80
+ # downsampling
81
+ self.conv_in = torch.nn.Conv2d(
82
+ self.in_ch, self.ch, kernel_size=3, stride=1, padding=1
83
+ )
84
+
85
+ curr_res = self.resolution
86
+ in_ch_mult = (1,) + tuple(ch_mult)
87
+ self.down = nn.ModuleList()
88
+ for i_level in range(self.num_resolutions):
89
+ block = nn.ModuleList()
90
+ attn = nn.ModuleList()
91
+ block_in = self.ch * in_ch_mult[i_level]
92
+ block_out = self.ch * ch_mult[i_level]
93
+ for i_block in range(self.num_res_blocks[i_level]):
94
+ block.append(
95
+ ResnetBlock(
96
+ in_channels=block_in,
97
+ out_channels=block_out,
98
+ temb_channels=self.temb_ch,
99
+ dropout=dropout,
100
+ )
101
+ )
102
+ block_in = block_out
103
+ if curr_res in attn_resolutions:
104
+ attn.append(AttnBlock(block_in))
105
+ down = nn.Module()
106
+ down.block = block
107
+ down.attn = attn
108
+ if i_level != self.num_resolutions - 1:
109
+ down.downsample = Downsample(block_in, True)
110
+ curr_res = curr_res // 2
111
+ self.down.append(down)
112
+
113
+ # middle
114
+ self.mid = nn.Module()
115
+ self.mid.block_1 = ResnetBlock(
116
+ in_channels=block_in,
117
+ out_channels=block_in,
118
+ temb_channels=self.temb_ch,
119
+ dropout=dropout,
120
+ )
121
+ self.mid.attn_1 = AttnBlock(block_in)
122
+ self.mid.block_2 = ResnetBlock(
123
+ in_channels=block_in,
124
+ out_channels=block_in,
125
+ temb_channels=self.temb_ch,
126
+ dropout=dropout,
127
+ )
128
+
129
+
130
+ self.norm_out = Normalize(block_in)
131
+ self.conv_out = torch.nn.Conv2d(
132
+ block_in,
133
+ 2 * z_channels if double_z else z_channels,
134
+ kernel_size=3,
135
+ stride=1,
136
+ padding=1,
137
+ )
138
+
139
+ self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
140
+ # for param in self.parameters():
141
+ # broadcast(param, src=0)
142
+
143
+ def forward(self, x):
144
+ # timestep embedding
145
+ temb = None
146
+
147
+ # downsampling
148
+ hs = [self.conv_in(x)]
149
+ for i_level in range(self.num_resolutions):
150
+ for i_block in range(self.num_res_blocks[i_level]):
151
+ h = self.down[i_level].block[i_block](hs[-1], temb)
152
+ if len(self.down[i_level].attn) > 0:
153
+ h = self.down[i_level].attn[i_block](h)
154
+ hs.append(h)
155
+ if i_level != self.num_resolutions - 1:
156
+ hs.append(self.down[i_level].downsample(hs[-1]))
157
+
158
+ # middle
159
+ h = hs[-1]
160
+ h = self.mid.block_1(h, temb)
161
+ h = self.mid.attn_1(h)
162
+ h = self.mid.block_2(h, temb)
163
+
164
+ # end
165
+ h = self.norm_out(h)
166
+ h = nonlinearity(h)
167
+ h = self.conv_out(h)
168
+ h = self.quant_conv(h)
169
+ return h
170
+
171
+
172
+ class LFQuantizer(nn.Module):
173
+ def __init__(self, num_codebook_entry: int = -1,
174
+ codebook_dim: int = 13,
175
+ beta: float = 0.25,
176
+ entropy_multiplier: float = 0.1,
177
+ commit_loss_multiplier: float = 0.1, ):
178
+ super().__init__()
179
+ self.codebook_size = 2 ** codebook_dim
180
+ print(
181
+ f"Look-up free quantizer with codebook size: {self.codebook_size}"
182
+ )
183
+ self.e_dim = codebook_dim
184
+ self.beta = beta
185
+
186
+ indices = torch.arange(self.codebook_size)
187
+
188
+ binary = (
189
+ indices.unsqueeze(1)
190
+ >> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long)
191
+ ) & 1
192
+
193
+ embedding = binary.float() * 2 - 1
194
+ self.register_buffer("embedding", embedding)
195
+ self.register_buffer(
196
+ "power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1)
197
+ )
198
+ self.commit_loss_multiplier = commit_loss_multiplier
199
+ self.entropy_multiplier = entropy_multiplier
200
+
201
+ def get_indices(self, z_q):
202
+ return (
203
+ (self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float())
204
+ .sum(1, keepdim=True)
205
+ .long()
206
+ )
207
+
208
+ def get_codebook_entry(self, indices, shape=None):
209
+ if shape is None:
210
+ h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1]))
211
+ else:
212
+ h, w = shape
213
+ b, _ = indices.shape
214
+ indices = indices.reshape(-1)
215
+ z_q = self.embedding[indices]
216
+ z_q = z_q.view(b, h, w, -1)
217
+
218
+ # reshape back to match original input shape
219
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
220
+
221
+ return z_q
222
+
223
+ def forward(self, z, get_code=False):
224
+ """
225
+ Inputs the output of the encoder network z and maps it to a discrete
226
+ one-hot vector that is the index of the closest embedding vector e_j
227
+ z (continuous) -> z_q (discrete)
228
+ z.shape = (batch, channel, height, width)
229
+ quantization pipeline:
230
+ 1. get encoder input (B,C,H,W)
231
+ 2. flatten input to (B*H*W,C)
232
+ """
233
+ if get_code:
234
+ return self.get_codebook_entry(z)
235
+
236
+ # reshape z -> (batch, height, width, channel) and flatten
237
+ z = z.permute(0, 2, 3, 1).contiguous()
238
+ z_flattened = z.view(-1, self.e_dim)
239
+ ge_zero = (z_flattened > 0).float()
240
+ ones = torch.ones_like(z_flattened)
241
+ z_q = ones * ge_zero + -ones * (1 - ge_zero)
242
+
243
+ # preserve gradients
244
+ z_q = z_flattened + (z_q - z_flattened).detach()
245
+
246
+ # compute entropy loss
247
+ CatDist = torch.distributions.categorical.Categorical
248
+ logit = torch.stack(
249
+ [
250
+ -(z_flattened - torch.ones_like(z_q)).pow(2),
251
+ -(z_flattened - torch.ones_like(z_q) * -1).pow(2),
252
+ ],
253
+ dim=-1,
254
+ )
255
+ cat_dist = CatDist(logits=logit)
256
+ entropy = cat_dist.entropy().mean()
257
+ mean_prob = cat_dist.probs.mean(0)
258
+ mean_entropy = CatDist(probs=mean_prob).entropy().mean()
259
+
260
+ # compute loss for embedding
261
+ commit_loss = torch.mean(
262
+ (z_q.detach() - z_flattened) ** 2
263
+ ) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)
264
+
265
+ # reshape back to match original input shape
266
+ z_q = z_q.view(z.shape)
267
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
268
+
269
+ return {
270
+ "z": z_q,
271
+ "quantizer_loss": commit_loss * self.commit_loss_multiplier,
272
+ "entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier,
273
+ "indices": self.get_indices(z_q),
274
+ }
275
+
276
+
277
+ class VQGANDecoder(ModelMixin, ConfigMixin):
278
+ def __init__(self, ch: int = 128,
279
+ ch_mult: List[int] = [1, 1, 2, 2, 4],
280
+ num_res_blocks: List[int] = [4, 4, 3, 4, 3],
281
+ attn_resolutions: List[int] = [5],
282
+ dropout: float = 0.0,
283
+ in_ch: int = 3,
284
+ out_ch: int = 3,
285
+ resolution: int = 256,
286
+ z_channels: int = 13,
287
+ double_z: bool = False):
288
+ super().__init__()
289
+ self.ch = ch
290
+ self.temb_ch = 0
291
+ self.num_resolutions = len(ch_mult)
292
+ self.num_res_blocks = num_res_blocks
293
+ self.resolution = resolution
294
+ self.in_ch = in_ch
295
+ self.give_pre_end = False
296
+
297
+ self.z_channels = z_channels
298
+ # compute in_ch_mult, block_in and curr_res at lowest res
299
+ in_ch_mult = (1,) + tuple(ch_mult)
300
+ block_in = ch * ch_mult[self.num_resolutions - 1]
301
+ curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
302
+ self.z_shape = (1, z_channels, curr_res, curr_res)
303
+ print(
304
+ "Working with z of shape {} = {} dimensions.".format(
305
+ self.z_shape, np.prod(self.z_shape)
306
+ )
307
+ )
308
+
309
+ # z to block_in
310
+ self.conv_in = torch.nn.Conv2d(
311
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
312
+ )
313
+
314
+ # middle
315
+ self.mid = nn.Module()
316
+ self.mid.block_1 = ResnetBlock(
317
+ in_channels=block_in,
318
+ out_channels=block_in,
319
+ temb_channels=self.temb_ch,
320
+ dropout=dropout,
321
+ )
322
+ self.mid.attn_1 = AttnBlock(block_in)
323
+ self.mid.block_2 = ResnetBlock(
324
+ in_channels=block_in,
325
+ out_channels=block_in,
326
+ temb_channels=self.temb_ch,
327
+ dropout=dropout,
328
+ )
329
+
330
+ # upsampling
331
+ self.up = nn.ModuleList()
332
+ for i_level in reversed(range(self.num_resolutions)):
333
+ block = nn.ModuleList()
334
+ attn = nn.ModuleList()
335
+ block_out = ch * ch_mult[i_level]
336
+ for i_block in range(self.num_res_blocks[i_level]):
337
+ block.append(
338
+ ResnetBlock(
339
+ in_channels=block_in,
340
+ out_channels=block_out,
341
+ temb_channels=self.temb_ch,
342
+ dropout=dropout,
343
+ )
344
+ )
345
+ block_in = block_out
346
+ if curr_res in attn_resolutions:
347
+ attn.append(AttnBlock(block_in))
348
+ up = nn.Module()
349
+ up.block = block
350
+ up.attn = attn
351
+ if i_level != 0:
352
+ up.upsample = Upsample(block_in, True)
353
+ curr_res = curr_res * 2
354
+ self.up.insert(0, up) # prepend to get consistent order
355
+
356
+ self.norm_out = Normalize(block_in)
357
+ self.conv_out = torch.nn.Conv2d(
358
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
359
+ )
360
+ self.post_quant_conv = torch.nn.Conv2d(
361
+ z_channels, z_channels, 1
362
+ )
363
+
364
+
365
+ def forward(self, z):
366
+ # assert z.shape[1:] == self.z_shape[1:]
367
+ self.last_z_shape = z.shape
368
+ # timestep embedding
369
+ temb = None
370
+ output = dict()
371
+ z = self.post_quant_conv(z)
372
+
373
+ # z to block_in
374
+ h = self.conv_in(z)
375
+
376
+ # middle
377
+ h = self.mid.block_1(h, temb)
378
+ h = self.mid.attn_1(h)
379
+ h = self.mid.block_2(h, temb)
380
+
381
+ # upsampling
382
+ for i_level in reversed(range(self.num_resolutions)):
383
+ for i_block in range(self.num_res_blocks[i_level]):
384
+ h = self.up[i_level].block[i_block](h, temb)
385
+ if len(self.up[i_level].attn) > 0:
386
+ h = self.up[i_level].attn[i_block](h)
387
+ if i_level != 0:
388
+ h = self.up[i_level].upsample(h)
389
+
390
+ # end
391
+ output["output"] = h
392
+ if self.give_pre_end:
393
+ return output
394
+
395
+ h = self.norm_out(h)
396
+ h = nonlinearity(h)
397
+ h = self.conv_out(h)
398
+ output["output"] = h
399
+ return output
400
+
401
+
402
+ class MAGVITv2(ModelMixin, ConfigMixin):
403
+ @register_to_config
404
+ def __init__(
405
+ self,
406
+ ):
407
+ super().__init__()
408
+
409
+ self.encoder = VQGANEncoder()
410
+ self.decoder = VQGANDecoder()
411
+ self.quantize = LFQuantizer()
412
+
413
+ def forward(self, pixel_values, return_loss=False):
414
+ pass
415
+
416
+ def encode(self, pixel_values, return_loss=False):
417
+ hidden_states = self.encoder(pixel_values)
418
+ quantized_states = self.quantize(hidden_states)['z']
419
+ codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1)
420
+ output = (quantized_states, codebook_indices)
421
+ return output
422
+
423
+ def get_code(self, pixel_values):
424
+ hidden_states = self.encoder(pixel_values)
425
+ codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1)
426
+
427
+ return codebook_indices
428
+
429
+ def decode_code(self, codebook_indices, shape=None):
430
+ z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape)
431
+
432
+ reconstructed_pixel_values = self.decoder(z_q)["output"]
433
+ return reconstructed_pixel_values
434
+
435
+
436
+ if __name__ == '__main__':
437
+ encoder = VQGANEncoder()
438
+ import ipdb
439
+ ipdb.set_trace()
440
+ print()
models/modeling_showo.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoConfig
4
+ from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
5
+ from .sampling import cosine_schedule, mask_by_random_topk
6
+ from .phi import PhiForCausalLM
7
+
8
+ try:
9
+ import xformers.ops as xops
10
+
11
+ is_xformers_available = True
12
+ except ImportError:
13
+ is_xformers_available = False
14
+
15
+
16
+ class Showo(ModelMixin, ConfigMixin):
17
+ _supports_gradient_checkpointing = True
18
+
19
+ @register_to_config
20
+ def __init__(
21
+ self,
22
+ w_clip_vit,
23
+ vocab_size,
24
+ llm_vocab_size,
25
+ llm_model_path='',
26
+ codebook_size=8192,
27
+ num_vq_tokens=256,
28
+ **kwargs,
29
+ ):
30
+ super().__init__()
31
+ self.vocab_size = vocab_size
32
+ self.register_to_config(mask_token_id=vocab_size - 1)
33
+ config = AutoConfig.from_pretrained(llm_model_path)
34
+ self.showo = PhiForCausalLM(config)
35
+ self.showo.resize_token_embeddings(self.vocab_size)
36
+ self.output_size = self.vocab_size
37
+
38
+ if self.w_clip_vit:
39
+ self.mm_projector = torch.nn.Sequential(
40
+ torch.nn.Linear(1024, 2048),
41
+ torch.nn.GELU(),
42
+ torch.nn.Linear(2048, 2048)
43
+ )
44
+
45
+ def _set_gradient_checkpointing(self, module, value=False):
46
+ self.gradient_checkpointing = True
47
+
48
+ def forward(
49
+ self,
50
+ input_ids,
51
+ input_embeddings=None,
52
+ attention_mask=None,
53
+ labels=None,
54
+ label_smoothing=0.0,
55
+ config=None,
56
+ labels_mask_text=None,
57
+ labels_mask_image=None,
58
+ **kwargs,
59
+ ):
60
+
61
+ if input_embeddings is None:
62
+ logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
63
+ else:
64
+ logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']
65
+
66
+ if labels is not None:
67
+ raise NotImplementedError
68
+
69
+ return logits
70
+
71
+ def t2i_generate(
72
+ self,
73
+ input_ids: torch.LongTensor = None,
74
+ uncond_input_ids: torch.LongTensor = None,
75
+ attention_mask=None,
76
+ temperature=1.0,
77
+ timesteps=18, # ideal number of steps is 18 in maskgit paper
78
+ guidance_scale=0,
79
+ noise_schedule=cosine_schedule,
80
+ generator: torch.Generator = None,
81
+ uni_prompting=None,
82
+ config=None,
83
+ **kwargs,
84
+ ):
85
+ """
86
+ Generate 1:1 similar to the original MaskGit repo
87
+ https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
88
+ """
89
+ # begin with all image token ids masked
90
+ mask_token_id = self.config.mask_token_id
91
+ seq_len = config.model.showo.num_vq_tokens
92
+
93
+ input_ids_minus_lm_vocab_size = input_ids[:, -(seq_len + 1):-1].clone()
94
+ input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id,
95
+ mask_token_id,
96
+ input_ids_minus_lm_vocab_size - config.model.showo.llm_vocab_size - 10)
97
+ # import ipdb
98
+ # ipdb.set_trace()
99
+ if uncond_input_ids is not None:
100
+ uncond_prefix = uncond_input_ids[:, :config.dataset.preprocessing.max_seq_length + 1]
101
+
102
+ for step in range(timesteps):
103
+ if uncond_input_ids is not None and guidance_scale > 0:
104
+ uncond_input_ids = torch.cat(
105
+ [uncond_prefix, input_ids[:, config.dataset.preprocessing.max_seq_length + 1:]], dim=1)
106
+ model_input = torch.cat([input_ids, uncond_input_ids])
107
+ cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
108
+ # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
109
+ # it seems that muse has different cfg setting
110
+ logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
111
+ logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
112
+ else:
113
+ logits = self(input_ids, attention_mask=attention_mask)
114
+ logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
115
+
116
+ probs = logits.softmax(dim=-1)
117
+ sampled = probs.reshape(-1, logits.size(-1))
118
+ sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])
119
+
120
+ unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
121
+ sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
122
+ # Defines the mask ratio for the next round. The number to mask out is
123
+ # determined by mask_ratio * unknown_number_in_the_beginning.
124
+ ratio = 1.0 * (step + 1) / timesteps
125
+ mask_ratio = noise_schedule(torch.tensor(ratio))
126
+ # Computes the probabilities of each selected tokens.
127
+ selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
128
+ selected_probs = selected_probs.squeeze(-1)
129
+
130
+ # Ignores the tokens given in the input by overwriting their confidence.
131
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
132
+ # Gets mask lens for each sample in the batch according to the mask ratio.
133
+ mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
134
+ # Keeps at least one of prediction in this round and also masks out at least
135
+ # one and for the next iteration
136
+ mask_len = torch.max(
137
+ torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
138
+ )
139
+ # Adds noise for randomness
140
+ temperature = temperature * (1.0 - ratio)
141
+ masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
142
+ # Masks tokens with lower confidence.
143
+ input_ids[:, -(seq_len + 1):-1] = torch.where(masking, mask_token_id,
144
+ sampled_ids + config.model.showo.llm_vocab_size + 10)
145
+ input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
146
+
147
+ return sampled_ids
148
+
149
+ @torch.no_grad()
150
+ def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
151
+ """
152
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
153
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
154
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
155
+ """
156
+ try:
157
+ device = idx.device
158
+ except:
159
+ device = input_embeddings.device
160
+
161
+ result = []
162
+ for _ in range(max_new_tokens):
163
+ # if the sequence context is growing too long we must crop it at block_size
164
+ # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
165
+ # forward the model to get the logits for the index in the sequence
166
+ # logits, _ = self(idx_cond)
167
+ logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)
168
+
169
+ L = attention_mask.shape[-1]
170
+ attention_mask = attention_mask.squeeze()
171
+ attention_mask_a = torch.hstack(
172
+ [
173
+ attention_mask, # L, L
174
+ torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
175
+ ]
176
+ )
177
+ attention_mask_b = torch.vstack(
178
+ [
179
+ attention_mask_a, # L, L+1
180
+ torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
181
+ ]
182
+ )
183
+ attention_mask = attention_mask_b
184
+
185
+ # pluck the logits at the final step and scale by desired temperature
186
+ logits = logits[:, -1, :] / temperature
187
+ # optionally crop the logits to only the top k options
188
+ if top_k is not None:
189
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
190
+ logits[logits < v[:, [-1]]] = -float('Inf')
191
+ # apply softmax to convert logits to (normalized) probabilities
192
+ probs = F.softmax(logits, dim=-1)
193
+ # sample from the distribution
194
+ idx_next = torch.multinomial(probs, num_samples=1)
195
+ result.append(idx_next[0][0])
196
+ # append sampled index to the running sequence and continue
197
+ if self.config.w_clip_vit:
198
+ idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
199
+ input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
200
+ else:
201
+ idx = torch.cat((idx, idx_next), dim=1)
202
+
203
+ if eot_token is not None and idx_next.cpu() == eot_token:
204
+ break
205
+
206
+ return result
models/modeling_utils.py ADDED
@@ -0,0 +1,1207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import json
20
+ import os
21
+ import re
22
+ from collections import OrderedDict
23
+ from functools import partial
24
+ from pathlib import Path
25
+ from typing import Any, Callable, List, Optional, Tuple, Union
26
+
27
+ import safetensors
28
+ import torch
29
+ from huggingface_hub import create_repo, split_torch_state_dict_into_shards
30
+ from huggingface_hub.utils import validate_hf_hub_args
31
+ from torch import Tensor, nn
32
+
33
+ from diffusers import __version__
34
+ from diffusers.utils import (
35
+ FLAX_WEIGHTS_NAME,
36
+ SAFE_WEIGHTS_INDEX_NAME,
37
+ WEIGHTS_INDEX_NAME,
38
+ _add_variant,
39
+ _get_checkpoint_shard_files,
40
+ _get_model_file,
41
+ deprecate,
42
+ is_accelerate_available,
43
+ is_torch_version,
44
+ logging,
45
+ )
46
+
47
+ CONFIG_NAME = "config.json"
48
+ WEIGHTS_NAME = "pytorch_model.bin"
49
+ SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
50
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
51
+
52
+ from diffusers.utils.hub_utils import (
53
+ PushToHubMixin,
54
+ load_or_create_model_card,
55
+ populate_model_card,
56
+ )
57
+ from diffusers.models.model_loading_utils import (
58
+ _determine_device_map,
59
+ _fetch_index_file,
60
+ _load_state_dict_into_model,
61
+ load_model_dict_into_meta,
62
+ load_state_dict,
63
+ )
64
+
65
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
66
+
67
+ logger = logging.get_logger(__name__)
68
+
69
+ _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
70
+
71
+
72
+ if is_torch_version(">=", "1.9.0"):
73
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
74
+ else:
75
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
76
+
77
+
78
+ if is_accelerate_available():
79
+ import accelerate
80
+
81
+
82
+ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
83
+ try:
84
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
85
+ return next(parameters_and_buffers).device
86
+ except StopIteration:
87
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
88
+
89
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
90
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
91
+ return tuples
92
+
93
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
94
+ first_tuple = next(gen)
95
+ return first_tuple[1].device
96
+
97
+
98
+ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
99
+ try:
100
+ params = tuple(parameter.parameters())
101
+ if len(params) > 0:
102
+ return params[0].dtype
103
+
104
+ buffers = tuple(parameter.buffers())
105
+ if len(buffers) > 0:
106
+ return buffers[0].dtype
107
+
108
+ except StopIteration:
109
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
110
+
111
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
112
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
113
+ return tuples
114
+
115
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
116
+ first_tuple = next(gen)
117
+ return first_tuple[1].dtype
118
+
119
+
120
+ class ModelMixin(torch.nn.Module, PushToHubMixin):
121
+ r"""
122
+ Base class for all models.
123
+
124
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
125
+ saving models.
126
+
127
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
128
+ """
129
+
130
+ config_name = CONFIG_NAME
131
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
132
+ _supports_gradient_checkpointing = False
133
+ _keys_to_ignore_on_load_unexpected = None
134
+ _no_split_modules = None
135
+
136
+ def __init__(self):
137
+ super().__init__()
138
+
139
+ def __getattr__(self, name: str) -> Any:
140
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
141
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
142
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
143
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
144
+ """
145
+
146
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
147
+ is_attribute = name in self.__dict__
148
+
149
+ if is_in_config and not is_attribute:
150
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
151
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
152
+ return self._internal_dict[name]
153
+
154
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
155
+ return super().__getattr__(name)
156
+
157
+ @property
158
+ def is_gradient_checkpointing(self) -> bool:
159
+ """
160
+ Whether gradient checkpointing is activated for this model or not.
161
+ """
162
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
163
+
164
+ def enable_gradient_checkpointing(self) -> None:
165
+ """
166
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
167
+ *checkpoint activations* in other frameworks).
168
+ """
169
+ if not self._supports_gradient_checkpointing:
170
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
171
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
172
+
173
+ def disable_gradient_checkpointing(self) -> None:
174
+ """
175
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
176
+ *checkpoint activations* in other frameworks).
177
+ """
178
+ if self._supports_gradient_checkpointing:
179
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
180
+
181
+ def set_use_npu_flash_attention(self, valid: bool) -> None:
182
+ r"""
183
+ Set the switch for the npu flash attention.
184
+ """
185
+
186
+ def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
187
+ if hasattr(module, "set_use_npu_flash_attention"):
188
+ module.set_use_npu_flash_attention(valid)
189
+
190
+ for child in module.children():
191
+ fn_recursive_set_npu_flash_attention(child)
192
+
193
+ for module in self.children():
194
+ if isinstance(module, torch.nn.Module):
195
+ fn_recursive_set_npu_flash_attention(module)
196
+
197
+ def enable_npu_flash_attention(self) -> None:
198
+ r"""
199
+ Enable npu flash attention from torch_npu
200
+
201
+ """
202
+ self.set_use_npu_flash_attention(True)
203
+
204
+ def disable_npu_flash_attention(self) -> None:
205
+ r"""
206
+ disable npu flash attention from torch_npu
207
+
208
+ """
209
+ self.set_use_npu_flash_attention(False)
210
+
211
+ def set_use_memory_efficient_attention_xformers(
212
+ self, valid: bool, attention_op: Optional[Callable] = None
213
+ ) -> None:
214
+ # Recursively walk through all the children.
215
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
216
+ # gets the message
217
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
218
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
219
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
220
+
221
+ for child in module.children():
222
+ fn_recursive_set_mem_eff(child)
223
+
224
+ for module in self.children():
225
+ if isinstance(module, torch.nn.Module):
226
+ fn_recursive_set_mem_eff(module)
227
+
228
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
229
+ r"""
230
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
231
+
232
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
233
+ inference. Speed up during training is not guaranteed.
234
+
235
+ <Tip warning={true}>
236
+
237
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
238
+ precedent.
239
+
240
+ </Tip>
241
+
242
+ Parameters:
243
+ attention_op (`Callable`, *optional*):
244
+ Override the default `None` operator for use as `op` argument to the
245
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
246
+ function of xFormers.
247
+
248
+ Examples:
249
+
250
+ ```py
251
+ >>> import torch
252
+ >>> from diffusers import UNet2DConditionModel
253
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
254
+
255
+ >>> model = UNet2DConditionModel.from_pretrained(
256
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
257
+ ... )
258
+ >>> model = model.to("cuda")
259
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
260
+ ```
261
+ """
262
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
263
+
264
+ def disable_xformers_memory_efficient_attention(self) -> None:
265
+ r"""
266
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
267
+ """
268
+ self.set_use_memory_efficient_attention_xformers(False)
269
+
270
+ def save_pretrained(
271
+ self,
272
+ save_directory: Union[str, os.PathLike],
273
+ is_main_process: bool = True,
274
+ save_function: Optional[Callable] = None,
275
+ safe_serialization: bool = True,
276
+ variant: Optional[str] = None,
277
+ max_shard_size: Union[int, str] = "10GB",
278
+ push_to_hub: bool = False,
279
+ **kwargs,
280
+ ):
281
+ """
282
+ Save a model and its configuration file to a directory so that it can be reloaded using the
283
+ [`~models.ModelMixin.from_pretrained`] class method.
284
+
285
+ Arguments:
286
+ save_directory (`str` or `os.PathLike`):
287
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
288
+ is_main_process (`bool`, *optional*, defaults to `True`):
289
+ Whether the process calling this is the main process or not. Useful during distributed training and you
290
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
291
+ process to avoid race conditions.
292
+ save_function (`Callable`):
293
+ The function to use to save the state dictionary. Useful during distributed training when you need to
294
+ replace `torch.save` with another method. Can be configured with the environment variable
295
+ `DIFFUSERS_SAVE_MODE`.
296
+ safe_serialization (`bool`, *optional*, defaults to `True`):
297
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
298
+ variant (`str`, *optional*):
299
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
300
+ max_shard_size (`int` or `str`, defaults to `"10GB"`):
301
+ The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
302
+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
303
+ If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
304
+ period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
305
+ This is to establish a common default size for this argument across different libraries in the Hugging
306
+ Face ecosystem (`transformers`, and `accelerate`, for example).
307
+ push_to_hub (`bool`, *optional*, defaults to `False`):
308
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
309
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
310
+ namespace).
311
+ kwargs (`Dict[str, Any]`, *optional*):
312
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
313
+ """
314
+ if os.path.isfile(save_directory):
315
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
316
+ return
317
+
318
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
319
+ weights_name = _add_variant(weights_name, variant)
320
+ weight_name_split = weights_name.split(".")
321
+ if len(weight_name_split) in [2, 3]:
322
+ weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
323
+ else:
324
+ raise ValueError(f"Invalid {weights_name} provided.")
325
+
326
+ os.makedirs(save_directory, exist_ok=True)
327
+
328
+ if push_to_hub:
329
+ commit_message = kwargs.pop("commit_message", None)
330
+ private = kwargs.pop("private", False)
331
+ create_pr = kwargs.pop("create_pr", False)
332
+ token = kwargs.pop("token", None)
333
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
334
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
335
+
336
+ # Only save the model itself if we are using distributed training
337
+ model_to_save = self
338
+
339
+ # Attach architecture to the config
340
+ # Save the config
341
+ if is_main_process:
342
+ model_to_save.save_config(save_directory)
343
+
344
+ # Save the model
345
+ state_dict = model_to_save.state_dict()
346
+
347
+ # Save the model
348
+ state_dict_split = split_torch_state_dict_into_shards(
349
+ state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
350
+ )
351
+
352
+ # Clean the folder from a previous save
353
+ if is_main_process:
354
+ for filename in os.listdir(save_directory):
355
+ if filename in state_dict_split.filename_to_tensors.keys():
356
+ continue
357
+ full_filename = os.path.join(save_directory, filename)
358
+ if not os.path.isfile(full_filename):
359
+ continue
360
+ weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
361
+ weights_without_ext = weights_without_ext.replace("{suffix}", "")
362
+ filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
363
+ # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
364
+ if (
365
+ filename.startswith(weights_without_ext)
366
+ and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
367
+ ):
368
+ os.remove(full_filename)
369
+
370
+ for filename, tensors in state_dict_split.filename_to_tensors.items():
371
+ shard = {tensor: state_dict[tensor] for tensor in tensors}
372
+ filepath = os.path.join(save_directory, filename)
373
+ if safe_serialization:
374
+ # At some point we will need to deal better with save_function (used for TPU and other distributed
375
+ # joyfulness), but for now this enough.
376
+ safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
377
+ else:
378
+ torch.save(shard, filepath)
379
+
380
+ if state_dict_split.is_sharded:
381
+ index = {
382
+ "metadata": state_dict_split.metadata,
383
+ "weight_map": state_dict_split.tensor_to_filename,
384
+ }
385
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
386
+ save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
387
+ # Save the index as well
388
+ with open(save_index_file, "w", encoding="utf-8") as f:
389
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
390
+ f.write(content)
391
+ logger.info(
392
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
393
+ f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
394
+ f"index located at {save_index_file}."
395
+ )
396
+ else:
397
+ path_to_weights = os.path.join(save_directory, weights_name)
398
+ logger.info(f"Model weights saved in {path_to_weights}")
399
+
400
+ if push_to_hub:
401
+ # Create a new empty model card and eventually tag it
402
+ model_card = load_or_create_model_card(repo_id, token=token)
403
+ model_card = populate_model_card(model_card)
404
+ model_card.save(Path(save_directory, "README.md").as_posix())
405
+
406
+ self._upload_folder(
407
+ save_directory,
408
+ repo_id,
409
+ token=token,
410
+ commit_message=commit_message,
411
+ create_pr=create_pr,
412
+ )
413
+
414
+ @classmethod
415
+ @validate_hf_hub_args
416
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
417
+ r"""
418
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
419
+
420
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
421
+ train the model, set it back in training mode with `model.train()`.
422
+
423
+ Parameters:
424
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
425
+ Can be either:
426
+
427
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
428
+ the Hub.
429
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
430
+ with [`~ModelMixin.save_pretrained`].
431
+
432
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
433
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
434
+ is not used.
435
+ torch_dtype (`str` or `torch.dtype`, *optional*):
436
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
437
+ dtype is automatically derived from the model's weights.
438
+ force_download (`bool`, *optional*, defaults to `False`):
439
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
440
+ cached versions if they exist.
441
+ proxies (`Dict[str, str]`, *optional*):
442
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
443
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
444
+ output_loading_info (`bool`, *optional*, defaults to `False`):
445
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
446
+ local_files_only(`bool`, *optional*, defaults to `False`):
447
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
448
+ won't be downloaded from the Hub.
449
+ token (`str` or *bool*, *optional*):
450
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
451
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
452
+ revision (`str`, *optional*, defaults to `"main"`):
453
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
454
+ allowed by Git.
455
+ from_flax (`bool`, *optional*, defaults to `False`):
456
+ Load the model weights from a Flax checkpoint save file.
457
+ subfolder (`str`, *optional*, defaults to `""`):
458
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
459
+ mirror (`str`, *optional*):
460
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
461
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
462
+ information.
463
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
464
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
465
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
466
+ same device. Defaults to `None`, meaning that the model will be loaded on CPU.
467
+
468
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
469
+ more information about each option see [designing a device
470
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
471
+ max_memory (`Dict`, *optional*):
472
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
473
+ each GPU and the available CPU RAM if unset.
474
+ offload_folder (`str` or `os.PathLike`, *optional*):
475
+ The path to offload weights if `device_map` contains the value `"disk"`.
476
+ offload_state_dict (`bool`, *optional*):
477
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
478
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
479
+ when there is some disk offload.
480
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
481
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
482
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
483
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
484
+ argument to `True` will raise an error.
485
+ variant (`str`, *optional*):
486
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
487
+ loading `from_flax`.
488
+ use_safetensors (`bool`, *optional*, defaults to `None`):
489
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
490
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
491
+ weights. If set to `False`, `safetensors` weights are not loaded.
492
+
493
+ <Tip>
494
+
495
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
496
+ `huggingface-cli login`. You can also activate the special
497
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
498
+ firewalled environment.
499
+
500
+ </Tip>
501
+
502
+ Example:
503
+
504
+ ```py
505
+ from diffusers import UNet2DConditionModel
506
+
507
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
508
+ ```
509
+
510
+ If you get the error message below, you need to finetune the weights for your downstream task:
511
+
512
+ ```bash
513
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
514
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
515
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
516
+ ```
517
+ """
518
+ cache_dir = kwargs.pop("cache_dir", None)
519
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
520
+ force_download = kwargs.pop("force_download", False)
521
+ from_flax = kwargs.pop("from_flax", False)
522
+ proxies = kwargs.pop("proxies", None)
523
+ output_loading_info = kwargs.pop("output_loading_info", False)
524
+ local_files_only = kwargs.pop("local_files_only", None)
525
+ token = kwargs.pop("token", None)
526
+ revision = kwargs.pop("revision", None)
527
+ torch_dtype = kwargs.pop("torch_dtype", None)
528
+ subfolder = kwargs.pop("subfolder", None)
529
+ device_map = kwargs.pop("device_map", None)
530
+ max_memory = kwargs.pop("max_memory", None)
531
+ offload_folder = kwargs.pop("offload_folder", None)
532
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
533
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
534
+ variant = kwargs.pop("variant", None)
535
+ use_safetensors = kwargs.pop("use_safetensors", None)
536
+
537
+ allow_pickle = False
538
+ if use_safetensors is None:
539
+ use_safetensors = True
540
+ allow_pickle = True
541
+
542
+ if low_cpu_mem_usage and not is_accelerate_available():
543
+ low_cpu_mem_usage = False
544
+ logger.warning(
545
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
546
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
547
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
548
+ " install accelerate\n```\n."
549
+ )
550
+
551
+ if device_map is not None and not is_accelerate_available():
552
+ raise NotImplementedError(
553
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
554
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
555
+ )
556
+
557
+ # Check if we can handle device_map and dispatching the weights
558
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
559
+ raise NotImplementedError(
560
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
561
+ " `device_map=None`."
562
+ )
563
+
564
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
565
+ raise NotImplementedError(
566
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
567
+ " `low_cpu_mem_usage=False`."
568
+ )
569
+
570
+ if low_cpu_mem_usage is False and device_map is not None:
571
+ raise ValueError(
572
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
573
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
574
+ )
575
+
576
+ # change device_map into a map if we passed an int, a str or a torch.device
577
+ if isinstance(device_map, torch.device):
578
+ device_map = {"": device_map}
579
+ elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
580
+ try:
581
+ device_map = {"": torch.device(device_map)}
582
+ except RuntimeError:
583
+ raise ValueError(
584
+ "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
585
+ f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
586
+ )
587
+ elif isinstance(device_map, int):
588
+ if device_map < 0:
589
+ raise ValueError(
590
+ "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
591
+ )
592
+ else:
593
+ device_map = {"": device_map}
594
+
595
+ if device_map is not None:
596
+ if low_cpu_mem_usage is None:
597
+ low_cpu_mem_usage = True
598
+ elif not low_cpu_mem_usage:
599
+ raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
600
+
601
+ if low_cpu_mem_usage:
602
+ if device_map is not None and not is_torch_version(">=", "1.10"):
603
+ # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
604
+ raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
605
+
606
+ # Load config if we don't provide a configuration
607
+ config_path = pretrained_model_name_or_path
608
+
609
+ user_agent = {
610
+ "diffusers": __version__,
611
+ "file_type": "model",
612
+ "framework": "pytorch",
613
+ }
614
+
615
+ # load config
616
+ config, unused_kwargs, commit_hash = cls.load_config(
617
+ config_path,
618
+ cache_dir=cache_dir,
619
+ return_unused_kwargs=True,
620
+ return_commit_hash=True,
621
+ force_download=force_download,
622
+ proxies=proxies,
623
+ local_files_only=local_files_only,
624
+ token=token,
625
+ revision=revision,
626
+ subfolder=subfolder,
627
+ user_agent=user_agent,
628
+ **kwargs,
629
+ )
630
+
631
+ # Determine if we're loading from a directory of sharded checkpoints.
632
+ is_sharded = False
633
+ index_file = None
634
+ is_local = os.path.isdir(pretrained_model_name_or_path)
635
+ index_file = _fetch_index_file(
636
+ is_local=is_local,
637
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
638
+ subfolder=subfolder or "",
639
+ use_safetensors=use_safetensors,
640
+ cache_dir=cache_dir,
641
+ variant=variant,
642
+ force_download=force_download,
643
+ proxies=proxies,
644
+ local_files_only=local_files_only,
645
+ token=token,
646
+ revision=revision,
647
+ user_agent=user_agent,
648
+ commit_hash=commit_hash,
649
+ )
650
+ if index_file is not None and index_file.is_file():
651
+ is_sharded = True
652
+
653
+ if is_sharded and from_flax:
654
+ raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
655
+
656
+ # load model
657
+ model_file = None
658
+ if from_flax:
659
+ model_file = _get_model_file(
660
+ pretrained_model_name_or_path,
661
+ weights_name=FLAX_WEIGHTS_NAME,
662
+ cache_dir=cache_dir,
663
+ force_download=force_download,
664
+ proxies=proxies,
665
+ local_files_only=local_files_only,
666
+ token=token,
667
+ revision=revision,
668
+ subfolder=subfolder,
669
+ user_agent=user_agent,
670
+ commit_hash=commit_hash,
671
+ )
672
+ model = cls.from_config(config, **unused_kwargs)
673
+
674
+ # Convert the weights
675
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
676
+
677
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
678
+ else:
679
+ if is_sharded:
680
+ sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
681
+ pretrained_model_name_or_path,
682
+ index_file,
683
+ cache_dir=cache_dir,
684
+ proxies=proxies,
685
+ local_files_only=local_files_only,
686
+ token=token,
687
+ user_agent=user_agent,
688
+ revision=revision,
689
+ subfolder=subfolder or "",
690
+ )
691
+
692
+ elif use_safetensors and not is_sharded:
693
+ try:
694
+ model_file = _get_model_file(
695
+ pretrained_model_name_or_path,
696
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
697
+ cache_dir=cache_dir,
698
+ force_download=force_download,
699
+ proxies=proxies,
700
+ local_files_only=local_files_only,
701
+ token=token,
702
+ revision=revision,
703
+ subfolder=subfolder,
704
+ user_agent=user_agent,
705
+ commit_hash=commit_hash,
706
+ )
707
+
708
+ except IOError as e:
709
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
710
+ if not allow_pickle:
711
+ raise
712
+ logger.warning(
713
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
714
+ )
715
+
716
+ if model_file is None and not is_sharded:
717
+ model_file = _get_model_file(
718
+ pretrained_model_name_or_path,
719
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
720
+ cache_dir=cache_dir,
721
+ force_download=force_download,
722
+ proxies=proxies,
723
+ local_files_only=local_files_only,
724
+ token=token,
725
+ revision=revision,
726
+ subfolder=subfolder,
727
+ user_agent=user_agent,
728
+ commit_hash=commit_hash,
729
+ )
730
+
731
+ if low_cpu_mem_usage:
732
+ # Instantiate model with empty weights
733
+ with accelerate.init_empty_weights():
734
+ model = cls.from_config(config, **unused_kwargs)
735
+
736
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
737
+ if device_map is None and not is_sharded:
738
+ param_device = "cpu"
739
+ state_dict = load_state_dict(model_file, variant=variant)
740
+ model._convert_deprecated_attention_blocks(state_dict)
741
+ # move the params from meta device to cpu
742
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
743
+ if len(missing_keys) > 0:
744
+ raise ValueError(
745
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
746
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
747
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
748
+ " those weights or else make sure your checkpoint file is correct."
749
+ )
750
+
751
+ unexpected_keys = load_model_dict_into_meta(
752
+ model,
753
+ state_dict,
754
+ device=param_device,
755
+ dtype=torch_dtype,
756
+ model_name_or_path=pretrained_model_name_or_path,
757
+ )
758
+
759
+ if cls._keys_to_ignore_on_load_unexpected is not None:
760
+ for pat in cls._keys_to_ignore_on_load_unexpected:
761
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
762
+
763
+ if len(unexpected_keys) > 0:
764
+ logger.warning(
765
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
766
+ )
767
+
768
+ else: # else let accelerate handle loading and dispatching.
769
+ # Load weights and dispatch according to the device_map
770
+ # by default the device_map is None and the weights are loaded on the CPU
771
+ force_hook = True
772
+ device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
773
+ if device_map is None and is_sharded:
774
+ # we load the parameters on the cpu
775
+ device_map = {"": "cpu"}
776
+ force_hook = False
777
+ try:
778
+ accelerate.load_checkpoint_and_dispatch(
779
+ model,
780
+ model_file if not is_sharded else index_file,
781
+ device_map,
782
+ max_memory=max_memory,
783
+ offload_folder=offload_folder,
784
+ offload_state_dict=offload_state_dict,
785
+ dtype=torch_dtype,
786
+ force_hooks=force_hook,
787
+ strict=True,
788
+ )
789
+ except AttributeError as e:
790
+ # When using accelerate loading, we do not have the ability to load the state
791
+ # dict and rename the weight names manually. Additionally, accelerate skips
792
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
793
+ # (which look like they should be private variables?), so we can't use the standard hooks
794
+ # to rename parameters on load. We need to mimic the original weight names so the correct
795
+ # attributes are available. After we have loaded the weights, we convert the deprecated
796
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
797
+ # the weights so we don't have to do this again.
798
+
799
+ if "'Attention' object has no attribute" in str(e):
800
+ logger.warning(
801
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
802
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
803
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
804
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
805
+ " please also re-upload it or open a PR on the original repository."
806
+ )
807
+ model._temp_convert_self_to_deprecated_attention_blocks()
808
+ accelerate.load_checkpoint_and_dispatch(
809
+ model,
810
+ model_file if not is_sharded else index_file,
811
+ device_map,
812
+ max_memory=max_memory,
813
+ offload_folder=offload_folder,
814
+ offload_state_dict=offload_state_dict,
815
+ dtype=torch_dtype,
816
+ force_hooks=force_hook,
817
+ strict=True,
818
+ )
819
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
820
+ else:
821
+ raise e
822
+
823
+ loading_info = {
824
+ "missing_keys": [],
825
+ "unexpected_keys": [],
826
+ "mismatched_keys": [],
827
+ "error_msgs": [],
828
+ }
829
+ else:
830
+ model = cls.from_config(config, **unused_kwargs)
831
+
832
+ state_dict = load_state_dict(model_file, variant=variant)
833
+ model._convert_deprecated_attention_blocks(state_dict)
834
+
835
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
836
+ model,
837
+ state_dict,
838
+ model_file,
839
+ pretrained_model_name_or_path,
840
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
841
+ )
842
+
843
+ loading_info = {
844
+ "missing_keys": missing_keys,
845
+ "unexpected_keys": unexpected_keys,
846
+ "mismatched_keys": mismatched_keys,
847
+ "error_msgs": error_msgs,
848
+ }
849
+
850
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
851
+ raise ValueError(
852
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
853
+ )
854
+ elif torch_dtype is not None:
855
+ model = model.to(torch_dtype)
856
+
857
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
858
+
859
+ # Set model in evaluation mode to deactivate DropOut modules by default
860
+ model.eval()
861
+ if output_loading_info:
862
+ return model, loading_info
863
+
864
+ return model
865
+
866
+ @classmethod
867
+ def _load_pretrained_model(
868
+ cls,
869
+ model,
870
+ state_dict: OrderedDict,
871
+ resolved_archive_file,
872
+ pretrained_model_name_or_path: Union[str, os.PathLike],
873
+ ignore_mismatched_sizes: bool = False,
874
+ ):
875
+ # Retrieve missing & unexpected_keys
876
+ model_state_dict = model.state_dict()
877
+ loaded_keys = list(state_dict.keys())
878
+
879
+ expected_keys = list(model_state_dict.keys())
880
+
881
+ original_loaded_keys = loaded_keys
882
+
883
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
884
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
885
+
886
+ # Make sure we are able to load base models as well as derived models (with heads)
887
+ model_to_load = model
888
+
889
+ def _find_mismatched_keys(
890
+ state_dict,
891
+ model_state_dict,
892
+ loaded_keys,
893
+ ignore_mismatched_sizes,
894
+ ):
895
+ mismatched_keys = []
896
+ if ignore_mismatched_sizes:
897
+ for checkpoint_key in loaded_keys:
898
+ model_key = checkpoint_key
899
+
900
+ if (
901
+ model_key in model_state_dict
902
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
903
+ ):
904
+ mismatched_keys.append(
905
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
906
+ )
907
+ del state_dict[checkpoint_key]
908
+ return mismatched_keys
909
+
910
+ if state_dict is not None:
911
+ # Whole checkpoint
912
+ mismatched_keys = _find_mismatched_keys(
913
+ state_dict,
914
+ model_state_dict,
915
+ original_loaded_keys,
916
+ ignore_mismatched_sizes,
917
+ )
918
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
919
+
920
+ if len(error_msgs) > 0:
921
+ error_msg = "\n\t".join(error_msgs)
922
+ if "size mismatch" in error_msg:
923
+ error_msg += (
924
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
925
+ )
926
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
927
+
928
+ if len(unexpected_keys) > 0:
929
+ logger.warning(
930
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
931
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
932
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
933
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
934
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
935
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
936
+ " identical (initializing a BertForSequenceClassification model from a"
937
+ " BertForSequenceClassification model)."
938
+ )
939
+ else:
940
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
941
+ if len(missing_keys) > 0:
942
+ logger.warning(
943
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
944
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
945
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
946
+ )
947
+ elif len(mismatched_keys) == 0:
948
+ logger.info(
949
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
950
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
951
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
952
+ " without further training."
953
+ )
954
+ if len(mismatched_keys) > 0:
955
+ mismatched_warning = "\n".join(
956
+ [
957
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
958
+ for key, shape1, shape2 in mismatched_keys
959
+ ]
960
+ )
961
+ logger.warning(
962
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
963
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
964
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
965
+ " able to use it for predictions and inference."
966
+ )
967
+
968
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
969
+
970
+ @classmethod
971
+ def _get_signature_keys(cls, obj):
972
+ parameters = inspect.signature(obj.__init__).parameters
973
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
974
+ optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
975
+ expected_modules = set(required_parameters.keys()) - {"self"}
976
+
977
+ return expected_modules, optional_parameters
978
+
979
+ # Adapted from `transformers` modeling_utils.py
980
+ def _get_no_split_modules(self, device_map: str):
981
+ """
982
+ Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
983
+ get the underlying `_no_split_modules`.
984
+
985
+ Args:
986
+ device_map (`str`):
987
+ The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
988
+
989
+ Returns:
990
+ `List[str]`: List of modules that should not be split
991
+ """
992
+ _no_split_modules = set()
993
+ modules_to_check = [self]
994
+ while len(modules_to_check) > 0:
995
+ module = modules_to_check.pop(-1)
996
+ # if the module does not appear in _no_split_modules, we also check the children
997
+ if module.__class__.__name__ not in _no_split_modules:
998
+ if isinstance(module, ModelMixin):
999
+ if module._no_split_modules is None:
1000
+ raise ValueError(
1001
+ f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
1002
+ "class needs to implement the `_no_split_modules` attribute."
1003
+ )
1004
+ else:
1005
+ _no_split_modules = _no_split_modules | set(module._no_split_modules)
1006
+ modules_to_check += list(module.children())
1007
+ return list(_no_split_modules)
1008
+
1009
+ @property
1010
+ def device(self) -> torch.device:
1011
+ """
1012
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1013
+ device).
1014
+ """
1015
+ return get_parameter_device(self)
1016
+
1017
+ @property
1018
+ def dtype(self) -> torch.dtype:
1019
+ """
1020
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1021
+ """
1022
+ return get_parameter_dtype(self)
1023
+
1024
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
1025
+ """
1026
+ Get number of (trainable or non-embedding) parameters in the module.
1027
+
1028
+ Args:
1029
+ only_trainable (`bool`, *optional*, defaults to `False`):
1030
+ Whether or not to return only the number of trainable parameters.
1031
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
1032
+ Whether or not to return only the number of non-embedding parameters.
1033
+
1034
+ Returns:
1035
+ `int`: The number of parameters.
1036
+
1037
+ Example:
1038
+
1039
+ ```py
1040
+ from diffusers import UNet2DConditionModel
1041
+
1042
+ model_id = "runwayml/stable-diffusion-v1-5"
1043
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
1044
+ unet.num_parameters(only_trainable=True)
1045
+ 859520964
1046
+ ```
1047
+ """
1048
+
1049
+ if exclude_embeddings:
1050
+ embedding_param_names = [
1051
+ f"{name}.weight"
1052
+ for name, module_type in self.named_modules()
1053
+ if isinstance(module_type, torch.nn.Embedding)
1054
+ ]
1055
+ non_embedding_parameters = [
1056
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
1057
+ ]
1058
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
1059
+ else:
1060
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
1061
+
1062
+ def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
1063
+ deprecated_attention_block_paths = []
1064
+
1065
+ def recursive_find_attn_block(name, module):
1066
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1067
+ deprecated_attention_block_paths.append(name)
1068
+
1069
+ for sub_name, sub_module in module.named_children():
1070
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
1071
+ recursive_find_attn_block(sub_name, sub_module)
1072
+
1073
+ recursive_find_attn_block("", self)
1074
+
1075
+ # NOTE: we have to check if the deprecated parameters are in the state dict
1076
+ # because it is possible we are loading from a state dict that was already
1077
+ # converted
1078
+
1079
+ for path in deprecated_attention_block_paths:
1080
+ # group_norm path stays the same
1081
+
1082
+ # query -> to_q
1083
+ if f"{path}.query.weight" in state_dict:
1084
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
1085
+ if f"{path}.query.bias" in state_dict:
1086
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
1087
+
1088
+ # key -> to_k
1089
+ if f"{path}.key.weight" in state_dict:
1090
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
1091
+ if f"{path}.key.bias" in state_dict:
1092
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
1093
+
1094
+ # value -> to_v
1095
+ if f"{path}.value.weight" in state_dict:
1096
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
1097
+ if f"{path}.value.bias" in state_dict:
1098
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
1099
+
1100
+ # proj_attn -> to_out.0
1101
+ if f"{path}.proj_attn.weight" in state_dict:
1102
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
1103
+ if f"{path}.proj_attn.bias" in state_dict:
1104
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
1105
+
1106
+ def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1107
+ deprecated_attention_block_modules = []
1108
+
1109
+ def recursive_find_attn_block(module):
1110
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1111
+ deprecated_attention_block_modules.append(module)
1112
+
1113
+ for sub_module in module.children():
1114
+ recursive_find_attn_block(sub_module)
1115
+
1116
+ recursive_find_attn_block(self)
1117
+
1118
+ for module in deprecated_attention_block_modules:
1119
+ module.query = module.to_q
1120
+ module.key = module.to_k
1121
+ module.value = module.to_v
1122
+ module.proj_attn = module.to_out[0]
1123
+
1124
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
1125
+ # that _all_ the weights are loaded into the new attributes and we're not
1126
+ # making an incorrect assumption that this model should be converted when
1127
+ # it really shouldn't be.
1128
+ del module.to_q
1129
+ del module.to_k
1130
+ del module.to_v
1131
+ del module.to_out
1132
+
1133
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1134
+ deprecated_attention_block_modules = []
1135
+
1136
+ def recursive_find_attn_block(module) -> None:
1137
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1138
+ deprecated_attention_block_modules.append(module)
1139
+
1140
+ for sub_module in module.children():
1141
+ recursive_find_attn_block(sub_module)
1142
+
1143
+ recursive_find_attn_block(self)
1144
+
1145
+ for module in deprecated_attention_block_modules:
1146
+ module.to_q = module.query
1147
+ module.to_k = module.key
1148
+ module.to_v = module.value
1149
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
1150
+
1151
+ del module.query
1152
+ del module.key
1153
+ del module.value
1154
+ del module.proj_attn
1155
+
1156
+
1157
+ class LegacyModelMixin(ModelMixin):
1158
+ r"""
1159
+ A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
1160
+ pipeline-specific classes (like `DiTTransformer2DModel`).
1161
+ """
1162
+
1163
+ @classmethod
1164
+ @validate_hf_hub_args
1165
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
1166
+ # To prevent dependency import problem.
1167
+ from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config
1168
+
1169
+ # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
1170
+ kwargs_copy = kwargs.copy()
1171
+
1172
+ cache_dir = kwargs.pop("cache_dir", None)
1173
+ force_download = kwargs.pop("force_download", False)
1174
+ proxies = kwargs.pop("proxies", None)
1175
+ local_files_only = kwargs.pop("local_files_only", None)
1176
+ token = kwargs.pop("token", None)
1177
+ revision = kwargs.pop("revision", None)
1178
+ subfolder = kwargs.pop("subfolder", None)
1179
+
1180
+ # Load config if we don't provide a configuration
1181
+ config_path = pretrained_model_name_or_path
1182
+
1183
+ user_agent = {
1184
+ "diffusers": __version__,
1185
+ "file_type": "model",
1186
+ "framework": "pytorch",
1187
+ }
1188
+
1189
+ # load config
1190
+ config, _, _ = cls.load_config(
1191
+ config_path,
1192
+ cache_dir=cache_dir,
1193
+ return_unused_kwargs=True,
1194
+ return_commit_hash=True,
1195
+ force_download=force_download,
1196
+ proxies=proxies,
1197
+ local_files_only=local_files_only,
1198
+ token=token,
1199
+ revision=revision,
1200
+ subfolder=subfolder,
1201
+ user_agent=user_agent,
1202
+ **kwargs,
1203
+ )
1204
+ # resolve remapping
1205
+ remapped_class = _fetch_remapped_cls_from_config(config, cls)
1206
+
1207
+ return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
models/phi.py ADDED
@@ -0,0 +1,1489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """PyTorch Phi model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from packaging import version
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
+ from transformers.modeling_attn_mask_utils import (
31
+ _prepare_4d_causal_attention_mask,
32
+ _prepare_4d_causal_attention_mask_for_sdpa,
33
+ )
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.utils import (
42
+ add_code_sample_docstrings,
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ get_torch_version,
46
+ is_flash_attn_2_available,
47
+ is_flash_attn_greater_or_equal_2_10,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.models.phi.configuration_phi import PhiConfig
52
+
53
+
54
+ if is_flash_attn_2_available():
55
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
56
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
+
58
+
59
+ logger = logging.get_logger(__name__)
60
+
61
+ _CHECKPOINT_FOR_DOC = "microsoft/phi-1"
62
+ _CONFIG_FOR_DOC = "PhiConfig"
63
+
64
+
65
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
66
+ def _get_unpad_data(attention_mask):
67
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
70
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
76
+
77
+
78
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
79
+ class PhiRotaryEmbedding(nn.Module):
80
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
81
+ super().__init__()
82
+
83
+ self.dim = dim
84
+ self.max_position_embeddings = max_position_embeddings
85
+ self.base = base
86
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
87
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
88
+
89
+ # Build here to make `torch.jit.trace` work.
90
+ self._set_cos_sin_cache(
91
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
92
+ )
93
+
94
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
95
+ self.max_seq_len_cached = seq_len
96
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
97
+
98
+ freqs = torch.outer(t, self.inv_freq)
99
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
100
+ emb = torch.cat((freqs, freqs), dim=-1)
101
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
102
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
103
+
104
+ def forward(self, x, seq_len=None):
105
+ # x: [bs, num_attention_heads, seq_len, head_size]
106
+ if seq_len > self.max_seq_len_cached:
107
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
108
+
109
+ return (
110
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
111
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
112
+ )
113
+
114
+
115
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
116
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
117
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
118
+
119
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
120
+ self.scaling_factor = scaling_factor
121
+ super().__init__(dim, max_position_embeddings, base, device)
122
+
123
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
124
+ self.max_seq_len_cached = seq_len
125
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
126
+ t = t / self.scaling_factor
127
+
128
+ freqs = torch.outer(t, self.inv_freq)
129
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
130
+ emb = torch.cat((freqs, freqs), dim=-1)
131
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
132
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
133
+
134
+
135
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
136
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
137
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
138
+
139
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
140
+ self.scaling_factor = scaling_factor
141
+ super().__init__(dim, max_position_embeddings, base, device)
142
+
143
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
144
+ self.max_seq_len_cached = seq_len
145
+
146
+ if seq_len > self.max_position_embeddings:
147
+ base = self.base * (
148
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
149
+ ) ** (self.dim / (self.dim - 2))
150
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
151
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
152
+
153
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
154
+
155
+ freqs = torch.outer(t, self.inv_freq)
156
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
157
+ emb = torch.cat((freqs, freqs), dim=-1)
158
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
159
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
160
+
161
+
162
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
163
+ def rotate_half(x):
164
+ """Rotates half the hidden dims of the input."""
165
+ x1 = x[..., : x.shape[-1] // 2]
166
+ x2 = x[..., x.shape[-1] // 2 :]
167
+ return torch.cat((-x2, x1), dim=-1)
168
+
169
+
170
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
171
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
172
+ """Applies Rotary Position Embedding to the query and key tensors.
173
+
174
+ Args:
175
+ q (`torch.Tensor`): The query tensor.
176
+ k (`torch.Tensor`): The key tensor.
177
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
178
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
179
+ position_ids (`torch.Tensor`):
180
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
181
+ used to pass offsetted position ids when working with a KV-cache.
182
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
183
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
184
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
185
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
186
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
187
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
188
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
189
+ Returns:
190
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
191
+ """
192
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
193
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
194
+ q_embed = (q * cos) + (rotate_half(q) * sin)
195
+ k_embed = (k * cos) + (rotate_half(k) * sin)
196
+ return q_embed, k_embed
197
+
198
+
199
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
200
+ class PhiMLP(nn.Module):
201
+ def __init__(self, config):
202
+ super().__init__()
203
+ self.config = config
204
+ self.activation_fn = ACT2FN[config.hidden_act]
205
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
206
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
207
+
208
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
209
+ hidden_states = self.fc1(hidden_states)
210
+ hidden_states = self.activation_fn(hidden_states)
211
+ hidden_states = self.fc2(hidden_states)
212
+ return hidden_states
213
+
214
+
215
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
216
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
217
+ """
218
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
219
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
220
+ """
221
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
222
+ if n_rep == 1:
223
+ return hidden_states
224
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
225
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
226
+
227
+
228
+ class PhiAttention(nn.Module):
229
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
230
+
231
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
232
+ super().__init__()
233
+ self.config = config
234
+ self.layer_idx = layer_idx
235
+ if layer_idx is None:
236
+ logger.warning_once(
237
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
238
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
239
+ "when creating this class."
240
+ )
241
+
242
+ self.attention_dropout = config.attention_dropout
243
+ self.hidden_size = config.hidden_size
244
+ self.num_heads = config.num_attention_heads
245
+ self.head_dim = self.hidden_size // self.num_heads
246
+ self.num_key_value_heads = config.num_key_value_heads
247
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
248
+ self.max_position_embeddings = config.max_position_embeddings
249
+ self.rope_theta = config.rope_theta
250
+ self.partial_rotary_factor = config.partial_rotary_factor
251
+ self.is_causal = True
252
+
253
+ if (self.head_dim * self.num_heads) != self.hidden_size:
254
+ raise ValueError(
255
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
256
+ f" and `num_heads`: {self.num_heads})."
257
+ )
258
+
259
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
260
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
261
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
262
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
263
+
264
+ self.qk_layernorm = config.qk_layernorm
265
+ if self.qk_layernorm:
266
+ self.q_layernorm = nn.LayerNorm(
267
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
268
+ )
269
+ self.k_layernorm = nn.LayerNorm(
270
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
271
+ )
272
+
273
+ self._init_rope()
274
+
275
+ def _init_rope(self):
276
+ if self.config.rope_scaling is None:
277
+ self.rotary_emb = PhiRotaryEmbedding(
278
+ int(self.partial_rotary_factor * self.head_dim),
279
+ max_position_embeddings=self.max_position_embeddings,
280
+ base=self.rope_theta,
281
+ )
282
+ else:
283
+ scaling_type = self.config.rope_scaling["type"]
284
+ scaling_factor = self.config.rope_scaling["factor"]
285
+ if scaling_type == "linear":
286
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
287
+ int(self.partial_rotary_factor * self.head_dim),
288
+ max_position_embeddings=self.max_position_embeddings,
289
+ scaling_factor=scaling_factor,
290
+ base=self.rope_theta,
291
+ )
292
+ elif scaling_type == "dynamic":
293
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
294
+ int(self.partial_rotary_factor * self.head_dim),
295
+ max_position_embeddings=self.max_position_embeddings,
296
+ scaling_factor=scaling_factor,
297
+ base=self.rope_theta,
298
+ )
299
+ else:
300
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ attention_mask: Optional[torch.Tensor] = None,
306
+ position_ids: Optional[torch.LongTensor] = None,
307
+ past_key_value: Optional[Cache] = None,
308
+ output_attentions: bool = False,
309
+ use_cache: bool = False,
310
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
+ bsz, q_len, _ = hidden_states.size()
312
+
313
+ query_states = self.q_proj(hidden_states)
314
+ key_states = self.k_proj(hidden_states)
315
+ value_states = self.v_proj(hidden_states)
316
+
317
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
318
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
319
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
320
+
321
+ if self.qk_layernorm:
322
+ query_states = self.q_layernorm(query_states)
323
+ key_states = self.k_layernorm(key_states)
324
+
325
+ kv_seq_len = key_states.shape[-2]
326
+ if past_key_value is not None:
327
+ if self.layer_idx is None:
328
+ raise ValueError(
329
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
330
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
331
+ "with a layer index."
332
+ )
333
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
334
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
335
+
336
+ # Partial rotary embedding
337
+ query_rot, query_pass = (
338
+ query_states[..., : self.rotary_emb.dim],
339
+ query_states[..., self.rotary_emb.dim :],
340
+ )
341
+ key_rot, key_pass = (
342
+ key_states[..., : self.rotary_emb.dim],
343
+ key_states[..., self.rotary_emb.dim :],
344
+ )
345
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
346
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
347
+
348
+ # [batch_size, seq_length, num_heads, head_dim]
349
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
350
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
351
+
352
+ if past_key_value is not None:
353
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
354
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
355
+
356
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
357
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
358
+
359
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
360
+ attn_weights = torch.matmul(
361
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
362
+ ) / math.sqrt(self.head_dim)
363
+
364
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
365
+ raise ValueError(
366
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
367
+ f" {attn_weights.size()}"
368
+ )
369
+
370
+ if attention_mask is not None:
371
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
372
+ raise ValueError(
373
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
374
+ )
375
+ attn_weights = attn_weights + attention_mask
376
+
377
+ # upcast attention to fp32
378
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
379
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
380
+
381
+ attn_output = torch.matmul(attn_weights, value_states)
382
+
383
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
384
+ raise ValueError(
385
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
386
+ f" {attn_output.size()}"
387
+ )
388
+
389
+ attn_output = attn_output.transpose(1, 2).contiguous()
390
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
391
+
392
+ attn_output = self.dense(attn_output)
393
+
394
+ if not output_attentions:
395
+ attn_weights = None
396
+
397
+ return attn_output, attn_weights, past_key_value
398
+
399
+
400
+ class PhiFlashAttention2(PhiAttention):
401
+ """
402
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
403
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
404
+ flash attention and deal with padding tokens in case the input contains any of them.
405
+ """
406
+
407
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
408
+ def __init__(self, *args, **kwargs):
409
+ super().__init__(*args, **kwargs)
410
+
411
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
412
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
413
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
414
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
415
+
416
+ def forward(
417
+ self,
418
+ hidden_states: torch.Tensor,
419
+ attention_mask: Optional[torch.LongTensor] = None,
420
+ position_ids: Optional[torch.LongTensor] = None,
421
+ past_key_value: Optional[Cache] = None,
422
+ output_attentions: bool = False,
423
+ use_cache: bool = False,
424
+ **kwargs,
425
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
426
+ # PhiFlashAttention2 attention does not support output_attentions
427
+
428
+ output_attentions = False
429
+
430
+ bsz, q_len, _ = hidden_states.size()
431
+
432
+ query_states = self.q_proj(hidden_states)
433
+ key_states = self.k_proj(hidden_states)
434
+ value_states = self.v_proj(hidden_states)
435
+
436
+ # Flash attention requires the input to have the shape
437
+ # batch_size x seq_length x head_dim x hidden_dim
438
+ # therefore we just need to keep the original shape
439
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
440
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
441
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
442
+
443
+ if self.qk_layernorm:
444
+ query_states = self.q_layernorm(query_states)
445
+ key_states = self.k_layernorm(key_states)
446
+
447
+ kv_seq_len = key_states.shape[-2]
448
+ if past_key_value is not None:
449
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
450
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
451
+
452
+ # Partial rotary embedding
453
+ query_rot, query_pass = (
454
+ query_states[..., : self.rotary_emb.dim],
455
+ query_states[..., self.rotary_emb.dim :],
456
+ )
457
+ key_rot, key_pass = (
458
+ key_states[..., : self.rotary_emb.dim],
459
+ key_states[..., self.rotary_emb.dim :],
460
+ )
461
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
462
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
463
+
464
+ # [batch_size, seq_length, num_heads, head_dim]
465
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
466
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
467
+
468
+ if past_key_value is not None:
469
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
470
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
471
+
472
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
473
+ # to be able to avoid many of these transpose/reshape/view.
474
+ query_states = query_states.transpose(1, 2)
475
+ key_states = key_states.transpose(1, 2)
476
+ value_states = value_states.transpose(1, 2)
477
+
478
+ attn_dropout = self.attention_dropout if self.training else 0.0
479
+
480
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
481
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
482
+ # cast them back in the correct dtype just to be sure everything works as expected.
483
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
484
+ # in fp32.
485
+
486
+ if query_states.dtype == torch.float32:
487
+ if torch.is_autocast_enabled():
488
+ target_dtype = torch.get_autocast_gpu_dtype()
489
+ # Handle the case where the model is quantized
490
+ elif hasattr(self.config, "_pre_quantization_dtype"):
491
+ target_dtype = self.config._pre_quantization_dtype
492
+ else:
493
+ target_dtype = self.q_proj.weight.dtype
494
+
495
+ logger.warning_once(
496
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
497
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
498
+ f" {target_dtype}."
499
+ )
500
+
501
+ query_states = query_states.to(target_dtype)
502
+ key_states = key_states.to(target_dtype)
503
+ value_states = value_states.to(target_dtype)
504
+
505
+ attn_output = self._flash_attention_forward(
506
+ query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
507
+ )
508
+
509
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
510
+ attn_output = self.dense(attn_output)
511
+
512
+ if not output_attentions:
513
+ attn_weights = None
514
+
515
+ return attn_output, attn_weights, past_key_value
516
+
517
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
518
+ def _flash_attention_forward(
519
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
520
+ ):
521
+ """
522
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
523
+ first unpad the input, then computes the attention scores and pad the final attention scores.
524
+
525
+ Args:
526
+ query_states (`torch.Tensor`):
527
+ Input query states to be passed to Flash Attention API
528
+ key_states (`torch.Tensor`):
529
+ Input key states to be passed to Flash Attention API
530
+ value_states (`torch.Tensor`):
531
+ Input value states to be passed to Flash Attention API
532
+ attention_mask (`torch.Tensor`):
533
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
534
+ position of padding tokens and 1 for the position of non-padding tokens.
535
+ dropout (`float`):
536
+ Attention dropout
537
+ softmax_scale (`float`, *optional*):
538
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
539
+ """
540
+ if not self._flash_attn_uses_top_left_mask:
541
+ causal = self.is_causal
542
+ else:
543
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
544
+ causal = self.is_causal and query_length != 1
545
+
546
+ # Contains at least one padding token in the sequence
547
+ if attention_mask is not None:
548
+ batch_size = query_states.shape[0]
549
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
550
+ query_states, key_states, value_states, attention_mask, query_length
551
+ )
552
+
553
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
554
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
555
+
556
+ attn_output_unpad = flash_attn_varlen_func(
557
+ query_states,
558
+ key_states,
559
+ value_states,
560
+ cu_seqlens_q=cu_seqlens_q,
561
+ cu_seqlens_k=cu_seqlens_k,
562
+ max_seqlen_q=max_seqlen_in_batch_q,
563
+ max_seqlen_k=max_seqlen_in_batch_k,
564
+ dropout_p=dropout,
565
+ softmax_scale=softmax_scale,
566
+ causal=causal,
567
+ )
568
+
569
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
570
+ else:
571
+ attn_output = flash_attn_func(
572
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
573
+ )
574
+
575
+ return attn_output
576
+
577
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
578
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
579
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
580
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
581
+
582
+ key_layer = index_first_axis(
583
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
584
+ )
585
+ value_layer = index_first_axis(
586
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
587
+ )
588
+ if query_length == kv_seq_len:
589
+ query_layer = index_first_axis(
590
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
591
+ )
592
+ cu_seqlens_q = cu_seqlens_k
593
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
594
+ indices_q = indices_k
595
+ elif query_length == 1:
596
+ max_seqlen_in_batch_q = 1
597
+ cu_seqlens_q = torch.arange(
598
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
599
+ ) # There is a memcpy here, that is very bad.
600
+ indices_q = cu_seqlens_q[:-1]
601
+ query_layer = query_layer.squeeze(1)
602
+ else:
603
+ # The -q_len: slice assumes left padding.
604
+ attention_mask = attention_mask[:, -query_length:]
605
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
606
+
607
+ return (
608
+ query_layer,
609
+ key_layer,
610
+ value_layer,
611
+ indices_q,
612
+ (cu_seqlens_q, cu_seqlens_k),
613
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
614
+ )
615
+
616
+
617
+ class PhiSdpaAttention(PhiAttention):
618
+ def __init__(self, *args, **kwargs):
619
+ super().__init__(*args, **kwargs)
620
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
621
+
622
+ """
623
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
624
+ `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
625
+ SDPA API.
626
+ """
627
+
628
+ # Adapted from PhiAttention.forward
629
+ def forward(
630
+ self,
631
+ hidden_states: torch.Tensor,
632
+ attention_mask: Optional[torch.Tensor] = None,
633
+ position_ids: Optional[torch.LongTensor] = None,
634
+ past_key_value: Optional[Cache] = None,
635
+ output_attentions: bool = False,
636
+ use_cache: bool = False,
637
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
638
+ if output_attentions:
639
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
640
+ logger.warning_once(
641
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
642
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
643
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
644
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
645
+ )
646
+ return super().forward(
647
+ hidden_states=hidden_states,
648
+ attention_mask=attention_mask,
649
+ position_ids=position_ids,
650
+ past_key_value=past_key_value,
651
+ output_attentions=output_attentions,
652
+ use_cache=use_cache,
653
+ )
654
+
655
+ bsz, q_len, _ = hidden_states.size()
656
+
657
+ query_states = self.q_proj(hidden_states)
658
+ key_states = self.k_proj(hidden_states)
659
+ value_states = self.v_proj(hidden_states)
660
+
661
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
662
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
663
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
664
+
665
+ if self.qk_layernorm:
666
+ query_states = self.q_layernorm(query_states)
667
+ key_states = self.k_layernorm(key_states)
668
+
669
+ kv_seq_len = key_states.shape[-2]
670
+ if past_key_value is not None:
671
+ if self.layer_idx is None:
672
+ raise ValueError(
673
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
674
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
675
+ "with a layer index."
676
+ )
677
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
678
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
679
+
680
+ # Partial rotary embedding
681
+ query_rot, query_pass = (
682
+ query_states[..., : self.rotary_emb.dim],
683
+ query_states[..., self.rotary_emb.dim :],
684
+ )
685
+ key_rot, key_pass = (
686
+ key_states[..., : self.rotary_emb.dim],
687
+ key_states[..., self.rotary_emb.dim :],
688
+ )
689
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
690
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
691
+
692
+ # [batch_size, seq_length, num_heads, head_dim]
693
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
694
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
695
+
696
+ if past_key_value is not None:
697
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
698
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
699
+
700
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
701
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
702
+
703
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
704
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
705
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
706
+ if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None:
707
+ query_states = query_states.contiguous()
708
+ key_states = key_states.contiguous()
709
+ value_states = value_states.contiguous()
710
+
711
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
712
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
713
+ is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
714
+
715
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
716
+ query_states,
717
+ key_states,
718
+ value_states,
719
+ attn_mask=attention_mask,
720
+ dropout_p=self.attention_dropout if self.training else 0.0,
721
+ is_causal=is_causal,
722
+ )
723
+
724
+ attn_output = attn_output.transpose(1, 2).contiguous()
725
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
726
+
727
+ attn_output = self.dense(attn_output)
728
+
729
+ return attn_output, None, past_key_value
730
+
731
+
732
+ PHI_ATTENTION_CLASSES = {
733
+ "eager": PhiAttention,
734
+ "flash_attention_2": PhiFlashAttention2,
735
+ "sdpa": PhiSdpaAttention,
736
+ }
737
+
738
+
739
+ class PhiDecoderLayer(nn.Module):
740
+ def __init__(self, config: PhiConfig, layer_idx: int):
741
+ super().__init__()
742
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
743
+ self.mlp = PhiMLP(config)
744
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
745
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
746
+
747
+ def forward(
748
+ self,
749
+ hidden_states: torch.Tensor,
750
+ attention_mask: Optional[torch.Tensor] = None,
751
+ position_ids: Optional[torch.LongTensor] = None,
752
+ output_attentions: Optional[bool] = False,
753
+ use_cache: Optional[bool] = False,
754
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
755
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
756
+ """
757
+ Args:
758
+ hidden_states (`torch.FloatTensor`):
759
+ input to the layer of shape `(batch, seq_len, embed_dim)`
760
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
761
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
762
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
763
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
764
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
765
+ output_attentions (`bool`, *optional*):
766
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
767
+ returned tensors for more detail.
768
+ use_cache (`bool`, *optional*):
769
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
770
+ (see `past_key_values`).
771
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
772
+ """
773
+
774
+ residual = hidden_states
775
+
776
+ hidden_states = self.input_layernorm(hidden_states)
777
+
778
+ # Self Attention
779
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
780
+ hidden_states=hidden_states,
781
+ attention_mask=attention_mask,
782
+ position_ids=position_ids,
783
+ past_key_value=past_key_value,
784
+ output_attentions=output_attentions,
785
+ use_cache=use_cache,
786
+ )
787
+ attn_outputs = self.resid_dropout(attn_outputs)
788
+
789
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
790
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
791
+ outputs = (hidden_states,)
792
+
793
+ if output_attentions:
794
+ outputs += (self_attn_weights,)
795
+
796
+ if use_cache:
797
+ outputs += (present_key_value,)
798
+
799
+ return outputs
800
+
801
+
802
+ PHI_START_DOCSTRING = r"""
803
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
804
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
805
+ etc.)
806
+
807
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
808
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
809
+ and behavior.
810
+
811
+ Parameters:
812
+ config ([`PhiConfig`]):
813
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
814
+ load the weights associated with the model, only the configuration. Check out the
815
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
816
+ """
817
+
818
+
819
+ @add_start_docstrings(
820
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
821
+ PHI_START_DOCSTRING,
822
+ )
823
+ class PhiPreTrainedModel(PreTrainedModel):
824
+ config_class = PhiConfig
825
+ base_model_prefix = "model"
826
+ supports_gradient_checkpointing = True
827
+ _no_split_modules = ["PhiDecoderLayer"]
828
+ _skip_keys_device_placement = "past_key_values"
829
+ _supports_flash_attn_2 = True
830
+ _supports_sdpa = True
831
+ _supports_cache_class = True
832
+
833
+ def _init_weights(self, module):
834
+ std = self.config.initializer_range
835
+ if isinstance(module, nn.Linear):
836
+ module.weight.data.normal_(mean=0.0, std=std)
837
+ if module.bias is not None:
838
+ module.bias.data.zero_()
839
+ elif isinstance(module, nn.Embedding):
840
+ module.weight.data.normal_(mean=0.0, std=std)
841
+ if module.padding_idx is not None:
842
+ module.weight.data[module.padding_idx].zero_()
843
+
844
+
845
+ PHI_INPUTS_DOCSTRING = r"""
846
+ Args:
847
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
848
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
849
+ it.
850
+
851
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
852
+ [`PreTrainedTokenizer.__call__`] for details.
853
+
854
+ [What are input IDs?](../glossary#input-ids)
855
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
856
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
857
+
858
+ - 1 for tokens that are **not masked**,
859
+ - 0 for tokens that are **masked**.
860
+
861
+ [What are attention masks?](../glossary#attention-mask)
862
+
863
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
864
+ [`PreTrainedTokenizer.__call__`] for details.
865
+
866
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
867
+ `past_key_values`).
868
+
869
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
870
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
871
+ information on the default strategy.
872
+
873
+ - 1 indicates the head is **not masked**,
874
+ - 0 indicates the head is **masked**.
875
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
876
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
877
+ config.n_positions - 1]`.
878
+
879
+ [What are position IDs?](../glossary#position-ids)
880
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
881
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
882
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
883
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
884
+
885
+ Two formats are allowed:
886
+ - a [`~cache_utils.Cache`] instance;
887
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
888
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
889
+ cache format.
890
+
891
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
892
+ legacy cache format will be returned.
893
+
894
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
895
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
896
+ of shape `(batch_size, sequence_length)`.
897
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
898
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
899
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
900
+ model's internal embedding lookup matrix.
901
+ use_cache (`bool`, *optional*):
902
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
903
+ `past_key_values`).
904
+ output_attentions (`bool`, *optional*):
905
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
906
+ tensors for more detail.
907
+ output_hidden_states (`bool`, *optional*):
908
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
909
+ more detail.
910
+ return_dict (`bool`, *optional*):
911
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
912
+ """
913
+
914
+
915
+ @add_start_docstrings(
916
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
917
+ PHI_START_DOCSTRING,
918
+ )
919
+ class PhiModel(PhiPreTrainedModel):
920
+ """
921
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
922
+
923
+ Args:
924
+ config: PhiConfig
925
+ """
926
+
927
+ def __init__(self, config: PhiConfig):
928
+ super().__init__(config)
929
+ self.padding_idx = config.pad_token_id
930
+ self.vocab_size = config.vocab_size
931
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
932
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
933
+ print("attention implementation: ", config._attn_implementation)
934
+ self.layers = nn.ModuleList(
935
+ [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
936
+ )
937
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
938
+
939
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
940
+ self._use_sdpa = config._attn_implementation == "sdpa"
941
+
942
+ self.gradient_checkpointing = False
943
+ # Initialize weights and apply final processing
944
+ self.post_init()
945
+
946
+ def get_input_embeddings(self):
947
+ return self.embed_tokens
948
+
949
+ def set_input_embeddings(self, value):
950
+ self.embed_tokens = value
951
+
952
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
953
+ def forward(
954
+ self,
955
+ input_ids: torch.LongTensor = None,
956
+ attention_mask: Optional[torch.Tensor] = None,
957
+ position_ids: Optional[torch.LongTensor] = None,
958
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
959
+ inputs_embeds: Optional[torch.FloatTensor] = None,
960
+ use_cache: Optional[bool] = None,
961
+ output_attentions: Optional[bool] = None,
962
+ output_hidden_states: Optional[bool] = None,
963
+ return_dict: Optional[bool] = None,
964
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
965
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
966
+ output_hidden_states = (
967
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
968
+ )
969
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
970
+
971
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
972
+
973
+ # retrieve input_ids and inputs_embeds
974
+ if input_ids is not None and inputs_embeds is not None:
975
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
976
+ elif input_ids is not None:
977
+ batch_size, seq_length = input_ids.shape[:2]
978
+ elif inputs_embeds is not None:
979
+ batch_size, seq_length = inputs_embeds.shape[:2]
980
+ else:
981
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
982
+
983
+ past_key_values_length = 0
984
+
985
+ if self.gradient_checkpointing and self.training:
986
+ if use_cache:
987
+ logger.warning_once(
988
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
989
+ )
990
+ use_cache = False
991
+
992
+ if use_cache:
993
+ use_legacy_cache = not isinstance(past_key_values, Cache)
994
+ if use_legacy_cache:
995
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
996
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
997
+
998
+ if position_ids is None:
999
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1000
+ position_ids = torch.arange(
1001
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1002
+ )
1003
+ position_ids = position_ids.unsqueeze(0)
1004
+
1005
+ if inputs_embeds is None:
1006
+ inputs_embeds = self.embed_tokens(input_ids)
1007
+
1008
+ inputs_embeds = self.embed_dropout(inputs_embeds)
1009
+ # commented by Xavier
1010
+ # Attention mask.
1011
+ # if self._use_flash_attention_2:
1012
+ # # 2d mask is passed through the layers
1013
+ # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1014
+ # elif self._use_sdpa and not output_attentions:
1015
+ # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1016
+ # attention_mask,
1017
+ # (batch_size, seq_length),
1018
+ # inputs_embeds,
1019
+ # past_key_values_length,
1020
+ # )
1021
+ # else:
1022
+ # # 4d mask is passed through the layers
1023
+ # attention_mask = _prepare_4d_causal_attention_mask(
1024
+ # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1025
+ # )
1026
+ # commented by Xavier
1027
+
1028
+ hidden_states = inputs_embeds
1029
+
1030
+ # decoder layers
1031
+ all_hidden_states = () if output_hidden_states else None
1032
+ all_self_attns = () if output_attentions else None
1033
+ next_decoder_cache = None
1034
+ for decoder_layer in self.layers:
1035
+ if output_hidden_states:
1036
+ all_hidden_states += (hidden_states,)
1037
+
1038
+ if self.gradient_checkpointing and self.training:
1039
+ layer_outputs = self._gradient_checkpointing_func(
1040
+ decoder_layer.__call__,
1041
+ hidden_states,
1042
+ attention_mask,
1043
+ position_ids,
1044
+ past_key_values,
1045
+ output_attentions,
1046
+ )
1047
+ else:
1048
+ layer_outputs = decoder_layer(
1049
+ hidden_states,
1050
+ attention_mask=attention_mask,
1051
+ position_ids=position_ids,
1052
+ past_key_value=past_key_values,
1053
+ output_attentions=output_attentions,
1054
+ use_cache=use_cache,
1055
+ )
1056
+
1057
+ hidden_states = layer_outputs[0]
1058
+
1059
+ if use_cache:
1060
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1061
+
1062
+ if output_attentions:
1063
+ all_self_attns += (layer_outputs[1],)
1064
+
1065
+ hidden_states = self.final_layernorm(hidden_states)
1066
+
1067
+ # add hidden states from the last decoder layer
1068
+ if output_hidden_states:
1069
+ all_hidden_states += (hidden_states,)
1070
+
1071
+ next_cache = None
1072
+ if use_cache:
1073
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1074
+ if not return_dict:
1075
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1076
+ return BaseModelOutputWithPast(
1077
+ last_hidden_state=hidden_states,
1078
+ past_key_values=next_cache,
1079
+ hidden_states=all_hidden_states,
1080
+ attentions=all_self_attns,
1081
+ )
1082
+
1083
+
1084
+ class PhiForCausalLM(PhiPreTrainedModel):
1085
+ _tied_weights_keys = ["lm_head.weight"]
1086
+ def __init__(self, config):
1087
+ super().__init__(config)
1088
+ config.qk_layernorm = True
1089
+ config.use_cache = False
1090
+ self.model = PhiModel(config)
1091
+ self.vocab_size = config.vocab_size
1092
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
1093
+
1094
+ # Initialize weights and apply final processing
1095
+ self.post_init()
1096
+
1097
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1098
+ def get_input_embeddings(self):
1099
+ return self.model.embed_tokens
1100
+
1101
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1102
+ def set_input_embeddings(self, value):
1103
+ self.model.embed_tokens = value
1104
+
1105
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1106
+ def get_output_embeddings(self):
1107
+ return self.lm_head
1108
+
1109
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1110
+ def set_output_embeddings(self, new_embeddings):
1111
+ self.lm_head = new_embeddings
1112
+
1113
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1114
+ def set_decoder(self, decoder):
1115
+ self.model = decoder
1116
+
1117
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1118
+ def get_decoder(self):
1119
+ return self.model
1120
+
1121
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1122
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1123
+ def forward(
1124
+ self,
1125
+ input_ids: torch.LongTensor = None,
1126
+ attention_mask: Optional[torch.Tensor] = None,
1127
+ position_ids: Optional[torch.LongTensor] = None,
1128
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1129
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1130
+ labels: Optional[torch.LongTensor] = None,
1131
+ use_cache: Optional[bool] = None,
1132
+ output_attentions: Optional[bool] = None,
1133
+ output_hidden_states: Optional[bool] = None,
1134
+ return_dict: Optional[bool] = None,
1135
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1136
+ r"""
1137
+ Args:
1138
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1139
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1140
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1141
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1142
+
1143
+ Returns:
1144
+
1145
+ Example:
1146
+
1147
+ ```python
1148
+ >>> from transformers import AutoTokenizer, PhiForCausalLM
1149
+
1150
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1151
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1152
+
1153
+ >>> prompt = "This is an example script ."
1154
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1155
+
1156
+ >>> # Generate
1157
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1158
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1159
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1160
+ ```"""
1161
+
1162
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1163
+ output_hidden_states = (
1164
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1165
+ )
1166
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1167
+
1168
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1169
+ outputs = self.model(
1170
+ input_ids=input_ids,
1171
+ attention_mask=attention_mask,
1172
+ position_ids=position_ids,
1173
+ past_key_values=past_key_values,
1174
+ inputs_embeds=inputs_embeds,
1175
+ use_cache=use_cache,
1176
+ output_attentions=output_attentions,
1177
+ output_hidden_states=output_hidden_states,
1178
+ return_dict=return_dict,
1179
+ )
1180
+
1181
+ hidden_states = outputs[0]
1182
+ logits = self.lm_head(hidden_states)
1183
+ logits = logits.float()
1184
+
1185
+ loss = None
1186
+ if labels is not None:
1187
+ # Shift so that tokens < n predict n
1188
+ shift_logits = logits[..., :-1, :].contiguous()
1189
+ shift_labels = labels[..., 1:].contiguous()
1190
+ # Flatten the tokens
1191
+ loss_fct = CrossEntropyLoss()
1192
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1193
+ shift_labels = shift_labels.view(-1)
1194
+ # Enable model parallelism
1195
+ shift_labels = shift_labels.to(shift_logits.device)
1196
+ loss = loss_fct(shift_logits, shift_labels)
1197
+
1198
+ if not return_dict:
1199
+ output = (logits,) + outputs[1:]
1200
+ return (loss,) + output if loss is not None else output
1201
+
1202
+ return CausalLMOutputWithPast(
1203
+ loss=loss,
1204
+ logits=logits,
1205
+ past_key_values=outputs.past_key_values,
1206
+ hidden_states=outputs.hidden_states,
1207
+ attentions=outputs.attentions,
1208
+ )
1209
+
1210
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1211
+ def prepare_inputs_for_generation(
1212
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1213
+ ):
1214
+ if past_key_values is not None:
1215
+ if isinstance(past_key_values, Cache):
1216
+ cache_length = past_key_values.get_seq_length()
1217
+ past_length = past_key_values.seen_tokens
1218
+ max_cache_length = past_key_values.get_max_length()
1219
+ else:
1220
+ cache_length = past_length = past_key_values[0][0].shape[2]
1221
+ max_cache_length = None
1222
+
1223
+ # Keep only the unprocessed tokens:
1224
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1225
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1226
+ # input)
1227
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1228
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1229
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1230
+ # input_ids based on the past_length.
1231
+ elif past_length < input_ids.shape[1]:
1232
+ input_ids = input_ids[:, past_length:]
1233
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1234
+
1235
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1236
+ if (
1237
+ max_cache_length is not None
1238
+ and attention_mask is not None
1239
+ and cache_length + input_ids.shape[1] > max_cache_length
1240
+ ):
1241
+ attention_mask = attention_mask[:, -max_cache_length:]
1242
+
1243
+ position_ids = kwargs.get("position_ids", None)
1244
+ if attention_mask is not None and position_ids is None:
1245
+ # create position_ids on the fly for batch generation
1246
+ position_ids = attention_mask.long().cumsum(-1) - 1
1247
+ position_ids.masked_fill_(attention_mask == 0, 1)
1248
+ if past_key_values:
1249
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1250
+
1251
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1252
+ if inputs_embeds is not None and past_key_values is None:
1253
+ model_inputs = {"inputs_embeds": inputs_embeds}
1254
+ else:
1255
+ model_inputs = {"input_ids": input_ids}
1256
+
1257
+ model_inputs.update(
1258
+ {
1259
+ "position_ids": position_ids,
1260
+ "past_key_values": past_key_values,
1261
+ "use_cache": kwargs.get("use_cache"),
1262
+ "attention_mask": attention_mask,
1263
+ }
1264
+ )
1265
+ return model_inputs
1266
+
1267
+ @staticmethod
1268
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1269
+ def _reorder_cache(past_key_values, beam_idx):
1270
+ reordered_past = ()
1271
+ for layer_past in past_key_values:
1272
+ reordered_past += (
1273
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1274
+ )
1275
+ return reordered_past
1276
+
1277
+
1278
+ @add_start_docstrings(
1279
+ """
1280
+ The PhiModel with a sequence classification head on top (linear layer).
1281
+
1282
+ [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1283
+ (e.g. GPT-2) do.
1284
+
1285
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1286
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1287
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1288
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1289
+ each row of the batch).
1290
+ """,
1291
+ PHI_START_DOCSTRING,
1292
+ )
1293
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
1294
+ class PhiForSequenceClassification(PhiPreTrainedModel):
1295
+ def __init__(self, config):
1296
+ super().__init__(config)
1297
+ self.num_labels = config.num_labels
1298
+ self.model = PhiModel(config)
1299
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1300
+
1301
+ # Initialize weights and apply final processing
1302
+ self.post_init()
1303
+
1304
+ def get_input_embeddings(self):
1305
+ return self.model.embed_tokens
1306
+
1307
+ def set_input_embeddings(self, value):
1308
+ self.model.embed_tokens = value
1309
+
1310
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1311
+ def forward(
1312
+ self,
1313
+ input_ids: torch.LongTensor = None,
1314
+ attention_mask: Optional[torch.Tensor] = None,
1315
+ position_ids: Optional[torch.LongTensor] = None,
1316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1317
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1318
+ labels: Optional[torch.LongTensor] = None,
1319
+ use_cache: Optional[bool] = None,
1320
+ output_attentions: Optional[bool] = None,
1321
+ output_hidden_states: Optional[bool] = None,
1322
+ return_dict: Optional[bool] = None,
1323
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1324
+ r"""
1325
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1326
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1327
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1328
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1329
+ """
1330
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1331
+
1332
+ model_outputs = self.model(
1333
+ input_ids,
1334
+ attention_mask=attention_mask,
1335
+ position_ids=position_ids,
1336
+ past_key_values=past_key_values,
1337
+ inputs_embeds=inputs_embeds,
1338
+ use_cache=use_cache,
1339
+ output_attentions=output_attentions,
1340
+ output_hidden_states=output_hidden_states,
1341
+ return_dict=return_dict,
1342
+ )
1343
+ hidden_states = model_outputs[0]
1344
+ logits = self.score(hidden_states)
1345
+
1346
+ if input_ids is not None:
1347
+ batch_size = input_ids.shape[0]
1348
+ else:
1349
+ batch_size = inputs_embeds.shape[0]
1350
+
1351
+ if self.config.pad_token_id is None and batch_size != 1:
1352
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1353
+ if self.config.pad_token_id is None:
1354
+ sequence_lengths = -1
1355
+ else:
1356
+ if input_ids is not None:
1357
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1358
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1359
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1360
+ sequence_lengths = sequence_lengths.to(logits.device)
1361
+ else:
1362
+ sequence_lengths = -1
1363
+
1364
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1365
+
1366
+ loss = None
1367
+ if labels is not None:
1368
+ labels = labels.to(logits.device)
1369
+ if self.config.problem_type is None:
1370
+ if self.num_labels == 1:
1371
+ self.config.problem_type = "regression"
1372
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1373
+ self.config.problem_type = "single_label_classification"
1374
+ else:
1375
+ self.config.problem_type = "multi_label_classification"
1376
+
1377
+ if self.config.problem_type == "regression":
1378
+ loss_fct = MSELoss()
1379
+ if self.num_labels == 1:
1380
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1381
+ else:
1382
+ loss = loss_fct(pooled_logits, labels)
1383
+ elif self.config.problem_type == "single_label_classification":
1384
+ loss_fct = CrossEntropyLoss()
1385
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1386
+ elif self.config.problem_type == "multi_label_classification":
1387
+ loss_fct = BCEWithLogitsLoss()
1388
+ loss = loss_fct(pooled_logits, labels)
1389
+ if not return_dict:
1390
+ output = (pooled_logits,) + model_outputs[1:]
1391
+ return ((loss,) + output) if loss is not None else output
1392
+
1393
+ return SequenceClassifierOutputWithPast(
1394
+ loss=loss,
1395
+ logits=pooled_logits,
1396
+ past_key_values=model_outputs.past_key_values,
1397
+ hidden_states=model_outputs.hidden_states,
1398
+ attentions=model_outputs.attentions,
1399
+ )
1400
+
1401
+
1402
+ @add_start_docstrings(
1403
+ """
1404
+ PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1405
+ Named-Entity-Recognition (NER) tasks.
1406
+ """,
1407
+ PHI_START_DOCSTRING,
1408
+ )
1409
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
1410
+ class PhiForTokenClassification(PhiPreTrainedModel):
1411
+ def __init__(self, config: PhiConfig):
1412
+ super().__init__(config)
1413
+ self.num_labels = config.num_labels
1414
+
1415
+ self.model = PhiModel(config)
1416
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1417
+ classifier_dropout = config.classifier_dropout
1418
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1419
+ classifier_dropout = config.hidden_dropout
1420
+ else:
1421
+ classifier_dropout = 0.1
1422
+ self.dropout = nn.Dropout(classifier_dropout)
1423
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1424
+
1425
+ # Initialize weights and apply final processing
1426
+ self.post_init()
1427
+
1428
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1429
+ @add_code_sample_docstrings(
1430
+ checkpoint=_CHECKPOINT_FOR_DOC,
1431
+ output_type=TokenClassifierOutput,
1432
+ config_class=_CONFIG_FOR_DOC,
1433
+ )
1434
+ def forward(
1435
+ self,
1436
+ input_ids: Optional[torch.LongTensor] = None,
1437
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1438
+ attention_mask: Optional[torch.Tensor] = None,
1439
+ inputs_embeds: Optional[torch.Tensor] = None,
1440
+ labels: Optional[torch.Tensor] = None,
1441
+ use_cache: Optional[bool] = None,
1442
+ output_attentions: Optional[bool] = None,
1443
+ output_hidden_states: Optional[bool] = None,
1444
+ return_dict: Optional[bool] = None,
1445
+ **deprecated_arguments,
1446
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1447
+ r"""
1448
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1449
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1450
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1451
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1452
+ """
1453
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1454
+
1455
+ model_outputs = self.model(
1456
+ input_ids,
1457
+ past_key_values=past_key_values,
1458
+ attention_mask=attention_mask,
1459
+ inputs_embeds=inputs_embeds,
1460
+ use_cache=use_cache,
1461
+ output_attentions=output_attentions,
1462
+ output_hidden_states=output_hidden_states,
1463
+ return_dict=return_dict,
1464
+ )
1465
+
1466
+ hidden_states = model_outputs[0]
1467
+ hidden_states = self.dropout(hidden_states)
1468
+ logits = self.classifier(hidden_states)
1469
+
1470
+ loss = None
1471
+ if labels is not None:
1472
+ # move labels to correct device to enable model parallelism
1473
+ labels = labels.to(logits.device)
1474
+ batch_size, seq_length = labels.shape
1475
+ loss_fct = CrossEntropyLoss()
1476
+ loss = loss_fct(
1477
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1478
+ )
1479
+
1480
+ if not return_dict:
1481
+ output = (logits,) + model_outputs[2:]
1482
+ return ((loss,) + output) if loss is not None else output
1483
+
1484
+ return TokenClassifierOutput(
1485
+ loss=loss,
1486
+ logits=logits,
1487
+ hidden_states=model_outputs.hidden_states,
1488
+ attentions=model_outputs.attentions,
1489
+ )
models/sampling.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/lucidrains/muse-maskgit-pytorch
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def log(t, eps=1e-20):
11
+ return torch.log(t.clamp(min=eps))
12
+
13
+
14
+ def gumbel_noise(t, generator=None):
15
+ noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
16
+ return -log(-log(noise))
17
+
18
+
19
+ def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
20
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)
21
+
22
+
23
+ def top_k(logits, thres=0.9):
24
+ k = math.ceil((1 - thres) * logits.shape[-1])
25
+ val, ind = logits.topk(k, dim=-1)
26
+ probs = torch.full_like(logits, float("-inf"))
27
+ probs.scatter_(2, ind, val)
28
+ return probs
29
+
30
+
31
+ def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
32
+ confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator)
33
+ sorted_confidence = torch.sort(confidence, dim=-1).values
34
+ cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
35
+ masking = confidence < cut_off
36
+ return masking
37
+
38
+
39
+ def cosine_schedule(t):
40
+ return torch.cos(t * math.pi * 0.5)
41
+
42
+
43
+ def linear_schedule(t):
44
+ mask_ratio = 1 - t
45
+ mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
46
+ return mask_ratio
47
+
48
+
49
+ def pow(t, method):
50
+ exponent = float(method.replace("pow", ""))
51
+ mask_ratio = 1.0 - t**exponent
52
+ mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
53
+ return mask_ratio
54
+
55
+
56
+ def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6):
57
+ for item in [t, start, end, tau]:
58
+ item = torch.tensor(item) if not torch.is_tensor(item) else item
59
+
60
+ # A gamma function based on sigmoid function.
61
+ v_start = torch.sigmoid(torch.tensor(start / tau))
62
+ v_end = torch.sigmoid(torch.tensor(end / tau))
63
+ output = torch.sigmoid((t * (end - start) + start) / tau)
64
+ output = (v_end - output) / (v_end - v_start)
65
+ return torch.clip(output, clip_min, 1.0)
66
+
67
+
68
+ def get_mask_chedule(method, **schedule_kwargs):
69
+ if method == "cosine":
70
+ return cosine_schedule
71
+ elif method == "linear":
72
+ return linear_schedule
73
+ elif "pow" in method:
74
+ return partial(pow, method=method)
75
+ elif method == "sigmoid":
76
+ return partial(sigmoid_schedule, **schedule_kwargs)
77
+ else:
78
+ raise ValueError("Unknown schedule method: {}".format(method))
79
+
80
+ def top_k_top_p_filtering(
81
+ logits: torch.Tensor,
82
+ top_k: int = 0,
83
+ top_p: float = 1.0,
84
+ filter_value: float = -float("Inf"),
85
+ min_tokens_to_keep: int = 1,
86
+ ) -> torch.Tensor:
87
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
88
+ Args:
89
+ logits: logits distribution shape (batch size, vocabulary size)
90
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
91
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
92
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
93
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
94
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
95
+ """
96
+ if top_k > 0:
97
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
98
+ # Remove all tokens with a probability less than the last token of the top-k
99
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
100
+ logits[indices_to_remove] = filter_value
101
+
102
+ if top_p < 1.0:
103
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
104
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
105
+
106
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
107
+ sorted_indices_to_remove = cumulative_probs > top_p
108
+ if min_tokens_to_keep > 1:
109
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
110
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
111
+ # Shift the indices to the right to keep also the first token above the threshold
112
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
113
+ sorted_indices_to_remove[..., 0] = 0
114
+
115
+ # scatter sorted tensors to original indexing
116
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
117
+ logits[indices_to_remove] = filter_value
118
+ return logits
models/training_utils.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ import os
18
+ import random
19
+ from typing import Any, Dict, Iterable, Optional, Union
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+
27
+ def enable_full_determinism(seed: int):
28
+ """
29
+ Helper function for reproducible behavior during distributed training. See
30
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
31
+ """
32
+ # set seed first
33
+ set_seed(seed)
34
+
35
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
36
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
37
+ # depending on the CUDA version, so we set them both here
38
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
39
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
40
+ torch.use_deterministic_algorithms(True)
41
+
42
+ # Enable CUDNN deterministic mode
43
+ torch.backends.cudnn.deterministic = True
44
+ torch.backends.cudnn.benchmark = False
45
+
46
+
47
+ def set_seed(seed: int):
48
+ """
49
+ Args:
50
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
51
+ seed (`int`): The seed to set.
52
+ """
53
+ random.seed(seed)
54
+ np.random.seed(seed)
55
+ torch.manual_seed(seed)
56
+ torch.cuda.manual_seed_all(seed)
57
+ # ^^ safe to call this function even if cuda is not available
58
+
59
+
60
+ # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
61
+ class EMA:
62
+ """
63
+ Exponential Moving Average of models weights
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ parameters: Iterable[torch.nn.Parameter],
69
+ decay: float = 0.9999,
70
+ min_decay: float = 0.0,
71
+ update_after_step: int = 0,
72
+ use_ema_warmup: bool = False,
73
+ inv_gamma: Union[float, int] = 1.0,
74
+ power: Union[float, int] = 2 / 3,
75
+ model_cls: Optional[Any] = None,
76
+ model_config: Dict[str, Any] = None,
77
+ **kwargs,
78
+ ):
79
+ """
80
+ Args:
81
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
82
+ decay (float): The decay factor for the exponential moving average.
83
+ min_decay (float): The minimum decay factor for the exponential moving average.
84
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
85
+ use_ema_warmup (bool): Whether to use EMA warmup.
86
+ inv_gamma (float):
87
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
88
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
89
+ device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
90
+ weights will be stored on CPU.
91
+
92
+ @crowsonkb's notes on EMA Warmup:
93
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
94
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
95
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
96
+ at 215.4k steps).
97
+ """
98
+
99
+ parameters = list(parameters)
100
+ self.shadow_params = [p.clone().detach() for p in parameters]
101
+
102
+ self.temp_stored_params = None
103
+
104
+ self.decay = decay
105
+ self.min_decay = min_decay
106
+ self.update_after_step = update_after_step
107
+ self.use_ema_warmup = use_ema_warmup
108
+ self.inv_gamma = inv_gamma
109
+ self.power = power
110
+ self.optimization_step = 0
111
+ self.cur_decay_value = None # set in `step()`
112
+
113
+ self.model_cls = model_cls
114
+ self.model_config = model_config
115
+
116
+ @classmethod
117
+ def from_pretrained(cls, path, model_cls) -> "EMA":
118
+ _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
119
+ model = model_cls.from_pretrained(path)
120
+
121
+ ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
122
+
123
+ ema_model.load_state_dict(ema_kwargs)
124
+ return ema_model
125
+
126
+ def save_pretrained(self, path):
127
+ if self.model_cls is None:
128
+ raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
129
+
130
+ if self.model_config is None:
131
+ raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
132
+
133
+ model = self.model_cls.from_config(self.model_config)
134
+ state_dict = self.state_dict()
135
+ state_dict.pop("shadow_params", None)
136
+
137
+ model.register_to_config(**state_dict)
138
+ self.copy_to(model.parameters())
139
+ model.save_pretrained(path)
140
+
141
+ def get_decay(self, optimization_step: int) -> float:
142
+ """
143
+ Compute the decay factor for the exponential moving average.
144
+ """
145
+ step = max(0, optimization_step - self.update_after_step - 1)
146
+
147
+ if step <= 0:
148
+ return 0.0
149
+
150
+ if self.use_ema_warmup:
151
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
152
+ else:
153
+ cur_decay_value = (1 + step) / (10 + step)
154
+
155
+ cur_decay_value = min(cur_decay_value, self.decay)
156
+ # make sure decay is not smaller than min_decay
157
+ cur_decay_value = max(cur_decay_value, self.min_decay)
158
+ return cur_decay_value
159
+
160
+ @torch.no_grad()
161
+ def step(self, parameters: Iterable[torch.nn.Parameter]):
162
+ parameters = list(parameters)
163
+
164
+ self.optimization_step += 1
165
+
166
+ # Compute the decay factor for the exponential moving average.
167
+ decay = self.get_decay(self.optimization_step)
168
+ self.cur_decay_value = decay
169
+ one_minus_decay = 1 - decay
170
+
171
+ for s_param, param in zip(self.shadow_params, parameters):
172
+ if param.requires_grad:
173
+ s_param.sub_(one_minus_decay * (s_param - param))
174
+ else:
175
+ s_param.copy_(param)
176
+
177
+ torch.cuda.empty_cache()
178
+
179
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
180
+ """
181
+ Copy current averaged parameters into given collection of parameters.
182
+
183
+ Args:
184
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
185
+ updated with the stored moving averages. If `None`, the parameters with which this
186
+ `ExponentialMovingAverage` was initialized will be used.
187
+ """
188
+ parameters = list(parameters)
189
+ for s_param, param in zip(self.shadow_params, parameters):
190
+ param.data.copy_(s_param.to(param.device).data)
191
+
192
+ def to(self, device=None, dtype=None) -> None:
193
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
194
+
195
+ Args:
196
+ device: like `device` argument to `torch.Tensor.to`
197
+ """
198
+ # .to() on the tensors handles None correctly
199
+ self.shadow_params = [
200
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
201
+ for p in self.shadow_params
202
+ ]
203
+
204
+ def state_dict(self) -> dict:
205
+ r"""
206
+ Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
207
+ checkpointing to save the ema state dict.
208
+ """
209
+ # Following PyTorch conventions, references to tensors are returned:
210
+ # "returns a reference to the state and not its copy!" -
211
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
212
+ return {
213
+ "decay": self.decay,
214
+ "min_decay": self.min_decay,
215
+ "optimization_step": self.optimization_step,
216
+ "update_after_step": self.update_after_step,
217
+ "use_ema_warmup": self.use_ema_warmup,
218
+ "inv_gamma": self.inv_gamma,
219
+ "power": self.power,
220
+ "shadow_params": self.shadow_params,
221
+ }
222
+
223
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
224
+ r"""
225
+ Args:
226
+ Save the current parameters for restoring later.
227
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
228
+ temporarily stored.
229
+ """
230
+ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
231
+
232
+ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
233
+ r"""
234
+ Args:
235
+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
236
+ affecting the original optimization process. Store the parameters before the `copy_to()` method. After
237
+ validation (or model saving), use this to restore the former parameters.
238
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
239
+ updated with the stored parameters. If `None`, the parameters with which this
240
+ `ExponentialMovingAverage` was initialized will be used.
241
+ """
242
+ if self.temp_stored_params is None:
243
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
244
+ for c_param, param in zip(self.temp_stored_params, parameters):
245
+ param.data.copy_(c_param.data)
246
+
247
+ # Better memory-wise.
248
+ self.temp_stored_params = None
249
+
250
+ def load_state_dict(self, state_dict: dict) -> None:
251
+ r"""
252
+ Args:
253
+ Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
254
+ ema state dict.
255
+ state_dict (dict): EMA state. Should be an object returned
256
+ from a call to :meth:`state_dict`.
257
+ """
258
+ # deepcopy, to be consistent with module API
259
+ state_dict = copy.deepcopy(state_dict)
260
+
261
+ self.decay = state_dict.get("decay", self.decay)
262
+ if self.decay < 0.0 or self.decay > 1.0:
263
+ raise ValueError("Decay must be between 0 and 1")
264
+
265
+ self.min_decay = state_dict.get("min_decay", self.min_decay)
266
+ if not isinstance(self.min_decay, float):
267
+ raise ValueError("Invalid min_decay")
268
+
269
+ self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
270
+ if not isinstance(self.optimization_step, int):
271
+ raise ValueError("Invalid optimization_step")
272
+
273
+ self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
274
+ if not isinstance(self.update_after_step, int):
275
+ raise ValueError("Invalid update_after_step")
276
+
277
+ self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
278
+ if not isinstance(self.use_ema_warmup, bool):
279
+ raise ValueError("Invalid use_ema_warmup")
280
+
281
+ self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
282
+ if not isinstance(self.inv_gamma, (float, int)):
283
+ raise ValueError("Invalid inv_gamma")
284
+
285
+ self.power = state_dict.get("power", self.power)
286
+ if not isinstance(self.power, (float, int)):
287
+ raise ValueError("Invalid power")
288
+
289
+ shadow_params = state_dict.get("shadow_params", None)
290
+ if shadow_params is not None:
291
+ self.shadow_params = shadow_params
292
+ if not isinstance(self.shadow_params, list):
293
+ raise ValueError("shadow_params must be a list")
294
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
295
+ raise ValueError("shadow_params must all be Tensors")
296
+
297
+
298
+ # calculates entropy over each pixel distribution
299
+ def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
300
+ # only calculated entropy over image tokens that were masked in the original image
301
+ masked_tokens = input_ids == mask_id
302
+ num_masked_pixels = masked_tokens.sum(-1)
303
+
304
+ probs = F.softmax(logits, dim=-1)
305
+ log_probs = F.log_softmax(logits, dim=-1)
306
+
307
+ entropy_per_pixel = -((probs * log_probs).sum(-1))
308
+
309
+ # the predictions for non-masked aren't used, so set their entropies to zero
310
+ entropy_per_pixel[~masked_tokens] = 0
311
+
312
+ entropy_per_image_numerator = entropy_per_pixel.sum(-1)
313
+ entropy_per_image = entropy_per_image_numerator / num_masked_pixels
314
+
315
+ total_buckets = 10
316
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
317
+
318
+ entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)
319
+
320
+ return entropy_by_masked_bucket
321
+
322
+
323
+ # calculates entropy over the averaged distribution of pixels for the whole image
324
+ def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
325
+ # only calculated entropy over image tokens that were masked in the original image
326
+ masked_tokens = input_ids == mask_id
327
+ num_masked_pixels = masked_tokens.sum(-1, keepdim=True)
328
+
329
+ pixel_probs = F.softmax(logits, dim=-1)
330
+ pixel_probs[~masked_tokens] = 0
331
+ image_probs_numerator = pixel_probs.sum(-2)
332
+ image_probs = image_probs_numerator / num_masked_pixels
333
+
334
+ image_log_probs = image_probs.log()
335
+
336
+ entropy_per_image = -((image_probs * image_log_probs).sum(-1))
337
+
338
+ total_buckets = 10
339
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
340
+
341
+ entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)
342
+
343
+ return entropy_by_masked_bucket
344
+
345
+
346
+ def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
347
+ cross_entropy_per_image = F.cross_entropy(
348
+ logits.view(-1, output_size),
349
+ labels.view(-1),
350
+ ignore_index=-100,
351
+ label_smoothing=label_smoothing,
352
+ reduction="none",
353
+ )
354
+
355
+ total_buckets = 10
356
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
357
+
358
+ cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets)
359
+
360
+ return cross_entropy_by_percent_masked_bucket
361
+
362
+
363
+ def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
364
+ probs = F.softmax(logits, dim=-1)
365
+
366
+ total_buckets = 10
367
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
368
+
369
+ data = []
370
+
371
+ for bucket_idx in range(total_buckets):
372
+ indices_for_bucket = masked_buckets[masked_buckets == bucket_idx]
373
+
374
+ # It's ok if none were noised in the range of this bucket. This
375
+ # function will be called for a later training step where it's likely
376
+ # there will be an element noised in the range.
377
+ if indices_for_bucket.shape[0] == 0:
378
+ continue
379
+
380
+ index_for_bucket = indices_for_bucket[0]
381
+
382
+ image_probs = probs[index_for_bucket]
383
+
384
+ # find the index of a masked pixel for the image
385
+ input_ids_for_image = input_ids[index_for_bucket]
386
+ masked_pixels_probs = image_probs[input_ids_for_image == mask_id]
387
+
388
+ masked_pixel_probs = masked_pixels_probs[0]
389
+
390
+ masked_pixel_probs = masked_pixel_probs.cpu().numpy()
391
+
392
+ for masked_pixel_prob in masked_pixel_probs:
393
+ data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob})
394
+
395
+ df = pd.DataFrame(data)
396
+
397
+ return df
398
+
399
+
400
+ def average_by_buckets(values, masked_buckets, total_buckets):
401
+ unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True)
402
+
403
+ numerator = torch.zeros(total_buckets, device=values.device)
404
+
405
+ numerator.scatter_add_(0, masked_buckets, values)
406
+
407
+ # default value is one because the buckets for which there aren't
408
+ # any values will have a numerator of zero. So we just need to not divide
409
+ # by zero.
410
+ denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long)
411
+ denominator[unique_buckets] = bucket_counts
412
+
413
+ averaged_by_buckets = numerator / denominator
414
+
415
+ return averaged_by_buckets
416
+
417
+
418
+ def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10):
419
+ assert total_buckets == 10
420
+
421
+ masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1]
422
+
423
+ # we do not formally use timesteps to noise images. Instead, we mask a percent
424
+ # of the pixels. We don't want to log entropy for every mask percent between 0 and 1,
425
+ # and we also want to track how the entropy evolves over time w/in a range of mask
426
+ # percents that should have similar entropy. So we bucket the masked percents into a
427
+ # fixed number of buckets
428
+
429
+ # we could generalize this later if needed but for now, let's just assume a fixed
430
+ # number of 10 buckets.
431
+
432
+ # How this maps to a bucket index:
433
+ # (mask) * bucket_index +
434
+ # (mask_1) * bucket_index_1
435
+ #
436
+ # -> Where the mask is true will be set to the expected bucket index,
437
+ # where the mask is false will be set to 0.
438
+ #
439
+ # Given the probabilities are between 0 and 1, each masked_percent will get mapped
440
+ # to a timestep by one and only one of the masks.
441
+
442
+ masked_buckets = (
443
+ ((0 < masked_percent) & (masked_percent <= 0.1)) * 0
444
+ + ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1
445
+ + ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2
446
+ + ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3
447
+ + ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4
448
+ + ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5
449
+ + ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6
450
+ + ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7
451
+ + ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8
452
+ + ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9
453
+ )
454
+
455
+ return masked_buckets
prompting_utils.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class UniversalPrompting():
4
+ def __init__(self, text_tokenizer,
5
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
6
+ max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1):
7
+ """
8
+ :param text_tokenizer: original text tokenizer
9
+ """
10
+ self.text_tokenizer = text_tokenizer
11
+ self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
12
+ self.text_tokenizer.add_tokens(list(special_tokens))
13
+ self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in
14
+ special_tokens}
15
+ self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
16
+ self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
17
+ self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])
18
+ # plus 1 because at this time we add a task token before
19
+ self.max_text_len = max_text_len + 1
20
+ self.pad_id = self.text_tokenizer.convert_tokens_to_ids('[PAD]')
21
+ self.ignore_id = ignore_id
22
+ self.cond_dropout_prob = cond_dropout_prob
23
+
24
+ def t2i_prompt_predict_next(self, text_ids, image_ids, labels):
25
+
26
+ device = image_ids.device
27
+ sequence_ids = []
28
+ attention_masks = []
29
+ label_ids = []
30
+ probs = torch.rand(len(text_ids))
31
+ for i in range(len(text_ids)):
32
+
33
+ if len(text_ids[i]) == 0:
34
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
35
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
36
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
37
+
38
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
39
+
40
+ # randomly dropout text condition
41
+ if probs[i] < self.cond_dropout_prob:
42
+ temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]
43
+
44
+ if self.max_text_len >= len(temp_ids):
45
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
46
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
47
+ else:
48
+ # should add the eos token
49
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
50
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
51
+
52
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
53
+ temp_label_ids = torch.cat([
54
+ # should we predict text tokens when doing image reconstruction?
55
+ torch.tensor(temp_ids).to(device),
56
+ self.sptids_dict['<|soi|>'].to(device),
57
+ labels[i],
58
+ self.sptids_dict['<|eoi|>'].to(device)
59
+ ], dim=0)
60
+
61
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
62
+
63
+ temp_ids = torch.cat([
64
+ torch.tensor(temp_ids).to(device),
65
+ self.sptids_dict['<|soi|>'].to(device),
66
+ image_ids[i],
67
+ self.sptids_dict['<|eoi|>'].to(device)
68
+ ], dim=0)
69
+
70
+ temp_masks = torch.tensor(temp_masks).to(device)
71
+ sequence_ids.append(temp_ids.unsqueeze(0))
72
+ attention_masks.append(temp_masks.unsqueeze(0))
73
+ label_ids.append(temp_label_ids.unsqueeze(0))
74
+
75
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
76
+
77
+ def t2i_gen_prompt(self, text_ids, image_ids):
78
+
79
+ device = image_ids.device
80
+ sequence_ids = []
81
+ attention_masks = []
82
+ for i in range(len(text_ids)):
83
+ if len(text_ids[i]) == 0:
84
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
85
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
86
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
87
+ # note that, llama3 tokenizer automatically add the bot token at first but without eot
88
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
89
+ if self.max_text_len >= len(temp_ids):
90
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
91
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
92
+ else:
93
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
94
+ temp_masks = [1] * len(temp_ids) # +2 for two special tokens
95
+
96
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
97
+ temp_ids = torch.cat([
98
+ torch.tensor(temp_ids).to(device),
99
+ self.sptids_dict['<|soi|>'].to(device),
100
+ image_ids[i],
101
+ self.sptids_dict['<|eoi|>'].to(device)
102
+ ], dim=0)
103
+
104
+ temp_masks = torch.tensor(temp_masks).to(device)
105
+ sequence_ids.append(temp_ids.unsqueeze(0))
106
+ attention_masks.append(temp_masks.unsqueeze(0))
107
+
108
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
109
+
110
+ # language modeling
111
+ def lm_prompt(self, text_ids, max_seq_len):
112
+
113
+ sequence_ids = []
114
+ attention_masks = []
115
+ label_ids = []
116
+ for i in range(len(text_ids)):
117
+ if len(text_ids[i]) == 0:
118
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
119
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
120
+ text_ids[i] = [self.text_tokenizer.eos_token_id] + text_ids[i]
121
+
122
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
123
+
124
+ if max_seq_len >= len(temp_ids):
125
+ temp_labels_ids = temp_ids + [self.ignore_id] * (max_seq_len - len(temp_ids))
126
+ temp_ids = temp_ids + [self.pad_id] * (max_seq_len - len(temp_ids))
127
+ temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids))
128
+ else:
129
+ # In language modeling, we only process text tokens. We do not add the eos token if the text length
130
+ # exceeds the max sequence length
131
+ temp_labels_ids = temp_ids[:max_seq_len]
132
+ temp_ids = temp_ids[:max_seq_len]
133
+ temp_masks = [1] * len(temp_ids) # +2 for two special tokens
134
+
135
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
136
+ temp_ids = torch.tensor(temp_ids)
137
+ temp_masks = torch.tensor(temp_masks)
138
+ temp_labels_ids = torch.tensor(temp_labels_ids)
139
+
140
+ sequence_ids.append(temp_ids.unsqueeze(0))
141
+ attention_masks.append(temp_masks.unsqueeze(0))
142
+ label_ids.append(temp_labels_ids.unsqueeze(0))
143
+
144
+ # input_ids, masks, labels
145
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
146
+
147
+ def mmu_prompt(self, image_ids, text_ids):
148
+ device = image_ids.device
149
+ sequence_ids = []
150
+ attention_masks = []
151
+ label_ids = []
152
+ max_text_len = self.max_text_len - 1
153
+ for i in range(len(text_ids)):
154
+ # note that, llama3 tokenizer automatically add the bot token at first but without eot
155
+ # for empty list []
156
+
157
+ if len(text_ids[i]) == 0:
158
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
159
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
160
+ text_ids[i] = [self.text_tokenizer.eos_token_id] + text_ids[i]
161
+
162
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
163
+
164
+ if max_text_len >= len(temp_ids):
165
+ # minus 1 because task token was prepended to the former image tokens
166
+ temp_ids = temp_ids + [self.pad_id] * (max_text_len - len(temp_ids))
167
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids))
168
+ else:
169
+ # should add the eos token
170
+ temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
171
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
172
+
173
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
174
+ temp_label_ids = torch.cat([
175
+ torch.tensor([self.ignore_id]).to(device),
176
+ torch.tensor([self.ignore_id]).to(device),
177
+ torch.ones_like(image_ids[i]) * self.ignore_id,
178
+ torch.tensor([self.ignore_id]).to(device),
179
+ torch.tensor(temp_ids).to(device),
180
+ ], dim=0)
181
+
182
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
183
+
184
+ temp_ids = torch.cat([
185
+ self.sptids_dict['<|mmu|>'].to(device), # task token
186
+ self.sptids_dict['<|soi|>'].to(device),
187
+ image_ids[i],
188
+ self.sptids_dict['<|eoi|>'].to(device),
189
+ torch.tensor(temp_ids).to(device),
190
+ ], dim=0)
191
+
192
+ temp_masks = torch.tensor(temp_masks).to(device)
193
+ sequence_ids.append(temp_ids.unsqueeze(0))
194
+ attention_masks.append(temp_masks.unsqueeze(0))
195
+ label_ids.append(temp_label_ids.unsqueeze(0))
196
+
197
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
198
+
199
+ def t2v_prompt(self, text_ids, video_ids):
200
+ """
201
+ :param text_ids:
202
+ :param video_ids:
203
+ :return:
204
+ """
205
+ pass
206
+
207
+ def i2v_prompt(self, image_ids, video_ids):
208
+ """
209
+ :param image_ids:
210
+ :param video_ids:
211
+ :return:
212
+ """
213
+ pass
214
+
215
+ def lvg_prompt(self, text_ids, image_ids, labels):
216
+
217
+ device = image_ids.device
218
+ sequence_ids = []
219
+ attention_masks = []
220
+ label_ids = []
221
+ probs = torch.rand(len(text_ids))
222
+ probs2 = torch.rand(len(text_ids))
223
+ for i in range(len(text_ids)):
224
+
225
+ if len(text_ids[i]) == 0:
226
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
227
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
228
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
229
+
230
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
231
+
232
+ # randomly dropout text condition
233
+ if probs[i] < self.cond_dropout_prob:
234
+ temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id,
235
+ self.text_tokenizer.eos_token_id]
236
+
237
+ if self.max_text_len >= len(temp_ids):
238
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
239
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
240
+ else:
241
+ # should add the eos token
242
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
243
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
244
+
245
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
246
+ temp_label_ids = torch.cat([
247
+ # should we predict text tokens when doing image reconstruction?
248
+ torch.tensor(temp_ids).to(device),
249
+ self.sptids_dict['<|soi|>'].to(device),
250
+ labels[i],
251
+ self.sptids_dict['<|eoi|>'].to(device)
252
+ ], dim=0)
253
+
254
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
255
+
256
+ temp_ids = torch.cat([
257
+ torch.tensor(temp_ids).to(device),
258
+ self.sptids_dict['<|soi|>'].to(device),
259
+ image_ids[i],
260
+ self.sptids_dict['<|eoi|>'].to(device)
261
+ ], dim=0)
262
+
263
+ temp_masks = torch.tensor(temp_masks).to(device)
264
+ sequence_ids.append(temp_ids.unsqueeze(0))
265
+ attention_masks.append(temp_masks.unsqueeze(0))
266
+ label_ids.append(temp_label_ids.unsqueeze(0))
267
+
268
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
269
+
270
+ def lvg_gen_prompt(self, text_ids, image_ids):
271
+
272
+ device = image_ids.device
273
+ sequence_ids = []
274
+ attention_masks = []
275
+ for i in range(len(text_ids)):
276
+ if len(text_ids[i]) == 0:
277
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
278
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
279
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
280
+ # note that, llama3 tokenizer automatically add the bot token at first but without eot
281
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
282
+ if self.max_text_len >= len(temp_ids):
283
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
284
+ temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * len(temp_ids)
285
+ else:
286
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
287
+ temp_masks = [1] * len(temp_ids) # +2 for two special tokens
288
+
289
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
290
+ temp_ids = torch.cat([
291
+ torch.tensor(temp_ids).to(device),
292
+ self.sptids_dict['<|soi|>'].to(device),
293
+ image_ids[i],
294
+ self.sptids_dict['<|eoi|>'].to(device)
295
+ ], dim=0)
296
+
297
+ temp_masks = torch.tensor(temp_masks).to(device)
298
+ sequence_ids.append(temp_ids.unsqueeze(0))
299
+ attention_masks.append(temp_masks.unsqueeze(0))
300
+
301
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
302
+
303
+ def mask_prompt(self):
304
+ pass
305
+
306
+ def __call__(self, input, task, padding=True, config=None):
307
+ """
308
+ input (tuple) : data pairs contain text(str), image(tensor), or videos(tensor).
309
+ task (str) : a flag indicates the current task.
310
+ """
311
+ if task == "t2i":
312
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
313
+ image_ids = input[1] # (B, #tokens)
314
+ sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2])
315
+
316
+ elif task == "t2i_predict_next":
317
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
318
+ image_ids = input[1] # (B, #tokens)
319
+ sequence_ids_with_masks = self.t2i_prompt_predict_next(text_ids, image_ids, input[2])
320
+
321
+ elif task == "t2i_predict_next_plus_lm":
322
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
323
+ image_ids = input[1] # (B, #tokens)
324
+ sequence_ids_with_masks = self.t2i_prompt_predict_next(text_ids[:config.training.batch_size], image_ids,
325
+ input[2])
326
+ sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3])
327
+ return sequence_ids_with_masks, sequence_ids_with_masks_lm
328
+
329
+ elif task == "t2i_gen":
330
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
331
+ image_ids = input[1] # (B, #tokens)
332
+ sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids)
333
+
334
+ elif task == "lm":
335
+ text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len)
336
+ sequence_ids_with_masks = self.lm_prompt(text_ids, input[1])
337
+
338
+ elif task == "mmu":
339
+ image_ids = input[0]
340
+ text_ids = self.text_tokenizer(input[1])['input_ids']
341
+ sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids)
342
+
343
+ elif task == "t2v":
344
+ text_ids = self.text_tokenizer(input[0]['input_ids'])
345
+ video_ids = self.vision_tokenizer(input[1])
346
+ sequence_ids_with_masks = self.t2v_prompt(text_ids, video_ids)
347
+
348
+ elif task == "i2v":
349
+ image_ids = self.text_tokenizer(input[0])
350
+ video_ids = self.vision_tokenizer(input[1])
351
+ sequence_ids_with_masks = self.i2v_prompt(image_ids, video_ids)
352
+
353
+ elif task == "lvg":
354
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
355
+ image_ids = input[1] # (B, #tokens)
356
+ sequence_ids_with_masks = self.lvg_prompt(text_ids, image_ids, input[2])
357
+
358
+ elif task == "lvg_gen":
359
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
360
+ image_ids = input[1] # (B, #tokens)
361
+ sequence_ids_with_masks = self.lvg_gen_prompt(text_ids, image_ids)
362
+ else:
363
+ raise NotImplementedError
364
+
365
+ return sequence_ids_with_masks
366
+
367
+ def create_attention_mask_predict_next(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, rm_pad_in_image=False,
368
+ return_inverse_mask=True):
369
+ # sequence is expected to be of shape [N, L]
370
+ N, L = sequence.shape
371
+
372
+ # Masks to identify different types of tokens
373
+ is_padding = sequence == pad_id
374
+
375
+ is_start_image = sequence == soi_id
376
+
377
+ is_end_image = sequence == eoi_id
378
+
379
+ # Create cumulative sum masks to identify regions of image tokens
380
+ cumulative_start = torch.cumsum(is_start_image, dim=1)
381
+ cumulative_end = torch.cumsum(is_end_image, dim=1)
382
+ in_image_segment = (cumulative_start > cumulative_end) | is_start_image | is_end_image
383
+
384
+ is_text = ~(in_image_segment)
385
+
386
+ causal_mask = torch.tril(torch.ones((L, L), dtype=torch.bool)).to(sequence.device)
387
+
388
+ mask_text = is_text[:, :, None] * causal_mask[None, :, :]
389
+
390
+ is_text_image = is_text | in_image_segment
391
+
392
+ mask_text_image_bi = is_text_image[:, :, None] * is_text_image[:, None, :]
393
+ if rm_pad_in_image:
394
+ sid_img = torch.where(sequence == soi_id)[1]
395
+ for i in range(mask_text_image_bi.shape[0]):
396
+ pad_end_idx = torch.where(sequence[i] == pad_id)
397
+ if len(pad_end_idx[0]) != 0:
398
+ pad_end_idx = pad_end_idx[0][-1]
399
+ mask_text[i][pad_end_idx + 1:, :pad_end_idx + 1] = 0
400
+ id_padding = torch.where(is_padding[i] == True)
401
+ mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
402
+
403
+ mask_text[in_image_segment] = mask_text_image_bi[in_image_segment]
404
+ # No token attends to padding tokens and padding tokens do not attend to any token
405
+ if return_inverse_mask:
406
+ inverted_mask = 1.0 - mask_text.type(sequence.dtype)
407
+ inverted_mask = inverted_mask.masked_fill(
408
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
409
+ )
410
+ return inverted_mask.unsqueeze(1)
411
+ else:
412
+ return mask_text.unsqueeze(1)
413
+
414
+ def create_attention_mask_lvg(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, return_inverse_mask=True):
415
+ # sequence is expected to be of shape [N, L]
416
+ N, L = sequence.shape
417
+ # Masks to identify different types of tokens
418
+ is_padding = sequence == pad_id
419
+ mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device)
420
+
421
+ sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
422
+ sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
423
+ eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
424
+ for i in range(N):
425
+ id_padding = torch.where(is_padding[i] == True)
426
+ mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
427
+ for j in range(sid_img_for_bi.shape[-1]):
428
+ mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1,
429
+ sid_img_for_bi[i, j]:eid_img_for_bi[i, j] + 1] = 1
430
+
431
+ # No token attends to padding tokens and padding tokens do not attend to any token
432
+ if return_inverse_mask:
433
+ inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
434
+ inverted_mask = inverted_mask.masked_fill(
435
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
436
+ )
437
+ return inverted_mask.unsqueeze(1)
438
+ else:
439
+ return mask_text_image_bi.unsqueeze(1)
440
+
441
+ # texts without attending image regions
442
+ def create_attention_mask_lvg_v2(sequence, pad_id=128256, soi_id=128257, eoi_id=128258, sot_id=1000, eot_id=1001, return_inverse_mask=True):
443
+ # sequence is expected to be of shape [N, L]
444
+ N, L = sequence.shape
445
+ # Masks to identify different types of tokens
446
+ is_padding = sequence == pad_id
447
+ # is_text = torch.where(sequence < 2000, True, False)
448
+ is_text = torch.where(sequence < pad_id, True, False)
449
+ mask_text_image_bi = torch.tril(torch.ones(N, L, L), diagonal=0).to(sequence.device).int()
450
+ sid_text_for_bi = torch.where(sequence == sot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
451
+ eid_text_for_bi = torch.where(sequence == eot_id)[1].reshape(mask_text_image_bi.shape[0], -1)
452
+ # import ipdb
453
+ # ipdb.set_trace()
454
+ if sot_id == eot_id:
455
+ if sid_text_for_bi.shape[-1] % 2 != 0:
456
+ sid_text_for_bi = sid_text_for_bi[:, :-1]
457
+ eid_text_for_bi = eid_text_for_bi[:, :-1]
458
+ select_idx = [i for i in range(0, sid_text_for_bi.shape[1], 2)]
459
+ sid_text_for_bi = sid_text_for_bi[:, select_idx]
460
+ select_idx = [i+1 for i in range(0, eid_text_for_bi.shape[1], 2)]
461
+ eid_text_for_bi = eid_text_for_bi[:, select_idx]
462
+ sid_img_for_bi = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
463
+ eid_img_for_bi = torch.where(sequence == eoi_id)[1].reshape(mask_text_image_bi.shape[0], -1)
464
+ all_zeros = torch.zeros_like(mask_text_image_bi).int()
465
+ for i in range(N):
466
+ all_zeros[i, :, is_text[i]] = 1
467
+ for j in range(sid_text_for_bi.shape[-1]):
468
+ all_zeros[i][is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
469
+ all_zeros[i][~is_text[i], sid_text_for_bi[i, j]:eid_text_for_bi[i, j]+1] = 1
470
+ for j in range(sid_img_for_bi.shape[-1]):
471
+ all_zeros[i][~is_text[i], sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
472
+ mask_text_image_bi = mask_text_image_bi * all_zeros
473
+ sid_img = torch.where(sequence == soi_id)[1].reshape(mask_text_image_bi.shape[0], -1)[:, 0]
474
+
475
+ for i in range(N):
476
+ id_padding = torch.where(is_padding[i] == True)
477
+ mask_text_image_bi[i][sid_img[i]:, id_padding[0]] = 0
478
+ for j in range(sid_img_for_bi.shape[-1]):
479
+ mask_text_image_bi[i][sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1, sid_img_for_bi[i, j]:eid_img_for_bi[i, j]+1] = 1
480
+
481
+ mask_text_image_bi[:, :, 0] = 1
482
+ # No token attends to padding tokens and padding tokens do not attend to any token
483
+ if return_inverse_mask:
484
+ inverted_mask = 1.0 - mask_text_image_bi.type(sequence.dtype)
485
+ inverted_mask = inverted_mask.masked_fill(
486
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
487
+ )
488
+ return inverted_mask.unsqueeze(1)
489
+ else:
490
+ return mask_text_image_bi.unsqueeze(1)
491
+
492
+ def create_attention_mask_for_mmu(sequence, eoi_id=128258, return_inverse_mask=True):
493
+ N, L = sequence.shape
494
+ causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
495
+ eoi_image = torch.where(sequence == eoi_id)[1]
496
+ causal_mask[:, :, :, :eoi_image[0] + 1] = 1
497
+
498
+ if return_inverse_mask:
499
+ inverted_mask = 1.0 - causal_mask.type(sequence.dtype)
500
+ inverted_mask = inverted_mask.masked_fill(
501
+ inverted_mask.to(torch.bool), torch.iinfo(sequence.dtype).min
502
+ )
503
+ return inverted_mask
504
+ else:
505
+ return causal_mask
506
+
507
+ def create_attention_mask_for_mmu_vit(
508
+ sequence,
509
+ return_inverse_mask=True,
510
+ system_prompt_len=0
511
+ ):
512
+ N, L, H = sequence.shape
513
+ causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
514
+ index = 1 + system_prompt_len + 1 + 576
515
+
516
+ causal_mask[:, :, :, :index] = 1
517
+ if return_inverse_mask:
518
+ inverted_mask = 1.0 - causal_mask.type(torch.int64)
519
+ inverted_mask = inverted_mask.masked_fill(
520
+ inverted_mask.to(torch.bool), torch.iinfo(torch.int64).min
521
+ )
522
+ return inverted_mask
523
+ else:
524
+ return causal_mask
525
+
526
+
527
+ if __name__ == '__main__':
528
+ pass
requirements.txt ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.21.0
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ albumentations==0.3.2
5
+ annotated-types==0.7.0
6
+ antlr4-python3-runtime==4.9.3
7
+ anykeystore==0.2
8
+ asn1crypto==1.5.1
9
+ asttokens==2.4.1
10
+ async-timeout==4.0.3
11
+ attrs==21.2.0
12
+ bidict==0.23.1
13
+ blessed==1.20.0
14
+ boto3==1.34.113
15
+ botocore==1.34.113
16
+ braceexpand==0.1.7
17
+ cachetools==5.3.3
18
+ certifi==2024.2.2
19
+ cffi==1.16.0
20
+ chardet==5.2.0
21
+ charset-normalizer==3.3.2
22
+ click==8.1.7
23
+ clip==0.2.0
24
+ clip-openai==1.0.post20230121
25
+ cmake==3.29.3
26
+ cramjam==2.8.3
27
+ crcmod==1.7
28
+ cryptacular==1.6.2
29
+ cryptography==39.0.2
30
+ cycler==0.12.1
31
+ datasets
32
+ diffusers==0.30.1
33
+ decorator==5.1.1
34
+ decord==0.6.0
35
+ deepspeed==0.14.2
36
+ defusedxml==0.7.1
37
+ Deprecated==1.2.14
38
+ descartes==1.1.0
39
+ dill==0.3.8
40
+ distlib==0.3.8
41
+ distro-info==1.0
42
+ dnspython==2.6.1
43
+ docker-pycreds==0.4.0
44
+ docstring_parser==0.16
45
+ ecdsa==0.19.0
46
+ einops==0.6.0
47
+ exceptiongroup==1.2.1
48
+ executing==2.0.1
49
+ fairscale==0.4.13
50
+ fastparquet==2024.5.0
51
+ ffmpegcv==0.3.13
52
+ filelock==3.14.0
53
+ fire==0.6.0
54
+ fonttools==4.51.0
55
+ frozenlist==1.4.1
56
+ fsspec==2023.6.0
57
+ ftfy==6.2.0
58
+ gitdb==4.0.11
59
+ GitPython==3.1.43
60
+ gpustat==1.1.1
61
+ greenlet==3.0.3
62
+ grpcio==1.64.0
63
+ h11==0.14.0
64
+ hjson==3.1.0
65
+ huggingface-hub==0.23.2
66
+ hupper==1.12.1
67
+ idna==3.7
68
+ imageio==2.34.1
69
+ imgaug==0.2.6
70
+ iniconfig==2.0.0
71
+ ipaddress==1.0.23
72
+ ipdb==0.13.13
73
+ ipython==8.18.1
74
+ jaxtyping==0.2.28
75
+ jedi==0.19.1
76
+ Jinja2==3.1.4
77
+ jmespath==1.0.1
78
+ joblib==1.4.2
79
+ jsonargparse==4.14.1
80
+ jsonlines==4.0.0
81
+ kiwisolver==1.4.5
82
+ kornia==0.7.2
83
+ kornia_rs==0.1.3
84
+ lazy_loader==0.4
85
+ lightning==2.2.3
86
+ lightning-utilities==0.11.2
87
+ lit==18.1.6
88
+ MarkupSafe==2.1.5
89
+ matplotlib==3.5.3
90
+ matplotlib-inline==0.1.7
91
+ miscreant==0.3.0
92
+ mpmath==1.3.0
93
+ msgpack==1.0.8
94
+ multidict==6.0.5
95
+ multiprocess==0.70.16
96
+ natsort==8.4.0
97
+ networkx==3.2.1
98
+ ninja==1.11.1.1
99
+ numpy==1.24.4
100
+ nuscenes-devkit==1.1.11
101
+ oauthlib==3.2.2
102
+ omegaconf==2.3.0
103
+ open-clip-torch==2.24.0
104
+ openai-clip
105
+ opencv-python==4.9.0.80
106
+ opencv-python-headless==3.4.18.65
107
+ packaging==22.0
108
+ pandas==1.5.3
109
+ parquet==1.3.1
110
+ parso==0.8.4
111
+ PasteDeploy==3.1.0
112
+ pathlib2==2.3.7.post1
113
+ pathtools==0.1.2
114
+ pbkdf2==1.3
115
+ pexpect==4.9.0
116
+ pillow==10.3.0
117
+ plaster==1.1.2
118
+ plaster-pastedeploy==1.0.1
119
+ platformdirs==4.2.2
120
+ plotly==5.22.0
121
+ pluggy==1.5.0
122
+ ply==3.11
123
+ promise==2.3
124
+ prompt-toolkit==3.0.43
125
+ protobuf==3.20.3
126
+ psutil==5.9.8
127
+ ptyprocess==0.7.0
128
+ pure-eval==0.2.2
129
+ py==1.11.0
130
+ py-cpuinfo==9.0.0
131
+ py-spy==0.3.14
132
+ pyarrow==11.0.0
133
+ pyarrow-hotfix==0.6
134
+ pyasn1==0.6.0
135
+ pycocotools==2.0.7
136
+ pycparser==2.22
137
+ pycryptodomex==3.20.0
138
+ pycurl==7.43.0.6
139
+ pydantic==1.10.15
140
+ pydantic_core==2.18.3
141
+ Pygments==2.18.0
142
+ PyJWT==2.8.0
143
+ pynvml==11.5.0
144
+ pyope==0.2.2
145
+ pyOpenSSL==23.2.0
146
+ pyparsing==3.1.2
147
+ pyquaternion==0.9.9
148
+ pyramid==2.0.2
149
+ pyramid-mailer==0.15.1
150
+ pytest==6.2.5
151
+ python-consul==1.1.0
152
+ python-dateutil==2.9.0.post0
153
+ python-engineio==4.9.1
154
+ python-etcd==0.4.5
155
+ python-jose==3.3.0
156
+ python-socketio==5.11.2
157
+ python3-openid==3.2.0
158
+ pytorch-extension==0.2
159
+ pytorch-lightning==2.2.3
160
+ pytz==2024.1
161
+ PyYAML==6.0.1
162
+ regex==2024.5.15
163
+ repoze.sendmail==4.4.1
164
+ requests==2.31.0
165
+ requests-oauthlib==2.0.0
166
+ rsa==4.9
167
+ s3transfer==0.10.1
168
+ safetensors==0.4.3
169
+ schedule==1.2.2
170
+ scikit-image==0.22.0
171
+ scikit-learn==1.5.0
172
+ scipy==1.13.1
173
+ sentencepiece==0.2.0
174
+ sentry-sdk==2.3.1
175
+ setproctitle==1.3.3
176
+ Shapely==1.8.5.post1
177
+ shortuuid==1.0.13
178
+ simple-websocket==1.0.0
179
+ six==1.16.0
180
+ smmap==5.0.1
181
+ SQLAlchemy==2.0.30
182
+ stack-data==0.6.3
183
+ sympy==1.12
184
+ taming-transformers-rom1504==0.0.6
185
+ tenacity==8.3.0
186
+ tensorboardX==2.6.2.2
187
+ termcolor==2.4.0
188
+ threadpoolctl==3.5.0
189
+ thriftpy2==0.5.0
190
+ tifffile==2024.5.22
191
+ timm==1.0.3
192
+ tokenizers==0.19.1
193
+ toml==0.10.2
194
+ tomli==2.0.1
195
+ torch==2.2.1
196
+ torch-fidelity==0.3.0
197
+ torchmetrics==1.4.0.post0
198
+ torchvision==0.17.1
199
+ tox==3.28.0
200
+ tqdm==4.66.4
201
+ traitlets==5.14.3
202
+ transaction==4.0
203
+ transformers==4.41.1
204
+ translationstring==1.4
205
+ triton==2.2.0
206
+ typeguard==2.13.3
207
+ typing_extensions==4.12.0
208
+ tzdata==2024.1
209
+ urllib3==1.26.18
210
+ velruse==1.1.1
211
+ venusian==3.1.0
212
+ virtualenv==20.26.2
213
+ wandb==0.17.0
214
+ watchdog==4.0.1
215
+ wcwidth==0.2.13
216
+ webdataset==0.2.86
217
+ WebOb==1.8.7
218
+ websocket-client==1.8.0
219
+ wrapt==1.16.0
220
+ wsproto==1.2.0
221
+ WTForms==3.1.2
222
+ wtforms-recaptcha==0.3.2
223
+ xformers==0.0.25
224
+ xxhash==3.4.1
225
+ yarl==1.9.4
226
+ zope.deprecation==5.0
227
+ zope.interface==6.4.post2
228
+ zope.sqlalchemy==3.1
training/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
training/conversation.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from LLaVA: https://github.com/haotian-liu/LLaVA.git
2
+ import dataclasses
3
+ from enum import auto, Enum
4
+ from typing import List, Tuple
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+
9
+
10
+ class SeparatorStyle(Enum):
11
+ """Different separator style."""
12
+ SINGLE = auto()
13
+ TWO = auto()
14
+ MPT = auto()
15
+ PLAIN = auto()
16
+ LLAMA_2 = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that keeps all conversation history."""
22
+ system: str
23
+ roles: List[str]
24
+ messages: List[List[str]]
25
+ offset: int
26
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
27
+ sep: str = "###"
28
+ sep2: str = None
29
+ version: str = "Unknown"
30
+
31
+ skip_next: bool = False
32
+
33
+ def get_prompt(self):
34
+ messages = self.messages
35
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
36
+ messages = self.messages.copy()
37
+ init_role, init_msg = messages[0].copy()
38
+ init_msg = init_msg[0].replace("<image>", "").strip()
39
+ if 'mmtag' in self.version:
40
+ messages[0] = (init_role, init_msg)
41
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
42
+ messages.insert(1, (self.roles[1], "Received."))
43
+ else:
44
+ messages[0] = (init_role, "<image>\n" + init_msg)
45
+
46
+ if self.sep_style == SeparatorStyle.SINGLE:
47
+ ret = self.system + self.sep
48
+ for role, message in messages:
49
+ if message:
50
+ if type(message) is tuple:
51
+ message, _, _ = message
52
+ ret += role + ": " + message + self.sep
53
+ else:
54
+ ret += role + ":"
55
+ elif self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(messages):
59
+ if message:
60
+ if type(message) is tuple:
61
+ message, _, _ = message
62
+ ret += role + ": " + message + seps[i % 2]
63
+ else:
64
+ ret += role + ":"
65
+ elif self.sep_style == SeparatorStyle.MPT:
66
+ ret = self.system + self.sep
67
+ for role, message in messages:
68
+ if message:
69
+ if type(message) is tuple:
70
+ message, _, _ = message
71
+ ret += role + message + self.sep
72
+ else:
73
+ ret += role
74
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
75
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
76
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
77
+ ret = ""
78
+
79
+ for i, (role, message) in enumerate(messages):
80
+ if i == 0:
81
+ assert message, "first message should not be none"
82
+ assert role == self.roles[0], "first message should come from user"
83
+ if message:
84
+ if type(message) is tuple:
85
+ message, _, _ = message
86
+ if i == 0: message = wrap_sys(self.system) + message
87
+ if i % 2 == 0:
88
+ message = wrap_inst(message)
89
+ ret += self.sep + message
90
+ else:
91
+ ret += " " + message + " " + self.sep2
92
+ else:
93
+ ret += ""
94
+ ret = ret.lstrip(self.sep)
95
+ elif self.sep_style == SeparatorStyle.PLAIN:
96
+ seps = [self.sep, self.sep2]
97
+ ret = self.system
98
+ for i, (role, message) in enumerate(messages):
99
+ if message:
100
+ if type(message) is tuple:
101
+ message, _, _ = message
102
+ ret += message + seps[i % 2]
103
+ else:
104
+ ret += ""
105
+ else:
106
+ raise ValueError(f"Invalid style: {self.sep_style}")
107
+
108
+ return ret
109
+
110
+ def append_message(self, role, message):
111
+ self.messages.append([role, message])
112
+
113
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
114
+ if image_process_mode == "Pad":
115
+ def expand2square(pil_img, background_color=(122, 116, 104)):
116
+ width, height = pil_img.size
117
+ if width == height:
118
+ return pil_img
119
+ elif width > height:
120
+ result = Image.new(pil_img.mode, (width, width), background_color)
121
+ result.paste(pil_img, (0, (width - height) // 2))
122
+ return result
123
+ else:
124
+ result = Image.new(pil_img.mode, (height, height), background_color)
125
+ result.paste(pil_img, ((height - width) // 2, 0))
126
+ return result
127
+ image = expand2square(image)
128
+ elif image_process_mode in ["Default", "Crop"]:
129
+ pass
130
+ elif image_process_mode == "Resize":
131
+ image = image.resize((336, 336))
132
+ else:
133
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
134
+ if max(image.size) > max_len:
135
+ max_hw, min_hw = max(image.size), min(image.size)
136
+ aspect_ratio = max_hw / min_hw
137
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
138
+ longest_edge = int(shortest_edge * aspect_ratio)
139
+ W, H = image.size
140
+ if H > W:
141
+ H, W = longest_edge, shortest_edge
142
+ else:
143
+ H, W = shortest_edge, longest_edge
144
+ image = image.resize((W, H))
145
+ if return_pil:
146
+ return image
147
+ else:
148
+ buffered = BytesIO()
149
+ image.save(buffered, format=image_format)
150
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
151
+ return img_b64_str
152
+
153
+ def get_images(self, return_pil=False):
154
+ images = []
155
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
156
+ if i % 2 == 0:
157
+ if type(msg) is tuple:
158
+ msg, image, image_process_mode = msg
159
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
160
+ images.append(image)
161
+ return images
162
+
163
+ def to_gradio_chatbot(self):
164
+ ret = []
165
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
166
+ if i % 2 == 0:
167
+ if type(msg) is tuple:
168
+ msg, image, image_process_mode = msg
169
+ img_b64_str = self.process_image(
170
+ image, "Default", return_pil=False,
171
+ image_format='JPEG')
172
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
173
+ msg = img_str + msg.replace('<image>', '').strip()
174
+ ret.append([msg, None])
175
+ else:
176
+ ret.append([msg, None])
177
+ else:
178
+ ret[-1][-1] = msg
179
+ return ret
180
+
181
+ def copy(self):
182
+ return Conversation(
183
+ system=self.system,
184
+ roles=self.roles,
185
+ messages=[[x, y] for x, y in self.messages],
186
+ offset=self.offset,
187
+ sep_style=self.sep_style,
188
+ sep=self.sep,
189
+ sep2=self.sep2,
190
+ version=self.version)
191
+
192
+ def dict(self):
193
+ if len(self.get_images()) > 0:
194
+ return {
195
+ "system": self.system,
196
+ "roles": self.roles,
197
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
198
+ "offset": self.offset,
199
+ "sep": self.sep,
200
+ "sep2": self.sep2,
201
+ }
202
+ return {
203
+ "system": self.system,
204
+ "roles": self.roles,
205
+ "messages": self.messages,
206
+ "offset": self.offset,
207
+ "sep": self.sep,
208
+ "sep2": self.sep2,
209
+ }
210
+
211
+
212
+ conv_vicuna_v0 = Conversation(
213
+ system="A chat between a curious human and an artificial intelligence assistant. "
214
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
215
+ roles=("Human", "Assistant"),
216
+ messages=(
217
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
218
+ ("Assistant",
219
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
220
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
221
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
222
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
223
+ "renewable and non-renewable energy sources:\n"
224
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
225
+ "energy sources are finite and will eventually run out.\n"
226
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
227
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
228
+ "and other negative effects.\n"
229
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
230
+ "have lower operational costs than non-renewable sources.\n"
231
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
232
+ "locations than non-renewable sources.\n"
233
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
234
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
235
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
236
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
237
+ ),
238
+ offset=2,
239
+ sep_style=SeparatorStyle.SINGLE,
240
+ sep="###",
241
+ )
242
+
243
+ conv_vicuna_v1 = Conversation(
244
+ system="A chat between a curious user and an artificial intelligence assistant. "
245
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
246
+ roles=("USER", "ASSISTANT"),
247
+ version="v1",
248
+ messages=(),
249
+ offset=0,
250
+ sep_style=SeparatorStyle.TWO,
251
+ sep=" ",
252
+ sep2="</s>",
253
+ )
254
+
255
+ conv_llama_2 = Conversation(
256
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
257
+
258
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
259
+ roles=("USER", "ASSISTANT"),
260
+ version="llama_v2",
261
+ messages=(),
262
+ offset=0,
263
+ sep_style=SeparatorStyle.LLAMA_2,
264
+ sep="<s>",
265
+ sep2="</s>",
266
+ )
267
+
268
+ conv_llava_llama_2 = Conversation(
269
+ system="You are a helpful language and vision assistant. "
270
+ "You are able to understand the visual content that the user provides, "
271
+ "and assist the user with a variety of tasks using natural language.",
272
+ roles=("USER", "ASSISTANT"),
273
+ version="llama_v2",
274
+ messages=(),
275
+ offset=0,
276
+ sep_style=SeparatorStyle.LLAMA_2,
277
+ sep="<s>",
278
+ sep2="</s>",
279
+ )
280
+
281
+ conv_mpt = Conversation(
282
+ system="""<|im_start|>system
283
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
284
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
285
+ version="mpt",
286
+ messages=(),
287
+ offset=0,
288
+ sep_style=SeparatorStyle.MPT,
289
+ sep="<|im_end|>",
290
+ )
291
+
292
+ conv_llava_plain = Conversation(
293
+ system="",
294
+ roles=("", ""),
295
+ messages=(
296
+ ),
297
+ offset=0,
298
+ sep_style=SeparatorStyle.PLAIN,
299
+ sep="\n",
300
+ )
301
+
302
+ conv_llava_v0 = Conversation(
303
+ system="A chat between a curious human and an artificial intelligence assistant. "
304
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
305
+ roles=("Human", "Assistant"),
306
+ messages=(
307
+ ),
308
+ offset=0,
309
+ sep_style=SeparatorStyle.SINGLE,
310
+ sep="###",
311
+ )
312
+
313
+ conv_llava_v0_mmtag = Conversation(
314
+ system="A chat between a curious user and an artificial intelligence assistant. "
315
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
316
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
317
+ roles=("Human", "Assistant"),
318
+ messages=(
319
+ ),
320
+ offset=0,
321
+ sep_style=SeparatorStyle.SINGLE,
322
+ sep="###",
323
+ version="v0_mmtag",
324
+ )
325
+
326
+ conv_llava_v1 = Conversation(
327
+ system="A chat between a curious human and an artificial intelligence assistant. "
328
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
329
+ roles=("USER", "ASSISTANT"),
330
+ version="v1",
331
+ messages=(),
332
+ offset=0,
333
+ sep_style=SeparatorStyle.TWO,
334
+ sep=" ",
335
+ sep2="</s>",
336
+ )
337
+
338
+ conv_llava_v1_mmtag = Conversation(
339
+ system="A chat between a curious user and an artificial intelligence assistant. "
340
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
341
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
342
+ roles=("USER", "ASSISTANT"),
343
+ messages=(),
344
+ offset=0,
345
+ sep_style=SeparatorStyle.TWO,
346
+ sep=" ",
347
+ sep2="</s>",
348
+ version="v1_mmtag",
349
+ )
350
+
351
+ conv_mistral_instruct = Conversation(
352
+ system="",
353
+ roles=("USER", "ASSISTANT"),
354
+ version="llama_v2",
355
+ messages=(),
356
+ offset=0,
357
+ sep_style=SeparatorStyle.LLAMA_2,
358
+ sep="",
359
+ sep2="</s>",
360
+ )
361
+
362
+ conv_chatml_direct = Conversation(
363
+ system="""<|im_start|>system
364
+ Answer the questions.""",
365
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
366
+ version="mpt",
367
+ messages=(),
368
+ offset=0,
369
+ sep_style=SeparatorStyle.MPT,
370
+ sep="<|im_end|>",
371
+ )
372
+
373
+ conv_phi3_instruct = Conversation(
374
+ system="""<|system|>\nYou are a helpful AI assistant.""",
375
+ roles=("\n<|user|>\n", "\n<|assistant|>\n"),
376
+ version="phi3",
377
+ messages=(),
378
+ offset=0,
379
+ sep_style=SeparatorStyle.MPT,
380
+ sep="<|end|>",
381
+ )
382
+
383
+ # conv_phi_v0 = Conversation(
384
+ # system="A chat between a curious user and an artificial intelligence assistant. "
385
+ # "The assistant gives helpful, detailed, and polite answers to the user's questions.",
386
+ # roles=("USER", "ASSISTANT"),
387
+ # version="v0",
388
+ # messages=(),
389
+ # offset=0,
390
+ # sep_style=SeparatorStyle.TWO,
391
+ # sep=" ",
392
+ # sep2="<|endoftext|>",
393
+ # )
394
+
395
+ conv_phi_v0 = Conversation(
396
+ system="",
397
+ roles=("USER", "ASSISTANT"),
398
+ version="v0",
399
+ messages=(),
400
+ offset=0,
401
+ sep_style=SeparatorStyle.TWO,
402
+ sep=" ",
403
+ sep2="<|endoftext|>",
404
+ )
405
+
406
+ default_conversation = conv_vicuna_v1
407
+ conv_templates = {
408
+ "default": conv_vicuna_v0,
409
+ "v0": conv_vicuna_v0,
410
+ "v1": conv_vicuna_v1,
411
+ "vicuna_v1": conv_vicuna_v1,
412
+ "llama_2": conv_llama_2,
413
+ "mistral_instruct": conv_mistral_instruct,
414
+ "chatml_direct": conv_chatml_direct,
415
+ "mistral_direct": conv_chatml_direct,
416
+
417
+ "plain": conv_llava_plain,
418
+ "v0_plain": conv_llava_plain,
419
+ "llava_v0": conv_llava_v0,
420
+ "v0_mmtag": conv_llava_v0_mmtag,
421
+ "llava_v1": conv_llava_v1,
422
+ "v1_mmtag": conv_llava_v1_mmtag,
423
+ "llava_llama_2": conv_llava_llama_2,
424
+ "phi3_instruct": conv_phi3_instruct,
425
+ "phi1.5": conv_phi_v0,
426
+
427
+ "mpt": conv_mpt,
428
+ }
429
+
430
+
431
+ if __name__ == "__main__":
432
+ print(default_conversation.get_prompt())
training/utils.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from omegaconf import DictConfig, ListConfig, OmegaConf
6
+ from typing import Any, List, Tuple, Union
7
+
8
+
9
+ ##################################################
10
+ # config utils
11
+ ##################################################
12
+ def get_config():
13
+ cli_conf = OmegaConf.from_cli()
14
+ yaml_conf = OmegaConf.load(cli_conf.config)
15
+ conf = OmegaConf.merge(yaml_conf, cli_conf)
16
+
17
+ return conf
18
+
19
+
20
+ def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]:
21
+ ret = []
22
+
23
+ def handle_dict(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
24
+ return [(f"{key}.{k1}", v1) for k1, v1 in flatten_omega_conf(value, resolve=resolve)]
25
+
26
+ def handle_list(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
27
+ return [(f"{key}.{idx}", v1) for idx, v1 in flatten_omega_conf(value, resolve=resolve)]
28
+
29
+ if isinstance(cfg, DictConfig):
30
+ for k, v in cfg.items_ex(resolve=resolve):
31
+ if isinstance(v, DictConfig):
32
+ ret.extend(handle_dict(k, v, resolve=resolve))
33
+ elif isinstance(v, ListConfig):
34
+ ret.extend(handle_list(k, v, resolve=resolve))
35
+ else:
36
+ ret.append((str(k), v))
37
+ elif isinstance(cfg, ListConfig):
38
+ for idx, v in enumerate(cfg._iter_ex(resolve=resolve)):
39
+ if isinstance(v, DictConfig):
40
+ ret.extend(handle_dict(idx, v, resolve=resolve))
41
+ elif isinstance(v, ListConfig):
42
+ ret.extend(handle_list(idx, v, resolve=resolve))
43
+ else:
44
+ ret.append((str(idx), v))
45
+ else:
46
+ assert False
47
+
48
+ return ret
49
+
50
+
51
+ ##################################################
52
+ # training utils
53
+ ##################################################
54
+ def soft_target_cross_entropy(logits, targets, soft_targets):
55
+ # ignore the first token from logits and targets (class id token)
56
+ logits = logits[:, 1:]
57
+ targets = targets[:, 1:]
58
+
59
+ logits = logits[..., : soft_targets.shape[-1]]
60
+
61
+ log_probs = F.log_softmax(logits, dim=-1)
62
+ padding_mask = targets.eq(-100)
63
+
64
+ loss = torch.sum(-soft_targets * log_probs, dim=-1)
65
+ loss.masked_fill_(padding_mask, 0.0)
66
+
67
+ # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
68
+ num_active_elements = padding_mask.numel() - padding_mask.long().sum()
69
+ loss = loss.sum() / num_active_elements
70
+ return loss
71
+
72
+
73
+ def get_loss_weight(t, mask, min_val=0.3):
74
+ return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None]
75
+
76
+
77
+ def mask_or_random_replace_tokens(image_tokens, mask_id, config, mask_schedule, is_train=True):
78
+ batch_size, seq_len = image_tokens.shape
79
+
80
+ if not is_train and config.training.get("eval_mask_ratios", None):
81
+ mask_prob = random.choices(config.training.eval_mask_ratios, k=batch_size)
82
+ mask_prob = torch.tensor(mask_prob, device=image_tokens.device)
83
+ else:
84
+ # Sample a random timestep for each image
85
+ timesteps = torch.rand(batch_size, device=image_tokens.device)
86
+ # Sample a random mask probability for each image using timestep and cosine schedule
87
+ mask_prob = mask_schedule(timesteps)
88
+ mask_prob = mask_prob.clip(config.training.min_masking_rate)
89
+
90
+ # creat a random mask for each image
91
+ num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
92
+
93
+ mask_contiguous_region_prob = config.training.get("mask_contiguous_region_prob", None)
94
+
95
+ if mask_contiguous_region_prob is None:
96
+ mask_contiguous_region = False
97
+ else:
98
+ mask_contiguous_region = random.random() < mask_contiguous_region_prob
99
+
100
+ if not mask_contiguous_region:
101
+ batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
102
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
103
+ else:
104
+ resolution = int(seq_len ** 0.5)
105
+ mask = torch.zeros((batch_size, resolution, resolution), device=image_tokens.device)
106
+
107
+ # TODO - would be nice to vectorize
108
+ for batch_idx, num_token_masked_ in enumerate(num_token_masked):
109
+ num_token_masked_ = int(num_token_masked_.item())
110
+
111
+ # NOTE: a bit handwavy with the bounds but gets a rectangle of ~num_token_masked_
112
+ num_token_masked_height = random.randint(
113
+ math.ceil(num_token_masked_ / resolution), min(resolution, num_token_masked_)
114
+ )
115
+ num_token_masked_height = min(num_token_masked_height, resolution)
116
+
117
+ num_token_masked_width = math.ceil(num_token_masked_ / num_token_masked_height)
118
+ num_token_masked_width = min(num_token_masked_width, resolution)
119
+
120
+ start_idx_height = random.randint(0, resolution - num_token_masked_height)
121
+ start_idx_width = random.randint(0, resolution - num_token_masked_width)
122
+
123
+ mask[
124
+ batch_idx,
125
+ start_idx_height: start_idx_height + num_token_masked_height,
126
+ start_idx_width: start_idx_width + num_token_masked_width,
127
+ ] = 1
128
+
129
+ mask = mask.reshape(batch_size, seq_len)
130
+ mask = mask.to(torch.bool)
131
+
132
+ # mask images and create input and labels
133
+ if config.training.get("noise_type", "mask"):
134
+ input_ids = torch.where(mask, mask_id, image_tokens)
135
+ elif config.training.get("noise_type", "random_replace"):
136
+ # sample random tokens from the vocabulary
137
+ random_tokens = torch.randint_like(
138
+ image_tokens, low=0, high=config.model.codebook_size, device=image_tokens.device
139
+ )
140
+ input_ids = torch.where(mask, random_tokens, image_tokens)
141
+ else:
142
+ raise ValueError(f"noise_type {config.training.noise_type} not supported")
143
+
144
+ if (
145
+ config.training.get("predict_all_tokens", False)
146
+ or config.training.get("noise_type", "mask") == "random_replace"
147
+ ):
148
+ labels = image_tokens
149
+ loss_weight = get_loss_weight(mask_prob, mask.long())
150
+ else:
151
+ labels = torch.where(mask, image_tokens, -100)
152
+ loss_weight = None
153
+
154
+ return input_ids, labels, loss_weight, mask_prob
155
+
156
+
157
+ ##################################################
158
+ # misc
159
+ ##################################################
160
+ class AverageMeter(object):
161
+ """Computes and stores the average and current value"""
162
+
163
+ def __init__(self):
164
+ self.reset()
165
+
166
+ def reset(self):
167
+ self.val = 0
168
+ self.avg = 0
169
+ self.sum = 0
170
+ self.count = 0
171
+
172
+ def update(self, val, n=1):
173
+ self.val = val
174
+ self.sum += val * n
175
+ self.count += n
176
+ self.avg = self.sum / self.count
177
+
178
+ from torchvision import transforms
179
+ def image_transform(image, resolution=256, normalize=True):
180
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
181
+ image = transforms.CenterCrop((resolution, resolution))(image)
182
+ image = transforms.ToTensor()(image)
183
+ if normalize:
184
+ image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
185
+ return image
training_utils.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from omegaconf import DictConfig, ListConfig, OmegaConf
6
+ from typing import Any, List, Tuple, Union
7
+
8
+
9
+ ##################################################
10
+ # config utils
11
+ ##################################################
12
+ def get_config():
13
+ cli_conf = OmegaConf.from_cli()
14
+ yaml_conf = OmegaConf.load(cli_conf.config)
15
+ conf = OmegaConf.merge(yaml_conf, cli_conf)
16
+
17
+ return conf
18
+
19
+
20
+ def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]:
21
+ ret = []
22
+
23
+ def handle_dict(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
24
+ return [(f"{key}.{k1}", v1) for k1, v1 in flatten_omega_conf(value, resolve=resolve)]
25
+
26
+ def handle_list(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]:
27
+ return [(f"{key}.{idx}", v1) for idx, v1 in flatten_omega_conf(value, resolve=resolve)]
28
+
29
+ if isinstance(cfg, DictConfig):
30
+ for k, v in cfg.items_ex(resolve=resolve):
31
+ if isinstance(v, DictConfig):
32
+ ret.extend(handle_dict(k, v, resolve=resolve))
33
+ elif isinstance(v, ListConfig):
34
+ ret.extend(handle_list(k, v, resolve=resolve))
35
+ else:
36
+ ret.append((str(k), v))
37
+ elif isinstance(cfg, ListConfig):
38
+ for idx, v in enumerate(cfg._iter_ex(resolve=resolve)):
39
+ if isinstance(v, DictConfig):
40
+ ret.extend(handle_dict(idx, v, resolve=resolve))
41
+ elif isinstance(v, ListConfig):
42
+ ret.extend(handle_list(idx, v, resolve=resolve))
43
+ else:
44
+ ret.append((str(idx), v))
45
+ else:
46
+ assert False
47
+
48
+ return ret
49
+
50
+
51
+ ##################################################
52
+ # training utils
53
+ ##################################################
54
+ def soft_target_cross_entropy(logits, targets, soft_targets):
55
+ # ignore the first token from logits and targets (class id token)
56
+ logits = logits[:, 1:]
57
+ targets = targets[:, 1:]
58
+
59
+ logits = logits[..., : soft_targets.shape[-1]]
60
+
61
+ log_probs = F.log_softmax(logits, dim=-1)
62
+ padding_mask = targets.eq(-100)
63
+
64
+ loss = torch.sum(-soft_targets * log_probs, dim=-1)
65
+ loss.masked_fill_(padding_mask, 0.0)
66
+
67
+ # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
68
+ num_active_elements = padding_mask.numel() - padding_mask.long().sum()
69
+ loss = loss.sum() / num_active_elements
70
+ return loss
71
+
72
+
73
+ def get_loss_weight(t, mask, min_val=0.3):
74
+ return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None]
75
+
76
+
77
+ def mask_or_random_replace_tokens(image_tokens, mask_id, config, mask_schedule, is_train=True):
78
+ batch_size, seq_len = image_tokens.shape
79
+
80
+ if not is_train and config.training.get("eval_mask_ratios", None):
81
+ mask_prob = random.choices(config.training.eval_mask_ratios, k=batch_size)
82
+ mask_prob = torch.tensor(mask_prob, device=image_tokens.device)
83
+ else:
84
+ # Sample a random timestep for each image
85
+ timesteps = torch.rand(batch_size, device=image_tokens.device)
86
+ # Sample a random mask probability for each image using timestep and cosine schedule
87
+ mask_prob = mask_schedule(timesteps)
88
+ mask_prob = mask_prob.clip(config.training.min_masking_rate)
89
+
90
+ # creat a random mask for each image
91
+ num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
92
+
93
+ mask_contiguous_region_prob = config.training.get("mask_contiguous_region_prob", None)
94
+
95
+ if mask_contiguous_region_prob is None:
96
+ mask_contiguous_region = False
97
+ else:
98
+ mask_contiguous_region = random.random() < mask_contiguous_region_prob
99
+
100
+ if not mask_contiguous_region:
101
+ batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
102
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
103
+ else:
104
+ resolution = int(seq_len ** 0.5)
105
+ mask = torch.zeros((batch_size, resolution, resolution), device=image_tokens.device)
106
+
107
+ # TODO - would be nice to vectorize
108
+ for batch_idx, num_token_masked_ in enumerate(num_token_masked):
109
+ num_token_masked_ = int(num_token_masked_.item())
110
+
111
+ # NOTE: a bit handwavy with the bounds but gets a rectangle of ~num_token_masked_
112
+ num_token_masked_height = random.randint(
113
+ math.ceil(num_token_masked_ / resolution), min(resolution, num_token_masked_)
114
+ )
115
+ num_token_masked_height = min(num_token_masked_height, resolution)
116
+
117
+ num_token_masked_width = math.ceil(num_token_masked_ / num_token_masked_height)
118
+ num_token_masked_width = min(num_token_masked_width, resolution)
119
+
120
+ start_idx_height = random.randint(0, resolution - num_token_masked_height)
121
+ start_idx_width = random.randint(0, resolution - num_token_masked_width)
122
+
123
+ mask[
124
+ batch_idx,
125
+ start_idx_height: start_idx_height + num_token_masked_height,
126
+ start_idx_width: start_idx_width + num_token_masked_width,
127
+ ] = 1
128
+
129
+ mask = mask.reshape(batch_size, seq_len)
130
+ mask = mask.to(torch.bool)
131
+
132
+ # mask images and create input and labels
133
+ if config.training.get("noise_type", "mask"):
134
+ input_ids = torch.where(mask, mask_id, image_tokens)
135
+ elif config.training.get("noise_type", "random_replace"):
136
+ # sample random tokens from the vocabulary
137
+ random_tokens = torch.randint_like(
138
+ image_tokens, low=0, high=config.model.codebook_size, device=image_tokens.device
139
+ )
140
+ input_ids = torch.where(mask, random_tokens, image_tokens)
141
+ else:
142
+ raise ValueError(f"noise_type {config.training.noise_type} not supported")
143
+
144
+ if (
145
+ config.training.get("predict_all_tokens", False)
146
+ or config.training.get("noise_type", "mask") == "random_replace"
147
+ ):
148
+ labels = image_tokens
149
+ loss_weight = get_loss_weight(mask_prob, mask.long())
150
+ else:
151
+ labels = torch.where(mask, image_tokens, -100)
152
+ loss_weight = None
153
+
154
+ return input_ids, labels, loss_weight, mask_prob
155
+
156
+
157
+ ##################################################
158
+ # misc
159
+ ##################################################
160
+ class AverageMeter(object):
161
+ """Computes and stores the average and current value"""
162
+
163
+ def __init__(self):
164
+ self.reset()
165
+
166
+ def reset(self):
167
+ self.val = 0
168
+ self.avg = 0
169
+ self.sum = 0
170
+ self.count = 0
171
+
172
+ def update(self, val, n=1):
173
+ self.val = val
174
+ self.sum += val * n
175
+ self.count += n
176
+ self.avg = self.sum / self.count
177
+
178
+ from torchvision import transforms
179
+ def image_transform(image, resolution=256, normalize=True):
180
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
181
+ image = transforms.CenterCrop((resolution, resolution))(image)
182
+ image = transforms.ToTensor()(image)
183
+ if normalize:
184
+ image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
185
+ return image
validation_prompts/showoprompts.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Close-up view of a computer screen, with the screen displaying a webpage.
2
+ A tranquil scene of a lotus pond with koi fish swimming gracefully in a peaceful Chinese ink painting.
3
+ Paper artwork, layered paper, colorful Chinese dragon surrounded by clouds.
4
+ Pixel art character riding a dragon through the clouds.
5
+ A peaceful village nestled at the foot of towering mountains in a tranquil East Asian watercolor scene.
6
+ A person with swirling patterns of teal paint on their face and a shimmering silver crescent moon placed above their eyebrow, symbolizing mystery and magic.
7
+ A dynamic scene of a rally car race.
8
+ The breathtaking view of Santorini, a renowned landmark in Greece. The white-washed buildings with blue domes overlook the deep blue waters of the Aegean Sea, creating a stunning contrast against the vibrant sunset.
9
+ An abstract portrait of a pensive face, rendered in cool shades of blues, purples, and grays.
10
+ A punk rock frog in a studded leather jacket shouting into a microphone while standing on a boulder.
11
+ A rebellious squirrel in a studded denim vest, strumming an electric guitar with fervor in a forest clearing.
12
+ a captivating watercolor portrait of a dog's head, rendered in a vibrant palette of colors.
13
+ A captivating watercolor portrait of a cat's face, rendered in a soft palette of pastels.
14
+ A captivating watercolor portrait of a rabbit's profile, rendered in gentle hues of pinks and browns.
15
+ a white Lamborghini Gallardo Spyder is parked on a cobblestone street.
16
+ the breathtaking beauty of Whitehaven Beach.
17
+ The breathtaking view of Moraine Lake, a renowned landmark in Canada. The turquoise waters of the lake reflect the rugged peaks of the Valley of the Ten Peaks, creating a scene of unparalleled natural beauty.
18
+ The breathtaking view of Mount Fuji, a renowned landmark in Japan. The iconic snow-capped peak rises majestically above the surrounding landscape, mirrored perfectly in the tranquil waters of Lake Kawaguchi.
19
+ A bustling Asian market at night, with colorful lanterns, street food vendors, and a mix of traditional and modern architecture.
20
+ A stunning coastal cliffside at sunset, with waves crashing against the rocks and the sky painted in shades of orange, pink, and purple.
21
+ A tranquil island paradise, with a white sandy beach, crystal-clear water, and palm trees swaying in the gentle breeze.
22
+ A carnival of dreams where carousel horses gallop into the sky and cotton candy clouds drift by.
23
+ A floating market in the sky where clouds serve as stalls for trading dreams.
24
+ Intricate paper-cut creation featuring a vibrant peacock perched among blooming flowers.