MohamedRashad committed on
Commit
32287b3
1 Parent(s): a8efd17

Add initial project structure with requirements and utility functions

.gitignore ADDED
@@ -0,0 +1,171 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
app.py ADDED
@@ -0,0 +1,475 @@
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ import os.path as osp
+ import time
+ import hashlib
+ import argparse
+ import shutil
+ import re
+ import random
+ from pathlib import Path
+ from typing import List
+
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image, ImageEnhance
+ import PIL.Image as PImage
+ from torchvision.transforms.functional import to_tensor
+ from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5ForConditionalGeneration
+ from huggingface_hub import hf_hub_download
+ import gradio as gr
+ import spaces
+
+ from models.infinity import Infinity
+ from models.basic import *
+ from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
+
+ torch._dynamo.config.cache_size_limit = 64
+
+ # Download weights if they are not already present locally
+ def download_weights(weights_path):
+     try:
+         model_file = weights_path / 'infinity_2b_reg.pth'
+         if not model_file.exists():
+             hf_hub_download(repo_id="FoundationVision/Infinity", filename="infinity_2b_reg.pth", local_dir=str(weights_path))
+
+         vae_file = weights_path / 'infinity_vae_d32reg.pth'
+         if not vae_file.exists():
+             hf_hub_download(repo_id="FoundationVision/Infinity", filename="infinity_vae_d32reg.pth", local_dir=str(weights_path))
+
+         # For the text encoder, we need to download the entire model
+         text_encoder_ckpt = weights_path / 'flan-t5-xl'
+         if not text_encoder_ckpt.exists():
+             tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
+             model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
+             tokenizer.save_pretrained(text_encoder_ckpt)
+             model.save_pretrained(text_encoder_ckpt)
+     except Exception as e:
+         print(f"Error downloading weights: {e}")
+
+ def extract_key_val(text):
+     pattern = r'<(.+?):(.+?)>'
+     matches = re.findall(pattern, text)
+     key_val = {}
+     for match in matches:
+         key_val[match[0]] = match[1].lstrip()
+     return key_val
+
+ def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
+     if enable_positive_prompt:
+         print(f'before positive_prompt aug: {prompt}')
+         prompt = aug_with_positive_prompt(prompt)
+         print(f'after positive_prompt aug: {prompt}')
+     print(f'prompt={prompt}')
+     captions = [prompt]
+     tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt')  # todo: put this into dataset
+     input_ids = tokens.input_ids.cuda(non_blocking=True)
+     mask = tokens.attention_mask.cuda(non_blocking=True)
+     text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
+     lens: List[int] = mask.sum(dim=-1).tolist()
+     cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
+     Ltext = max(lens)
+     kv_compact = []
+     for len_i, feat_i in zip(lens, text_features.unbind(0)):
+         kv_compact.append(feat_i[:len_i])
+     kv_compact = torch.cat(kv_compact, dim=0)
+     text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
+     return text_cond_tuple
+
+ def aug_with_positive_prompt(prompt):
+     for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
+                 'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
+         if key in prompt:
+             prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
+             break
+     return prompt
+
+ def enhance_image(image):
+     contrast_image = image.copy()
+     contrast_enhancer = ImageEnhance.Contrast(contrast_image)
+     contrast_image = contrast_enhancer.enhance(1.05)  # slightly boost contrast
+     color_image = contrast_image.copy()
+     color_enhancer = ImageEnhance.Color(color_image)
+     color_image = color_enhancer.enhance(1.05)  # slightly boost saturation
+     return color_image
+
+ def gen_one_img(
+     infinity_test,
+     vae,
+     text_tokenizer,
+     text_encoder,
+     prompt,
+     cfg_list=[],
+     tau_list=[],
+     negative_prompt='',
+     scale_schedule=None,
+     top_k=900,
+     top_p=0.97,
+     cfg_sc=3,
+     cfg_exp_k=0.0,
+     cfg_insertion_layer=-5,
+     vae_type=0,
+     gumbel=0,
+     softmax_merge_topk=-1,
+     gt_leak=-1,
+     gt_ls_Bl=None,
+     g_seed=None,
+     sampling_per_bits=1,
+     enable_positive_prompt=0,
+ ):
+     sstt = time.time()
+     if not isinstance(cfg_list, list):
+         cfg_list = [cfg_list] * len(scale_schedule)
+     if not isinstance(tau_list, list):
+         tau_list = [tau_list] * len(scale_schedule)
+     text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
+     if negative_prompt:
+         negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
+     else:
+         negative_label_B_or_BLT = None
+     print(f'cfg: {cfg_list}, tau: {tau_list}')
+     with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
+         stt = time.time()
+         _, _, img_list = infinity_test.autoregressive_infer_cfg(
+             vae=vae,
+             scale_schedule=scale_schedule,
+             label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
+             B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
+             cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
+             returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
+             cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
+             vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
+             ret_img=True, trunk_scale=1000,
+             gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
+             sampling_per_bits=sampling_per_bits,
+         )
+     print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
+     img = img_list[0]
+     return img
+
+ def get_prompt_id(prompt):
+     md5 = hashlib.md5()
+     md5.update(prompt.encode('utf-8'))
+     prompt_id = md5.hexdigest()
+     return prompt_id
+
+ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
+     print('[Save slim model]')
+     full_ckpt = torch.load(infinity_model_path, map_location=device)
+     infinity_slim = full_ckpt['trainer'][key]
+     # ema_state_dict = cpu_d['trainer'].get('gpt_ema_fsdp', state_dict)
+     if not save_file:
+         save_file = osp.splitext(infinity_model_path)[0] + '-slim.pth'
+     print(f'Save to {save_file}')
+     torch.save(infinity_slim, save_file)
+     print('[Save slim model] done')
+     return save_file
+
+ def load_tokenizer(t5_path=''):
+     print('[Loading tokenizer and text encoder]')
+     text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
+     text_tokenizer.model_max_length = 512
+     text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
+     text_encoder.to('cuda')
+     text_encoder.eval()
+     text_encoder.requires_grad_(False)
+     return text_tokenizer, text_encoder
+
+ def load_infinity(
+     rope2d_each_sa_layer,
+     rope2d_normalized_by_hw,
+     use_scale_schedule_embedding,
+     pn,
+     use_bit_label,
+     add_lvl_embeding_only_first_block,
+     model_path='',
+     scale_schedule=None,
+     vae=None,
+     device='cuda',
+     model_kwargs=None,
+     text_channels=2048,
+     apply_spatial_patchify=0,
+     use_flex_attn=False,
+     bf16=False,
+ ):
+     print('[Loading Infinity]')
+     text_maxlen = 512
+     with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
+         infinity_test: Infinity = Infinity(
+             vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
+             shared_aln=True, raw_scale_schedule=scale_schedule,
+             checkpointing='full-block',
+             customized_flash_attn=False,
+             fused_norm=True,
+             pad_to_multiplier=128,
+             use_flex_attn=use_flex_attn,
+             add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
+             use_bit_label=use_bit_label,
+             rope2d_each_sa_layer=rope2d_each_sa_layer,
+             rope2d_normalized_by_hw=rope2d_normalized_by_hw,
+             pn=pn,
+             apply_spatial_patchify=apply_spatial_patchify,
+             inference_mode=True,
+             train_h_div_w_list=[1.0],
+             **model_kwargs,
+         ).to(device=device)
+         print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
+
+         if bf16:
+             for block in infinity_test.unregistered_blocks:
+                 block.bfloat16()
+
+         infinity_test.eval()
+         infinity_test.requires_grad_(False)
+
+         infinity_test.cuda()
+         torch.cuda.empty_cache()
+
+         print('[Load Infinity weights]')
+         state_dict = torch.load(model_path, map_location=device)
+         print(infinity_test.load_state_dict(state_dict))
+         infinity_test.rng = torch.Generator(device=device)
+         return infinity_test
+
+ def transform(pil_img, tgt_h, tgt_w):
+     width, height = pil_img.size
+     if width / height <= tgt_w / tgt_h:
+         resized_width = tgt_w
+         resized_height = int(tgt_w / (width / height))
+     else:
+         resized_height = tgt_h
+         resized_width = int((width / height) * tgt_h)
+     pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
+     # crop the center out
+     arr = np.array(pil_img)
+     crop_y = (arr.shape[0] - tgt_h) // 2
+     crop_x = (arr.shape[1] - tgt_w) // 2
+     im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
+     return im.add(im).add_(-1)  # im + im - 1 == 2*im - 1: map [0, 1] to [-1, 1]
+
+ def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
+     pil_image = Image.open(image_path).convert('RGB')
+     inp = transform(pil_image, tgt_h, tgt_w)
+     inp = inp.unsqueeze(0).to(device)
+     scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
+     t1 = time.time()
+     h, z, _, all_bit_indices, _, infinity_input = vae.encode(inp, scale_schedule=scale_schedule)
+     t2 = time.time()
+     recons_img = vae.decode(z)[0]
+     if len(recons_img.shape) == 4:
+         recons_img = recons_img.squeeze(1)
+     print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
+     t3 = time.time()
+     print(f'vae encode takes {t2-t1:.2f}s, decode takes {t3-t2:.2f}s')
+     recons_img = (recons_img + 1) / 2
+     recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
+     gt_img = (inp[0] + 1) / 2
+     gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
+     print(recons_img.shape, gt_img.shape)
+     return gt_img, recons_img, all_bit_indices
+
+ def load_visual_tokenizer(args):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     # load vae
+     if args.vae_type in [16, 18, 20, 24, 32, 64]:
+         from models.bsq_vae.vae import vae_model
+         schedule_mode = "dynamic"
+         codebook_dim = args.vae_type
+         codebook_size = 2**codebook_dim
+         if args.apply_spatial_patchify:
+             patch_size = 8
+             encoder_ch_mult = [1, 2, 4, 4]
+             decoder_ch_mult = [1, 2, 4, 4]
+         else:
+             patch_size = 16
+             encoder_ch_mult = [1, 2, 4, 4, 4]
+             decoder_ch_mult = [1, 2, 4, 4, 4]
+         vae = vae_model(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
+                         encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
+     else:
+         raise ValueError(f'vae_type={args.vae_type} not supported')
+     return vae
+
+ def load_transformer(vae, args):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model_path = args.model_path
+     if args.checkpoint_type == 'torch':
+         # copy large model to local; save slim to local; and copy slim to nas; load local slim model
+         if osp.exists(args.cache_dir):
+             local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
+         else:
+             local_model_path = model_path
+         if args.enable_model_cache:
+             slim_model_path = model_path.replace('ar-', 'slim-')
+             local_slim_model_path = local_model_path.replace('ar-', 'slim-')
+             os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
+             print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
+             print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
+             if not osp.exists(local_slim_model_path):
+                 if osp.exists(slim_model_path):
+                     print(f'copy {slim_model_path} to {local_slim_model_path}')
+                     shutil.copyfile(slim_model_path, local_slim_model_path)
+                 else:
+                     if not osp.exists(local_model_path):
+                         print(f'copy {model_path} to {local_model_path}')
+                         shutil.copyfile(model_path, local_model_path)
+                     save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
+                     print(f'copy {local_slim_model_path} to {slim_model_path}')
+                     if not osp.exists(slim_model_path):
+                         shutil.copyfile(local_slim_model_path, slim_model_path)
+                     os.remove(local_model_path)
+                     os.remove(model_path)
+             slim_model_path = local_slim_model_path
+         else:
+             slim_model_path = model_path
+         print(f'load checkpoint from {slim_model_path}')
+
+     if args.model_type == 'infinity_2b':
+         kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8)  # 2b model
+     elif args.model_type == 'infinity_layer12':
+         kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer16':
+         kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer24':
+         kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer32':
+         kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer40':
+         kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     elif args.model_type == 'infinity_layer48':
+         kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
+     infinity = load_infinity(
+         rope2d_each_sa_layer=args.rope2d_each_sa_layer,
+         rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
+         use_scale_schedule_embedding=args.use_scale_schedule_embedding,
+         pn=args.pn,
+         use_bit_label=args.use_bit_label,
+         add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
+         model_path=slim_model_path,
+         scale_schedule=None,
+         vae=vae,
+         device=device,
+         model_kwargs=kwargs_model,
+         text_channels=args.text_channels,
+         apply_spatial_patchify=args.apply_spatial_patchify,
+         use_flex_attn=args.use_flex_attn,
+         bf16=args.bf16,
+     )
+     return infinity
+
+ # Set up paths
+ weights_path = Path(__file__).parent / 'weights'
+ weights_path.mkdir(exist_ok=True)
+ download_weights(weights_path)
+
+ # Device setup
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
+
+ # Define args
+ args = argparse.Namespace(
+     pn='1M',
+     model_path=str(weights_path / 'infinity_2b_reg.pth'),
+     cfg_insertion_layer=0,
+     vae_type=32,
+     vae_path=str(weights_path / 'infinity_vae_d32reg.pth'),
+     add_lvl_embeding_only_first_block=1,
+     use_bit_label=1,
+     model_type='infinity_2b',
+     rope2d_each_sa_layer=1,
+     rope2d_normalized_by_hw=2,
+     use_scale_schedule_embedding=0,
+     sampling_per_bits=1,
+     text_encoder_ckpt=str(weights_path / 'flan-t5-xl'),
+     text_channels=2048,
+     apply_spatial_patchify=0,
+     h_div_w_template=1.000,
+     use_flex_attn=0,
+     cache_dir='/dev/shm',
+     checkpoint_type='torch',
+     seed=0,
+     bf16=1 if dtype == torch.bfloat16 else 0,
+     save_file='tmp.jpg',
+     enable_model_cache=False,
+ )
+
+ # Load models
+ text_tokenizer, text_encoder = load_tokenizer(t5_path=str(weights_path / 'flan-t5-xl'))
+ vae = load_visual_tokenizer(args)
+ infinity = load_transformer(vae, args)
+
+ # Define the image generation function
+ @spaces.GPU
+ def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt):
+     try:
+         args.prompt = prompt
+         args.cfg = cfg
+         args.tau = tau
+         args.h_div_w = h_div_w
+         args.seed = seed
+         args.enable_positive_prompt = enable_positive_prompt
+
+         # Find the closest h_div_w_template
+         h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
+
+         # Get scale_schedule based on h_div_w_template_
+         scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
+         scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
+
+         # Generate the image
+         generated_image = gen_one_img(
+             infinity,
+             vae,
+             text_tokenizer,
+             text_encoder,
+             prompt,
+             g_seed=seed,
+             gt_leak=0,
+             gt_ls_Bl=None,
+             cfg_list=cfg,
+             tau_list=tau,
+             scale_schedule=scale_schedule,
+             cfg_insertion_layer=[args.cfg_insertion_layer],
+             vae_type=args.vae_type,
+             sampling_per_bits=args.sampling_per_bits,
+             enable_positive_prompt=enable_positive_prompt,
+         )
+
+         # Convert the image to RGB and uint8
+         image = generated_image.cpu().numpy()
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         image = np.uint8(image)
+
+         return image
+     except Exception as e:
+         print(f"Error generating image: {e}")
+         return None
+
+ # Set up Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
+
+     with gr.Row():
+         prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise")
+         cfg = gr.Slider(label="CFG", minimum=1, maximum=10, step=0.5, value=3)
+         tau = gr.Slider(label="Tau", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
+         h_div_w = gr.Slider(label="Aspect Ratio (Height/Width)", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
+         seed = gr.Number(label="Seed", value=random.randint(0, 10000))
+         enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False)
+
+     generate_button = gr.Button("Generate Image")
+     output_image = gr.Image(label="Generated Image", type="pil")
+
+     generate_button.click(
+         generate_image,
+         inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
+         outputs=output_image
+     )
+
+ # Launch the Gradio app
+ demo.launch()
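
Note: `generate_image` above snaps the requested aspect ratio to the nearest supported template before looking up a scale schedule. A minimal standalone sketch of that selection step (the template values below are illustrative placeholders; the real array comes from `utils.dynamic_resolution.h_div_w_templates`):

    import numpy as np

    # hypothetical template values, for illustration only
    h_div_w_templates = np.array([0.5, 0.667, 1.0, 1.5, 2.0])

    h_div_w = 1.1  # user-requested height/width ratio
    closest = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
    print(closest)  # 1.0 -- the scale schedule for this template is then used for generation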
models/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ from timm.loss import SoftTargetCrossEntropy
+
+ from timm.models.layers import DropPath
+
+ from .infinity import Infinity, sample_with_top_k_top_p_also_inplace_modifying_logits_
+
+ def _ex_repr(self):
+     return ', '.join(
+         f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
+         for k, v in vars(self).items()
+         if not k.startswith('_') and k != 'training'
+         and not isinstance(v, (torch.nn.Module, torch.Tensor))
+     )
+ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy):  # no longer __repr__ DropPath with drop_prob
+     if hasattr(clz, 'extra_repr'):
+         clz.extra_repr = _ex_repr
+     else:
+         clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
+
+ DropPath.__repr__ = lambda self: f'{type(self).__name__}(...)'
+
+ alias_dict = {}
+ for d in range(6, 40+2, 2):
+     alias_dict[f'd{d}'] = f'infinity_d{d}'
+ alias_dict_inv = {v: k for k, v in alias_dict.items()}
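
The alias loop at the end of `models/__init__.py` builds a fixed name mapping; an equivalent comprehension, shown only to make the result concrete:

    alias_dict = {f'd{d}': f'infinity_d{d}' for d in range(6, 42, 2)}
    # {'d6': 'infinity_d6', 'd8': 'infinity_d8', ..., 'd40': 'infinity_d40'}
    alias_dict_inv = {v: k for k, v in alias_dict.items()}
    assert alias_dict_inv['infinity_d40'] == 'd40'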
models/basic.py ADDED
@@ -0,0 +1,575 @@
+ """
+ Definitions of blocks of VAR transformer model.
+ """
+
+ import math
+ import os
+ from functools import partial
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from timm.models.layers import DropPath, drop_path
+ from torch.utils.checkpoint import checkpoint
+
+ # Import flash_attn's attention
+ from flash_attn import flash_attn_func                  # q, k, or v: BLHc, ret: BLHc
+ from flash_attn import flash_attn_varlen_kvpacked_func  # qkv: N3Hc, ret: NHc
+
+ from torch.nn.functional import scaled_dot_product_attention as slow_attn  # q, k, v: BHLc
+
+ # Import flash_attn's fused ops
+ try:
+     from flash_attn.ops.layer_norm import dropout_add_layer_norm
+     from flash_attn.ops.rms_norm import dropout_add_rms_norm
+     from flash_attn.ops.rms_norm import rms_norm as rms_norm_impl
+     from flash_attn.ops.fused_dense import fused_mlp_func
+     flash_fused_op_installed = True
+ except ImportError:
+     dropout_add_layer_norm = dropout_add_rms_norm = fused_mlp_func = None
+     flash_fused_op_installed = False
+
+     def rms_norm_impl(x, weight, epsilon):
+         return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(epsilon))) * weight
+
+
+ def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0):
+     # split the dimension into half, one for x and one for y
+     half_dim = dim // 2
+     inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim))  # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
+     t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
+     t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
+     t_height = t_height / scaling_factor
+     freqs_height = torch.outer(t_height, inv_freq)  # (max_height, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2), namely y*theta
+     t_width = t_width / scaling_factor
+     freqs_width = torch.outer(t_width, inv_freq)  # (max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2), namely x*theta
+     freqs_grid_map = torch.concat([
+         freqs_height[:, None, :].expand(-1, max_width, -1),  # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2)
+         freqs_width[None, :, :].expand(max_height, -1, -1),  # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2)
+     ], dim=-1)  # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d))
+     freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
+     # (2, max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d))
+
+     rope2d_freqs_grid = {}
+     for h_div_w in dynamic_resolution_h_w:
+         scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['scales']
+         _, ph, pw = scale_schedule[-1]
+         max_edge_length = freqs_grid_map.shape[1]
+         if ph >= pw:
+             uph, upw = max_edge_length, int(max_edge_length / ph * pw)
+         else:
+             uph, upw = int(max_edge_length / pw * ph), max_edge_length
+         rope_cache_list = []
+         for (_, ph, pw) in scale_schedule:
+             ph_mul_pw = ph * pw
+             if rope2d_normalized_by_hw == 1:  # downsample
+                 rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0, 3, 1, 2]), size=(ph, pw), mode='bilinear', align_corners=True)
+                 rope_cache = rope_cache.permute([0, 2, 3, 1])  # (2, ph, pw, half_head_dim)
+             elif rope2d_normalized_by_hw == 2:  # star style
+                 _, uph, upw = scale_schedule[-1]
+                 indices = torch.stack([
+                     (torch.arange(ph) * (uph / ph)).reshape(ph, 1).expand(ph, pw),
+                     (torch.arange(pw) * (upw / pw)).reshape(1, pw).expand(ph, pw),
+                 ], dim=-1).round().int()  # (ph, pw, 2)
+                 indices = indices.reshape(-1, 2)  # (ph*pw, 2)
+                 rope_cache = freqs_grid_map[:, indices[:, 0], indices[:, 1], :]  # (2, ph*pw, half_head_dim)
+                 rope_cache = rope_cache.reshape(2, ph, pw, -1)
+             elif rope2d_normalized_by_hw == 0:
+                 rope_cache = freqs_grid_map[:, :ph, :pw, :]  # (2, ph, pw, half_head_dim)
+             else:
+                 raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
+             rope_cache_list.append(rope_cache.reshape(2, ph_mul_pw, -1))
+         cat_rope_cache = torch.cat(rope_cache_list, 1)  # (2, seq_len, half_head_dim)
+         if cat_rope_cache.shape[1] % pad_to_multiplier:
+             pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache.shape[1] % pad_to_multiplier, half_dim)
+             cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
+         cat_rope_cache = cat_rope_cache[:, None, None, None]  # (2, 1, 1, 1, seq_len, half_dim)
+         for pn in dynamic_resolution_h_w[h_div_w]:
+             scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['scales']
+             tmp_scale_schedule = [(1, h, w) for _, h, w in scale_schedule]
+             rope2d_freqs_grid[str(tuple(tmp_scale_schedule))] = cat_rope_cache
+     return rope2d_freqs_grid
+
+
+ def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier, rope2d_normalized_by_hw, scale_ind):
+     qk = torch.stack((q, k), dim=0)  # (2, batch_size, heads, seq_len, head_dim)
+     device_type = qk.device.type
+     device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+     with torch.autocast(device_type=device_type, enabled=False):
+         seq_len = qk.shape[3]
+         start = 0
+         if scale_ind >= 1:
+             assert len(scale_schedule[0]) == 3
+             start = np.sum([item[0] * item[1] * item[2] for item in scale_schedule[:scale_ind]])
+         rope2d_freqs_grid[str(tuple(scale_schedule))] = rope2d_freqs_grid[str(tuple(scale_schedule))].to(qk.device)
+         assert start + seq_len <= rope2d_freqs_grid[str(tuple(scale_schedule))].shape[4]
+         rope_cache = rope2d_freqs_grid[str(tuple(scale_schedule))][:, :, :, :, start:start+seq_len]  # rope_cache shape: [2, 1, 1, 1, seq_len, half_head_dim]
+         qk = qk.reshape(*qk.shape[:-1], -1, 2)  # (2, batch_size, heads, seq_len, half_head_dim, 2)
+         qk = torch.stack([
+             rope_cache[0] * qk[..., 0] - rope_cache[1] * qk[..., 1],
+             rope_cache[1] * qk[..., 0] + rope_cache[0] * qk[..., 1],
+         ], dim=-1)  # (2, batch_size, heads, seq_len, half_head_dim, 2); here stack + reshape should not be concat
+         qk = qk.reshape(*qk.shape[:-2], -1)  # (2, batch_size, heads, seq_len, head_dim)
+         q, k = qk.unbind(dim=0)  # (batch_size, heads, seq_len, head_dim)
+     return q, k
+
+
+ class FastRMSNorm(nn.Module):
+     def __init__(self, C, eps=1e-6, elementwise_affine=True):
+         super().__init__()
+         self.C = C
+         self.eps = eps
+         self.elementwise_affine = elementwise_affine
+         if self.elementwise_affine:
+             self.weight = nn.Parameter(torch.ones(C))
+         else:
+             self.register_buffer('weight', torch.ones(C))
+
+     def forward(self, x):
+         src_type = x.dtype
+         return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(src_type)
+
+     def extra_repr(self) -> str:
+         return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}'
+
+
+ def get_dropout_layer(p):
+     return nn.Dropout(p, inplace=True) if p > 0 else nn.Identity()
+
+
+ class FFN(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_mlp=False):
+         super().__init__()
+         self.fused_mlp_func = fused_mlp_func if fused_mlp else None
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = nn.GELU(approximate='tanh')
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = get_dropout_layer(drop)
+         self.heuristic = -1
+
+     def forward(self, x):
+         if self.fused_mlp_func is not None:
+             return self.drop(self.fused_mlp_func(
+                 x=x,
+                 weight1=self.fc1.weight,
+                 weight2=self.fc2.weight,
+                 bias1=self.fc1.bias,
+                 bias2=self.fc2.bias,
+                 activation='gelu_approx',
+                 save_pre_act=self.training,
+                 return_residual=False,
+                 checkpoint_lvl=0,
+                 heuristic=self.heuristic,
+                 process_group=None,
+             ))
+         else:
+             return self.drop(self.fc2(self.act(self.fc1(x))))
+
+     def extra_repr(self) -> str:
+         return f'fused_mlp={self.fused_mlp_func is not None}'
+
+
+ class FFNSwiGLU(nn.Module):
+     def __init__(self, in_features, hidden_features, out_features=None, drop=0., fused_mlp=False):
+         super().__init__()
+         self.fused_mlp_func = None
+         hidden_features = round(2 * hidden_features / 3 / 256) * 256
+
+         out_features = out_features or in_features
+         self.fcg = nn.Linear(in_features, hidden_features, bias=False)
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
+         self.drop = get_dropout_layer(drop)
+
+     def forward(self, x):
+         return self.drop(self.fc2(F.silu(self.fcg(x), inplace=True).mul_(self.fc1(x))))
+
+     def extra_repr(self) -> str:
+         return f'fused_mlp={self.fused_mlp_func is not None}'
+
+
+ class SelfAttention(nn.Module):
+     def __init__(
+         self, embed_dim=768, num_heads=12,
+         proj_drop=0., tau=1, cos_attn=False, customized_flash_attn=True, use_flex_attn=False,
+         batch_size=2, pad_to_multiplier=1, rope2d_normalized_by_hw=0,
+     ):
+         """
+         :param embed_dim: model's width
+         :param num_heads: num heads of multi-head attention
+         :param proj_drop: always 0 for testing
+         :param tau: always 1
+         :param cos_attn: always True: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11
+         :param customized_flash_attn:
+         """
+         super().__init__()
+         assert embed_dim % num_heads == 0
+         self.using_flash = customized_flash_attn
+
+         self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
+         self.tau, self.cos_attn = tau, cos_attn
+         if self.cos_attn:
+             self.scale = 1
+             size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1)
+             # size: 11H1 or 1H11
+             self.scale_mul_1H11 = nn.Parameter(torch.full(size=size, fill_value=4.0).log(), requires_grad=True)
+             self.max_scale_mul = torch.log(torch.tensor(100)).item()
+         else:
+             self.scale = 1 / math.sqrt(self.head_dim) / self.tau
+
+         self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
+         self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
+         self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
+
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.proj_drop = get_dropout_layer(proj_drop)
+
+         self.caching = False   # kv caching: only used during inference
+         self.cached_k = None   # kv caching: only used during inference
+         self.cached_v = None   # kv caching: only used during inference
+
+         self.batch_size = batch_size
+         self.use_flex_attn = use_flex_attn
+         self.pad_to_multiplier = pad_to_multiplier
+
+         self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
+
+     def kv_caching(self, enable: bool):  # kv caching: only used during inference
+         self.caching = enable
+         self.cached_k = None
+         self.cached_v = None
+
+     # NOTE: attn_bias_or_two_vector is None during inference
+     def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.IntTensor, torch.IntTensor]], attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0):
+         """
+         :param (fp32) x: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
+         :param (fp32) attn_bias_or_two_vector:
+             if not using_flash:
+                 a block-wise, lower-triangle matrix, like:
+                 [[[[0, -, -, -, -, -, -, -, -, -, -, -, -, -],
+                    [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
+                    [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
+                    [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
+                    [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]
+                 where 0 means visible and - means invisible (-inf)
+             else:
+                 a tuple of two 1-dim int vector (VAR_visible_kvlen, VAR_invisible_qlen)
+         :return: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
+         """
+         # x: fp32
+         B, L, C = x.shape
+
+         # qkv: amp, bf16
+         qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)  # BL3Hc
+         if self.using_flash: q, k, v = qkv.unbind(dim=2); L_dim = 1            # q or k or v: all are shaped in (B:batch_size, L:seq_len, H:heads, c:head_dim)
+         else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); L_dim = 2    # q or k or v: all are shaped in (B:batch_size, H:heads, L:seq_len, c:head_dim)
+
+         if self.cos_attn:   # always True
+             scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()  # 11H1 (flash), or 1H11 (not flash)
+             q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous()    # fp32
+             k = F.normalize(k, dim=-1, eps=1e-12).contiguous()                   # fp32
+             v = v.contiguous()   # bf16
+         else:   # be contiguous, to make kernel happy
+             q = q.contiguous()   # bf16
+             k = k.contiguous()   # bf16
+             v = v.contiguous()   # bf16
+         if rope2d_freqs_grid is not None:
+             q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind)
+         if self.caching:   # kv caching: only used during inference
+             if self.cached_k is None: self.cached_k = k; self.cached_v = v
+             else: k = self.cached_k = torch.cat((self.cached_k, k), dim=L_dim); v = self.cached_v = torch.cat((self.cached_v, v), dim=L_dim)
+
+         if self.using_flash:
+             if attn_bias_or_two_vector is not None:  # training
+                 kw = dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1])
+             else:  # inference (autoregressive sampling)
+                 kw = dict()
+             oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C)
+         else:
+             # if self.cos_attn: q, k are in fp32; v is in bf16
+             # else: q, k, v are in bf16
+             if self.use_flex_attn and attn_fn is not None:
+                 oup = attn_fn(q, k, v, scale=self.scale).transpose(1, 2).reshape(B, L, C)
+             else:
+                 oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
+             # oup: bf16
+
+         return self.proj_drop(self.proj(oup))
+
+     def extra_repr(self) -> str:
+         tail = ''
+         return f'using_flash={self.using_flash}, tau={self.tau}, cos_attn={self.cos_attn}{tail}'
+
+
+ class CrossAttention(nn.Module):
+     def __init__(
+         self, for_attn_pool=False, embed_dim=768, kv_dim=4096, num_heads=12,
+         proj_drop=0., cos_attn=False,
+     ):
+         """
+         :param for_attn_pool: only used in VAR.text_proj_for_sos
+         :param embed_dim: Q's dim
+         :param kv_dim: K's and V's dim
+         :param num_heads: num heads of multi-head attention
+         :param proj_drop: proj drop out
+         :param cos_attn: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11
+         """
+         cos_attn = False  # TODO: never use cos attn in cross attention with T5 kv
+         super().__init__()
+         self.for_attn_pool = for_attn_pool
+         self.embed_dim = embed_dim
+         self.kv_dim = kv_dim
+         assert embed_dim % num_heads == 0
+         self.num_heads, self.head_dim = num_heads, embed_dim // num_heads  # =64
+         self.cos_attn = cos_attn
+         if self.cos_attn:
+             self.scale = 1
+             self.scale_mul_1H1 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
+             self.max_scale_mul = torch.log(torch.tensor(100)).item()
+         else:
+             self.scale = 1 / math.sqrt(self.head_dim)
+
+         if for_attn_pool:
+             q = torch.empty(1, self.num_heads, self.head_dim)
+             nn.init.trunc_normal_(q, mean=0, std=math.sqrt(1 / embed_dim / 3))
+             self.mat_q = nn.Parameter(q)
+         else:
+             self.mat_q = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.mat_kv = nn.Linear(kv_dim, embed_dim*2, bias=False)
+         self.v_bias = nn.Parameter(torch.zeros(embed_dim))
+         self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
+
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.proj_drop = get_dropout_layer(proj_drop)
+
+     def forward(self, q, ca_kv):
+         """
+         :param q: shaped as (batch, seq_len, Q_dim)
+         :param ca_kv: contains several vectors, each of which is shaped as (len_i, KV_dim). We have [len_1xKV_dim, len_2xKV_dim, len_3xKV_dim, ...] and lens == [len_1, len_2, len_3, ...]
+             - kv_compact: shaped as (sum(lens), KV_dim)
+             - cu_seqlens_k: cumulated sum of lens
+             - max_seqlen_k: int, max(lens)
+         NOTE: seq_len (num of Qs) can reach 10k; but len_i (num of KVs) must <= 256
+
+         :return: shaped as (batch, seq_len, Q_dim)
+         """
+         kv_compact, cu_seqlens_k, max_seqlen_k = ca_kv
+         N = kv_compact.shape[0]
+
+         kv_compact = F.linear(kv_compact, weight=self.mat_kv.weight, bias=torch.cat((self.zero_k_bias, self.v_bias))).view(N, 2, self.num_heads, self.head_dim)  # NC => N2Hc
+         # attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
+
+         if not self.for_attn_pool:
+             B, Lq = q.shape[:2]
+             q_compact = self.mat_q(q).view(-1, self.num_heads, self.head_dim)
+         else:
+             B = cu_seqlens_k.shape[0] - 1
+             Lq = 1
+             q_compact = self.mat_q.repeat(B, 1, 1).to(dtype=kv_compact.dtype)
+
+         if self.cos_attn:   # always False
+             scale_mul = self.scale_mul_1H1.clamp_max(self.max_scale_mul).exp()
+             k, v = kv_compact.unbind(dim=1)
+             q_compact = F.normalize(q_compact, dim=-1).mul(scale_mul)
+             k = F.normalize(k, dim=-1)
+             kv_compact = torch.stack((k, v), dim=1)
+
+         q_compact = q_compact.contiguous()
+         kv_compact = kv_compact.contiguous()
+
+         cu_seqlens_q = torch.arange(0, Lq * (B+1), Lq, dtype=torch.int32, device=q_compact.device)
+         if q_compact.dtype == torch.float32:  # todo: fp16 or bf16?
+             oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
+             oup = oup.float()
+         else:
+             oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
+
+         return self.proj_drop(self.proj(oup))
+
+     def extra_repr(self) -> str:
+         return f'Cq={self.embed_dim}, Ckv={self.kv_dim}, cos_attn={self.cos_attn}'
+
+
+ class SelfAttnBlock(nn.Module):
+     def __init__(
+         self, embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
+         num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
+         swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
+     ):
+         super(SelfAttnBlock, self).__init__()
+         self.C, self.D = embed_dim, cond_dim
+         self.drop_path_rate = drop_path
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.attn = SelfAttention(
+             embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
+         )
+         self.using_swiglu = swiglu
+         self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)
+
+         self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
+         self.fused_norm_func = fused_norm_func
+         self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
+
+         self.shared_aln = shared_aln
+         if self.shared_aln:
+             self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
+         else:
+             lin = nn.Linear(cond_dim, 6*embed_dim)
+             self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)
+
+     # NOTE: attn_bias_or_two_vector is None during inference
+     def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector):  # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
+         with torch.cuda.amp.autocast(enabled=False):
+             if self.shared_aln:  # always True; (1, 1, 6, C) + (B, 1, 6, C)
+                 gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2)  # 116C + B16C =unbind(2)=> 6 B1C
+             else:
+                 gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
+
+             if self.fused_norm_func is None:
+                 x = x + self.drop_path(self.attn(self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
+                 x = x + self.drop_path(self.ffn(self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
+             else:
+                 x = x + self.drop_path(self.attn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
+                 x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
+             return x
+
+     def extra_repr(self) -> str:
+         return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}'
+
+
+ class CrossAttnBlock(nn.Module):
+     def __init__(
+         self,
+         embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
+         num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
+         swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
+         use_flex_attn=False, batch_size=2, pad_to_multiplier=1, apply_rope2d=False, rope2d_normalized_by_hw=False,
+     ):
+         super(CrossAttnBlock, self).__init__()
+         self.C, self.D = embed_dim, cond_dim
+         self.drop_path_rate = drop_path
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.sa = SelfAttention(
+             embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
+             use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
+         )
+         self.ca = CrossAttention(embed_dim=embed_dim, kv_dim=kv_dim, num_heads=num_heads, proj_drop=drop, cos_attn=cos_attn)
+         self.using_swiglu = swiglu
+         self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)
+
+         self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
+         self.fused_norm_func = fused_norm_func
+         self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
+         self.ca_norm = norm_layer(embed_dim, elementwise_affine=True)
+
+         self.shared_aln = shared_aln
+         if self.shared_aln:  # always True
+             self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
+         else:
+             lin = nn.Linear(cond_dim, 6*embed_dim)
+             self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)
+
+         if cross_attn_layer_scale >= 0:
+             self.ca_gamma = nn.Parameter(cross_attn_layer_scale * torch.ones(embed_dim), requires_grad=True)
+         else:
+             self.ca_gamma = 1
+
+         self.checkpointing_sa_only = checkpointing_sa_only
+
+     # NOTE: attn_bias_or_two_vector is None during inference
+     def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0):  # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
+         with torch.cuda.amp.autocast(enabled=False):  # disable half precision
+             if self.shared_aln:  # always True; (1, 1, 6, C) + (B, 1, 6, C)
+                 gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2)  # 116C + B16C =unbind(2)=> 6 B1C
+             else:
+                 gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
+
+             if self.fused_norm_func is None:
+                 x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1)
+                 if self.checkpointing_sa_only and self.training:
+                     x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
+                 else:
+                     x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
+                 x = x + self.drop_path(x_sa.mul_(gamma1))
+                 x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
+                 x = x + self.drop_path(self.ffn(self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
+             else:
+                 x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1)
+                 if self.checkpointing_sa_only and self.training:
+                     x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
+                 else:
+                     x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind)
+                 x = x + self.drop_path(x_sa.mul_(gamma1))
+                 x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
+                 x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2))  # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
+             return x
+
+     def extra_repr(self) -> str:
+         return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}, ca_gamma={"<learnable>" if isinstance(self.ca_gamma, nn.Parameter) else self.ca_gamma}'
+
+
+ class AdaLNBeforeHead(nn.Module):
+     def __init__(self, C, D, act: bool, norm_layer: partial, fused_norm_func=None):  # C: embed_dim, D: cond_dim
+         super().__init__()
+         self.C, self.D = C, D
+         self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
+         self.fused_norm_func = fused_norm_func
+         self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
+         lin = nn.Linear(D, 2*C)
+         self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)
+
+     def forward(self, x_BLC: torch.Tensor, cond_BD: Optional[torch.Tensor]):
+         scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
+         if self.fused_norm_func is None:
+             return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
+         else:
+             return self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x_BLC, scale=scale, shift=shift)
+
+
+ def main():
+     dev = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
+     rng = torch.Generator(device=dev)
+     # for Li in ([1, 3, 5], [1, 3]):
+     rng.manual_seed(0)
+     B, H, cq, ckv = 4, 8, 64, 96
+     Cq = H*cq
+     Ckv = H*ckv
+
+     Li = [5, 4, 7, 6]
+     Lq = 10
+     L = max(Li)
+     attn_bias = torch.zeros(B, 1, Lq, L, device=dev)
+     for i, x in enumerate(Li):
+         attn_bias[i, 0, :, x:] = -torch.inf
+
+     q = torch.randn(B, Lq, H, cq, generator=rng, device=dev)
+     k = torch.randn(B, L, H, ckv, generator=rng, device=dev)
+     v = torch.randn(B, L, H, ckv, generator=rng, device=dev)
+     tq, tk, tv = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # BHLc
+
+     seqlen_k = torch.tensor(Li, dtype=torch.int32, device=dev)
+     cu_seqlens_k = F.pad(torch.cumsum(seqlen_k, dim=0, dtype=torch.int32), (1, 0))
+     kv = torch.stack([k, v], dim=2)
+     kv_compact = torch.cat([kv[i, :Li[i]] for i in range(B)], dim=0)
+
+     ca = CrossAttention(for_attn_pool=False, embed_dim=Cq, kv_dim=Ckv, num_heads=H)
+     ca(q, (kv_compact, cu_seqlens_k, max(Li))).mean().backward()
+
+
+ if __name__ == '__main__':
+     main()
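
For reference, `apply_rotary_emb` above rotates each (even, odd) feature pair of q and k by a precomputed angle from the 2-D RoPE grid. A self-contained sketch of just that rotation, with illustrative shapes (the cached `rope2d_freqs_grid` lookup is assumed already done):

    import torch

    def rotate_pairs(x, cos, sin):
        # x: (..., head_dim); cos/sin broadcast against (..., head_dim // 2)
        x = x.reshape(*x.shape[:-1], -1, 2)      # split features into (even, odd) pairs
        out = torch.stack([
            cos * x[..., 0] - sin * x[..., 1],   # standard 2-D rotation of each pair
            sin * x[..., 0] + cos * x[..., 1],
        ], dim=-1)
        return out.reshape(*out.shape[:-2], -1)

    q = torch.randn(1, 2, 4, 8)                  # (batch, heads, seq_len, head_dim)
    theta = torch.rand(1, 1, 4, 4)               # one angle per position and feature pair
    q_rot = rotate_pairs(q, torch.cos(theta), torch.sin(theta))
    assert q_rot.shape == q.shape                # rotation preserves shape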
models/bitwise_self_correction.py ADDED
@@ -0,0 +1,97 @@
1
+ import cv2
2
+ import os.path as osp
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+
8
+
9
+ def labels2image(vae, all_indices, label_type='int_label', scale_schedule=None):
10
+ summed_codes, recons_imgs = vae.decode_from_indices(all_indices, scale_schedule, label_type)
11
+ recons_img = recons_imgs[0]
12
+ recons_img = (recons_img + 1) / 2
13
+ recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
14
+ return recons_img
15
+
16
+ def features2image(vae, raw_features):
17
+ recons_imgs = vae.decode(raw_features.squeeze(-3))
18
+ recons_img = recons_imgs[0]
19
+ recons_img = (recons_img + 1) / 2
20
+ recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
21
+ return recons_img
22
+
23
+ class BitwiseSelfCorrection(object):
24
+ def __init__(self, vae, args):
25
+ self.noise_apply_layers = args.noise_apply_layers
26
+ self.noise_apply_requant = args.noise_apply_requant
27
+ self.noise_apply_strength = args.noise_apply_strength
28
+ self.apply_spatial_patchify = args.apply_spatial_patchify
29
+ self.vae = vae
30
+ self.debug_bsc = args.debug_bsc
31
+
32
+ def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
33
+ with torch.amp.autocast('cuda', enabled = False):
34
+ B = raw_features.shape[0]
35
+ if raw_features.dim() == 4:
36
+ codes_out = raw_features.unsqueeze(2)
37
+ else:
38
+ codes_out = raw_features
39
+ cum_var_input = 0
40
+ gt_all_bit_indices = []
41
+ pred_all_bit_indices = []
42
+ x_BLC_wo_prefix = []
43
+ for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
44
+ residual = codes_out - cum_var_input
45
+ if si != len(vae_scale_schedule)-1:
46
+ residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
47
+ quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual) # quantized shape: [B, d_vae, 1, h, w], bit_indices shape: [B,1,h,w,d_vae]
48
+ gt_all_bit_indices.append(bit_indices)
49
+ if si < self.noise_apply_layers:
50
+ noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
51
+ mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
52
+ pred_bit_indices = bit_indices.clone()
53
+ pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
54
+ pred_all_bit_indices.append(pred_bit_indices)
55
+ if self.noise_apply_requant:
56
+ quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
57
+ else:
58
+ pred_all_bit_indices.append(bit_indices)
59
+ cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()
60
+ if si < len(vae_scale_schedule)-1:
61
+ this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous()
62
+ if self.apply_spatial_patchify:
63
+ # (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2)
64
+ this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
65
+ x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1)) # (B,H/2*W/2,4C) or (B,H*W,C)
66
+
67
+ if self.apply_spatial_patchify:
68
+ gt_ms_idx_Bl = []
69
+ for item in gt_all_bit_indices:
70
+ # item shape: (B,1,H,W,d)
71
+ item = item.squeeze(1).permute(0,3,1,2) # (B,d,H,W)
72
+ # (B,d,H,W) -> (B,4d,H/2,W/2)
73
+ item = torch.nn.functional.pixel_unshuffle(item, 2)
74
+ # (B,4d,H/2,W/2) -> (B,H/2,W/2,4d) -> (B,H/2*W/2,4d)
75
+ item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
76
+ gt_ms_idx_Bl.append(item)
77
+ else:
78
+ gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
79
+ x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1)
80
+
81
+ if self.debug_bsc:
82
+ self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)
83
+
84
+ return x_BLC_wo_prefix, gt_ms_idx_Bl
85
+
86
+ def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices):
87
+ gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255
88
+ gt_img = gt_img[0].permute(1,2,0).cpu().numpy().astype(np.uint8)[:,:,::-1]
89
+ recons_img_2 = labels2image(self.vae, gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
90
+ recons_img_3 = labels2image(self.vae, pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
91
+ cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3], axis=1)
92
+ save_path = osp.abspath('non_teacher_force.jpg')
93
+ cv2.imwrite(save_path, cat_image)
94
+ print(f'Save to {save_path}')
95
+ import pdb; pdb.set_trace()
96
+ print(cat_image.shape)
97
+
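The heart of `flip_requant` is the self-correction corruption: at the first `noise_apply_layers` scales, ground-truth bits are flipped with a randomly drawn strength, and with `noise_apply_requant` the flipped bits are requantized so the running input reflects the corruption. A self-contained sketch of just the flipping step (shapes here are made up for illustration):

import torch
import numpy as np

bit_indices = torch.randint(0, 2, (2, 1, 4, 4, 8))  # (B, 1, H, W, d_vae), bits in {0, 1}
noise_apply_strength = np.random.randint(0, 31) * 0.01  # as if args.noise_apply_strength == 0.3
mask = torch.rand(*bit_indices.shape) < noise_apply_strength
pred_bit_indices = bit_indices.clone()
pred_bit_indices[mask] = 1 - pred_bit_indices[mask]  # flip the selected bits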
models/bsq_vae/conv.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class Conv(nn.Module):
8
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, cnn_type="2d", causal_offset=0, temporal_down=False):
9
+ super().__init__()
10
+ self.cnn_type = cnn_type
11
+ self.slice_seq_len = 17
12
+
13
+ if cnn_type == "2d":
14
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
15
+ if cnn_type == "3d":
16
+ if not temporal_down:
17
+ stride = (1, stride, stride)
18
+ else:
19
+ stride = (stride, stride, stride)
20
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0)
21
+ if isinstance(kernel_size, int):
22
+ kernel_size = (kernel_size, kernel_size, kernel_size)
23
+ self.padding = (
24
+ kernel_size[0] - 1 + causal_offset, # Temporal causal padding
25
+ padding, # Height padding
26
+ padding # Width padding
27
+ )
28
+ self.causal_offset = causal_offset
29
+ self.stride = stride
30
+ self.kernel_size = kernel_size
31
+
32
+ def forward(self, x):
33
+ if self.cnn_type == "2d":
34
+ if x.ndim == 5:
35
+ B, C, T, H, W = x.shape
36
+ x = rearrange(x, "B C T H W -> (B T) C H W")
37
+ x = self.conv(x)
38
+ x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
39
+ return x
40
+ else:
41
+ return self.conv(x)
42
+ if self.cnn_type == "3d":
43
+ assert self.stride[0] in (1, 2), "only temporal stride 1 or 2 is supported"
44
+ xs = []
45
+ for i in range(0, x.shape[2], self.slice_seq_len+self.stride[0]-1):
46
+ st = i
47
+ en = min(i+self.slice_seq_len, x.shape[2])
48
+ _x = x[:,:,st:en,:,:]
49
+ if i == 0:
50
+ _x = F.pad(_x, (self.padding[2], self.padding[2], # Width
51
+ self.padding[1], self.padding[1], # Height
52
+ self.padding[0], 0)) # Temporal
53
+ else:
54
+ padding_0 = self.kernel_size[0] - 1
55
+ _x = F.pad(_x, (self.padding[2], self.padding[2], # Width
56
+ self.padding[1], self.padding[1], # Height
57
+ padding_0, 0)) # Temporal
58
+ _x[:,:,:padding_0,
59
+ self.padding[1]:_x.shape[-2]-self.padding[1],
60
+ self.padding[2]:_x.shape[-1]-self.padding[2]] += x[:,:,i-padding_0:i,:,:]
61
+ _x = self.conv(_x)
62
+ xs.append(_x)
63
+ try:
64
+ x = torch.cat(xs, dim=2)
65
+ except torch.cuda.OutOfMemoryError: # GPU concat failed; stage the slices through pinned CPU memory
66
+ device = x.device
67
+ del x
68
+ xs = [_x.cpu().pin_memory() for _x in xs]
69
+ torch.cuda.empty_cache()
70
+ x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device)
71
+ return x
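`Conv` makes the 3d path temporally causal by padding only the left of the time axis with `kernel_size[0] - 1` frames, and it processes long clips in slices of `slice_seq_len`, copying the left temporal context in by hand. The padding convention in isolation (a sketch; `F.pad` orders pads from the last dimension backwards, which is the easy part to get wrong):

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 5, 16, 16)  # (B, C, T, H, W)
kt, pad_hw = 3, 1                 # temporal kernel 3, spatial padding 1
# pad order: (W_left, W_right, H_left, H_right, T_left, T_right)
x = F.pad(x, (pad_hw, pad_hw, pad_hw, pad_hw, kt - 1, 0))  # causal: left-pad T only
print(x.shape)  # torch.Size([1, 8, 7, 18, 18])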
models/bsq_vae/dynamic_resolution.py ADDED
@@ -0,0 +1,32 @@
1
4
+
5
+ vae_stride = 16
6
+ ratio2hws = {
7
+ 1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)],
8
+ 1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)],
9
+ 1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)],
10
+ 1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)],
11
+ 1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)],
12
+ 2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)],
13
+ 2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)],
14
+ 3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)],
15
+ }
16
+ full_ratio2hws = {}
17
+ for ratio, hws in ratio2hws.items():
18
+ full_ratio2hws[ratio] = hws
19
+ full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws]
20
+
21
+ dynamic_resolution_h_w = {}
22
+ predefined_HW_Scales_dynamic = {}
23
+ for ratio in full_ratio2hws:
24
+ dynamic_resolution_h_w[ratio] = {}
25
+ for ind, leng in enumerate([7, 10, 13]):
26
+ h, w = full_ratio2hws[ratio][leng-1][0], full_ratio2hws[ratio][leng-1][1] # feature map size
27
+ pixel = (h * vae_stride, w * vae_stride) # The original image (H, W)
28
+ dynamic_resolution_h_w[ratio][pixel[1]] = {
29
+ 'pixel': pixel,
30
+ 'scales': full_ratio2hws[ratio][:leng]
31
+ } # W as key
32
+ predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng]
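`dynamic_resolution_h_w` is keyed by aspect ratio and then by image width, giving the pixel size and the scale schedule truncated to 7, 10, or 13 levels. A quick lookup (values follow directly from the tables above):

from models.bsq_vae.dynamic_resolution import dynamic_resolution_h_w

entry = dynamic_resolution_h_w[1.000][1024]  # square images at 1024x1024 (64x64 features, vae_stride 16)
print(entry['pixel'])        # (1024, 1024)
print(entry['scales'][:3])   # [(1, 1), (2, 2), (4, 4)]
print(len(entry['scales']))  # 13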
models/bsq_vae/flux_vqgan.py ADDED
@@ -0,0 +1,557 @@
1
+ import argparse
2
+ import os
3
+ import imageio
4
+ import torch
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ from torchvision import transforms
11
+ from safetensors.torch import load_file
12
+ import torch.utils.checkpoint as checkpoint
13
+
14
+ from .conv import Conv
15
+ from .multiscale_bsq import MultiScaleBSQ
16
+
17
+ ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
18
+
19
+ class Normalize(nn.Module):
20
+ def __init__(self, in_channels, norm_type, norm_axis="spatial"):
21
+ super().__init__()
22
+ self.norm_axis = norm_axis
23
+ assert norm_type in ['group', 'batch', "no"]
24
+ if norm_type == 'group':
25
+ if in_channels % 32 == 0:
26
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
27
+ elif in_channels % 24 == 0:
28
+ self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True)
29
+ else:
30
+ raise NotImplementedError
31
+ elif norm_type == 'batch':
32
+ self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False) # track_running_stats=True triggers an in-place grad RuntimeError here
33
+ elif norm_type == 'no':
34
+ self.norm = nn.Identity()
35
+
36
+ def forward(self, x):
37
+ if self.norm_axis == "spatial":
38
+ if x.ndim == 4:
39
+ x = self.norm(x)
40
+ else:
41
+ B, C, T, H, W = x.shape
42
+ x = rearrange(x, "B C T H W -> (B T) C H W")
43
+ x = self.norm(x)
44
+ x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
45
+ elif self.norm_axis == "spatial-temporal":
46
+ x = self.norm(x)
47
+ else:
48
+ raise NotImplementedError
49
+ return x
50
+
51
+ def swish(x: Tensor) -> Tensor:
52
+ try:
53
+ return x * torch.sigmoid(x)
54
+ except torch.cuda.OutOfMemoryError: # activation does not fit on GPU; compute on CPU and move back
55
+ device = x.device
56
+ x = x.cpu().pin_memory()
57
+ return (x*torch.sigmoid(x)).to(device=device)
58
+
59
+
60
+ class AttnBlock(nn.Module):
61
+ def __init__(self, in_channels, norm_type='group', cnn_param=None):
62
+ super().__init__()
63
+ self.in_channels = in_channels
64
+
65
+ self.norm = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
66
+
67
+ self.q = Conv(in_channels, in_channels, kernel_size=1)
68
+ self.k = Conv(in_channels, in_channels, kernel_size=1)
69
+ self.v = Conv(in_channels, in_channels, kernel_size=1)
70
+ self.proj_out = Conv(in_channels, in_channels, kernel_size=1)
71
+
72
+ def attention(self, h_: Tensor) -> Tensor:
73
+ B, _, T, _, _ = h_.shape
74
+ h_ = self.norm(h_)
75
+ h_ = rearrange(h_, "B C T H W -> (B T) C H W") # spatial attention only
76
+ q = self.q(h_)
77
+ k = self.k(h_)
78
+ v = self.v(h_)
79
+
80
+ b, c, h, w = q.shape
81
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
82
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
83
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
84
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
85
+
86
+ return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)
87
+
88
+ def forward(self, x: Tensor) -> Tensor:
89
+ return x + self.proj_out(self.attention(x))
90
+
91
+
92
+ class ResnetBlock(nn.Module):
93
+ def __init__(self, in_channels: int, out_channels: int, norm_type='group', cnn_param=None):
94
+ super().__init__()
95
+ self.in_channels = in_channels
96
+ out_channels = in_channels if out_channels is None else out_channels
97
+ self.out_channels = out_channels
98
+
99
+ self.norm1 = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
100
+ if cnn_param["res_conv_2d"] in ["half", "full"]:
101
+ self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
102
+ else:
103
+ self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
104
+ self.norm2 = Normalize(out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
105
+ if cnn_param["res_conv_2d"] in ["full"]:
106
+ self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
107
+ else:
108
+ self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
109
+ if self.in_channels != self.out_channels:
110
+ self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
111
+
112
+ def forward(self, x):
113
+ h = x
114
+ h = self.norm1(h)
115
+ h = swish(h)
116
+ h = self.conv1(h)
117
+
118
+ h = self.norm2(h)
119
+ h = swish(h)
120
+ h = self.conv2(h)
121
+
122
+ if self.in_channels != self.out_channels:
123
+ x = self.nin_shortcut(x)
124
+
125
+ return x + h
126
+
127
+
128
+ class Downsample(nn.Module):
129
+ def __init__(self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False):
130
+ super().__init__()
131
+ assert spatial_down
132
+ if cnn_type == "2d":
133
+ self.pad = (0,1,0,1)
134
+ if cnn_type == "3d":
135
+ self.pad = (0,1,0,1,0,0) # add padding to the right for h-axis and w-axis. No padding for t-axis
136
+ # no asymmetric padding in torch conv, must do it ourselves
137
+ self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0, cnn_type=cnn_type, temporal_down=temporal_down)
138
+
139
+ def forward(self, x: Tensor):
140
+ x = nn.functional.pad(x, self.pad, mode="constant", value=0)
141
+ x = self.conv(x)
142
+ return x
143
+
144
+
145
+ class Upsample(nn.Module):
146
+ def __init__(self, in_channels, cnn_type="2d", spatial_up=False, temporal_up=False, use_pxsl=False):
147
+ super().__init__()
148
+ if cnn_type == "2d":
149
+ self.scale_factor = 2
150
+ self.causal_offset = 0
151
+ else:
152
+ assert spatial_up
153
+ if temporal_up:
154
+ self.scale_factor = (2,2,2)
155
+ self.causal_offset = -1
156
+ else:
157
+ self.scale_factor = (1,2,2)
158
+ self.causal_offset = 0
159
+ self.use_pxsl = use_pxsl
160
+ if self.use_pxsl:
161
+ self.conv = Conv(in_channels, in_channels*4, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
162
+ self.pxsl = nn.PixelShuffle(2)
163
+ else:
164
+ self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
165
+
166
+ def forward(self, x: Tensor):
167
+ if self.use_pxsl:
168
+ x = self.conv(x)
169
+ x = self.pxsl(x)
170
+ else:
171
+ try:
172
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
173
+ except RuntimeError: # e.g. CUDA OOM on very large tensors; fall back to per-channel interpolation
174
+ # shard across channel
175
+ _xs = []
176
+ for i in range(x.shape[1]):
177
+ _x = F.interpolate(x[:,i:i+1,...], scale_factor=self.scale_factor, mode="nearest")
178
+ _xs.append(_x)
179
+ x = torch.cat(_xs, dim=1)
180
+ x = self.conv(x)
181
+ return x
182
+
183
+
184
+ class Encoder(nn.Module):
185
+ def __init__(
186
+ self,
187
+ ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ z_channels: int,
191
+ in_channels = 3,
192
+ patch_size=8, temporal_patch_size=4,
193
+ norm_type='group', cnn_param=None,
194
+ use_checkpoint=False,
195
+ use_vae=True,
196
+ ):
197
+ super().__init__()
198
+ self.max_down = np.log2(patch_size)
199
+ self.temporal_max_down = np.log2(temporal_patch_size)
200
+ self.temporal_down_offset = self.max_down - self.temporal_max_down
201
+ self.ch = ch
202
+ self.num_resolutions = len(ch_mult)
203
+ self.num_res_blocks = num_res_blocks
204
+ self.in_channels = in_channels
205
+ self.cnn_param = cnn_param
206
+ self.use_checkpoint = use_checkpoint
207
+ # downsampling
208
+ # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
209
+ # cnn_param["cnn_type"] = "2d" for images, cnn_param["cnn_type"] = "3d" for videos
210
+ if cnn_param["conv_in_out_2d"] == "yes": # "yes" for video
211
+ self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
212
+ else:
213
+ self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
214
+
215
+ in_ch_mult = (1,) + tuple(ch_mult)
216
+ self.in_ch_mult = in_ch_mult
217
+ self.down = nn.ModuleList()
218
+ block_in = self.ch
219
+ for i_level in range(self.num_resolutions):
220
+ block = nn.ModuleList()
221
+ attn = nn.ModuleList()
222
+ block_in = ch * in_ch_mult[i_level]
223
+ block_out = ch * ch_mult[i_level]
224
+ for _ in range(self.num_res_blocks):
225
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
226
+ block_in = block_out
227
+ down = nn.Module()
228
+ down.block = block
229
+ down.attn = attn
230
+ # downsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE
231
+ spatial_down = True if i_level < self.max_down else False
232
+ temporal_down = True if i_level < self.max_down and i_level >= self.temporal_down_offset else False
233
+ if spatial_down or temporal_down:
234
+ down.downsample = Downsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_down=spatial_down, temporal_down=temporal_down)
235
+ self.down.append(down)
236
+
237
+ # middle
238
+ self.mid = nn.Module()
239
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
240
+ if cnn_param["cnn_attention"] == "yes":
241
+ self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
242
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
243
+
244
+ # end
245
+ self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
246
+ if cnn_param["conv_inner_2d"] == "yes":
247
+ self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
248
+ else:
249
+ self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
250
+
251
+ def forward(self, x, return_hidden=False):
252
+ if not self.use_checkpoint:
253
+ return self._forward(x, return_hidden=return_hidden)
254
+ else:
255
+ return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False)
256
+
257
+ def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
258
+ # downsampling
259
+ h0 = self.conv_in(x)
260
+ hs = [h0]
261
+ for i_level in range(self.num_resolutions):
262
+ for i_block in range(self.num_res_blocks):
263
+ h = self.down[i_level].block[i_block](hs[-1])
264
+ if len(self.down[i_level].attn) > 0:
265
+ h = self.down[i_level].attn[i_block](h)
266
+ hs.append(h)
267
+ if hasattr(self.down[i_level], "downsample"):
268
+ hs.append(self.down[i_level].downsample(hs[-1]))
269
+
270
+ # middle
271
+ h = hs[-1]
272
+ hs_mid = [h]
273
+ h = self.mid.block_1(h)
274
+ if self.cnn_param["cnn_attention"] == "yes":
275
+ h = self.mid.attn_1(h)
276
+ h = self.mid.block_2(h)
277
+ hs_mid.append(h)
278
+ # end
279
+ h = self.norm_out(h)
280
+ h = swish(h)
281
+ h = self.conv_out(h)
282
+ if return_hidden:
283
+ return h, hs, hs_mid
284
+ else:
285
+ return h
286
+
287
+
288
+ class Decoder(nn.Module):
289
+ def __init__(
290
+ self,
291
+ ch: int,
292
+ ch_mult: list[int],
293
+ num_res_blocks: int,
294
+ z_channels: int,
295
+ out_ch = 3,
296
+ patch_size=8, temporal_patch_size=4,
297
+ norm_type="group", cnn_param=None,
298
+ use_checkpoint=False,
299
+ use_freq_dec=False, # use frequency features for decoder
300
+ use_pxsf=False
301
+ ):
302
+ super().__init__()
303
+ self.max_up = np.log2(patch_size)
304
+ self.temporal_max_up = np.log2(temporal_patch_size)
305
+ self.temporal_up_offset = self.max_up - self.temporal_max_up
306
+ self.ch = ch
307
+ self.num_resolutions = len(ch_mult)
308
+ self.num_res_blocks = num_res_blocks
309
+ self.ffactor = 2 ** (self.num_resolutions - 1)
310
+ self.cnn_param = cnn_param
311
+ self.use_checkpoint = use_checkpoint
312
+ self.use_freq_dec = use_freq_dec
313
+ self.use_pxsf = use_pxsf
314
+
315
+ # compute in_ch_mult, block_in and curr_res at lowest res
316
+ block_in = ch * ch_mult[self.num_resolutions - 1]
317
+
318
+ # z to block_in
319
+ if cnn_param["conv_inner_2d"] == "yes":
320
+ self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d")
321
+ else:
322
+ self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
323
+
324
+ # middle
325
+ self.mid = nn.Module()
326
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
327
+ if cnn_param["cnn_attention"] == "yes":
328
+ self.mid.attn_1 = AttnBlock(block_in, norm_type=norm_type, cnn_param=cnn_param)
329
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
330
+
331
+ # upsampling
332
+ self.up = nn.ModuleList()
333
+ for i_level in reversed(range(self.num_resolutions)):
334
+ block = nn.ModuleList()
335
+ attn = nn.ModuleList()
336
+ block_out = ch * ch_mult[i_level]
337
+ for _ in range(self.num_res_blocks + 1):
338
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
339
+ block_in = block_out
340
+ up = nn.Module()
341
+ up.block = block
342
+ up.attn = attn
343
+ # upsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE, offset 1 compared with encoder
344
+ # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
345
+ spatial_up = True if 1 <= i_level <= self.max_up else False
346
+ temporal_up = True if 1 <= i_level <= self.max_up and i_level >= self.temporal_up_offset+1 else False
347
+ if spatial_up or temporal_up:
348
+ up.upsample = Upsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_up=spatial_up, temporal_up=temporal_up, use_pxsl=self.use_pxsf)
349
+ self.up.insert(0, up) # prepend to get consistent order
350
+
351
+ # end
352
+ self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
353
+ if cnn_param["conv_in_out_2d"] == "yes":
354
+ self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
355
+ else:
356
+ self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
357
+
358
+ def forward(self, z):
359
+ if not self.use_checkpoint:
360
+ return self._forward(z)
361
+ else:
362
+ return checkpoint.checkpoint(self._forward, z, use_reentrant=False)
363
+
364
+ def _forward(self, z: Tensor) -> Tensor:
365
+ # z to block_in
366
+ h = self.conv_in(z)
367
+
368
+ # middle
369
+ h = self.mid.block_1(h)
370
+ if self.cnn_param["cnn_attention"] == "yes":
371
+ h = self.mid.attn_1(h)
372
+ h = self.mid.block_2(h)
373
+
374
+ # upsampling
375
+ for i_level in reversed(range(self.num_resolutions)):
376
+ for i_block in range(self.num_res_blocks + 1):
377
+ h = self.up[i_level].block[i_block](h)
378
+ if len(self.up[i_level].attn) > 0:
379
+ h = self.up[i_level].attn[i_block](h)
380
+ if hasattr(self.up[i_level], "upsample"):
381
+ h = self.up[i_level].upsample(h)
382
+
383
+ # end
384
+ h = self.norm_out(h)
385
+ h = swish(h)
386
+ h = self.conv_out(h)
387
+ return h
388
+
389
+
390
+ class AutoEncoder(nn.Module):
391
+ def __init__(self, args):
392
+ super().__init__()
393
+ self.args = args
394
+ cnn_param = dict(
395
+ cnn_type=args.cnn_type,
396
+ conv_in_out_2d=args.conv_in_out_2d,
397
+ res_conv_2d=args.res_conv_2d,
398
+ cnn_attention=args.cnn_attention,
399
+ cnn_norm_axis=args.cnn_norm_axis,
400
+ conv_inner_2d=args.conv_inner_2d,
401
+ )
402
+ self.encoder = Encoder(
403
+ ch=args.base_ch,
404
+ ch_mult=args.encoder_ch_mult,
405
+ num_res_blocks=args.num_res_blocks,
406
+ z_channels=args.codebook_dim,
407
+ patch_size=args.patch_size,
408
+ temporal_patch_size=args.temporal_patch_size,
409
+ cnn_param=cnn_param,
410
+ use_checkpoint=args.use_checkpoint,
411
+ use_vae=args.use_vae,
412
+ )
413
+ self.decoder = Decoder(
414
+ ch=args.base_ch,
415
+ ch_mult=args.decoder_ch_mult,
416
+ num_res_blocks=args.num_res_blocks,
417
+ z_channels=args.codebook_dim,
418
+ patch_size=args.patch_size,
419
+ temporal_patch_size=args.temporal_patch_size,
420
+ cnn_param=cnn_param,
421
+ use_checkpoint=args.use_checkpoint,
422
+ use_freq_dec=args.use_freq_dec,
423
+ use_pxsf=args.use_pxsf # pixelshuffle for upsampling
424
+ )
425
+ self.z_drop = nn.Dropout(args.z_drop)
426
+ self.scale_factor = 0.3611
427
+ self.shift_factor = 0.1159
428
+ self.codebook_dim = self.embed_dim = args.codebook_dim
429
+
430
+ self.gan_feat_weight = args.gan_feat_weight
431
+ self.video_perceptual_weight = args.video_perceptual_weight
432
+ self.recon_loss_type = args.recon_loss_type
433
+ self.l1_weight = args.l1_weight
434
+ self.use_vae = args.use_vae
435
+ self.kl_weight = args.kl_weight
436
+ self.lfq_weight = args.lfq_weight
437
+ self.image_gan_weight = args.image_gan_weight # image GAN loss weight
438
+ self.video_gan_weight = args.video_gan_weight # video GAN loss weight
439
+ self.perceptual_weight = args.perceptual_weight
440
+ self.flux_weight = args.flux_weight
441
+ self.cycle_weight = args.cycle_weight
442
+ self.cycle_feat_weight = args.cycle_feat_weight
443
+ self.cycle_gan_weight = args.cycle_gan_weight
444
+
445
+ self.flux_image_encoder = None
446
+
447
+ if not args.use_vae:
448
+ if args.quantizer_type == 'MultiScaleBSQ':
449
+ self.quantizer = MultiScaleBSQ(
450
+ dim = args.codebook_dim, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
451
+ codebook_size = args.codebook_size, # codebook size, must be a power of 2
452
+ entropy_loss_weight = args.entropy_loss_weight, # how much weight to place on entropy loss
453
+ diversity_gamma = args.diversity_gamma, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
454
+ preserve_norm=args.preserve_norm, # preserve norm of the input for BSQ
455
+ ln_before_quant=args.ln_before_quant, # use layer norm before quantization
456
+ ln_init_by_sqrt=args.ln_init_by_sqrt, # layer norm init value 1/sqrt(d)
457
+ commitment_loss_weight=args.commitment_loss_weight, # loss weight of commitment loss
458
+ new_quant=args.new_quant,
459
+ use_decay_factor=args.use_decay_factor,
460
+ mask_out=args.mask_out,
461
+ use_stochastic_depth=args.use_stochastic_depth,
462
+ drop_rate=args.drop_rate,
463
+ schedule_mode=args.schedule_mode,
464
+ keep_first_quant=args.keep_first_quant,
465
+ keep_last_quant=args.keep_last_quant,
466
+ remove_residual_detach=args.remove_residual_detach,
467
+ use_out_phi=args.use_out_phi,
468
+ use_out_phi_res=args.use_out_phi_res,
469
+ random_flip = args.random_flip,
470
+ flip_prob = args.flip_prob,
471
+ flip_mode = args.flip_mode,
472
+ max_flip_lvl = args.max_flip_lvl,
473
+ random_flip_1lvl = args.random_flip_1lvl,
474
+ flip_lvl_idx = args.flip_lvl_idx,
475
+ drop_when_test = args.drop_when_test,
476
+ drop_lvl_idx = args.drop_lvl_idx,
477
+ drop_lvl_num = args.drop_lvl_num,
478
+ )
479
+ self.quantize = self.quantizer
480
+ self.vocab_size = args.codebook_size
481
+ else:
482
+ raise NotImplementedError(f"{args.quantizer_type} not supported")
483
+
484
+
485
+ def forward(self, x):
486
+ is_image = x.ndim == 4
487
+ if not is_image:
488
+ B, C, T, H, W = x.shape
489
+ else:
490
+ B, C, H, W = x.shape
491
+ T = 1
492
+ enc_dtype = ptdtype[self.args.encoder_dtype]
493
+
494
+ with torch.amp.autocast("cuda", dtype=enc_dtype):
495
+ h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
496
+ hs = [_h.detach() for _h in hs]
497
+ hs_mid = [_h.detach() for _h in hs_mid]
498
+ h = h.to(dtype=torch.float32)
499
+ # print(z.shape)
500
+ # Multiscale LFQ
501
+ z, all_indices, all_bit_indices, residual_norm_per_scale, all_loss, var_input = self.quantizer(h) # MultiScaleBSQ.forward returns a 6-tuple
502
+ x_recon = self.decoder(z)
503
+ vq_output = {
504
+ "commitment_loss": torch.mean(all_loss) * self.lfq_weight, # here commitment loss is sum of commitment loss and entropy penalty
505
+ "encodings": all_indices,
506
+ }
507
+ return x_recon, vq_output
508
+
509
+ def encode_for_raw_features(self, x, scale_schedule, return_residual_norm_per_scale=False):
510
+ is_image = x.ndim == 4
511
+ if not is_image:
512
+ B, C, T, H, W = x.shape
513
+ else:
514
+ B, C, H, W = x.shape
515
+ T = 1
516
+
517
+ enc_dtype = ptdtype[self.args.encoder_dtype]
518
+ with torch.amp.autocast("cuda", dtype=enc_dtype):
519
+ h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
520
+
521
+ hs = [_h.detach() for _h in hs]
522
+ hs_mid = [_h.detach() for _h in hs_mid]
523
+ h = h.to(dtype=torch.float32)
524
+ return h, hs, hs_mid
525
+
526
+ def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
527
+ h, hs, hs_mid = self.encode_for_raw_features(x, scale_schedule, return_residual_norm_per_scale)
528
+ # Multiscale LFQ
529
+ z, all_indices, all_bit_indices, residual_norm_per_scale, all_loss, var_input = self.quantizer(h, scale_schedule=scale_schedule, return_residual_norm_per_scale=return_residual_norm_per_scale)
530
+ return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input
531
+
532
+ def decode(self, z):
533
+ x_recon = self.decoder(z)
534
+ x_recon = torch.clamp(x_recon, min=-1, max=1)
535
+ return x_recon
536
+
537
+ def decode_from_indices(self, all_indices, scale_schedule, label_type):
538
+ summed_codes = 0
539
+ for idx_Bl in all_indices:
540
+ codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
541
+ summed_codes += F.interpolate(codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up)
542
+ assert summed_codes.shape[-3] == 1
543
+ x_recon = self.decoder(summed_codes.squeeze(-3))
544
+ x_recon = torch.clamp(x_recon, min=-1, max=1)
545
+ return summed_codes, x_recon
546
+
547
+ @staticmethod
548
+ def add_model_specific_args(parent_parser):
549
+ parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
550
+ parser.add_argument("--flux_weight", type=float, default=0)
551
+ parser.add_argument("--cycle_weight", type=float, default=0)
552
+ parser.add_argument("--cycle_feat_weight", type=float, default=0)
553
+ parser.add_argument("--cycle_gan_weight", type=float, default=0)
554
+ parser.add_argument("--cycle_loop", type=int, default=0)
555
+ parser.add_argument("--z_drop", type=float, default=0.)
556
+ return parser
557
+
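`decode_from_indices` de-quantizes each scale and accumulates the upsampled codes at the finest resolution before a single decoder pass, mirroring the residual accumulation inside the quantizer. A toy sketch of that accumulation (random tensors stand in for `lfq.indices_to_codes` output; `'trilinear'` matches `z_interplote_up`):

import torch
import torch.nn.functional as F

scale_schedule = [(1, 1, 1), (1, 2, 2), (1, 4, 4)]
codes_per_scale = [torch.randn(1, 16, *s) for s in scale_schedule]  # (B, d_vae, t, h, w)
summed = sum(F.interpolate(c, size=scale_schedule[-1], mode='trilinear') for c in codes_per_scale)
print(summed.shape)  # torch.Size([1, 16, 1, 4, 4])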
models/bsq_vae/multiscale_bsq.py ADDED
@@ -0,0 +1,718 @@
1
+ """
2
+ Binary Spherical Quantization
3
+ Proposed in https://arxiv.org/abs/2406.07548
4
+
5
+ In the simplest setup, each dimension is quantized into {-1, 1}.
6
+ An entropy penalty is used to encourage utilization.
7
+ """
8
+
9
+ import random
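+ # Concretely: each latent channel is mapped to {-1, 1} with sign() and, in quantize_new,
+ # rescaled by 1/sqrt(d) onto the unit sphere; e.g. with d = 4,
+ # z = [0.3, -1.2, 0.7, -0.1] -> zhat = [0.5, -0.5, 0.5, -0.5].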
10
+ from math import log2, ceil
11
+ from functools import partial, cache
12
+ from collections import namedtuple
13
+ from contextlib import nullcontext
14
+
15
+ import torch.distributed as dist
16
+ from torch.distributed import nn as dist_nn
17
+
18
+ import torch
19
+ from torch import nn, einsum
20
+ import torch.nn.functional as F
21
+ from torch.nn import Module
22
+ from torch.amp import autocast
23
+ import numpy as np
24
+
25
+ from einops import rearrange, reduce, pack, unpack
26
+
27
+ # from einx import get_at
28
+
29
+ from .dynamic_resolution import predefined_HW_Scales_dynamic
30
+
31
+ # constants
32
+
33
+ Return = namedtuple('Return', ['quantized', 'indices', 'bit_indices', 'entropy_aux_loss'])
34
+
35
+ LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
36
+
37
+ # distributed helpers
38
+
39
+ @cache
40
+ def is_distributed():
41
+ return dist.is_initialized() and dist.get_world_size() > 1
42
+
43
+ def maybe_distributed_mean(t):
44
+ if not is_distributed():
45
+ return t
46
+
47
+ dist_nn.all_reduce(t)
48
+ t = t / dist.get_world_size()
49
+ return t
50
+
51
+ # helper functions
52
+
53
+ def exists(v):
54
+ return v is not None
55
+
56
+ def identity(t):
57
+ return t
58
+
59
+ def default(*args):
60
+ for arg in args:
61
+ if exists(arg):
62
+ return arg() if callable(arg) else arg
63
+ return None
64
+
65
+ def round_up_multiple(num, mult):
66
+ return ceil(num / mult) * mult
67
+
68
+ def pack_one(t, pattern):
69
+ return pack([t], pattern)
70
+
71
+ def unpack_one(t, ps, pattern):
72
+ return unpack(t, ps, pattern)[0]
73
+
74
+ def l2norm(t):
75
+ return F.normalize(t, dim = -1)
76
+
77
+ # entropy
78
+
79
+ def log(t, eps = 1e-5):
80
+ return t.clamp(min = eps).log()
81
+
82
+ def entropy(prob):
83
+ return (-prob * log(prob)).sum(dim=-1)
84
+
85
+ # cosine sim linear
86
+
87
+ class CosineSimLinear(Module):
88
+ def __init__(
89
+ self,
90
+ dim_in,
91
+ dim_out,
92
+ scale = 1.
93
+ ):
94
+ super().__init__()
95
+ self.scale = scale
96
+ self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
97
+
98
+ def forward(self, x):
99
+ x = F.normalize(x, dim = -1)
100
+ w = F.normalize(self.weight, dim = 0)
101
+ return (x @ w) * self.scale
102
+
103
+
104
+ def get_latent2scale_schedule(T: int, H: int, W: int, mode="original"):
105
+ assert mode in ["original", "dynamic", "dense", "same1", "same2", "same3"]
106
+ predefined_HW_Scales = {
107
+ # 256 * 256
108
+ (32, 32): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 9), (13, 13), (18, 18), (24, 24), (32, 32)],
109
+ (16, 16): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)],
110
+ # 1024x1024
111
+ (64, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (7, 7), (9, 9), (12, 12), (16, 16), (21, 21), (27, 27), (36, 36), (48, 48), (64, 64)],
112
+
113
+ (36, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 12), (13, 16), (18, 24), (24, 32), (32, 48), (36, 64)],
114
+ }
115
+ if mode == "dynamic":
116
+ predefined_HW_Scales.update(predefined_HW_Scales_dynamic)
117
+ elif mode == "dense":
118
+ predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)]
119
+ predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
120
+ predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
121
+ elif mode.startswith("same"):
122
+ num_quant = int(mode[len("same"):])
123
+ predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)]
124
+ predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)]
125
+ predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)]
126
+
127
+ predefined_T_Scales = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 17, 17, 17, 17]
128
+ patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)]
129
+ if len(predefined_T_Scales) < len(patch_THW_shape_per_scale):
130
+ # print("warning: the length of predefined_T_Scales is less than the length of patch_THW_shape_per_scale!")
131
+ predefined_T_Scales += [predefined_T_Scales[-1]] * (len(patch_THW_shape_per_scale) - len(predefined_T_Scales))
132
+ patch_THW_shape_per_scale = [(min(T, t), h, w ) for (h, w), t in zip(patch_THW_shape_per_scale, predefined_T_Scales[:len(patch_THW_shape_per_scale)])]
133
+ return patch_THW_shape_per_scale
134
+
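+ # Example: get_latent2scale_schedule(T=1, H=16, W=16) returns [(1, 1, 1), (1, 2, 2), (1, 3, 3),
+ # (1, 4, 4), (1, 5, 5), (1, 6, 6), (1, 8, 8), (1, 10, 10), (1, 13, 13), (1, 16, 16)]:
+ # min(T, t) clamps every temporal size to 1 for single images.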
135
+ class LayerNorm(nn.Module):
136
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
137
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
138
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
139
+ with shape (batch_size, channels, height, width).
140
+ normalized_shape: int
141
+ """
142
+ def __init__(self, normalized_shape, norm_weight=False, eps=1e-6, data_format="channels_first"):
143
+ super().__init__()
144
+ if norm_weight:
145
+ self.weight = nn.Parameter(torch.ones(normalized_shape)/(normalized_shape**0.5))
146
+ else:
147
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
148
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
149
+ self.eps = eps
150
+ self.data_format = data_format
151
+ if self.data_format not in ["channels_last", "channels_first"]:
152
+ raise NotImplementedError
153
+ self.normalized_shape = (normalized_shape, )
154
+
155
+ def forward(self, x):
156
+ if self.data_format == "channels_last":
157
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
158
+ elif self.data_format == "channels_first":
159
+ u = x.mean(1, keepdim=True)
160
+ s = (x - u).pow(2).mean(1, keepdim=True)
161
+ x = (x - u) / torch.sqrt(s + self.eps)
162
+ if x.ndim == 4: # (b, c, h, w)
163
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
164
+ elif x.ndim == 5: # (b, c, t, h, w)
165
+ x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
166
+ else:
167
+ raise ValueError("the number of dimensions of the input should be 4 or 5")
168
+ return x
169
+
170
+ class MultiScaleBSQ(Module):
171
+ """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
172
+
173
+ def __init__(
174
+ self,
175
+ *,
176
+ dim,
177
+ codebook_size,
178
+ soft_clamp_input_value = None,
179
+ aux_loss = False, # intermediate auxiliary loss
180
+ ln_before_quant=False, # add a LN before multi-scale RQ
181
+ ln_init_by_sqrt=False, # weight init by 1/sqrt(d)
182
+ use_decay_factor=False,
183
+ use_stochastic_depth=False,
184
+ drop_rate=0.,
185
+ schedule_mode="original", # ["original", "dynamic", "dense"]
186
+ keep_first_quant=False,
187
+ keep_last_quant=False,
188
+ remove_residual_detach=False,
189
+ random_flip = False,
190
+ flip_prob = 0.5,
191
+ flip_mode = "stochastic", # "stochastic", "deterministic"
192
+ max_flip_lvl = 1,
193
+ random_flip_1lvl = False, # random flip one level each time
194
+ flip_lvl_idx = None,
195
+ drop_when_test=False,
196
+ drop_lvl_idx=None,
197
+ drop_lvl_num=0,
198
+ **kwargs
199
+ ):
200
+ super().__init__()
201
+ codebook_dim = int(log2(codebook_size))
202
+
203
+ requires_projection = codebook_dim != dim
204
+ self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
205
+ self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
206
+ self.has_projections = requires_projection
207
+ self.layernorm = LayerNorm(codebook_dim, norm_weight=ln_init_by_sqrt) if ln_before_quant else nn.Identity()
208
+ self.use_stochastic_depth = use_stochastic_depth
209
+ self.drop_rate = drop_rate
210
+ self.remove_residual_detach = remove_residual_detach
211
+ self.random_flip = random_flip
212
+ self.flip_prob = flip_prob
213
+ self.flip_mode = flip_mode
214
+ self.max_flip_lvl = max_flip_lvl
215
+ self.random_flip_1lvl = random_flip_1lvl
216
+ self.flip_lvl_idx = flip_lvl_idx
217
+ assert not (random_flip and random_flip_1lvl)
218
+ self.drop_when_test = drop_when_test
219
+ self.drop_lvl_idx = drop_lvl_idx
220
+ self.drop_lvl_num = drop_lvl_num
221
+ if self.drop_when_test:
222
+ assert drop_lvl_idx is not None
223
+ assert drop_lvl_num > 0
224
+
225
+ self.lfq = BSQ(
226
+ dim = codebook_dim,
227
+ codebook_scale = 1/np.sqrt(codebook_dim),
228
+ soft_clamp_input_value = soft_clamp_input_value,
229
+ # experimental_softplus_entropy_loss=True,
230
+ # entropy_loss_offset=2,
231
+ **kwargs
232
+ )
233
+
234
+ self.z_interplote_up = 'trilinear'
235
+ self.z_interplote_down = 'area'
236
+
237
+ self.use_decay_factor = use_decay_factor
238
+ self.schedule_mode = schedule_mode
239
+ self.keep_first_quant = keep_first_quant
240
+ self.keep_last_quant = keep_last_quant
241
+ if self.use_stochastic_depth and self.drop_rate > 0:
242
+ assert self.keep_first_quant or self.keep_last_quant
243
+
244
+ @property
245
+ def codebooks(self):
246
+ return self.lfq.codebook
247
+
248
+ def get_codes_from_indices(self, indices_list):
249
+ all_codes = []
250
+ for indices in indices_list:
251
+ codes = self.lfq.indices_to_codes(indices)
252
+ all_codes.append(codes)
253
+ _, _, T, H, W = all_codes[-1].size()
254
+ summed_codes = 0
255
+ for code in all_codes:
256
+ summed_codes += F.interpolate(code, size=(T, H, W), mode=self.z_interplote_up)
257
+ return summed_codes
258
+
259
+ def get_output_from_indices(self, indices):
260
+ codes = self.get_codes_from_indices(indices)
261
+ codes_summed = reduce(codes, 'q ... -> ...', 'sum')
262
+ return self.project_out(codes_summed)
263
+
264
+ def flip_quant(self, x):
265
+ assert self.flip_mode == 'stochastic'
266
+ flip_mask = torch.rand_like(x) < self.flip_prob
267
+ x = x.clone()
268
+ x[flip_mask] = -x[flip_mask]
269
+ return x
270
+
271
+ def forward(
272
+ self,
273
+ x,
274
+ scale_schedule=None,
275
+ mask = None,
276
+ return_all_codes = False,
277
+ return_residual_norm_per_scale = False
278
+ ):
279
+ if x.ndim == 4:
280
+ x = x.unsqueeze(2)
281
+ B, C, T, H, W = x.size()
282
+
283
+ if scale_schedule is None:
284
+ if self.schedule_mode.startswith("same"):
285
+ scale_num = int(self.schedule_mode[len("same"):])
286
+ assert T == 1
287
+ scale_schedule = [(1, H, W)] * scale_num
288
+ else:
289
+ scale_schedule = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode)
290
+ scale_num = len(scale_schedule)
291
+
292
+ # x = self.project_in(x)
293
+ x = x.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c)
294
+ x = self.project_in(x)
295
+ x = x.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w)
296
+ x = self.layernorm(x)
297
+
298
+ quantized_out = 0.
299
+ residual = x
300
+
301
+ all_losses = []
302
+ all_indices = []
303
+ all_bit_indices = []
304
+ var_inputs = []
305
+ residual_norm_per_scale = []
306
+
307
+ # go through the layers
308
+ out_fact = init_out_fact = 1.0
309
+ # residual_list = []
310
+ # interpolate_residual_list = []
311
+ # quantized_list = []
312
+ if self.drop_when_test:
313
+ drop_lvl_start = self.drop_lvl_idx
314
+ drop_lvl_end = self.drop_lvl_idx + self.drop_lvl_num
315
+ scale_num = len(scale_schedule)
316
+ with autocast('cuda', enabled = False):
317
+ for si, (pt, ph, pw) in enumerate(scale_schedule):
318
+ out_fact = max(0.1, out_fact) if self.use_decay_factor else init_out_fact
319
+ if (pt, ph, pw) != (T, H, W):
320
+ interpolate_residual = F.interpolate(residual, size=(pt, ph, pw), mode=self.z_interplote_down)
321
+ else:
322
+ interpolate_residual = residual
323
+ if return_residual_norm_per_scale:
324
+ residual_norm_per_scale.append((torch.abs(interpolate_residual) < 0.05 * self.lfq.codebook_scale).sum() / interpolate_residual.numel())
325
+ # residual_list.append(torch.norm(residual.detach(), dim=1).mean())
326
+ # interpolate_residual_list.append(torch.norm(interpolate_residual.detach(), dim=1).mean())
327
+ if self.training and self.use_stochastic_depth and random.random() < self.drop_rate:
328
+ if (si == 0 and self.keep_first_quant) or (si == scale_num - 1 and self.keep_last_quant):
329
+ quantized, indices, bit_indices, loss = self.lfq(interpolate_residual) # also capture bit_indices for the shared bookkeeping below
+ quantized = quantized * out_fact
+ all_indices.append(indices)
333
+ else:
334
+ quantized = torch.zeros_like(interpolate_residual)
335
+ elif self.drop_when_test and drop_lvl_start <= si < drop_lvl_end:
336
+ continue
337
+ else:
338
+ # residual_norm = torch.norm(interpolate_residual.detach(), dim=1) # (b, t, h, w)
339
+ # print(si, residual_norm.min(), residual_norm.max(), residual_norm.mean())
340
+ quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)
341
+ if self.random_flip and si < self.max_flip_lvl:
342
+ quantized = self.flip_quant(quantized)
343
+ if self.random_flip_1lvl and si == self.flip_lvl_idx:
344
+ quantized = self.flip_quant(quantized)
345
+ quantized = quantized * out_fact
346
+ all_indices.append(indices)
347
+ # quantized_list.append(torch.norm(quantized.detach(), dim=1).mean())
348
+ if (pt, ph, pw) != (T, H, W):
349
+ quantized = F.interpolate(quantized, size=(T, H, W), mode=self.z_interplote_up).contiguous()
350
+
351
+ if self.remove_residual_detach:
352
+ residual = residual - quantized
353
+ else:
354
+ residual = residual - quantized.detach()
355
+ quantized_out = quantized_out + quantized
356
+
357
+ all_bit_indices.append(bit_indices)
358
+ all_losses.append(loss)
359
+ if si != scale_num - 1:
360
+ var_inputs.append(F.interpolate(quantized_out, size=scale_schedule[si+1], mode=self.z_interplote_down).contiguous())
361
+
362
+ if self.use_decay_factor:
363
+ out_fact -= 0.1
364
+ # print("residual_list:", residual_list)
365
+ # print("interpolate_residual_list:", interpolate_residual_list)
366
+ # print("quantized_list:", quantized_list)
367
+ # import ipdb; ipdb.set_trace()
368
+ # project out, if needed
369
+ quantized_out = quantized_out.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c)
370
+ quantized_out = self.project_out(quantized_out)
371
+ quantized_out = quantized_out.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w)
372
+
373
+ # image
374
+ if quantized_out.size(2) == 1:
375
+ quantized_out = quantized_out.squeeze(2)
376
+
377
+ # stack all losses and indices
378
+
379
+ all_losses = torch.stack(all_losses, dim = -1)
380
+
381
+ ret = (quantized_out, all_indices, all_bit_indices, residual_norm_per_scale, all_losses, var_inputs)
382
+
383
+ if not return_all_codes:
384
+ return ret
385
+
386
+ # whether to return all codes from all codebooks across layers
387
+ all_codes = self.get_codes_from_indices(all_indices)
388
+
389
+ # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
390
+
391
+ return (*ret, all_codes)
392
+
393
+
394
+ class BSQ(Module):
395
+ def __init__(
396
+ self,
397
+ *,
398
+ dim = None,
399
+ codebook_size = None,
400
+ entropy_loss_weight = 0.1,
401
+ commitment_loss_weight = 0.25,
402
+ diversity_gamma = 1.,
403
+ straight_through_activation = nn.Identity(),
404
+ num_codebooks = 1,
405
+ keep_num_codebooks_dim = None,
406
+ codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
407
+ frac_per_sample_entropy = 1., # make less than 1. to only use a random fraction of the probs for per sample entropy
408
+ has_projections = None,
409
+ projection_has_bias = True,
410
+ soft_clamp_input_value = None,
411
+ cosine_sim_project_in = False,
412
+ cosine_sim_project_in_scale = None,
413
+ channel_first = None,
414
+ experimental_softplus_entropy_loss = False,
415
+ entropy_loss_offset = 5., # how much to shift the loss before softplus
416
+ spherical = True, # from https://arxiv.org/abs/2406.07548
417
+ force_quantization_f32 = True, # will force the quantization step to be full precision
418
+ inv_temperature = 100.0,
419
+ gamma0=1.0, gamma=1.0, zeta=1.0,
420
+ preserve_norm = False, # whether to preserve the original norm info
421
+ new_quant = False, # new quant function,
422
+ mask_out = False, # mask the output as 0 in some conditions
423
+ use_out_phi = False, # use output phi network
424
+ use_out_phi_res = False, # residual out phi
425
+ ):
426
+ super().__init__()
427
+
428
+ # some assert validations
429
+
430
+ assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
431
+ assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
432
+
433
+ codebook_size = default(codebook_size, lambda: 2 ** dim)
434
+ self.codebook_size = codebook_size
435
+
436
+ codebook_dim = int(log2(codebook_size))
437
+ codebook_dims = codebook_dim * num_codebooks
438
+ dim = default(dim, codebook_dims)
439
+ self.codebook_dims = codebook_dims
440
+
441
+ has_projections = default(has_projections, dim != codebook_dims)
442
+
443
+ if cosine_sim_project_in:
444
+ cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
445
+ project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
446
+ else:
447
+ project_in_klass = partial(nn.Linear, bias = projection_has_bias)
448
+
449
+ self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity() # nn.Identity()
450
+ self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity() # nn.Identity()
451
+ self.has_projections = has_projections
452
+
453
+ self.out_phi = nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity()
454
+ self.use_out_phi_res = use_out_phi_res
455
+ if self.use_out_phi_res:
456
+ self.out_phi_scale = nn.Parameter(torch.zeros(codebook_dims), requires_grad=True) # init as zero
457
+
458
+ self.dim = dim
459
+ self.codebook_dim = codebook_dim
460
+ self.num_codebooks = num_codebooks
461
+
462
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
463
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
464
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
465
+
466
+ # channel first
467
+
468
+ self.channel_first = channel_first
469
+
470
+ # straight through activation
471
+
472
+ self.activation = straight_through_activation
473
+
474
+ # For BSQ (binary spherical quantization)
475
+ if not spherical:
476
+ raise ValueError("For BSQ, spherical must be True.")
477
+ self.persample_entropy_compute = 'analytical'
478
+ self.inv_temperature = inv_temperature
479
+ self.gamma0 = gamma0 # loss weight for entropy penalty
480
+ self.gamma = gamma # loss weight for entropy penalty
481
+ self.zeta = zeta # loss weight for entire entropy penalty
482
+ self.preserve_norm = preserve_norm
483
+ self.new_quant = new_quant
484
+ self.mask_out = mask_out
485
+
486
+ # entropy aux loss related weights
487
+
488
+ assert 0 < frac_per_sample_entropy <= 1.
489
+ self.frac_per_sample_entropy = frac_per_sample_entropy
490
+
491
+ self.diversity_gamma = diversity_gamma
492
+ self.entropy_loss_weight = entropy_loss_weight
493
+
494
+ # codebook scale
495
+
496
+ self.codebook_scale = codebook_scale
497
+
498
+ # commitment loss
499
+
500
+ self.commitment_loss_weight = commitment_loss_weight
501
+
502
+ # whether to soft clamp the input value from -value to value
503
+
504
+ self.soft_clamp_input_value = soft_clamp_input_value
505
+ assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
506
+
507
+ # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
508
+
509
+ self.entropy_loss_offset = entropy_loss_offset
510
+ self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
511
+
512
+ # for no auxiliary loss, during inference
513
+
514
+ self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
515
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
516
+
517
+ # whether to force quantization step to be f32
518
+
519
+ self.force_quantization_f32 = force_quantization_f32
520
+
521
+ # codes
522
+
523
+ # all_codes = torch.arange(codebook_size)
524
+ # bits = ((all_codes[..., None].int() & self.mask) != 0).float()
525
+ # codebook = self.bits_to_codes(bits)
526
+
527
+ # self.register_buffer('codebook', codebook.float(), persistent = False)
528
+
529
+ def bits_to_codes(self, bits):
530
+ return bits * self.codebook_scale * 2 - self.codebook_scale
531
+
532
+ # @property
533
+ # def dtype(self):
534
+ # return self.codebook.dtype
535
+
536
+ def indices_to_codes(
537
+ self,
538
+ indices,
539
+ label_type = 'int_label',
540
+ project_out = True
541
+ ):
542
+ assert label_type in ['int_label', 'bit_label']
543
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
544
+ should_transpose = default(self.channel_first, is_img_or_video)
545
+
546
+ if not self.keep_num_codebooks_dim:
547
+ if label_type == 'int_label':
548
+ indices = rearrange(indices, '... -> ... 1')
549
+ else:
550
+ indices = indices.unsqueeze(-2)
551
+
552
+ # indices to codes, which are bits of either -1 or 1
553
+
554
+ if label_type == 'int_label':
555
+ assert indices[..., None].int().min() >= 0 # int labels are non-negative
556
+ bits = ((indices[..., None].int() & self.mask) != 0).float() # .to(self.dtype)
557
+ else:
558
+ bits = indices
559
+
560
+ codes = self.bits_to_codes(bits)
561
+
562
+ codes = l2norm(codes) # must normalize when using BSQ
563
+
564
+ codes = rearrange(codes, '... c d -> ... (c d)')
565
+
566
+ # whether to project codes out to original dimensions
567
+ # if the input feature dimensions were not log2(codebook size)
568
+
569
+ if project_out:
570
+ codes = self.project_out(codes)
571
+
572
+ # rearrange codes back to original shape
573
+
574
+ if should_transpose:
575
+ codes = rearrange(codes, 'b ... d -> b d ...')
576
+
577
+ return codes
578
+
579
+ def quantize(self, z):
580
+ assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
581
+
582
+ zhat = torch.where(z > 0,
583
+ torch.tensor(1, dtype=z.dtype, device=z.device),
584
+ torch.tensor(-1, dtype=z.dtype, device=z.device))
585
+ return z + (zhat - z).detach()
586
+
587
+ def quantize_new(self, z):
588
+ assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
589
+
590
+ zhat = torch.where(z > 0,
591
+ torch.tensor(1, dtype=z.dtype, device=z.device),
592
+ torch.tensor(-1, dtype=z.dtype, device=z.device))
593
+
594
+ q_scale = 1. / (self.codebook_dims ** 0.5)
595
+ zhat = q_scale * zhat # on unit sphere
596
+
597
+ return z + (zhat - z).detach()
598
+
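Both quantize() and quantize_new() rely on the straight-through estimator: the forward pass emits hard ±1 codes while gradients flow back to z as if the rounding were the identity. A minimal standalone sketch of the trick (toy shapes, not tied to this class):

import torch

z = torch.randn(4, 16, requires_grad=True)
zhat = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))
out = z + (zhat - z).detach()  # forward value equals zhat, but d(out)/dz == 1

out.sum().backward()
assert torch.allclose(out, zhat)                # hard codes in the forward pass
assert torch.equal(z.grad, torch.ones_like(z))  # gradient passes straight through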
599
+ def soft_entropy_loss(self, z):
600
+ if self.persample_entropy_compute == 'analytical':
601
+ # if self.l2_norm:
602
+ p = torch.sigmoid(-4 * z / (self.codebook_dims ** 0.5) * self.inv_temperature)
603
+ # else:
604
+ # p = torch.sigmoid(-4 * z * self.inv_temperature)
605
+ prob = torch.stack([p, 1-p], dim=-1) # (b, h, w, 18, 2)
606
+ per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() # (b,h,w,18)->(b,h,w)->scalar
607
+ else: # the non-analytical path is not implemented here; `prob` would be undefined below
+ raise NotImplementedError(f"persample_entropy_compute={self.persample_entropy_compute!r} is not supported")
609
+
610
+ # macro average of the probability of each subgroup
611
+ avg_prob = reduce(prob, '... g d -> g d', 'mean') # (18, 2)
612
+ codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
613
+
614
+ # the approximation of the entropy is the sum of the entropy of each subgroup
615
+ return per_sample_entropy, codebook_entropy.sum(), avg_prob
616
+
617
+ def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
618
+ if normalize: # normalize raw counts into probabilities
619
+ probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
620
+ else: # the input is already a probability distribution
621
+ probs = count
622
+ H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
623
+ return H
624
+
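For intuition about what soft_entropy_loss optimizes: the per-sample entropy term is low when each bit probability saturates toward 0 or 1 (confident assignments), while the codebook entropy of the batch-averaged probabilities is high when every bit fires half the time (all codes get used). A toy check with two confident, complementary samples (numbers are illustrative):

import torch

def binary_entropy(p):
    # entropy of the {p, 1-p} pair, summed over the last dim, as in get_entropy
    probs = torch.stack([p, 1 - p], dim=-1)
    return -(probs * torch.log(probs + 1e-8)).sum(dim=-1)

p = torch.tensor([[0.99, 0.01], [0.01, 0.99]])     # (samples, bit groups)
per_sample = binary_entropy(p).sum(dim=-1).mean()  # ~0.11: each sample is confident
codebook = binary_entropy(p.mean(dim=0)).sum()     # ~1.39 (= 2 ln 2): bits used evenly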
625
+ def forward(
626
+ self,
627
+ x,
628
+ return_loss_breakdown = False,
629
+ mask = None,
630
+ entropy_weight=0.1
631
+ ):
632
+ """
633
+ einstein notation
634
+ b - batch
635
+ n - sequence (or flattened spatial dimensions)
636
+ d - feature dimension, which is also log2(codebook size)
637
+ c - number of codebooks
638
+ """
639
+
640
+ is_img_or_video = x.ndim >= 4
641
+ should_transpose = default(self.channel_first, is_img_or_video)
642
+
643
+ # standardize image or video into (batch, seq, dimension)
644
+
645
+ if should_transpose:
646
+ x = rearrange(x, 'b d ... -> b ... d')
647
+ x, ps = pack_one(x, 'b * d') # x.shape [b, hwt, c]
648
+
649
+ assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
650
+
651
+ x = self.project_in(x)
652
+
653
+ # split out number of codebooks
654
+
655
+ x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
656
+
657
+ x = l2norm(x)
658
+
659
+ # whether to force quantization step to be full precision or not
660
+
661
+ force_f32 = self.force_quantization_f32
662
+
663
+ quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
664
+
665
+ indices = None
666
+ with quantization_context():
667
+
668
+ if force_f32:
669
+ orig_dtype = x.dtype
670
+ x = x.float()
671
+
672
+ # use straight-through gradients (optionally with custom activation fn) if training
673
+ if self.new_quant:
674
+ quantized = self.quantize_new(x)
+ else:
+ quantized = self.quantize(x) # fall back to the plain quantizer so `quantized` is always defined
675
+
676
+ # calculate indices
677
+ bit_indices = (quantized > 0).int()
678
+ entropy_penalty = persample_entropy = cb_entropy = self.zero
679
+ commit_loss = self.zero
680
+
681
+ # input back to original dtype if needed
682
+
683
+ if force_f32:
684
+ x = x.type(orig_dtype)
685
+
686
+ # merge back codebook dim
687
+ x = quantized # rename quantized to x for output
688
+ x = rearrange(x, 'b n c d -> b n (c d)')
689
+
690
+ # project out to feature dimension if needed
691
+
692
+ x = self.project_out(x)
693
+
694
+ # reconstitute image or video dimensions
695
+
696
+ if should_transpose:
697
+ x = unpack_one(x, ps, 'b * d')
698
+ x = rearrange(x, 'b ... d -> b d ...')
699
+
700
+ bit_indices = unpack_one(bit_indices, ps, 'b * c d')
701
+
702
+ # whether to remove single codebook dim
703
+
704
+ if not self.keep_num_codebooks_dim:
705
+ bit_indices = rearrange(bit_indices, '... 1 d -> ... d')
706
+
707
+ # complete aux loss
708
+
709
+ aux_loss = commit_loss * self.commitment_loss_weight + (self.zeta * entropy_penalty / self.inv_temperature)*entropy_weight
710
+ # returns
711
+
712
+ ret = Return(x, indices, bit_indices, aux_loss)
713
+
714
+ if not return_loss_breakdown:
715
+ return ret
716
+
717
+ return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss)
718
+
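For orientation, the label-to-code path above reduces to: a d-bit pattern in {0,1} is mapped to {-codebook_scale, +codebook_scale} by bits_to_codes, then normalized onto the unit sphere (the BSQ constraint). A standalone sketch with one 4-bit label, using F.normalize in place of the repo's l2norm helper:

import torch
import torch.nn.functional as F

codebook_scale = 1.0
bits = torch.tensor([[1., 0., 1., 1.]])             # one 4-bit label
codes = bits * codebook_scale * 2 - codebook_scale  # {0,1} -> {-1,+1}
codes = F.normalize(codes, dim=-1)                  # project onto the unit sphere
print(codes)                                        # [[0.5, -0.5, 0.5, 0.5]]
print(codes.norm(dim=-1))                           # [1.]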
models/bsq_vae/vae.py ADDED
@@ -0,0 +1,255 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from infinity.models.bsq_vae.flux_vqgan import AutoEncoder
5
+
6
+ def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
7
+ delete_keys = []
8
+ loaded_keys = []
9
+ for key in state_dict:
10
+ if key.startswith(prefix):
11
+ _key = key[len(prefix):]
12
+ if _key in model.state_dict():
13
+ # load nn.Conv2d or nn.Linear to nn.Linear
14
+ if use_linear and (".q.weight" in key or ".k.weight" in key or ".v.weight" in key or ".proj_out.weight" in key):
15
+ load_weights = state_dict[key].squeeze()
16
+ elif _key.endswith(".conv.weight") and expand:
17
+ if model.state_dict()[_key].shape == state_dict[key].shape:
18
+ # 2D cnn to 2D cnn
19
+ load_weights = state_dict[key]
20
+ else:
21
+ # 2D cnn to 3D cnn
22
+ _expand_dim = model.state_dict()[_key].shape[2]
23
+ load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
24
+ else:
25
+ load_weights = state_dict[key]
26
+ model.state_dict()[_key].copy_(load_weights)
27
+ delete_keys.append(key)
28
+ loaded_keys.append(prefix+_key)
29
+ # load nn.Conv2d to Conv class
30
+ conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]
31
+ if any(k in _key for k in conv_list):
32
+ if _key.endswith(".weight"):
33
+ conv_key = _key.replace(".weight", ".conv.weight")
34
+ if conv_key and conv_key in model.state_dict():
35
+ if model.state_dict()[conv_key].shape == state_dict[key].shape:
36
+ # 2D cnn to 2D cnn
37
+ load_weights = state_dict[key]
38
+ else:
39
+ # 2D cnn to 3D cnn
40
+ _expand_dim = model.state_dict()[conv_key].shape[2]
41
+ load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
42
+ model.state_dict()[conv_key].copy_(load_weights)
43
+ delete_keys.append(key)
44
+ loaded_keys.append(prefix+conv_key)
45
+ if _key.endswith(".bias"):
46
+ conv_key = _key.replace(".bias", ".conv.bias")
47
+ if conv_key and conv_key in model.state_dict():
48
+ model.state_dict()[conv_key].copy_(state_dict[key])
49
+ delete_keys.append(key)
50
+ loaded_keys.append(prefix+conv_key)
51
+ # load nn.GroupNorm to Normalize class
52
+ if "norm" in _key:
53
+ if _key.endswith(".weight"):
54
+ norm_key = _key.replace(".weight", ".norm.weight")
55
+ if norm_key and norm_key in model.state_dict():
56
+ model.state_dict()[norm_key].copy_(state_dict[key])
57
+ delete_keys.append(key)
58
+ loaded_keys.append(prefix+norm_key)
59
+ if _key.endswith(".bias"):
60
+ norm_key = _key.replace(".bias", ".norm.bias")
61
+ if norm_key and norm_key in model.state_dict():
62
+ model.state_dict()[norm_key].copy_(state_dict[key])
63
+ delete_keys.append(key)
64
+ loaded_keys.append(prefix+norm_key)
65
+
66
+ for key in delete_keys:
67
+ del state_dict[key]
68
+
69
+ return model, state_dict, loaded_keys
70
+
71
+
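The 2D-to-3D branches above inflate an image-model kernel for a video model by repeating it along a new temporal axis. A minimal reproduction of that unsqueeze/repeat step (shapes are illustrative):

import torch

w2d = torch.randn(64, 32, 3, 3)               # (out, in, kh, kw) from a 2D checkpoint
t = 3                                          # temporal size of the target 3D kernel
w3d = w2d.unsqueeze(2).repeat(1, 1, t, 1, 1)   # -> (out, in, t, kh, kw)
assert w3d.shape == (64, 32, 3, 3, 3)

Note that plain repetition, as load_cnn does it, scales the response to temporally constant inputs by t; dividing by t instead would preserve the 2D output magnitude.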
72
+ def vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True, patch_size=16, encoder_ch_mult=[1, 2, 4, 4, 4], decoder_ch_mult=[1, 2, 4, 4, 4],):
73
+ args=argparse.Namespace(
74
+ vqgan_ckpt=vqgan_ckpt,
75
+ sd_ckpt=None,
76
+ inference_type='image',
77
+ save='./imagenet_val_bsq',
78
+ save_prediction=True,
79
+ image_recon4video=False,
80
+ junke_old=False,
81
+ device='cuda',
82
+ max_steps=1000000.0,
83
+ log_every=1,
84
+ visu_every=1000,
85
+ ckpt_every=1000,
86
+ default_root_dir='',
87
+ compile='no',
88
+ ema='no',
89
+ lr=0.0001,
90
+ beta1=0.9,
91
+ beta2=0.95,
92
+ warmup_steps=0,
93
+ optim_type='Adam',
94
+ disc_optim_type=None,
95
+ lr_min=0.0,
96
+ warmup_lr_init=0.0,
97
+ max_grad_norm=1.0,
98
+ max_grad_norm_disc=1.0,
99
+ disable_sch=False,
100
+ patch_size=patch_size,
101
+ temporal_patch_size=4,
102
+ embedding_dim=256,
103
+ codebook_dim=codebook_dim,
104
+ num_quantizers=8,
105
+ quantizer_type='MultiScaleBSQ',
106
+ use_vae=False,
107
+ use_freq_enc=False,
108
+ use_freq_dec=False,
109
+ preserve_norm=False,
110
+ ln_before_quant=False,
111
+ ln_init_by_sqrt=False,
112
+ use_pxsf=False,
113
+ new_quant=True,
114
+ use_decay_factor=False,
115
+ mask_out=False,
116
+ use_stochastic_depth=False,
117
+ drop_rate=0.0,
118
+ schedule_mode=schedule_mode,
119
+ lr_drop=None,
120
+ lr_drop_rate=0.1,
121
+ keep_first_quant=False,
122
+ keep_last_quant=False,
123
+ remove_residual_detach=False,
124
+ use_out_phi=False,
125
+ use_out_phi_res=False,
126
+ use_lecam_reg=False,
127
+ lecam_weight=0.05,
128
+ perceptual_model='vgg16',
129
+ base_ch_disc=64,
130
+ random_flip=False,
131
+ flip_prob=0.5,
132
+ flip_mode='stochastic',
133
+ max_flip_lvl=1,
134
+ not_load_optimizer=False,
135
+ use_lecam_reg_zero=False,
136
+ freeze_encoder=False,
137
+ rm_downsample=False,
138
+ random_flip_1lvl=False,
139
+ flip_lvl_idx=0,
140
+ drop_when_test=False,
141
+ drop_lvl_idx=0,
142
+ drop_lvl_num=1,
143
+ disc_version='v1',
144
+ magvit_disc=False,
145
+ sigmoid_in_disc=False,
146
+ activation_in_disc='leaky_relu',
147
+ apply_blur=False,
148
+ apply_noise=False,
149
+ dis_warmup_steps=0,
150
+ dis_lr_multiplier=1.0,
151
+ dis_minlr_multiplier=False,
152
+ disc_channels=64,
153
+ disc_layers=3,
154
+ discriminator_iter_start=0,
155
+ disc_pretrain_iter=0,
156
+ disc_optim_steps=1,
157
+ disc_warmup=0,
158
+ disc_pool='no',
159
+ disc_pool_size=1000,
160
+ advanced_disc=False,
161
+ recon_loss_type='l1',
162
+ video_perceptual_weight=0.0,
163
+ image_gan_weight=1.0,
164
+ video_gan_weight=1.0,
165
+ image_disc_weight=0.0,
166
+ video_disc_weight=0.0,
167
+ l1_weight=4.0,
168
+ gan_feat_weight=0.0,
169
+ perceptual_weight=0.0,
170
+ kl_weight=0.0,
171
+ lfq_weight=0.0,
172
+ entropy_loss_weight=0.1,
173
+ commitment_loss_weight=0.25,
174
+ diversity_gamma=1,
175
+ norm_type='group',
176
+ disc_loss_type='hinge',
177
+ use_checkpoint=False,
178
+ precision='fp32',
179
+ encoder_dtype='fp32',
180
+ upcast_attention='',
181
+ upcast_tf32=False,
182
+ tokenizer='flux',
183
+ pretrained=None,
184
+ pretrained_mode='full',
185
+ inflation_pe=False,
186
+ init_vgen='no',
187
+ no_init_idis=False,
188
+ init_idis='keep',
189
+ init_vdis='no',
190
+ enable_nan_detector=False,
191
+ turn_on_profiler=False,
192
+ profiler_scheduler_wait_steps=10,
193
+ debug=True,
194
+ video_logger=False,
195
+ bytenas='',
196
+ username='',
197
+ seed=1234,
198
+ vq_to_vae=False,
199
+ load_not_strict=False,
200
+ zero=0,
201
+ bucket_cap_mb=40,
202
+ manual_gc_interval=1000,
203
+ data_path=[''],
204
+ data_type=[''],
205
+ dataset_list=['imagenet'],
206
+ fps=-1,
207
+ dataaug='resizecrop',
208
+ multi_resolution=False,
209
+ random_bucket_ratio=0.0,
210
+ sequence_length=16,
211
+ resolution=[256, 256],
212
+ batch_size=[1],
213
+ num_workers=0,
214
+ image_channels=3,
215
+ codebook_size=codebook_size,
216
+ codebook_l2_norm=True,
217
+ codebook_show_usage=True,
218
+ commit_loss_beta=0.25,
219
+ entropy_loss_ratio=0.0,
220
+ base_ch=128,
221
+ num_res_blocks=2,
222
+ encoder_ch_mult=encoder_ch_mult,
223
+ decoder_ch_mult=decoder_ch_mult,
224
+ dropout_p=0.0,
225
+ cnn_type='2d',
226
+ cnn_version='v1',
227
+ conv_in_out_2d='no',
228
+ conv_inner_2d='no',
229
+ res_conv_2d='no',
230
+ cnn_attention='no',
231
+ cnn_norm_axis='spatial',
232
+ flux_weight=0,
233
+ cycle_weight=0,
234
+ cycle_feat_weight=0,
235
+ cycle_gan_weight=0,
236
+ cycle_loop=0,
237
+ z_drop=0.0)
238
+
239
+ vae = AutoEncoder(args)
240
+ use_vae = vae.use_vae
241
+ if not use_vae:
242
+ num_codes = args.codebook_size
243
+ if isinstance(vqgan_ckpt, str):
244
+ state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
245
+ else:
246
+ state_dict = args.vqgan_ckpt
247
+ if state_dict:
248
+ if args.ema == "yes":
249
+ vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["ema"], prefix="", expand=False)
250
+ else:
251
+ vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["vae"], prefix="", expand=False)
252
+ if test_mode:
253
+ vae.eval()
254
+ for p in vae.parameters():
+ p.requires_grad_(False)
255
+ return vae
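A typical call, sketched with placeholder values (the checkpoint path, schedule name, and codebook hyperparameters below are illustrative, not defaults from this repo):

vae = vae_model(
    vqgan_ckpt='weights/vae.pth',  # hypothetical path; a preloaded state dict also works
    schedule_mode='dynamic',       # assumed schedule name
    codebook_dim=32,
    codebook_size=2**32,           # BSQ: one binary code per codebook dimension
    test_mode=True,                # eval mode with all gradients frozen
)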
models/ema.py ADDED
@@ -0,0 +1,23 @@
1
+ import copy
2
+ import torch
3
+ from collections import OrderedDict
4
+
5
+
6
+ def get_ema_model(model):
7
+ ema_model = copy.deepcopy(model)
8
+ ema_model.eval()
9
+ for param in ema_model.parameters():
10
+ param.requires_grad = False
11
+ return ema_model
12
+
13
+ @torch.no_grad()
14
+ def update_ema(ema_model, model, decay=0.9999):
15
+ """
16
+ Step the EMA model towards the current model.
17
+ """
18
+ ema_params = OrderedDict(ema_model.named_parameters())
19
+ model_params = OrderedDict(model.named_parameters())
20
+
21
+ for name, param in model_params.items():
22
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
23
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
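Typical wiring for these two helpers in a training loop (toy model, optimizer and loss, illustrative only):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
ema_model = get_ema_model(model)  # frozen deep copy in eval mode
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    update_ema(ema_model, model, decay=0.9999)  # EMA weights trail the online weights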
models/flex_attn.py ADDED
@@ -0,0 +1,130 @@
1
+ """
2
+ Wrap torch's flex attention and handle mess info or potentially refactor
3
+ """
4
+ from functools import partial
5
+ import torch
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ try:
10
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
11
+ flex_attention_available = True
12
+ except ImportError:
13
+ print(f"[Warning] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}")
14
+ flex_attention_available = False
15
+
16
+ def _causal_mask(b, h, q_idx, kv_idx):
17
+ return q_idx >= kv_idx
18
+
19
+ def _length_to_offsets(lengths, device):
20
+ """Converts a list of lengths to a list of offsets.
21
+
22
+ Args:
23
+ lengths: A list of lengths.
24
+
25
+ """
26
+ offsets = [0]
27
+ offsets.extend(lengths)
28
+ offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
29
+ offsets = torch.cumsum(offsets, dim=-1)
30
+ return offsets
31
+
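For example, three documents of lengths 2, 4 and 3 become the cumulative offsets consumed by the mask generator below:

offsets = _length_to_offsets([2, 4, 3], device='cpu')
print(offsets)  # -> [0, 2, 6, 9]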
32
+ def _generate_var_mask_mod(offsets):
33
+ """Generates mask mods that apply to inputs to flex attention in the sequence stacked
34
+ format.
35
+
36
+ Args:
37
+ offsets: This tensor should be of shape (num_documents + 1);
38
+ this should contain the cumulative counts of document tokens.
39
+ e.g. if you have 3 documents of length 2, 4, 3 then
40
+ offsets = [0, 2, 6, 9]
41
+
42
+ Note:
43
+ What is the sequence stacked format? When assembling batches of inputs, we
44
+ take multiple sequences and stack them together to form 1 large sequence. We then
45
+ use masking to ensure that the attention scores are only applied to tokens within
46
+ the same document.
47
+ """
48
+
49
+ def _offsets_to_doc_ids_tensor(offsets):
50
+ device = offsets.device
51
+ counts = offsets[1:] - offsets[:-1]
52
+ return torch.repeat_interleave(
53
+ torch.arange(len(counts), device=device, dtype=torch.int32), counts
54
+ )
55
+
56
+ document_id = _offsets_to_doc_ids_tensor(offsets)
57
+
58
+ def var_mask_mod(b, h, q_idx, kv_idx):
59
+ same_doc = document_id[q_idx] == document_id[kv_idx]
60
+ causal_mask = _causal_mask(b, h, q_idx, kv_idx)
61
+ return same_doc | causal_mask
62
+
63
+ return var_mask_mod
64
+
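The resulting mask allows full attention inside a scale block and causal attention across blocks (same_doc OR causal), which is the VAR attention pattern. A quick check on CPU with the offsets from the example above:

import torch

offsets = _length_to_offsets([2, 4, 3], device='cpu')
mask_mod = _generate_var_mask_mod(offsets)
# positions 2..5 form document 1: q=3 may attend kv=5 even though q < kv
print(bool(mask_mod(0, 0, torch.tensor(3), torch.tensor(5))))  # True
# q=1 (document 0) may not attend forward into document 1
print(bool(mask_mod(0, 0, torch.tensor(1), torch.tensor(5))))  # False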
65
+ def _generate_var_infer_mask_with_kv_cache(lengths):
66
+ kv_len = sum(lengths)
67
+ def var_mask_mod(b, h, q_idx, kv_idx):
68
+ return kv_idx < kv_len
69
+
70
+ return var_mask_mod
71
+
72
+ class FlexAttn(nn.Module):
73
+ def __init__(
74
+ self, block_scales:list, mask_type:str, B, H, L:int, auto_padding=False
75
+ ):
76
+ """
77
+ :param block_scales: accepts VAR's (t, h, w) block sizes like [(1,1,1), (1,2,2), (1,3,3)]
79
+ :param mask_type: var / causal / var_infer_mask_with_kv_cache
79
+ :param B: batch size
80
+ :param H: heads num
81
+ :param L: sequence length
82
+ """
83
+ super().__init__()
84
+ if not flex_attention_available:
85
+ raise NotImplementedError((f"[Error] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}"))
86
+
87
+ self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache"]
88
+ self.auto_padding = auto_padding
89
+
90
+ self.flex_attention = torch.compile(flex_attention)
91
+
92
+ self.block_scales = block_scales
93
+ self.lengths = [ x * y * z for x,y,z in block_scales]
94
+
95
+ self.offsets = _length_to_offsets(self.lengths, device='cuda')
96
+
97
+ # if L is padded to a multiple of 128, the block mask needs to cover the padding area
98
+ if self.offsets[-1] < L:
99
+ self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0)
100
+
101
+ if mask_type == "var":
102
+ self.mask_mod = _generate_var_mask_mod(self.offsets)
103
+ self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True)
104
+ elif mask_type == "causal":
105
+ self.mask_mod = _causal_mask
106
+ self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True)
107
+ elif mask_type == 'var_infer_mask_with_kv_cache':
108
+ self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths)
109
+ self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True)
110
+ else:
111
+ raise NotImplementedError(f"{mask_type} not supportted in FlexAttn, support type:{self.support_mask_type}")
112
+
113
+
114
+ def forward(self, q, k, v, scale = None):
115
+ if self.auto_padding:
116
+ q_pad_len = (128 - q.shape[-2] % 128) % 128
117
+ kv_pad_len = (128 - k.shape[-2] % 128) % 128
118
+ q_pad = F.pad(q, (0, 0, 0, q_pad_len))
119
+ k_pad = F.pad(k, (0, 0, 0, kv_pad_len))
120
+ v_pad = F.pad(v, (0, 0, 0, kv_pad_len))
121
+ oup = self.flex_attention(q_pad.to(v_pad.dtype), k_pad.to(v_pad.dtype), v_pad, block_mask = self.block_mask, scale = scale)
122
+ if q_pad_len > 0:
123
+ oup = oup[:,:,:-q_pad_len]
124
+ else:
125
+ oup = self.flex_attention(q.to(v.dtype), k.to(v.dtype), v, block_mask = self.block_mask, scale = scale)
126
+ return oup
127
+
128
+ def extra_repr(self) -> str:
129
+ return f'block size:{self.block_scales}'
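A sketch of constructing and calling the wrapper (CUDA is required, since the block mask is created on 'cuda'; batch, head and length values are illustrative):

import torch

# three (t, h, w) scale blocks: 1 + 4 + 16 = 21 tokens, block mask built for L = 128
attn = FlexAttn(block_scales=[(1, 1, 1), (1, 2, 2), (1, 4, 4)],
                mask_type='var', B=2, H=8, L=128, auto_padding=True)

q = torch.randn(2, 8, 21, 64, device='cuda')
k = torch.randn(2, 8, 21, 64, device='cuda')
v = torch.randn(2, 8, 21, 64, device='cuda')
out = attn(q, k, v)  # (2, 8, 21, 64); auto_padding pads to 128 and strips it again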
models/fused_op.py ADDED
@@ -0,0 +1,27 @@
1
+ import gc
2
+ from copy import deepcopy
3
+ from typing import Union
4
+
5
+ import torch
6
+ from torch import nn as nn
7
+ from torch.nn import functional as F
8
+
9
+
10
+ @torch.compile(fullgraph=True)
11
+ def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
12
+ x = x.float()
13
+ return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight
14
+
15
+
16
+ @torch.compile(fullgraph=True)
17
+ def fused_ada_layer_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
18
+ x = x.float()
19
+ x = F.layer_norm(input=x, normalized_shape=(C,), weight=None, bias=None, eps=eps)
20
+ return x.mul(scale.add(1)).add_(shift)
21
+
22
+
23
+ @torch.compile(fullgraph=True)
24
+ def fused_ada_rms_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
25
+ x = x.float()
26
+ x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps)))
27
+ return x.mul(scale.add(1)).add_(shift)
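Each fused op is numerically just the standard norm followed by the (1 + scale, shift) modulation; the first call triggers compilation. A quick equivalence check against the unfused formulation (toy shapes):

import torch
import torch.nn.functional as F

C, eps = 16, 1e-6
x = torch.randn(2, 5, C)
scale = torch.randn(2, 1, C) * 0.01
shift = torch.randn(2, 1, C) * 0.01

ref = F.layer_norm(x.float(), (C,), eps=eps) * (1 + scale) + shift
out = fused_ada_layer_norm(C, eps, x, scale, shift)
print(torch.allclose(out, ref, atol=1e-5))  # True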
models/infinity.py ADDED
@@ -0,0 +1,795 @@
1
+ """
2
+ Definition of Infinity transformer model.
3
+ """
4
+
5
+ import math
6
+ import random
7
+ import time
8
+ from contextlib import nullcontext
9
+ from functools import partial
10
+ from typing import List, Optional, Tuple, Union, Dict, Any
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from timm.models import register_model
16
+ from torch.utils.checkpoint import checkpoint
17
+ from PIL import Image
18
+ import numpy as np
19
+ from torch.nn.attention.flex_attention import flex_attention
20
+
21
+ import infinity.utils.dist as dist
22
+ from infinity.utils.dist import for_visualize
23
+ from infinity.models.basic import flash_attn_func, flash_fused_op_installed, AdaLNBeforeHead, CrossAttnBlock, SelfAttnBlock, CrossAttention, FastRMSNorm, precompute_rope2d_freqs_grid
24
+ from infinity.utils import misc
25
+ from infinity.models.flex_attn import FlexAttn
26
+ from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
27
+
28
+ try:
29
+ from infinity.models.fused_op import fused_ada_layer_norm, fused_ada_rms_norm
30
+ except Exception: # fused ops are optional; fall back to the unfused paths
31
+ fused_ada_layer_norm, fused_ada_rms_norm = None, None
32
+
33
+
34
+ class MultiInpIdentity(nn.Module):
35
+ def forward(self, x, *args, **kwargs):
36
+ return x
37
+
38
+
39
+ class TextAttentivePool(nn.Module):
40
+ def __init__(self, Ct5: int, D: int):
41
+ super().__init__()
42
+ self.Ct5, self.D = Ct5, D
43
+ if D > 4096:
44
+ self.head_dim = 64
45
+ else:
46
+ self.head_dim = 128
47
+
48
+ self.num_heads = Ct5 // self.head_dim
49
+ self.ca = CrossAttention(for_attn_pool=True, embed_dim=self.D, kv_dim=Ct5, num_heads=self.num_heads)
50
+ def forward(self, ca_kv):
51
+ return self.ca(None, ca_kv).squeeze(1)
52
+
53
+ class SharedAdaLin(nn.Linear):
54
+ def forward(self, cond_BD):
55
+ C = self.weight.shape[0] // 6
56
+ return super().forward(cond_BD).reshape(-1, 1, 6, C) # B16C
57
+
58
+
59
+ class MultipleLayers(nn.Module):
60
+ def __init__(self, ls, num_blocks_in_a_chunk, index):
61
+ super().__init__()
62
+ self.module = nn.ModuleList()
63
+ for i in range(index, index+num_blocks_in_a_chunk):
64
+ self.module.append(ls[i])
65
+
66
+ def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None):
67
+ h = x
68
+ for m in self.module:
69
+ if checkpointing_full_block:
70
+ h = torch.utils.checkpoint.checkpoint(m, h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
71
+ else:
72
+ h = m(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
73
+ return h
74
+
75
+ class Infinity(nn.Module):
76
+ def __init__(
77
+ self, vae_local,
78
+ text_channels=0, text_maxlen=0, # text-cond generation
79
+ selecting_idx=None, # class-cond generation
80
+ embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture
81
+ drop_rate=0., drop_path_rate=0., # drop out and drop path
82
+ norm_eps=1e-6, rms_norm=False, # norm layer
83
+ shared_aln=False, head_aln=True, # adaptive norm
84
+ cond_drop_rate=0.1, # for classifier-free guidance
85
+ rand_uncond=False,
86
+ cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
87
+ raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
88
+ head_depth=1,
89
+ top_p=0.0, top_k=0.0,
90
+ customized_flash_attn=False, fused_mlp=False, fused_norm=False,
91
+ block_chunks=1,
92
+ checkpointing=None,
93
+ pad_to_multiplier=0,
94
+ use_flex_attn=False,
95
+ batch_size=2,
96
+ add_lvl_embeding_only_first_block=1,
97
+ use_bit_label=1,
98
+ rope2d_each_sa_layer=0,
99
+ rope2d_normalized_by_hw=0,
100
+ pn=None,
101
+ train_h_div_w_list=None,
102
+ video_frames=1,
103
+ always_training_scales=20,
104
+ apply_spatial_patchify = 0,
105
+ inference_mode=False,
106
+ ):
107
+ # set hyperparameters
108
+ self.C = embed_dim
109
+ self.inference_mode = inference_mode
110
+ self.apply_spatial_patchify = apply_spatial_patchify
111
+ if self.apply_spatial_patchify:
112
+ self.d_vae = vae_local.embed_dim * 4
113
+ else:
114
+ self.d_vae = vae_local.embed_dim
115
+ self.use_bit_label = use_bit_label
116
+ self.codebook_dim = self.d_vae
117
+ self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
118
+ self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
119
+ self.Ct5 = text_channels
120
+ self.depth = depth
121
+ self.num_heads = num_heads
122
+ self.batch_size = batch_size
123
+ self.mlp_ratio = mlp_ratio
124
+ self.cond_drop_rate = cond_drop_rate
125
+ self.norm_eps = norm_eps
126
+ self.prog_si = -1
127
+ self.pn = pn
128
+ self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
129
+ self.video_frames = video_frames
130
+ self.always_training_scales = always_training_scales
131
+
132
+ assert add_lvl_embeding_only_first_block in [0,1]
133
+ self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
134
+ assert rope2d_each_sa_layer in [0,1]
135
+ self.rope2d_each_sa_layer = rope2d_each_sa_layer
136
+ self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
137
+ print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \
138
+ self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
139
+ head_up_method = ''
140
+ word_patch_size = 1 if head_up_method in {'', 'no'} else 2
141
+ if word_patch_size > 1:
142
+ assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'
143
+
144
+ self.checkpointing = checkpointing
145
+ self.pad_to_multiplier = max(1, pad_to_multiplier)
146
+
147
+ customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
148
+ self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
149
+ if customized_flash_attn and not customized_kernel_installed:
150
+ import inspect, warnings
151
+ file_path = inspect.getsourcefile(flash_attn_func)
152
+ line_number = inspect.getsourcelines(flash_attn_func)[1]
153
+ info = (
154
+ f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
155
+ f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
156
+ f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
157
+ )
158
+ warnings.warn(info, ImportWarning)
159
+ print(info, flush=True)
160
+
161
+ self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying
162
+ self.first_l = 1
163
+ # solve top-p top-k sampling hyperparameters
164
+ self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
165
+ if self.top_p < 1e-5: self.top_p = 0
166
+ if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0
167
+
168
+ t = torch.zeros(dist.get_world_size(), device=dist.get_device())
169
+ t[dist.get_rank()] = float(flash_fused_op_installed)
170
+ dist.barrier()
171
+ dist.allreduce(t)
172
+ assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'
173
+
174
+ super().__init__()
175
+ self.rng = torch.Generator(device=dist.get_device())
176
+ self.maybe_record_function = nullcontext
177
+ self.text_maxlen = text_maxlen
178
+ self.t2i = text_channels != 0
179
+
180
+ # [inp & position embedding]
181
+ init_std = math.sqrt(1 / self.C / 3)
182
+ self.norm0_cond = nn.Identity()
183
+ if self.t2i:
184
+ self.selecting_idx = None
185
+ self.num_classes = 0
186
+ self.D = self.C
187
+
188
+ cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
189
+ rng = torch.Generator(device='cpu')
190
+ rng.manual_seed(0)
191
+ torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
192
+ cfg_uncond /= self.Ct5 ** 0.5
193
+ if rand_uncond:
194
+ self.register_buffer('cfg_uncond', cfg_uncond)
195
+ else:
196
+ self.cfg_uncond = nn.Parameter(cfg_uncond)
197
+
198
+ self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
199
+ self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
200
+ self.text_proj_for_ca = nn.Sequential(
201
+ nn.Linear(self.Ct5, self.D),
202
+ nn.GELU(approximate='tanh'),
203
+ nn.Linear(self.D, self.D),
204
+ )
205
+ else: # class-label cond
206
+ if selecting_idx is None:
207
+ num_classes = 1000
208
+ print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
209
+ selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
210
+ self.selecting_idx = selecting_idx
211
+ self.num_classes = selecting_idx.shape[-1]
212
+ self.D = self.C
213
+ self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
214
+ nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
215
+
216
+ self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
217
+ nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
218
+ if self.rope2d_each_sa_layer:
219
+ rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
220
+ self.rope2d_freqs_grid = rope2d_freqs_grid
221
+ else:
222
+ raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
223
+ self.lvl_embed = nn.Embedding(15, self.C)
224
+ nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
225
+
226
+ # [input layers] input norm && input embedding
227
+ norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
228
+ self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
229
+ self.word_embed = nn.Linear(self.d_vae, self.C)
230
+
231
+ # [shared adaptive layernorm mapping network]
232
+ self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
233
+
234
+ # fused norm
235
+ if fused_norm:
236
+ fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
237
+ if fused_norm_func is not None: # pre-compile
238
+ B = 2
239
+ x = torch.randn(B, 1, self.C).requires_grad_(True)
240
+ scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
241
+ shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
242
+ # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward()
243
+ del B, x, scale, shift
244
+ else:
245
+ fused_norm_func = None
246
+
247
+ # [backbone and head]
248
+ self.use_flex_attn = use_flex_attn
249
+ self.attn_fn_compile_dict = {}
250
+ self.batch_size = batch_size
251
+ if self.use_flex_attn:
252
+ self.attn_fn_compile_dict = self.compile_flex_attn()
253
+
254
+ self.drop_path_rate = drop_path_rate
255
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing)
256
+ self.unregistered_blocks = []
257
+ for block_idx in range(depth):
258
+ block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
259
+ embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
260
+ num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
261
+ swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
262
+ checkpointing_sa_only=self.checkpointing == 'self-attn',
263
+ use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
264
+ )
265
+ self.unregistered_blocks.append(block)
266
+
267
+ # [head]
268
+ V = self.V
269
+ if head_aln:
270
+ self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
271
+ self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
272
+ else:
273
+ self.head_nm = MultiInpIdentity()
274
+ self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
275
+
276
+ self.num_block_chunks = block_chunks or 1
277
+ self.num_blocks_in_a_chunk = depth // block_chunks
278
+ print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
279
+ assert self.num_blocks_in_a_chunk * block_chunks == depth
280
+ if self.num_block_chunks == 1:
281
+ self.blocks = nn.ModuleList(self.unregistered_blocks)
282
+ else:
283
+ self.block_chunks = nn.ModuleList()
284
+ for i in range(self.num_block_chunks):
285
+ self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
286
+ print(
287
+ f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
288
+ f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
289
+ f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
290
+ end='\n\n', flush=True
291
+ )
292
+
293
+
294
+ def compile_flex_attn(self):
295
+ attn_fn_compile_dict = {}
296
+ for h_div_w in self.train_h_div_w_list:
297
+ h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
298
+ full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales']
299
+ if self.inference_mode:
300
+ apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule)))
301
+ mask_type = "infinity_infer_mask_with_kv_cache"
302
+ auto_padding = True
303
+ else:
304
+ mask_type = 'var'
305
+ auto_padding = False
306
+ apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))]
307
+ for scales_num in apply_flex_attn_scales:
308
+ print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
309
+ scale_schedule = full_scale_schedule[:scales_num]
310
+ scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule]
311
+ patchs_nums_tuple = tuple(scale_schedule)
312
+ SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
313
+ aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
314
+ attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
315
+ mask_type = mask_type,
316
+ B = self.batch_size,
317
+ H = self.num_heads,
318
+ L = aligned_L,
319
+ auto_padding=auto_padding)
320
+ attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
321
+
322
+ if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos)
323
+ scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule]
324
+ patchs_nums_tuple = tuple(scale_schedule)
325
+ SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
326
+ aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
327
+ attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
328
+ mask_type = mask_type,
329
+ B = self.batch_size,
330
+ H = self.num_heads,
331
+ L = aligned_L)
332
+ attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
333
+ return attn_fn_compile_dict
334
+
335
+ def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
336
+ """
337
+ :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim)
338
+ :param cond_BD: shaped (B or batch_size, D or cond_dim)
339
+ :return: logits, shaped (B or batch_size, L or seq_len, V or vocabulary_size)
341
+ """
342
+ with torch.amp.autocast('cuda', enabled=False):
343
+ return self.head(self.head_nm(h.float(), cond_BD.float()))
344
+
345
+ def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
346
+ bs, seq_len, c = feature.shape
347
+ patch_t, patch_h, patch_w = scale_schedule[scale_ind]
348
+ t_mul_h_mul_w = patch_t * patch_h * patch_w
349
+ assert t_mul_h_mul_w + need_to_pad == seq_len
350
+ feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device))
351
+ return feature
352
+
353
+ def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
354
+ ptr = 0
355
+ x_BLC_list = []
356
+ for scale_ind, patch_t_h_w in enumerate(scale_schedule):
357
+ scale_seq_len = np.array(patch_t_h_w).prod()
358
+ x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c]
359
+ ptr += scale_seq_len
360
+ x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule)
361
+ x_BLC_list.append(x_BLC_this_scale)
362
+ assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}'
363
+ x_BLC_list.append(x_BLC[:,ptr:])
364
+ x_BLC = torch.cat(x_BLC_list, dim=1)
365
+ return x_BLC
366
+
367
+ def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
368
+ cfg_infer=False,
369
+ **kwargs,
370
+ ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV
371
+ """
372
+ label_B_or_BLT: label_B or (kv_compact, lens, cu_seqlens_k, max_seqlen_k)
373
+ :return: logits BLV, V is vocab_size
374
+ """
375
+ if cfg_infer:
376
+ return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
377
+
378
+ x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32
379
+ B = x_BLC_wo_prefix.shape[0]
380
+
381
+ # [1. get input sequence x_BLC]
382
+ with torch.amp.autocast('cuda', enabled=False):
383
+ kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
384
+ # drop cond
385
+ total = 0
386
+ for le in lens:
387
+ if random.random() < self.cond_drop_rate:
388
+ kv_compact[total:total+le] = self.cfg_uncond[:le]
389
+ total += le
390
+ must_on_graph = self.cfg_uncond[0, 0] * 0
391
+ kv_compact = self.text_norm(kv_compact).contiguous()
392
+ sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32
393
+ kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
394
+ kv_compact[0, 0] += must_on_graph
395
+ ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
396
+
397
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32
398
+
399
+ sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
400
+ x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
401
+
402
+ # [1.1. pad the seqlen dim]
403
+ l_end = x_BLC.shape[1]
404
+ need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0
405
+
406
+ if self.customized_flash_attn:
407
+ Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
408
+ Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
409
+ attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
410
+ # todo: solve need_to_pad here
411
+ elif self.use_flex_attn:
412
+ if need_to_pad:
413
+ x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
414
+ assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
415
+ attn_bias_or_two_vector = None
416
+ else:
417
+ d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
418
+ dT = d.transpose(1, 2) # dT: 11L
419
+ attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
420
+ attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL
421
+ if need_to_pad:
422
+ attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
423
+ attn_bias[0, 0, l_end:, 0] = 0
424
+ x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
425
+ attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)
426
+
427
+ if self.use_flex_attn:
428
+ attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
429
+ else:
430
+ attn_fn = None
431
+
432
+ # [2. block loop]
433
+ SelfAttnBlock.forward, CrossAttnBlock.forward # no-op expression with no runtime effect
434
+ checkpointing_full_block = self.checkpointing == 'full-block' and self.training
435
+ if self.num_block_chunks == 1:
436
+ for i, b in enumerate(self.blocks):
437
+ if self.add_lvl_embeding_only_first_block and i == 0:
438
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
439
+ if not self.add_lvl_embeding_only_first_block:
440
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
441
+ if checkpointing_full_block:
442
+ x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
443
+ else:
444
+ x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
445
+ else:
446
+ for i, chunk in enumerate(self.block_chunks): # this path
447
+ if self.add_lvl_embeding_only_first_block and i == 0:
448
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
449
+ if not self.add_lvl_embeding_only_first_block:
450
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
451
+ x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
452
+
453
+ # [3. unpad the seqlen dim, and then get logits]
454
+ return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size
455
+
456
+ @torch.no_grad()
457
+ def autoregressive_infer_cfg(
458
+ self,
459
+ vae=None,
460
+ scale_schedule=None,
461
+ label_B_or_BLT=None,
462
+ B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
463
+ g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
464
+ returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
465
+ cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
466
+ vae_type=0, softmax_merge_topk=-1, ret_img=False,
467
+ trunk_scale=1000,
468
+ gt_leak=0, gt_ls_Bl=None,
469
+ inference_mode=False,
470
+ save_img_path=None,
471
+ sampling_per_bits=1,
472
+ ): # returns List[idx_Bl]
473
+ if g_seed is None: rng = None
474
+ else: self.rng.manual_seed(g_seed); rng = self.rng
475
+ assert len(cfg_list) >= len(scale_schedule)
476
+ assert len(tau_list) >= len(scale_schedule)
477
+
478
+ # scale_schedule is used by infinity, vae_scale_schedule is used by the vae; when spatial patchify is applied,
480
+ # scale_schedule is converted to vae_scale_schedule by multiplying h and w by 2
480
+ if self.apply_spatial_patchify:
481
+ vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
482
+ else:
483
+ vae_scale_schedule = scale_schedule
484
+
485
+ kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
486
+ if any(np.array(cfg_list) != 1):
487
+ bs = 2*B
488
+ if not negative_label_B_or_BLT:
489
+ kv_compact_un = kv_compact.clone()
490
+ total = 0
491
+ for le in lens:
492
+ kv_compact_un[total:total+le] = (self.cfg_uncond)[:le]
493
+ total += le
494
+ kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
495
+ cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
496
+ else:
497
+ kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
498
+ kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
499
+ cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
500
+ max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
501
+ else:
502
+ bs = B
503
+
504
+ kv_compact = self.text_norm(kv_compact)
505
+ sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096]
506
+ kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096]
507
+ ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
508
+ last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
509
+
510
+ with torch.amp.autocast('cuda', enabled=False):
511
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
512
+ accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
513
+ idx_Bl_list, idx_Bld_list = [], []
514
+
515
+ if inference_mode:
516
+ for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
517
+ else:
518
+ assert self.num_block_chunks > 1
519
+ for block_chunk_ in self.block_chunks:
520
+ for module in block_chunk_.module.module:
521
+ (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
522
+
523
+ abs_cfg_insertion_layers = []
524
+ add_cfg_on_logits, add_cfg_on_probs = False, False
525
+ leng = len(self.unregistered_blocks)
526
+ for item in cfg_insertion_layer:
527
+ if item == 0: # add cfg on logits
528
+ add_cfg_on_logits = True
529
+ elif item == 1: # add cfg on probs
530
+ add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs
531
+ elif item < 0: # determine to add cfg at item-th layer's output
532
+ assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={leng}'
533
+ abs_cfg_insertion_layers.append(leng+item)
534
+ else:
535
+ raise ValueError(f'cfg_insertion_layer: {item} is not valid')
536
+
537
+ num_stages_minus_1 = len(scale_schedule)-1
538
+ summed_codes = 0
539
+ for si, pn in enumerate(scale_schedule): # si: i-th segment
540
+ cfg = cfg_list[si]
541
+ if si >= trunk_scale:
542
+ break
543
+ cur_L += np.array(pn).prod()
544
+
545
+ need_to_pad = 0
546
+ attn_fn = None
547
+ if self.use_flex_attn:
548
+ # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier
549
+ # if need_to_pad:
550
+ # last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad))
551
+ attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None)
552
+
553
+ # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
554
+ layer_idx = 0
555
+ for block_idx, b in enumerate(self.block_chunks):
556
+ # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int
557
+ if self.add_lvl_embeding_only_first_block and block_idx == 0:
558
+ last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
559
+ if not self.add_lvl_embeding_only_first_block:
560
+ last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
561
+
562
+ for m in b.module:
563
+ last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
564
+ if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
565
+ # print(f'add cfg={cfg} on {layer_idx}-th layer output')
566
+ last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:]
567
+ last_stage = torch.cat((last_stage, last_stage), 0)
568
+ layer_idx += 1
569
+
570
+ if (cfg != 1) and add_cfg_on_logits:
571
+ # print(f'add cfg on add_cfg_on_logits')
572
+ logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si])
573
+ logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
574
+ else:
575
+ logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])
576
+
577
+ if self.use_bit_label:
578
+ tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
579
+ logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
580
+ idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
581
+ idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
582
+ else:
583
+ idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
584
+ if vae_type != 0:
585
+ assert returns_vemb
586
+ if si < gt_leak:
587
+ idx_Bld = gt_ls_Bl[si]
588
+ else:
589
+ assert pn[0] == 1
590
+ idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d]
591
+ if self.apply_spatial_patchify: # unpatchify operation
592
+ idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w]
593
+ idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w]
594
+ idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d]
595
+ idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d]
596
+
597
+ idx_Bld_list.append(idx_Bld)
598
+ codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
599
+ if si != num_stages_minus_1:
600
+ summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
601
+ last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
602
+ last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w]
603
+ if self.apply_spatial_patchify: # patchify operation
604
+ last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w]
605
+ last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w]
606
+ last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d]
607
+ else:
608
+ summed_codes += codes
609
+ else:
610
+ if si < gt_leak:
611
+ idx_Bl = gt_ls_Bl[si]
612
+ h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC
613
+
614
+ # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
615
+ h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
616
+ ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
617
+ idx_Bl_list.append(idx_Bl)
618
+ if si != num_stages_minus_1:
619
+ accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)
620
+
621
+ if si != num_stages_minus_1:
622
+ last_stage = self.word_embed(self.norm0_ve(last_stage))
623
+ last_stage = last_stage.repeat(bs//B, 1, 1)
624
+
625
+ if inference_mode:
626
+ for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
627
+ else:
628
+ assert self.num_block_chunks > 1
629
+ for block_chunk_ in self.block_chunks:
630
+ for module in block_chunk_.module.module:
631
+ (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
632
+
633
+ if not ret_img:
634
+ return ret, idx_Bl_list, []
635
+
636
+ if vae_type != 0:
637
+ img = vae.decode(summed_codes.squeeze(-3))
638
+ else:
639
+ img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)
640
+
641
+ img = (img + 1) / 2
642
+ img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
643
+ return ret, idx_Bl_list, img
644
+
645
+ @for_visualize
646
+ def vis_key_params(self, ep):
647
+ return
648
+
649
+ def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
650
+ for k in state_dict:
651
+ if 'cfg_uncond' in k:
652
+ old, new = state_dict[k], self.cfg_uncond.data
653
+ min_tlen = min(old.shape[0], new.shape[0])
654
+ if min_tlen == old.shape[0]:
655
+ state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:]))
656
+ else:
657
+ state_dict[k] = old[:min_tlen]
658
+
659
+ for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
660
+ state_dict.pop(buf_name, None)
661
+ if hasattr(self, buf_name):
662
+ state_dict[buf_name] = getattr(self, buf_name)
663
+
664
+ return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
665
+
666
+ def special_init(
667
+ self,
668
+ aln_init: float,
669
+ aln_gamma_init: float,
670
+ scale_head: float,
671
+ scale_proj: int,
672
+ ):
673
+ # init head's norm
674
+ if isinstance(self.head_nm, AdaLNBeforeHead):
675
+ self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head
676
+ if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
677
+ self.head_nm.ada_lin[-1].bias.data.zero_()
678
+
679
+ # init head's proj
680
+ if scale_head >= 0:
681
+ if isinstance(self.head, nn.Linear):
682
+ self.head.weight.data.mul_(scale_head)
683
+ self.head.bias.data.zero_()
684
+ elif isinstance(self.head, nn.Sequential):
685
+ self.head[-1].weight.data.mul_(scale_head)
686
+ self.head[-1].bias.data.zero_()
687
+
688
+ depth = len(self.unregistered_blocks)
689
+ for block_idx, sab in enumerate(self.unregistered_blocks):
690
+ sab: Union[SelfAttnBlock, CrossAttnBlock]
691
+ # init proj
692
+ scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx))
693
+ if scale_proj == 1:
694
+ if self.t2i:
695
+ sab.sa.proj.weight.data.mul_(scale)
696
+ sab.ca.proj.weight.data.mul_(scale)
697
+ else:
698
+ sab.attn.proj.weight.data.mul_(scale)
699
+ sab.ffn.fc2.weight.data.mul_(scale)
700
+ # if sab.using_swiglu:
701
+ # nn.init.ones_(sab.ffn.fcg.bias)
702
+ # nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
703
+
704
+ # init ada_lin
705
+ if hasattr(sab, 'ada_lin'):
706
+ lin = sab.ada_lin[-1]
707
+ lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma
708
+ lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift
709
+ if hasattr(lin, 'bias') and lin.bias is not None:
710
+ lin.bias.data.zero_()
711
+ elif hasattr(sab, 'ada_gss'):
712
+ sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma
713
+ sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift
714
+
715
+ def extra_repr(self):
716
+ return f'drop_path_rate={self.drop_path_rate}'
717
+
718
+ def get_layer_id_and_scale_exp(self, para_name: str):
719
+ raise NotImplementedError
720
+
721
+
722
+ def sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
723
+ B, l, V = logits_BlV.shape
724
+ if top_k > 0:
725
+ top_k = min(top_k, V)
726
+ idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
727
+ logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
728
+ if top_p > 0:
729
+ sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
730
+ sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
731
+ sorted_idx_to_remove[..., -1:] = False
732
+ logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
733
+ # sample (flatten to 2D first, since torch.multinomial only accepts 2D tensors)
734
+ replacement = num_samples >= 0
735
+ num_samples = abs(num_samples)
736
+ return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
737
+
738
+ def sampling_with_top_k_top_p_also_inplace_modifying_probs_(probs_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
739
+ B, l, V = probs_BlV.shape
740
+ if top_k > 0:
741
+ top_k = min(top_k, V)
742
+ idx_to_remove = probs_BlV < probs_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
743
+ probs_BlV.masked_fill_(idx_to_remove, 0)
744
+ if top_p > 0:
745
+ sorted_probs, sorted_idx = probs_BlV.sort(dim=-1, descending=False)
746
+ sorted_idx_to_remove = sorted_probs.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
747
+ sorted_idx_to_remove[..., -1:] = False
748
+ probs_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), 0)
749
+ # sample (flatten to 2D first, since torch.multinomial only accepts 2D tensors)
750
+ probs_BlV = probs_BlV / probs_BlV.sum(-1, keepdim=True) # renormalize after masking
751
+ replacement = num_samples >= 0
752
+ num_samples = abs(num_samples)
753
+ return torch.multinomial(probs_BlV.view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
754
+
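A quick sketch of calling the sampler above; the tensor sizes are illustrative, and note that the helper modifies the logits it is given in place:

    import torch
    logits = torch.randn(2, 4, 8)  # (B, l, V), made-up sizes
    idx = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits, top_k=3, top_p=0.9, num_samples=1)
    print(idx.shape)  # torch.Size([2, 4, 1])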
755
+
756
+ def get_params_num(d, w, mlp):
757
+ m = round(mlp * w / 256) * 256
758
+ s = d * (w**2 * 8 + w*m * 2) # sa+ca, mlp
759
+ s += w**2 * 6 # saln
760
+ s += 4096 * w # pred
761
+ s += 32 * w # we
762
+
763
+ Ct5 = 4096
764
+ s += Ct5*w * 4 # T5 attn pool
765
+ s += Ct5*w + w*w # T5 mlp
766
+ return f'{s/1e9:.2f}B'
767
+
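As a sanity check, plugging the 2B configuration registered below (depth 32, width 2048, MLP ratio 4) into this estimator:

    print(get_params_num(32, 2048, 4))  # -> '2.23B', consistent with the infinity_2b alias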
768
+
769
+ TIMM_KEYS = {'img_size', 'pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay', 'global_pool'}
770
+
771
+ @register_model
772
+ def infinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
773
+
774
+ @register_model
775
+ def infinity_20b(depth=58, embed_dim=4608, num_heads=4608//128, drop_path_rate=0.25, **kwargs):
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
776
+
777
+ # model configuration for scaling Infinity transformer
778
+ @register_model
779
+ def infinity_layer12(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, **kwargs):
780
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
781
+ @register_model
782
+ def infinity_layer16(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, **kwargs):
783
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
784
+ @register_model
785
+ def infinity_layer24(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, **kwargs):
786
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
787
+ @register_model
788
+ def infinity_layer32(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, **kwargs):
789
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
790
+ @register_model
791
+ def infinity_layer40(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, **kwargs):
792
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
793
+ @register_model
794
+ def infinity_layer48(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, **kwargs):
795
+ return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
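Because these builders are registered with timm, a model can be created by name; a minimal sketch (the keyword arguments Infinity actually requires, e.g. the text-encoder width, are omitted here and would be forwarded through **kwargs):

    from timm import create_model
    gpt = create_model('infinity_2b', drop_path_rate=0.1)  # kwargs are passed through to Infinity(...)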
models/init_param.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def init_weights(model: nn.Module, conv_std_or_gain: float = 0.02, other_std: float = 0.02):
5
+ """
6
+ :param model: the model to initialize
7
+ :param conv_std_or_gain: how to init every conv layer `m`
8
+ > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
9
+ < 0: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
10
+ :param other_std: how to init every linear layer or embedding layer
11
+ use nn.init.trunc_normal_(m.weight.data, std=other_std)
12
+ """
13
+ skip = abs(conv_std_or_gain) > 10
14
+ if skip: return
15
+ print(f'[init_weights] {type(model).__name__} with {"std" if conv_std_or_gain > 0 else "gain"}={abs(conv_std_or_gain):g}')
16
+ for m in model.modules():
17
+ if isinstance(m, nn.Linear):
18
+ nn.init.trunc_normal_(m.weight.data, std=other_std)
19
+ if m.bias is not None:
20
+ nn.init.constant_(m.bias.data, 0.)
21
+ elif isinstance(m, nn.Embedding):
22
+ nn.init.trunc_normal_(m.weight.data, std=other_std)
23
+ if m.padding_idx is not None:
24
+ m.weight.data[m.padding_idx].zero_()
25
+ elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
26
+ nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain) if conv_std_or_gain > 0 else nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain) # todo: StyleSwin: (..., gain=.02)
27
+ if hasattr(m, 'bias') and m.bias is not None:
28
+ nn.init.constant_(m.bias.data, 0.)
29
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
30
+ if m.bias is not None:
31
+ nn.init.constant_(m.bias.data, 0.)
32
+ if m.weight is not None:
33
+ nn.init.constant_(m.weight.data, 1.)
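A small usage sketch for init_weights (layer sizes are arbitrary):

    import torch.nn as nn
    net = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
    init_weights(net, conv_std_or_gain=0.02, other_std=0.02)  # trunc-normal init for the Linear layers
    init_weights(net, conv_std_or_gain=-0.02)  # conv layers (if any) would get xavier_normal_ with gain=0.02
    init_weights(net, conv_std_or_gain=100.0)  # |value| > 10 skips initialization entirely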
models/t5.py ADDED
@@ -0,0 +1,369 @@
1
+ import re
2
+ import torch
3
+ import os
4
+ import traceback
5
+ import numpy as np
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers import AutoTokenizer, T5EncoderModel
8
+
9
+ import ftfy
10
+ import html
11
+ from bs4 import BeautifulSoup
12
+ import urllib.parse as ul
13
+
14
+
15
+ class T5Embedder:
16
+
17
+ available_models = ['t5-v1_1-xxl']
18
+ bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
19
+
20
+ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
21
+ t5_model_kwargs=None, torch_dtype=torch.bfloat16, use_offload_folder=None, model_max_length=512, padding="max_length", clean_caption_func_name="clean_caption"):
22
+ self.device = torch.device(device)
23
+ self.torch_dtype = torch_dtype
24
+ if t5_model_kwargs is None:
25
+ t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
26
+ if use_offload_folder is not None:
27
+ t5_model_kwargs['offload_folder'] = use_offload_folder
28
+ t5_model_kwargs['device_map'] = {
29
+ 'shared': self.device,
30
+ 'encoder.embed_tokens': self.device,
31
+ 'encoder.block.0': self.device,
32
+ 'encoder.block.1': self.device,
33
+ 'encoder.block.2': self.device,
34
+ 'encoder.block.3': self.device,
35
+ 'encoder.block.4': self.device,
36
+ 'encoder.block.5': self.device,
37
+ 'encoder.block.6': self.device,
38
+ 'encoder.block.7': self.device,
39
+ 'encoder.block.8': self.device,
40
+ 'encoder.block.9': self.device,
41
+ 'encoder.block.10': self.device,
42
+ 'encoder.block.11': self.device,
43
+ 'encoder.block.12': 'disk',
44
+ 'encoder.block.13': 'disk',
45
+ 'encoder.block.14': 'disk',
46
+ 'encoder.block.15': 'disk',
47
+ 'encoder.block.16': 'disk',
48
+ 'encoder.block.17': 'disk',
49
+ 'encoder.block.18': 'disk',
50
+ 'encoder.block.19': 'disk',
51
+ 'encoder.block.20': 'disk',
52
+ 'encoder.block.21': 'disk',
53
+ 'encoder.block.22': 'disk',
54
+ 'encoder.block.23': 'disk',
55
+ 'encoder.final_layer_norm': 'disk',
56
+ 'encoder.dropout': 'disk',
57
+ }
58
+ else:
59
+ t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
60
+
61
+ self.use_text_preprocessing = use_text_preprocessing
62
+ self.hf_token = hf_token
63
+ self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
64
+ self.dir_or_name = dir_or_name
65
+ tokenizer_path, path = dir_or_name, dir_or_name
66
+ if local_cache:
67
+ cache_dir = os.path.join(self.cache_dir, dir_or_name)
68
+ tokenizer_path, path = cache_dir, cache_dir
69
+ elif dir_or_name in self.available_models:
70
+ cache_dir = os.path.join(self.cache_dir, dir_or_name)
71
+ for filename in [
72
+ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
73
+ 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
74
+ ]:
75
+ hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
76
+ force_filename=filename, token=self.hf_token)
77
+ tokenizer_path, path = cache_dir, cache_dir
78
+ else:
79
+ cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
80
+ for filename in [
81
+ 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
82
+ ]:
83
+ hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
84
+ force_filename=filename, token=self.hf_token)
85
+ tokenizer_path = cache_dir
86
+
87
+ print(f"Loading T5 from {tokenizer_path}")
88
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
89
+ self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
90
+ self.model_max_length = model_max_length
91
+ self.padding = padding
92
+ self.clean_caption_func = getattr(self, clean_caption_func_name)
93
+
94
+ @torch.no_grad()
95
+ def get_text_embeddings(self, texts):
96
+ import time
97
+ start_time = time.time()
98
+
99
+ texts = [self.text_preprocessing(text) for text in texts]
100
+ # print("text_preprocessing: ", time.time() - start_time)
101
+
102
+ text_tokens_and_mask = self.tokenizer(
103
+ texts,
104
+ max_length=self.model_max_length,
105
+ padding=self.padding,
106
+ truncation=True,
107
+ return_attention_mask=True,
108
+ add_special_tokens=True,
109
+ return_tensors='pt'
110
+ )
111
+
112
+ # print("tokenizer: ", time.time() - start_time)
113
+
114
+ text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'].to(self.device)
115
+ text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'].to(self.device)
116
+
117
+ with torch.no_grad():
118
+ text_encoder_embs = self.model(
119
+ input_ids=text_tokens_and_mask['input_ids'],
120
+ attention_mask=text_tokens_and_mask['attention_mask'],
121
+ )['last_hidden_state'].detach()
122
+
123
+ # print("model: ", time.time() - start_time)
124
+ return text_encoder_embs, text_tokens_and_mask['attention_mask'], text_tokens_and_mask['input_ids'], texts
125
+
126
+ def text_preprocessing(self, text):
127
+ if self.use_text_preprocessing:
128
+ try:
129
+ # The exact text cleaning used at the training stage (the cleaning pass is applied twice):
130
+ text = self.clean_caption_func(text)
131
+ text = self.clean_caption_func(text)
132
+ return text
133
+ except Exception as e:
134
+ print(f"Error in text preprocessing: {e} with text: {text}")
135
+ print(traceback.format_exc())
136
+ return text
137
+ else:
138
+ return text.lower().strip()
139
+
140
+ @staticmethod
141
+ def basic_clean(text):
142
+ text = ftfy.fix_text(text)
143
+ text = html.unescape(html.unescape(text))
144
+ return text.strip()
145
+
146
+ def clean_caption(self, caption):
147
+ caption = str(caption)
148
+ caption = ul.unquote_plus(caption)
149
+ caption = caption.strip().lower()
150
+ caption = re.sub('<person>', 'person', caption)
151
+ # urls:
152
+ caption = re.sub(
153
+ r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
154
+ '', caption) # regex for urls
155
+ caption = re.sub(
156
+ r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
157
+ '', caption) # regex for urls
158
+ # html:
159
+ try:
160
+ caption = BeautifulSoup(caption, features='html.parser').text
161
+ except Exception as e:
162
+ print(f"Error parsing caption:{caption} with html.parser: {e}")
163
+
164
+ # @<nickname>
165
+ caption = re.sub(r'@[\w\d]+\b', '', caption)
166
+
167
+ # 31C0—31EF CJK Strokes
168
+ # 31F0—31FF Katakana Phonetic Extensions
169
+ # 3200—32FF Enclosed CJK Letters and Months
170
+ # 3300—33FF CJK Compatibility
171
+ # 3400—4DBF CJK Unified Ideographs Extension A
172
+ # 4DC0—4DFF Yijing Hexagram Symbols
173
+ # 4E00—9FFF CJK Unified Ideographs
174
+ caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
175
+ caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
176
+ caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
177
+ caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
178
+ caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
179
+ caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
180
+ caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
181
+ #######################################################
182
+
183
+ # normalize all types of dash to "-"
184
+ caption = re.sub(
185
+ r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
186
+ '-', caption)
187
+
188
+ # standardize quotation marks
189
+ caption = re.sub(r'[`´«»“”¨]', '"', caption)
190
+ caption = re.sub(r'[‘’]', "'", caption)
191
+
192
+ # &quot;
193
+ caption = re.sub(r'&quot;?', '', caption)
194
+ # &amp
195
+ caption = re.sub(r'&amp', '', caption)
196
+
197
+ # ip addresses:
198
+ caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
199
+
200
+ # article ids:
201
+ caption = re.sub(r'\d:\d\d\s+$', '', caption)
202
+
203
+ # \n
204
+ caption = re.sub(r'\\n', ' ', caption)
205
+
206
+ # "#123"
207
+ caption = re.sub(r'#\d{1,3}\b', '', caption)
208
+ # "#12345.."
209
+ caption = re.sub(r'#\d{5,}\b', '', caption)
210
+ # "123456.."
211
+ caption = re.sub(r'\b\d{6,}\b', '', caption)
212
+ # filenames:
213
+ caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
214
+
215
+ #
216
+ caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
217
+ caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
218
+
219
+ caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
220
+ caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
221
+
222
+ # this-is-my-cute-cat / this_is_my_cute_cat
223
+ regex2 = re.compile(r'(?:\-|\_)')
224
+ if len(re.findall(regex2, caption)) > 3:
225
+ caption = re.sub(regex2, ' ', caption)
226
+
227
+ caption = self.basic_clean(caption)
228
+
229
+ caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
230
+ caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
231
+ caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
232
+
233
+ caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
234
+ caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
235
+ caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
236
+ caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
237
+ caption = re.sub(r'\bpage\s+\d+\b', '', caption)
238
+
239
+ caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
240
+
241
+ caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
242
+
243
+ caption = re.sub(r'\b\s+\:\s+', r': ', caption)
244
+ caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
245
+ caption = re.sub(r'\s+', ' ', caption)
246
+
247
+ caption = caption.strip()
248
+
249
+ caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
250
+ caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
251
+ caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
252
+ caption = re.sub(r'^\.\S+$', '', caption)
253
+
254
+ return caption.strip()
255
+
256
+
+ def clean_caption_simplify(self, caption):
+ # convert the caption to a string
+ caption = str(caption)
+
+ # decode URL-encoded sequences
+ caption = ul.unquote_plus(caption)
+
+ # strip leading/trailing whitespace and lowercase
+ caption = caption.strip().lower()
+
+ # replace '<person>' with 'person'
+ caption = re.sub('<person>', 'person', caption)
+
+ # remove URLs
+ caption = re.sub(
+ r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
+ '', caption) # matches URLs starting with http:// or https://
+ caption = re.sub(
+ r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
+ '', caption) # matches URLs starting with www.
+
+ # parse the HTML and drop the tags
+ caption = BeautifulSoup(caption, features='html.parser').text
+
+ # remove @nickname mentions
+ caption = re.sub(r'@[\w\d]+\b', '', caption)
+
+ # remove characters from CJK-related Unicode ranges
+ caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) # CJK Strokes
+ caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) # Katakana Phonetic Extensions
+ caption = re.sub(r'[\u3200-\u32ff]+', '', caption) # Enclosed CJK Letters and Months
+ caption = re.sub(r'[\u3300-\u33ff]+', '', caption) # CJK Compatibility
+ caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) # CJK Unified Ideographs Extension A
+ caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) # Yijing Hexagram Symbols
+ caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) # CJK Unified Ideographs
+
+ # normalize all types of dash to "-"
+ caption = re.sub(
+ r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',
+ '-', caption) # matches the various Unicode dashes
+
+ # standardize the different kinds of quotation marks
+ caption = re.sub(r'[`´«»“”¨]', '"', caption) # replace assorted double quotes with the standard one
+ caption = re.sub(r'[‘’]', "'", caption) # replace left/right single quotes with the standard one
+
+ # remove &quot; and &amp
+ caption = re.sub(r'&quot;?', '', caption) # remove the HTML entity &quot;
+ caption = re.sub(r'&amp', '', caption) # remove the HTML entity &amp
+
+ # remove IP addresses
+ caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) # matches IPv4 addresses
+
+ # remove article-id formats
+ caption = re.sub(r'\d:\d\d\s+$', '', caption) # matches trailing patterns like '1:23 '
+
+ # remove escaped \n sequences
+ caption = re.sub(r'\\n', ' ', caption)
+
+ # remove tags of specific formats
+ # caption = re.sub(r'#\d{1,3}\b', '', caption) # #123: '#' followed by 1-3 digits
+ # caption = re.sub(r'#\d{5,}\b', '', caption) # #12345..: '#' followed by 5+ digits
+ # caption = re.sub(r'\b\d{6,}\b', '', caption) # 123456..: bare runs of 6+ digits
+
+ # remove filenames
+ caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) # matches complete image/video filenames, name plus extension
+
+ # collapse repeated quotes and dots
+ caption = re.sub(r'[\"\']{2,}', r'"', caption) # runs of quotes become a single double quote
+ caption = re.sub(r'[\.]{2,}', r' ', caption) # runs of dots become a space
+
+ # clean junk punctuation with the shared regex
+ caption = re.sub(self.bad_punct_regex, r' ', caption) # the custom junk-punctuation pattern
+ caption = re.sub(r'\s+\.\s+', r' ', caption) # remove a dot surrounded by spaces
+
+ # break up text that has too many dashes or underscores
+ regex2 = re.compile(r'(?:\-|\_)')
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, ' ', caption)
+
+ # basic cleaning
+ caption = self.basic_clean(caption)
+
+ # remove short strings of specific formats
+ # caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # up to 3 letters followed by 3+ digits
+ # caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # letter-digit-letter strings
+ # caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # digit-letter-digit strings
+
+ # remove specific advertising or call-to-action phrases
+ # caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) # 'worldwide free shipping', 'free shipping'
+ # caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) # 'free download', 'download free'
+ # caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) # 'click for ...' or 'click on ...'
+ # caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) # bare file extensions, possibly followed by 'image(s)'
+ # caption = re.sub(r'\bpage\s+\d+\b', '', caption) # 'page 123'
+
+ # remove strings with complex digit/letter patterns
+ # caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # 123A456B789
+
+ # remove dimension-like identifiers, e.g. 1920x1080
+ caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
+
+ # fix extra whitespace and punctuation
+ caption = re.sub(r'\b\s+\:\s+', r': ', caption)
+ caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
+ caption = re.sub(r'\s+', ' ', caption)
+
+ # trim stray characters at the ends
+ caption = caption.strip()
+ caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
+ caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
+ caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
+ caption = re.sub(r'^\.\S+$', '', caption)
+
+ return caption.strip()
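End to end, the embedder is used roughly as follows (the device, the prompt, and access to the DeepFloyd checkpoint are assumptions; the shape follows t5-v1_1-xxl with max_length padding):

    import torch
    t5 = T5Embedder(device='cuda', dir_or_name='t5-v1_1-xxl', torch_dtype=torch.bfloat16, model_max_length=512)
    embs, mask, ids, cleaned = t5.get_text_embeddings(['A photo of a cat ***NEW*** https://example.com'])
    print(embs.shape)  # torch.Size([1, 512, 4096])
    print(cleaned[0])  # roughly 'a photo of a cat new' after the two cleaning passes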
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ opencv-python
+ numpy
+ gradio
+ huggingface-hub
+ transformers
+ spaces
utils/amp_opt.py ADDED
@@ -0,0 +1,187 @@
1
+ import math
2
+ import os
3
+ import signal
4
+ import sys
5
+ import time
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
10
+ # from memory_profiler import profile
11
+
12
+ import infinity.utils.dist as dist
13
+ from infinity.utils import misc
14
+
15
+ class NullCtx:
16
+ def __enter__(self):
17
+ pass
18
+
19
+ def __exit__(self, exc_type, exc_val, exc_tb):
20
+ pass
21
+
22
+
23
+ def handle_timeout(signum, frame):
24
+ raise TimeoutError('took too long')
25
+
26
+
27
+ def per_param_clip_grad_norm_(parameters, thresh: float, stable=False, fp=None) -> Tuple[float, float]:
28
+ skipped, max_grad = [], 0
29
+ for pi, p in enumerate(parameters):
30
+ if p.grad is not None:
31
+ g = p.grad.data.norm(2).item() + 1e-7
32
+ max_grad = max(max_grad, g)
33
+ clip_coef = thresh / g
34
+ if clip_coef < 1:
35
+ if stable and clip_coef < 0.2:
36
+ skipped.append(clip_coef)
37
+ p.grad.data.mul_(0) # todo NOTE: inf.mul_(0)==nan will shrink the scale ratio, but inf.zero_()==0 won't
38
+ else:
39
+ p.grad.data.mul_(clip_coef)
40
+
41
+ # if fp is not None: fp.write(f'[per_param_clip_grad_norm_:47] finished.\n'); fp.flush()
42
+ return 0 if len(skipped) == 0 else math.log10(max(min(skipped), 1e-7)), max_grad
43
+
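A self-contained sketch of the per-parameter clipper on a toy module (the threshold is illustrative):

    import torch
    import torch.nn as nn
    net = nn.Linear(4, 4)
    net(torch.randn(2, 4)).sum().backward()
    log_min_coef, max_grad = per_param_clip_grad_norm_(net.parameters(), thresh=0.5)
    print(max_grad)  # largest per-parameter grad norm observed before clipping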
44
+
45
+ class AmpOptimizer:
46
+ def __init__(
47
+ self,
48
+ model_name_3letters: str, mixed_precision: int,
49
+ optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP],
50
+ r_accu: float, grad_clip: float, zero: int,
51
+ ):
52
+ self.enable_amp = mixed_precision > 0
53
+ self.zero = zero
54
+ if self.enable_amp:
55
+ self.using_fp16_rather_bf16 = mixed_precision != 2
56
+ self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768)
57
+
58
+ # todo: on both V100 and A100, torch.get_autocast_gpu_dtype() returns fp16, not bf16.
59
+ self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0) # todo: cache_enabled=False
60
+ if self.using_fp16_rather_bf16:
61
+ self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
62
+ else:
63
+ self.scaler = None
64
+ else:
65
+ self.using_fp16_rather_bf16 = True
66
+ self.amp_ctx = NullCtx()
67
+ self.scaler = None
68
+
69
+ t = torch.zeros(dist.get_world_size())
70
+ t[dist.get_rank()] = float(self.enable_amp)
71
+ dist.allreduce(t)
72
+ assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}'
73
+
74
+ t = torch.zeros(dist.get_world_size())
75
+ t[dist.get_rank()] = float(self.using_fp16_rather_bf16)
76
+ dist.allreduce(t)
77
+ assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}'
78
+
79
+ self.model_name_3letters = model_name_3letters
80
+ self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp
81
+ self.r_accu = r_accu
82
+
83
+ self.paras = self.names = ... # todo: solve EMA-related code
84
+
85
+ self.grad_clip, self.grad_clip_we = grad_clip, 0 # todo: disable wclip
86
+ if self.grad_clip > 100:
87
+ self.grad_clip %= 100
88
+ self.per_param = True
89
+ else:
90
+ self.per_param = False
91
+ self.per_param = False # todo: disable wclip
92
+
93
+ self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
94
+ self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm') # deepspeed's optimizer
95
+
96
+ self.fp = None
97
+ self.last_orig_norm: torch.Tensor = torch.tensor(0.1)
98
+
99
+ @torch.no_grad()
100
+ def log_param(self, ep: int):
101
+ if self.zero == 0:
102
+ for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items():
103
+ values: List[float]
104
+ if len(values) == 1: # e.g., cls token will only have one value
105
+ values.append(values[0])
106
+ else:
107
+ ...
108
+ # todo: log params
109
+
110
+ # @profile(precision=4, stream=open('amp_sc.log', 'w+'))
111
+ def backward_clip_step(
112
+ self, ep: int, it: int, g_it: int, stepping: bool, logging_params: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False,
113
+ ) -> Tuple[torch.Tensor, Optional[float]]:
114
+ # backward
115
+ loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
116
+ orig_norm = scaler_sc = None
117
+ # if self.fp is not None:
118
+ # if g_it % 20 == 0: self.fp.seek(0); self.fp.truncate(0)
119
+ if self.scaler is not None:
120
+ self.scaler.scale(loss).backward(retain_graph=False, create_graph=False) # retain_graph=retain_graph, create_graph=create_graph
121
+ else:
122
+ loss.backward(retain_graph=False, create_graph=False)
123
+ # if self.fp is not None: self.fp.write(f'[backward_clip_step:131] [it{it}, g_it{g_it}] after backward\n'); self.fp.flush()
124
+
125
+ # clip gradients then step optimizer
126
+ if stepping:
127
+ if self.scaler is not None: self.scaler.unscale_(self.optimizer) # now the gradient can be correctly got
128
+ # if self.fp is not None: self.fp.write(f'[backward_clip_step:137] [it{it}, g_it{g_it}] after scaler.unscale_\n'); self.fp.flush()
129
+
130
+ skipped, orig_norm = 0, self.last_orig_norm
131
+ # try:
132
+ if self.fp is not None:
133
+ if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0)
134
+ self.fp.write(f'<ep{ep} it{it} {g_it}>\n'); self.fp.flush()
135
+ if self.early_clipping:
136
+ c = self.grad_clip * clip_decay_ratio
137
+ if self.zero:
138
+ orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c)
139
+ else:
140
+ orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c)
141
+
142
+ # if self.fp is not None: self.fp.write(f'[backward_clip_step:175] [it{it}, g_it{g_it}] before opt step\n'); self.fp.flush()
143
+ if self.scaler is not None:
144
+ self.scaler: torch.cuda.amp.GradScaler
145
+ if self.zero:
146
+ # synchronize found_inf_per_device before calling step, so that even if only some ranks found inf on their sharded params, all other ranks will know
147
+ # otherwise, when saving FSDP optimizer state, it will cause AssertionError saying "Different ranks have different values for step."
148
+ for optimizer_state in self.scaler._per_optimizer_states.values():
149
+ for t in optimizer_state['found_inf_per_device'].values():
150
+ dist.allreduce(t) # ideally, each rank only has one single t; so no need to use async allreduce
151
+
152
+ self.scaler.step(self.optimizer)
153
+ scaler_sc: Optional[float] = self.scaler.get_scale()
154
+ if scaler_sc > self.max_sc: # fp16 will overflow when >65536, so multiply 32768 could be dangerous
155
+ # print(f'[fp16 scaling] too large loss scale {scaler_sc}! (clip to {self.max_sc:g})')
156
+ self.scaler.update(new_scale=self.max_sc)
157
+ else:
158
+ self.scaler.update()
159
+ try:
160
+ scaler_sc = float(math.log2(scaler_sc))
161
+ except Exception as e:
162
+ print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
163
+ time.sleep(1)
164
+ print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
165
+ raise e
166
+ else:
167
+ self.optimizer.step()
168
+
169
+ if self.late_clipping:
170
+ orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm
171
+ self.last_orig_norm = orig_norm
172
+ # no zero_grad calling here, gonna log those gradients!
173
+ return orig_norm, scaler_sc
174
+
175
+ def state_dict(self):
176
+ return {
177
+ 'optimizer': self.optimizer.state_dict()
178
+ } if self.scaler is None else {
179
+ 'scaler': self.scaler.state_dict(),
180
+ 'optimizer': self.optimizer.state_dict()
181
+ }
182
+
183
+ def load_state_dict(self, state, strict=True):
184
+ if self.scaler is not None:
185
+ try: self.scaler.load_state_dict(state['scaler'])
186
+ except Exception as e: print(f'[fp16 load_state_dict err] {e}')
187
+ self.optimizer.load_state_dict(state['optimizer'])
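A sketch of driving AmpOptimizer with gradient accumulation; the toy model, the data, and a single-process fallback for the dist helpers are assumptions (mixed_precision=2 selects bf16, so no GradScaler is involved):

    import torch
    import torch.nn as nn
    net = nn.Linear(8, 8).cuda()
    opt = torch.optim.AdamW(net.parameters(), lr=1e-4)
    amp_opt = AmpOptimizer('gpt', mixed_precision=2, optimizer=opt, model_maybe_fsdp=net, r_accu=0.25, grad_clip=2.0, zero=0)
    for it in range(8):
        with amp_opt.amp_ctx:
            loss = net(torch.randn(2, 8, device='cuda')).pow(2).mean()
        stepping = (it + 1) % 4 == 0  # step once every 4 micro-batches
        norm, log2_scale = amp_opt.backward_clip_step(ep=0, it=it, g_it=it, stepping=stepping, logging_params=False, loss=loss)
        if stepping:
            opt.zero_grad(set_to_none=True)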
utils/arg_util.py ADDED
@@ -0,0 +1,482 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import random
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from collections import OrderedDict, deque
9
+ from typing import Optional, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from tap import Tap
14
+
15
+ import infinity.utils.dist as dist
16
+
17
+
18
+ class Args(Tap):
19
+ local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # directory for saving checkpoints
20
+ data_path: str = '' # dataset
21
+ bed: str = '' # extra directory to copy checkpoints to, in addition to local_out_path
22
+ vae_ckpt: str = '' # VAE ckpt
23
+ exp_name: str = '' # experiment name
24
+ ds: str = 'oi' # only used in GPT training::load_viz_data & FID benchmark
25
+ model: str = '' # for VAE training, 'b' or any other for GPT training
26
+ short_cap_prob: float = 0.2 # prob for training with short captions
27
+ project_name: str = 'Infinity' # name of wandb project
28
+ tf32: bool = True # whether to use TensorFloat32
29
+ auto_resume: bool = True # whether to automatically resume from the last checkpoint found in args.bed
30
+ rush_resume: str = '' # pretrained infinity checkpoint
31
+ nowd: int = 1 # whether to disable weight decay on sparse params (like class token)
32
+ enable_hybrid_shard: bool = False # whether to use hybrid FSDP
33
+ inner_shard_degree: int = 1 # inner degree for FSDP
34
+ zero: int = 0 # ds zero
35
+ buck: str = 'chunk' # =0 for using module-wise
36
+ fsdp_orig: bool = True
37
+ enable_checkpointing: str = None # checkpointing strategy: full-block, self-attn
38
+ pad_to_multiplier: int = 1 # >1 for padding the seq len to a multiplier of this
39
+ log_every_iter: bool = False
40
+ checkpoint_type: str = 'torch' # checkpoint_type: torch, onmistore
41
+ seed: int = None # 3407
42
+ rand: bool = True # actual seed = seed + (dist.get_rank()*512 if rand else 0)
43
+ device: str = 'cpu'
44
+ task_id: str = '2493513'
45
+ trial_id: str = '7260554'
46
+ robust_run_id: str = '00'
47
+ ckpt_trials = []
48
+ real_trial_id: str = '7260552'
49
+ chunk_nodes: int = None
50
+ is_master_node: bool = None
51
+ # dir
52
+ log_txt_path: str = ''
53
+ t5_path: str = '' # if not specified: automatically find from all bytenas
54
+ online_t5: bool = True # whether to use online t5 or load local features
55
+ # GPT
56
+ sdpa_mem: bool = True # whether to use with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
57
+ tfast: int = 0 # compile GPT
58
+ model_alias: str = 'b' # [automatically set; don't specify this]
59
+ rms: bool = False
60
+ aln: float = 1e-3 # multiplier of ada_lin.w's initialization
61
+ alng: float = -1 # multiplier of ada_lin.w[gamma channels]'s initialization, -1: the same as aln
62
+ saln: bool = False # whether to use a shared adaln layer
63
+ haln: bool = True # whether to use a specific adaln layer in head layer
64
+ nm0: bool = False # norm before word proj linear
65
+ tau: float = 1 # tau of self attention in GPT
66
+ cos: bool = True # cosine attn as in swin v2
67
+ swi: bool = False # whether to use FFNSwiGLU, instead of vanilla FFN
68
+ dp: float = -1
69
+ drop: float = 0.0 # GPT's dropout (VAE's is --vd)
70
+ hd: int = 0
71
+ ca_gamma: float = -1 # >=0 for using layer-scale for cross attention
72
+ diva: int = 1 # rescale_attn_fc_weights
73
+ hd0: float = 0.02 # head.w *= hd0
74
+ dec: int = 1 # dec depth
75
+ cum: int = 3 # cumulating fea map as GPT TF input, 0: not cum; 1: cum @ next hw, 2: cum @ final hw
76
+ rwe: bool = False # random word emb
77
+ tp: float = 0.0 # top-p
78
+ tk: float = 0.0 # top-k
79
+ tini: float = 0.02 # init parameters
80
+ cfg: float = 0.1 # >0: classifier-free guidance, drop cond with prob cfg
81
+ rand_uncond = False # whether to use a random, unlearnable uncond embedding
82
+ ema: float = 0.9999 # VAE's ema ratio, not VAR's. 0.9977844 == 0.5 ** (32 / (10 * 1000)) from gans, 0.9999 from SD
83
+ tema: float = 0 # 0.9999 in DiffiT, DiT
84
+ fp16: int = 0 # 1: fp16, 2: bf16, >2: fp16's max scaling multiplier. todo: remember to force quantize-related features to fp32, and preferably keep the residual in fp32 too (per flash-attention); does nn.Conv2d have a use_float16 parameter?
85
+ fuse: bool = False # whether to use fused mlp
86
+ fused_norm: bool = False # whether to use fused norm
87
+ flash: bool = False # whether to use customized flash-attn kernel
88
+ xen: bool = False # whether to use xentropy
89
+ use_flex_attn: bool = False # whether to use flex_attn to speedup training
90
+ stable: bool = False
91
+ gblr: float = 1e-4
92
+ dblr: float = None # =gblr if is None
93
+ tblr: float = 6e-4
94
+ glr: float = None
95
+ dlr: float = None
96
+ tlr: float = None # vqgan: 4e-5
97
+ gwd: float = 0.005
98
+ dwd: float = 0.0005
99
+ twd: float = 0.005 # vqgan: 0.01
100
+ gwde: float = 0
101
+ dwde: float = 0
102
+ twde: float = 0
103
+ ls: float = 0.0 # label smooth
104
+ lz: float = 0.0 # z loss from PaLM = 1e-4 todo
105
+ eq: int = 0 # equalized loss
106
+ ep: int = 100
107
+ wp: float = 0
108
+ wp0: float = 0.005
109
+ wpe: float = 0.3 # 0.001, final cosine lr = wpe * peak lr
110
+ sche: str = '' # cos, exp, lin
111
+ log_freq: int = 50 # log frequency in the stdout
112
+ gclip: float = 6. # <=0 for not grad clip VAE
113
+ dclip: float = 6. # <=0 for not grad clip discriminator
114
+ tclip: float = 2. # <=0 for not grad clip GPT; >100 for per-param clip (%= 100 automatically)
115
+ cdec: bool = False # decay the grad clip thresholds of GPT and GPT's word embed
116
+ opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (4x lower than Adam's) and wd=0.8 (8x higher than Adam's); e.g., with a small batch_size, Lion underperforms AdamW
117
+ ada: str = '' # adam's beta0 and beta1 for VAE or GPT, '0_0.99' from style-swin and magvit, '0.5_0.9' from VQGAN
118
+ dada: str = '' # adam's beta0 and beta1 for discriminator
119
+ oeps: float = 0 # adam's eps, pixart uses 1e-10
120
+ afuse: bool = True # fused adam
121
+ # data
122
+ pn: str = '' # pixel nums, choose from 0.06M, 0.25M, 1M
123
+ scale_schedule: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
124
+ patch_size: int = None # [automatically set; don't specify this] = 2 ** (len(args.scale_schedule) - 1)
125
+ resos: tuple = None # [automatically set; don't specify this]
126
+ data_load_reso: int = None # [automatically set; don't specify this]
127
+ workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
128
+ lbs: int = 0 # local batch size; if lbs != 0, bs will be ignored, and will be reset as round(args.lbs / args.ac) * dist.get_world_size()
129
+ bs: int = 0 # global batch size; if lbs != 0, bs will be ignored
130
+ batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
131
+ glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
132
+ ac: int = 1 # gradient accumulation
133
+ r_accu: float = 1.0 # [automatically set; don't specify this] = 1 / args.ac
134
+ norm_eps: float = 1e-6 # norm eps for infinity
135
+ tlen: int = 512 # truncate text embedding to this length
136
+ Ct5: int = 2048 # feature dimension of text encoder
137
+ use_bit_label: int = 1 # pred bitwise labels or index-wise labels
138
+ bitloss_type: str = 'mean' # mean or sum
139
+ dynamic_resolution_across_gpus: int = 1 # allow dynamic resolution across gpus
140
+ enable_dynamic_length_prompt: int = 0 # enable dynamic length prompt during training
141
+ use_streaming_dataset: int = 0 # use streaming dataset
142
+ iterable_data_buffersize: int = 90000 # streaming dataset buffer size
143
+ save_model_iters_freq: int = 1000 # save model iter freq
144
+ noise_apply_layers: int = -1 # Bitwise Self-Correction: apply noise to layers, -1 means not apply noise
145
+ noise_apply_strength: float = -1 # Bitwise Self-Correction: apply noise strength, -1 means not apply noise
146
+ noise_apply_requant: int = 1 # Bitwise Self-Correction: requant after apply noise
147
+ rope2d_each_sa_layer: int = 0 # apply rope2d to each self-attention layer
148
+ rope2d_normalized_by_hw: int = 1 # apply normalized rope2d
149
+ use_fsdp_model_ema: int = 0 # use fsdp model ema
150
+ add_lvl_embeding_only_first_block: int = 1 # apply lvl pe embedding only first block or each block
151
+ reweight_loss_by_scale: int = 0 # reweight loss by scale
152
+ always_training_scales: int = 100 # trunc training scales
153
+ vae_type: int = 1 # here 16/32/64 is bsq vae of different quant bits
154
+ fake_vae_input: bool = False # fake vae input for debug
155
+ model_init_device: str = 'cuda' # model_init_device
156
+ prefetch_factor: int = 2 # prefetch_factor for dataset
157
+ apply_spatial_patchify: int = 0 # apply apply_spatial_patchify or not
158
+ debug_bsc: int = 0 # save figs and set breakpoint for debug bsc and check input
159
+ task_type: str = 't2i' # task type: t2i or t2v
160
+
161
+
162
+ ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
163
+ ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
164
+ ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
165
+
166
+
167
+ # would be automatically set in runtime
168
+ branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
169
+ commit_id: str = '' # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
170
+ commit_msg: str = ''# (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
171
+ cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:]) # [automatically set; don't specify this]
172
+ tag: str = 'UK' # [automatically set; don't specify this]
173
+ acc_all: float = None # [automatically set; don't specify this]
174
+ acc_real: float = None # [automatically set; don't specify this]
175
+ acc_fake: float = None # [automatically set; don't specify this]
176
+ last_Lnll: float = None # [automatically set; don't specify this]
177
+ last_L1: float = None # [automatically set; don't specify this]
178
+ last_Ld: float = None # [automatically set; don't specify this]
179
+ last_wei_g: float = None # [automatically set; don't specify this]
180
+ grad_boom: str = None # [automatically set; don't specify this]
181
+ diff: float = None # [automatically set; don't specify this]
182
+ diffs: str = '' # [automatically set; don't specify this]
183
+ diffs_ema: str = None # [automatically set; don't specify this]
184
+ ca_performance: str = '' # [automatically set; don't specify this]
185
+ cur_phase: str = '' # [automatically set; don't specify this]
186
+ cur_it: str = '' # [automatically set; don't specify this]
187
+ cur_ep: str = '' # [automatically set; don't specify this]
188
+ remain_time: str = '' # [automatically set; don't specify this]
189
+ finish_time: str = '' # [automatically set; don't specify this]
190
+ iter_speed: float = None # [automatically set; don't specify this]
191
+ img_per_day: float = None # [automatically set; don't specify this]
192
+ max_nvidia_smi: float = 0 # [automatically set; don't specify this]
193
+ max_memory_allocated: float = None # [automatically set; don't specify this]
194
+ max_memory_reserved: float = None # [automatically set; don't specify this]
195
+ num_alloc_retries: int = None # [automatically set; don't specify this]
196
+ MFU: float = None # [automatically set; don't specify this]
197
+ HFU: float = None # [automatically set; don't specify this]
198
+ # ==================================================================================================================
199
+ # ======================== ignore these parts below since they are only for debug use ==============================
200
+ # ==================================================================================================================
201
+ dbg_modified: bool = False
202
+ dbg_ks: bool = False
203
+ dbg_ks_last = None
204
+ dbg_ks_fp = None
205
+ def dbg_ks_this_line(self, g_it: int):
206
+ if self.dbg_ks:
207
+ if self.dbg_ks_last is None:
208
+ self.dbg_ks_last = deque(maxlen=6)
209
+
210
+ from utils.misc import time_str
211
+ self.dbg_ks_fp.seek(0)
212
+ f_back = sys._getframe().f_back
213
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
214
+ info = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})'
215
+ if g_it is not None:
216
+ info += f' [g_it: {g_it}]'
217
+
218
+ self.dbg_ks_last.append(info)
219
+ self.dbg_ks_fp.write('\n'.join(self.dbg_ks_last) + '\n')
220
+ self.dbg_ks_fp.flush()
221
+
222
+ dbg: bool = 'KEVIN_LOCAL' in os.environ # only used when debug about unused param in DDP
223
+ ks: bool = False
224
+ nodata: bool = False # if True, will set nova=True as well
225
+ nodata_tlen: int = 320
226
+ nova: bool = False # no val, no FID
227
+ prof: int = 0 # profile
228
+ prof_freq: int = 50 # profile
229
+ tos_profiler_file_prefix: str = 'vgpt_default/'
230
+ profall: int = 0
231
+ @property
232
+ def is_vae_visualization_only(self) -> bool:
233
+ return self.v_seed > 0
234
+ v_seed: int = 0 # v_seed != 0 means the visualization-only mode
235
+ @property
236
+ def is_gpt_visualization_only(self) -> bool:
237
+ return self.g_seed > 0
238
+ g_seed: int = 0 # g_seed != 0 means the visualization-only mode
239
+ # ==================================================================================================================
240
+ # ======================== ignore these parts above since they are only for debug use ==============================
241
+ # ==================================================================================================================
242
+
243
+ @property
244
+ def gpt_training(self):
245
+ return len(self.model) > 0
246
+
247
+ def set_initial_seed(self, benchmark: bool):
248
+ torch.backends.cudnn.enabled = True
249
+ torch.backends.cudnn.benchmark = benchmark
250
+ if self.seed is None:
251
+ torch.backends.cudnn.deterministic = False
252
+ else:
253
+ seed = self.seed + (dist.get_rank()*512 if self.rand else 0)
254
+ torch.backends.cudnn.deterministic = True
255
+ os.environ['PYTHONHASHSEED'] = str(seed)
256
+ random.seed(seed)
257
+ np.random.seed(seed)
258
+ torch.manual_seed(seed)
259
+ if torch.cuda.is_available():
260
+ torch.cuda.manual_seed(seed)
261
+ torch.cuda.manual_seed_all(seed)
262
+
263
+ def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
264
+ if self.seed is None:
265
+ return None
266
+ g = torch.Generator()
267
+ g.manual_seed(self.seed + dist.get_rank()*512)
268
+ return g
269
+
270
+ def compile_model(self, m, fast):
271
+ if fast == 0:
272
+ return m
273
+ return torch.compile(m, mode={
274
+ 1: 'reduce-overhead',
275
+ 2: 'max-autotune',
276
+ 3: 'default',
277
+ }[fast]) if hasattr(torch, 'compile') else m
278
+
279
+ def dump_log(self):
280
+ if not dist.is_local_master():
281
+ return
282
+ nd = {'is_master': dist.is_visualizer()}
283
+ r_trial, trial = str(self.real_trial_id), str(self.trial_id)
284
+ for k, v in {
285
+ 'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch,
286
+ 'Lnll': self.last_Lnll, 'L1': self.last_L1,
287
+ 'Ld': self.last_Ld,
288
+ 'acc': self.acc_all, 'acc_r': self.acc_real, 'acc_f': self.acc_fake,
289
+ 'weiG': self.last_wei_g if (self.last_wei_g is None or math.isfinite(self.last_wei_g)) else -23333,
290
+ 'grad': self.grad_boom,
291
+
292
+ 'cur': self.cur_phase, 'cur_ep': self.cur_ep, 'cur_it': self.cur_it,
293
+ 'rema': self.remain_time, 'fini': self.finish_time, 'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
294
+ 'bsep': f'{self.glb_batch_size}/{self.ep}',
295
+ 'G_lrwd': f'{self.glr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.gwd:g}',
296
+ 'D_lrwd': f'{self.dlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.dwd:g}',
297
+ 'T_lrwd': f'{self.tlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.twd:g}',
298
+ 'diff': self.diff, 'diffs': self.diffs, 'diffs_ema': self.diffs_ema if self.diffs_ema else None,
299
+ 'opt': self.opt,
300
+ 'is_master_node': self.is_master_node,
301
+ }.items():
302
+ if hasattr(v, 'item'):v = v.item()
303
+ if v is None or (isinstance(v, str) and len(v) == 0): continue
304
+ nd[k] = v
305
+ if r_trial == trial:
306
+ nd.pop('trial', None)
307
+
308
+ with open(self.log_txt_path, 'w') as fp:
309
+ json.dump(nd, fp, indent=2)
310
+
311
+ def touch_log(self): # listener will kill me if log_txt_path is not updated for 120s
312
+ os.utime(self.log_txt_path) # about 2e-6 sec
313
+
314
+ def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
315
+ d = (OrderedDict if key_ordered else dict)()
316
+ # self.as_dict() would contain methods, but we only need variables
317
+ for k in self.class_variables.keys():
318
+ if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
319
+ d[k] = getattr(self, k)
320
+ return d
321
+
322
+ def load_state_dict(self, d: Union[OrderedDict, dict, str]):
323
+ if isinstance(d, str): # for compatibility with old version
324
+ d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
325
+ for k in d.keys():
326
+ if k in {'is_large_model', 'gpt_training'}:
327
+ continue
328
+ try:
329
+ setattr(self, k, d[k])
330
+ except Exception as e:
331
+ print(f'k={k}, v={d[k]}')
332
+ raise e
333
+
334
+ @staticmethod
335
+ def set_tf32(tf32: bool):
336
+ if torch.cuda.is_available():
337
+ torch.backends.cudnn.allow_tf32 = bool(tf32)
338
+ torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
339
+ if hasattr(torch, 'set_float32_matmul_precision'):
340
+ torch.set_float32_matmul_precision('high' if tf32 else 'highest')
341
+ print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
342
+ print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
343
+ print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
344
+
345
+ def __str__(self):
346
+ s = []
347
+ for k in self.class_variables.keys():
348
+ if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
349
+ s.append(f' {k:20s}: {getattr(self, k)}')
350
+ s = '\n'.join(s)
351
+ return f'{{\n{s}\n}}\n'
352
+
353
+
354
+ def init_dist_and_get_args():
355
+ for i in range(len(sys.argv)):
356
+ if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
357
+ del sys.argv[i]
358
+ break
359
+ args = Args(explicit_bool=True).parse_args(known_only=True)
360
+ args.chunk_nodes = int(os.environ.get('CK', '') or '0')
361
+
362
+ if len(args.extra_args) > 0 and args.is_master_node == 0:
363
+ print(f'======================================================================================')
364
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
365
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
366
+ print(f'======================================================================================\n\n')
367
+
368
+ args.set_tf32(args.tf32)
369
+ if args.dbg:
370
+ torch.autograd.set_detect_anomaly(True)
371
+
372
+ try: os.makedirs(args.bed, exist_ok=True)
373
+ except: pass
374
+ try: os.makedirs(args.local_out_path, exist_ok=True)
375
+ except: pass
376
+
377
+ day3 = 60*24*3
378
+ dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=day3 if int(os.environ.get('LONG_DBG', '0') or '0') > 0 else 30)
379
+
380
+ args.tlen = max(args.tlen, args.nodata_tlen)
381
+ if args.zero and args.tema != 0:
382
+ args.tema = 0
383
+ print(f'======================================================================================')
384
+ print(f'======================== WARNING: args.tema:=0, due to zero={args.zero} ========================')
385
+ print(f'======================================================================================\n\n')
386
+
387
+ if args.nodata:
388
+ args.nova = True
389
+
390
+ if not args.tos_profiler_file_prefix.endswith('/'): args.tos_profiler_file_prefix += '/'
391
+
392
+ if args.alng < 0:
393
+ args.alng = args.aln
394
+
395
+ args.device = dist.get_device()
396
+ args.r_accu = 1 / args.ac # gradient accumulation
397
+ args.data_load_reso = None
398
+ args.rand |= args.seed is None
399
+ args.sche = args.sche or ('lin0' if args.gpt_training else 'cos')
400
+ if args.wp == 0:
401
+ args.wp = args.ep * 1/100
402
+
403
+ di = {
404
+ 'b': 'bilinear', 'c': 'bicubic', 'n': 'nearest', 'a': 'area', 'aa': 'area+area',
405
+ 'at': 'auto', 'auto': 'auto',
406
+ 'v': 'vae',
407
+ 'x': 'pix', 'xg': 'pix_glu', 'gx': 'pix_glu', 'g': 'pix_glu'
408
+ }
409
+
410
+ args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
411
+ args.dada = args.dada or args.ada
412
+ args.opt = args.opt.lower().strip()
413
+
414
+ if args.lbs:
415
+ bs_per_gpu = args.lbs / args.ac
416
+ else:
417
+ bs_per_gpu = args.bs / args.ac / dist.get_world_size()
418
+ bs_per_gpu = round(bs_per_gpu)
419
+ args.batch_size = bs_per_gpu
420
+ args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
421
+ args.workers = min(args.workers, bs_per_gpu)
422
+ args.dblr = args.dblr or args.gblr
423
+ args.glr = args.ac * args.gblr * args.glb_batch_size / 256
424
+ args.dlr = args.ac * args.dblr * args.glb_batch_size / 256
425
+ args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
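+ # linear scaling rule, e.g. ac=1, tblr=6e-4, glb_batch_size=512 -> tlr = 6e-4 * 512 / 256 = 1.2e-3 (double the batch, double the peak lr)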
426
+ args.gwde = args.gwde or args.gwd
427
+ args.dwde = args.dwde or args.dwd
428
+ args.twde = args.twde or args.twd
429
+
430
+ if args.dbg_modified:
431
+ torch.autograd.set_detect_anomaly(True)
432
+ args.dbg_ks &= dist.is_local_master()
433
+ if args.dbg_ks:
434
+ args.dbg_ks_fp = open(os.path.join(args.local_out_path, 'dbg_ks.txt'), 'w')
435
+
436
+ # gpt args
437
+ if args.gpt_training:
438
+ assert args.vae_ckpt, 'VAE ckpt must be specified when training GPT'
439
+ from infinity.models import alias_dict, alias_dict_inv
440
+ if args.model in alias_dict:
441
+ args.model = alias_dict[args.model]
442
+ args.model_alias = alias_dict_inv[args.model]
443
+ else:
444
+ args.model_alias = args.model
445
+ args.model = f'infinity_{args.model}'
446
+
447
+ args.task_id = '123'
448
+ args.trial_id = '123'
449
+ args.robust_run_id = '0'
450
+ args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')
451
+
452
+ ls = [] # candidate auto-resume trial ids
453
+ if 'AUTO_RESUME' in os.environ:
454
+ ls.append(int(os.environ['AUTO_RESUME']))
455
+ ls = sorted(ls, reverse=True)
456
+ ls = [str(i) for i in ls]
457
+ args.ckpt_trials = ls
458
+ args.real_trial_id = args.trial_id if len(ls) == 0 else str(ls[-1])
459
+
460
+ args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
461
+ args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
462
+ assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
463
+ f"only support no-checkpointing or full-block/full-attn checkpointing, but got {args.enable_checkpointing}."
464
+
465
+ if len(args.exp_name) == 0:
466
+ args.exp_name = os.path.basename(args.bed) or 'test_exp'
467
+
468
+ if '-' in args.exp_name:
469
+ args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
470
+ else:
471
+ args.tag = 'UK'
472
+
473
+ if dist.is_master():
474
+ os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')
475
+
476
+ if args.sdpa_mem:
477
+ from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
478
+ enable_flash_sdp(True)
479
+ enable_mem_efficient_sdp(True)
480
+ enable_math_sdp(False)
481
+
482
+ return args
utils/csv_util.py ADDED
@@ -0,0 +1,20 @@
1
+ import os
2
+ import os.path as osp
3
+ import csv
4
+
5
+ import numpy as np
6
+
7
+
8
+ def write_dicts2csv_file(input_dict_list, csv_filename):
9
+ os.makedirs(osp.dirname(csv_filename), exist_ok=True)
10
+ with open(csv_filename, mode='w', newline='', encoding='utf-8') as file:
11
+ fieldnames = input_dict_list[0].keys()
12
+ writer = csv.DictWriter(file, fieldnames=fieldnames)
13
+ writer.writeheader()
14
+ writer.writerows(input_dict_list)
15
+ print(f'"{csv_filename}" has been written.')
16
+
17
+ def load_csv_as_dicts(csv_filename):
18
+ with open(csv_filename, mode='r', newline='', encoding='utf-8') as csvfile:
19
+ reader = csv.DictReader(csvfile)
20
+ return list(reader)
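
A quick round-trip with these two helpers might look as follows (paths hypothetical). Note that `csv.DictReader` returns every field as a string, and `write_dicts2csv_file` expects a path with a directory component, since it calls `os.makedirs` on the dirname:

```python
rows = [{'name': 'a', 'score': 1}, {'name': 'b', 'score': 2}]
write_dicts2csv_file(rows, 'out/metrics.csv')    # creates out/ if needed
loaded = load_csv_as_dicts('out/metrics.csv')
assert loaded[0] == {'name': 'a', 'score': '1'}  # values come back as strings
```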
utils/dist.py ADDED
@@ -0,0 +1,326 @@
1
+ import datetime
2
+ import functools
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from typing import Union
7
+
8
+ import pytz
9
+ import torch
10
+ import torch.distributed as tdist
11
+ import torch.multiprocessing as mp
12
+
13
+
14
+ __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
15
+ __rank_str_zfill = '0'
16
+ __initialized = False
17
+
18
+
19
+ def initialized():
20
+ return __initialized
21
+
22
+
23
+ def __initialize(fork=False, backend='nccl', gpu_id_if_not_distributed=0, timeout_minutes=30):
24
+ global __device
25
+ if not torch.cuda.is_available():
26
+ print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
27
+ return
28
+ elif 'RANK' not in os.environ:
29
+ torch.cuda.set_device(gpu_id_if_not_distributed)
30
+ __device = torch.empty(1).cuda().device
31
+ print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
32
+ return
33
+ # then 'RANK' must exist
34
+ global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
35
+ local_rank = global_rank % num_gpus
36
+ torch.cuda.set_device(local_rank)
37
+
38
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
39
+ """
40
+ if mp.get_start_method(allow_none=True) is None:
41
+ method = 'fork' if fork else 'spawn'
42
+ print(f'[dist initialize] mp method={method}')
43
+ mp.set_start_method(method)
44
+ """
45
+ tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
46
+
47
+ global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
48
+ __local_rank = local_rank
49
+ __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
50
+ __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
51
+ __device = torch.device(local_rank)
52
+ __initialized = True
53
+
54
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
55
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
56
+
57
+
58
+ def get_rank():
59
+ return __rank
60
+
61
+
62
+ def get_rank_given_group(group: tdist.ProcessGroup):
63
+ return tdist.get_rank(group=group)
64
+
65
+
66
+ def get_rank_str_zfill():
67
+ return __rank_str_zfill
68
+
69
+
70
+ def get_local_rank():
71
+ return __local_rank
72
+
73
+
74
+ def get_world_size():
75
+ return __world_size
76
+
77
+
78
+ def get_device():
79
+ return __device
80
+
81
+
82
+ def set_gpu_id(gpu_id: int):
83
+ if gpu_id is None: return
84
+ global __device
85
+ if isinstance(gpu_id, (str, int)):
86
+ torch.cuda.set_device(int(gpu_id))
87
+ __device = torch.empty(1).cuda().device
88
+ else:
89
+ raise NotImplementedError
90
+
91
+
92
+ def is_master():
93
+ return __rank == 0
94
+
95
+
96
+ def is_local_master():
97
+ return __local_rank == 0
98
+
99
+
100
+ def is_visualizer():
101
+ return __rank == 0
102
+ # return __rank == max(__world_size - 8, 0)
103
+
104
+
105
+ def parallelize(net, syncbn=False):
106
+ if syncbn:
107
+ net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
108
+ net = net.cuda()
109
+ net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
110
+ return net
111
+
112
+
113
+ def new_group(ranks: List[int]):
114
+ if __initialized:
115
+ return tdist.new_group(ranks=ranks)
116
+ return None
117
+
118
+
119
+ def new_local_machine_group():
120
+ if __initialized:
121
+ cur_subgroup, subgroups = tdist.new_subgroups()
122
+ return cur_subgroup
123
+ return None
124
+
125
+
126
+ def barrier():
127
+ if __initialized:
128
+ tdist.barrier()
129
+
130
+
131
+ def allreduce(t: torch.Tensor, async_op=False):
132
+ if __initialized:
133
+ if not t.is_cuda:
134
+ cu = t.detach().cuda()
135
+ ret = tdist.all_reduce(cu, async_op=async_op)
136
+ t.copy_(cu.cpu())
137
+ else:
138
+ ret = tdist.all_reduce(t, async_op=async_op)
139
+ return ret
140
+ return None
141
+
142
+
143
+ def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
144
+ if __initialized:
145
+ if not t.is_cuda:
146
+ t = t.cuda()
147
+ ls = [torch.empty_like(t) for _ in range(__world_size)]
148
+ tdist.all_gather(ls, t)
149
+ else:
150
+ ls = [t]
151
+ if cat:
152
+ ls = torch.cat(ls, dim=0)
153
+ return ls
154
+
155
+
156
+ def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
157
+ if __initialized:
158
+ if not t.is_cuda:
159
+ t = t.cuda()
160
+
161
+ t_size = torch.tensor(t.size(), device=t.device)
162
+ ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
163
+ tdist.all_gather(ls_size, t_size)
164
+
165
+ max_B = max(size[0].item() for size in ls_size)
166
+ pad = max_B - t_size[0].item()
167
+ if pad:
168
+ pad_size = (pad, *t.size()[1:])
169
+ t = torch.cat((t, t.new_empty(pad_size)), dim=0)
170
+
171
+ ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
172
+ tdist.all_gather(ls_padded, t)
173
+ ls = []
174
+ for t, size in zip(ls_padded, ls_size):
175
+ ls.append(t[:size[0].item()])
176
+ else:
177
+ ls = [t]
178
+ if cat:
179
+ ls = torch.cat(ls, dim=0)
180
+ return ls
181
+
182
+
183
+ def broadcast(t: torch.Tensor, src_rank) -> None:
184
+ if __initialized:
185
+ if not t.is_cuda:
186
+ cu = t.detach().cuda()
187
+ tdist.broadcast(cu, src=src_rank)
188
+ t.copy_(cu.cpu())
189
+ else:
190
+ tdist.broadcast(t, src=src_rank)
191
+
192
+
193
+ def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
194
+ if not initialized():
195
+ return torch.tensor([val]) if fmt is None else [fmt % val]
196
+
197
+ ts = torch.zeros(__world_size)
198
+ ts[__rank] = val
199
+ allreduce(ts)
200
+ if fmt is None:
201
+ return ts
202
+ return [fmt % v for v in ts.cpu().numpy().tolist()]
203
+
204
+
205
+ def master_only(func):
206
+ @functools.wraps(func)
207
+ def wrapper(*args, **kwargs):
208
+ force = kwargs.pop('force', False)
209
+ if force or is_master():
210
+ ret = func(*args, **kwargs)
211
+ else:
212
+ ret = None
213
+ barrier()
214
+ return ret
215
+ return wrapper
216
+
217
+
218
+ def local_master_only(func):
219
+ @functools.wraps(func)
220
+ def wrapper(*args, **kwargs):
221
+ force = kwargs.pop('force', False)
222
+ if force or is_local_master():
223
+ ret = func(*args, **kwargs)
224
+ else:
225
+ ret = None
226
+ barrier()
227
+ return ret
228
+ return wrapper
229
+
230
+
231
+ def for_visualize(func):
232
+ @functools.wraps(func)
233
+ def wrapper(*args, **kwargs):
234
+ if is_visualizer():
235
+ # with torch.no_grad():
236
+ ret = func(*args, **kwargs)
237
+ else:
238
+ ret = None
239
+ return ret
240
+ return wrapper
241
+
242
+
243
+ def finalize():
244
+ if __initialized:
245
+ tdist.destroy_process_group()
246
+
247
+
248
+ def init_distributed_mode(local_out_path, fork=False, only_sync_master=False, timeout_minutes=30):
249
+ try:
250
+ __initialize(fork=fork, timeout_minutes=timeout_minutes)
251
+ barrier()
252
+ except RuntimeError as e:
253
+ print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
254
+ raise e
255
+
256
+ if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
257
+ _change_builtin_print(is_local_master())
258
+ if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
259
+ sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
260
+
261
+
262
+ def _change_builtin_print(is_master):
263
+ import builtins as __builtin__
264
+
265
+ builtin_print = __builtin__.print
266
+ if type(builtin_print) != type(open):
267
+ return
268
+
269
+ def prt(*args, **kwargs):
270
+ force = kwargs.pop('force', False)
271
+ clean = kwargs.pop('clean', False)
272
+ deeper = kwargs.pop('deeper', False)
273
+ if is_master or force:
274
+ if not clean:
275
+ f_back = sys._getframe().f_back
276
+ if deeper and f_back.f_back is not None:
277
+ f_back = f_back.f_back
278
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
279
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
280
+ builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
281
+ else:
282
+ builtin_print(*args, **kwargs)
283
+
284
+ __builtin__.print = prt
285
+
286
+
287
+ class BackupStreamToFile(object):
288
+ def __init__(self, local_output_dir, for_stdout=True):
289
+ self.for_stdout = for_stdout
290
+ self.terminal_stream = sys.stdout if for_stdout else sys.stderr
291
+ fname = os.path.join(local_output_dir, 'b1_stdout.txt' if for_stdout else 'b2_stderr.txt')
292
+ existing = os.path.exists(fname)
293
+ self.file_stream = open(fname, 'a')
294
+ if existing:
295
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
296
+ self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
297
+ self.file_stream.flush()
298
+ os.system(f'ln -s {fname} /opt/tiger/run_trial/ >/dev/null 2>&1')
299
+ self.enabled = True
300
+
301
+ def write(self, message):
302
+ self.terminal_stream.write(message)
303
+ self.file_stream.write(message)
304
+
305
+ def flush(self):
306
+ self.terminal_stream.flush()
307
+ self.file_stream.flush()
308
+
309
+ def isatty(self):
310
+ return True
311
+
312
+ def close(self):
313
+ if not self.enabled:
314
+ return
315
+ self.enabled = False
316
+ self.file_stream.flush()
317
+ self.file_stream.close()
318
+ if self.for_stdout:
319
+ sys.stdout = self.terminal_stream
320
+ sys.stdout.flush()
321
+ else:
322
+ sys.stderr = self.terminal_stream
323
+ sys.stderr.flush()
324
+
325
+ def __del__(self):
326
+ self.close()
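
A hedged usage sketch, assuming the module is importable as `infinity.utils.dist` and the script is launched under `torchrun` so that `RANK` is set:

```python
# torchrun --nproc_per_node=8 demo.py
import torch
import infinity.utils.dist as dist

dist.init_distributed_mode(local_out_path='local_output')  # also redirects stdout/stderr to files
t = torch.full((1,), float(dist.get_rank()), device=dist.get_device())
gathered = dist.allgather(t)      # cat=True -> a single tensor of shape (world_size,)
if dist.is_master():
    print(gathered)               # tensor([0., 1., ..., world_size - 1])
dist.finalize()
```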
utils/dynamic_resolution.py ADDED
@@ -0,0 +1,73 @@
1
+ import json
2
+ import numpy as np
3
+ import tqdm
4
+
5
+ vae_stride = 16
6
+ ratio2hws = {
7
+ 1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)],
8
+ 1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)],
9
+ 1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)],
10
+ 1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)],
11
+ 1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)],
12
+ 2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)],
13
+ 2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)],
14
+ 3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)],
15
+ }
16
+ predefined_t = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 21]
17
+
18
+ full_ratio2hws = {}
19
+ for ratio, hws in ratio2hws.items():
20
+ full_ratio2hws[ratio] = hws
21
+ if ratio != 1.000:
22
+ full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws]
23
+
24
+ dynamic_resolution_h_w = {}
25
+ for ratio in full_ratio2hws:
26
+ dynamic_resolution_h_w[ratio] ={}
27
+ for ind, leng in enumerate([7, 10, 12, 13]):
28
+ h_div_w = full_ratio2hws[ratio][leng-1][0] / full_ratio2hws[ratio][leng-1][1]
29
+ assert np.abs(h_div_w-ratio) < 0.01, f'{full_ratio2hws[ratio][leng-1]}: {h_div_w} != {ratio}'
30
+ pixel = (full_ratio2hws[ratio][leng-1][0] * vae_stride, full_ratio2hws[ratio][leng-1][1] * vae_stride)
31
+ if ind == 0:
32
+ total_pixels = '0.06M'
33
+ elif ind == 1:
34
+ total_pixels = '0.25M'
35
+ elif ind == 2:
36
+ total_pixels = '0.60M'
37
+ else:
38
+ total_pixels = '1M'
39
+
40
+ scales = full_ratio2hws[ratio][:leng]
41
+ scales = [ (t, h, w) for t, (h, w) in zip(predefined_t, scales) ]
42
+ dynamic_resolution_h_w[ratio][total_pixels] = {
43
+ 'pixel': pixel,
44
+ 'scales': scales
45
+ }
46
+
47
+ h_div_w_templates = []
48
+ for h_div_w in dynamic_resolution_h_w.keys():
49
+ h_div_w_templates.append(h_div_w)
50
+ h_div_w_templates = np.array(h_div_w_templates)
51
+
52
+ def get_h_div_w_template2indices(h_div_w_list, h_div_w_templates):
53
+ indices = list(range(len(h_div_w_list)))
54
+ h_div_w_template2indices = {}
55
+ pbar = tqdm.tqdm(total=len(indices), desc='get_h_div_w_template2indices...')
56
+ for h_div_w, index in zip(h_div_w_list, indices):
57
+ pbar.update(1)
58
+ nearest_h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w-h_div_w_templates))]
59
+ if nearest_h_div_w_template_ not in h_div_w_template2indices:
60
+ h_div_w_template2indices[nearest_h_div_w_template_] = []
61
+ h_div_w_template2indices[nearest_h_div_w_template_].append(index)
62
+ for h_div_w_template_, sub_indices in h_div_w_template2indices.items():
63
+ h_div_w_template2indices[h_div_w_template_] = np.array(sub_indices)
64
+ return h_div_w_template2indices
65
+
66
+ if __name__ == '__main__':
67
+ for h_div_w_template in dynamic_resolution_h_w:
68
+ for total_pixels in dynamic_resolution_h_w[h_div_w_template]:
69
+ scales = np.array(dynamic_resolution_h_w[h_div_w_template][total_pixels]['scales'])
70
+ seq_len = np.sum(scales[:,0]*scales[:,1])
71
+ if total_pixels == '1M':
72
+ string = f'{h_div_w_template}, {total_pixels}, {dynamic_resolution_h_w[h_div_w_template][total_pixels]}, seq_len: {seq_len}'.replace(', ', ',')
73
+ print(string)
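
Assuming the module is importable as `utils.dynamic_resolution`, a lookup for a 720x1280 (h x w) image might look like this; the values in the comments are only what the tables above imply:

```python
import numpy as np
from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates

h_div_w = 720 / 1280                      # 0.5625 -> nearest template is 0.5
nearest = h_div_w_templates[np.argmin(np.abs(h_div_w - h_div_w_templates))]
entry = dynamic_resolution_h_w[nearest]['1M']
print(entry['pixel'])   # final (h, w) in pixels at the 1M-pixel budget
print(entry['scales'])  # per-scale (t, h, w) token-grid sizes
```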
utils/large_file_util.py ADDED
@@ -0,0 +1,70 @@
1
+ import os
2
+ import os.path as osp
3
+ import time
4
+ import itertools
5
+ import shutil
6
+ import glob
7
+ import argparse
8
+
9
+ import tqdm
10
+ import numpy as np
11
+ import threading
12
+
13
+ def save_lines(lines, filename):
14
+ os.makedirs(osp.dirname(filename), exist_ok=True)
15
+ with open(filename, 'w') as f:
16
+ f.writelines(lines)
17
+ del lines
18
+
19
+ def get_part_jsonls(filepath, total_line_number, parts=512):
20
+ dirname, filename, ext = osp.dirname(filepath), osp.splitext(osp.basename(filepath))[0], osp.splitext(osp.basename(filepath))[1]
21
+ if parts == 1:
22
+ return False, {1: filepath}
23
+ save_dir = osp.join(dirname, f'{parts:04d}_parts')
24
+ chunk_id2save_files = {}
25
+ missing = False
26
+ chunk_size = int(total_line_number/parts)
27
+ for chunk_id in range(1, parts+1):
28
+ if chunk_id == parts:
29
+ num_of_lines = total_line_number - chunk_size * (parts-1)
30
+ else:
31
+ num_of_lines = chunk_size
32
+ chunk_id2save_files[chunk_id] = osp.join(save_dir, f'{filename}_{chunk_id:04d}_{parts:04d}_{num_of_lines:09d}{ext}')
33
+ if not osp.exists(chunk_id2save_files[chunk_id]):
34
+ missing = True
35
+ return missing, chunk_id2save_files
36
+
37
+ def split_large_txt_files(filepath, chunk_id2save_files):
38
+ thread_list = []
39
+ chunk_id = 1
40
+ with open(filepath, 'r') as f:
41
+ chunk = []
42
+ pbar = tqdm.tqdm(total=len(chunk_id2save_files))
43
+ for line in f:
44
+ chunk.append(line)
45
+ cur_chunk_size = int(osp.splitext(osp.basename(chunk_id2save_files[chunk_id]))[0].split('_')[-1])
46
+ if len(chunk) >= cur_chunk_size:
47
+ pbar.update(1)
48
+ thread_list.append(threading.Thread(target=save_lines, args=(chunk, chunk_id2save_files[chunk_id])))
49
+ thread_list[-1].start()
50
+ chunk = []
51
+ chunk_id += 1
52
+ if len(chunk):
53
+ raise RuntimeError(f'{len(chunk)} lines left over after splitting {filepath}; the line count in the filename may be wrong')
54
+ assert not len(chunk)
55
+ for thread in thread_list:
56
+ thread.join()
57
+
58
+ if __name__ == '__main__':
59
+ parser = argparse.ArgumentParser()
60
+ parser.add_argument('--jsonl_folder', type=str, default='')
61
+ parser.add_argument('--parts', type=int, default=600)
62
+ args = parser.parse_args()
63
+ for jsonl_filepath in sorted(glob.glob(osp.join(args.jsonl_folder, '*.jsonl'))):
64
+ print(jsonl_filepath)
65
+ t1 = time.time()
66
+ line_num = int(jsonl_filepath.split('_')[-1].split('.')[0])
67
+ missing, chunk_id2save_files = get_part_jsonls(jsonl_filepath, line_num, parts=args.parts)
68
+ split_large_txt_files(jsonl_filepath, chunk_id2save_files)
69
+ t2 = time.time()
70
+ print(f'split takes {t2-t1}s')
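
The splitter relies on the convention that a jsonl's line count is the last `_`-separated field of its name. A hypothetical invocation:

```python
# data/captions_000123456.jsonl -> 512 part files under data/0512_parts/
path, n_lines = 'data/captions_000123456.jsonl', 123456
missing, chunk_id2save_files = get_part_jsonls(path, n_lines, parts=512)
if missing:  # only re-split when some part file is absent
    split_large_txt_files(path, chunk_id2save_files)
```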
utils/load.py ADDED
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/python3
2
+ import gc
3
+ import os
4
+ import os.path as osp
5
+ import random
6
+ import sys
7
+ from copy import deepcopy
8
+ from typing import Tuple, Union
9
+
10
+ import colorama
11
+ import torch
12
+ import yaml
13
+
14
+ import infinity.utils.dist as dist
15
+
16
+ from infinity.models import Infinity
17
+ from infinity.models.ema import get_ema_model
18
+ from infinity.utils import arg_util, misc
19
+ from infinity.utils.misc import os_system
20
+
21
+
22
+ def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
23
+ if args.vae_type in [8,16,18,20,24,32,64,128]:
24
+ from infinity.models.bsq_vae.vae import vae_model
25
+ schedule_mode = "dynamic"
26
+ codebook_dim = args.vae_type # 18
27
+ codebook_size = 2**codebook_dim
28
+ if args.apply_spatial_patchify:
29
+ patch_size = 8
30
+ encoder_ch_mult=[1, 2, 4, 4]
31
+ decoder_ch_mult=[1, 2, 4, 4]
32
+ else:
33
+ patch_size = 16
34
+ encoder_ch_mult=[1, 2, 4, 4, 4]
35
+ decoder_ch_mult=[1, 2, 4, 4, 4]
36
+ vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
37
+ encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
38
+ if args.fake_vae_input:
39
+ vae_local.encoder = None
40
+ vae_local.decoder = None
41
+ torch.cuda.empty_cache()
42
+ else:
43
+ raise ValueError(f"vae_type {args.vae_type} not supported")
44
+ if force_flash: args.flash = True
45
+ gpt_kw = dict(
46
+ pretrained=False, global_pool='',
47
+ text_channels=args.Ct5, text_maxlen=args.tlen,
48
+ norm_eps=args.norm_eps, rms_norm=args.rms,
49
+ shared_aln=args.saln, head_aln=args.haln,
50
+ cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
51
+ cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
52
+ raw_scale_schedule=args.scale_schedule,
53
+ head_depth=args.dec,
54
+ top_p=args.tp, top_k=args.tk,
55
+ customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
56
+ checkpointing=args.enable_checkpointing,
57
+ pad_to_multiplier=args.pad_to_multiplier,
58
+ use_flex_attn=args.use_flex_attn,
59
+ batch_size=args.batch_size,
60
+ add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
61
+ use_bit_label=args.use_bit_label,
62
+ rope2d_each_sa_layer=args.rope2d_each_sa_layer,
63
+ rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
64
+ pn=args.pn,
65
+ train_h_div_w_list=args.train_h_div_w_list,
66
+ always_training_scales=args.always_training_scales,
67
+ apply_spatial_patchify=args.apply_spatial_patchify,
68
+ )
69
+ if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
70
+ if args.hd > 0: gpt_kw['num_heads'] = args.hd
71
+
72
+ print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
73
+ gpt_kw['vae_local'] = vae_local
74
+
75
+ model_str = args.model.replace('vgpt', 'infinity') # legacy
76
+ print(f"{model_str=}")
77
+ if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
78
+ model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
79
+ block_chunks = int(block_chunks)
80
+ else:
81
+ block_chunks = 1
82
+ gpt_kw['block_chunks'] = block_chunks
83
+
84
+ from infinity.models import Infinity
85
+ from timm.models import create_model
86
+ gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
87
+ if args.use_fsdp_model_ema:
88
+ gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
89
+ else:
90
+ gpt_wo_ddp_ema = None
91
+ gpt_wo_ddp = gpt_wo_ddp.to(device)
92
+
93
+ assert all(not p.requires_grad for p in vae_local.parameters())
94
+ assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
95
+
96
+ return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
97
+
98
+
99
+ if __name__ == '__main__':
100
+ ld(sys.argv[1])  # NOTE: `ld` is not defined in this module
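
The `c<N>` suffix handling above encodes the number of block chunks in the model string. A self-contained sketch of just that convention (the model names here are hypothetical):

```python
def parse_block_chunks(model_str: str) -> tuple:
    # a trailing decimal after the last 'c' selects the block-chunk count
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        name, chunks = model_str.rsplit('c', maxsplit=1)
        return name, int(chunks)
    return model_str, 1

assert parse_block_chunks('infinity_2bc8') == ('infinity_2b', 8)
assert parse_block_chunks('infinity_2b') == ('infinity_2b', 1)
```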
utils/lr_control.py ADDED
@@ -0,0 +1,148 @@
1
+ import math
2
+ from pprint import pformat
3
+ from typing import Tuple, List, Dict, Union
4
+
5
+ import torch.nn
6
+ import infinity.utils.dist as dist
7
+
8
+
9
+ def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
10
+ """Decay the learning rate with half-cycle cosine after warmup"""
11
+ wp_it = round(wp_it)
12
+
13
+ if cur_it < wp_it:
14
+ cur_lr = wp0 + (1-wp0) * cur_it / wp_it
15
+ else:
16
+ pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
17
+ rest = 1 - pasd # [1, 0]
18
+ if sche_type == 'cos':
19
+ cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
20
+ elif sche_type == 'lin':
21
+ T = 0.15; max_rest = 1-T
22
+ if pasd < T: cur_lr = 1
23
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
24
+ elif sche_type == 'lin0':
25
+ T = 0.05; max_rest = 1-T
26
+ if pasd < T: cur_lr = 1
27
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest
28
+ elif sche_type == 'lin00':
29
+ cur_lr = wpe + (1-wpe) * rest
30
+ elif sche_type.startswith('lin'):
31
+ T = float(sche_type[3:]); max_rest = 1-T
32
+ wpe_mid = wpe + (1-wpe) * max_rest
33
+ wpe_mid = (1 + wpe_mid) / 2
34
+ if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
35
+ else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
36
+ elif sche_type == 'exp':
37
+ T = 0.15; max_rest = 1-T
38
+ if pasd < T: cur_lr = 1
39
+ else:
40
+ expo = (pasd-T) / max_rest * math.log(wpe)
41
+ cur_lr = math.exp(expo)
42
+ else:
43
+ raise NotImplementedError(f'unknown sche_type {sche_type}')
44
+
45
+ cur_lr *= peak_lr
46
+ pasd = cur_it / (max_it-1)
47
+ cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
48
+
49
+ inf = 1e6
50
+ min_lr, max_lr = inf, -1
51
+ min_wd, max_wd = inf, -1
52
+ for param_group in optimizer.param_groups:
53
+ param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
54
+ max_lr = max(max_lr, param_group['lr'])
55
+ min_lr = min(min_lr, param_group['lr'])
56
+
57
+ param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
58
+ max_wd = max(max_wd, param_group['weight_decay'])
59
+ if param_group['weight_decay'] > 0:
60
+ min_wd = min(min_wd, param_group['weight_decay'])
61
+
62
+ if min_lr == inf: min_lr = -1
63
+ if min_wd == inf: min_wd = -1
64
+ return min_lr, max_lr, min_wd, max_wd
65
+
66
+
67
+ def filter_params(model, ndim_dict, nowd_keys=(), lr_scale=0.0) -> Tuple[
68
+ List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
69
+ ]:
70
+ with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale <= 1
71
+ print(f'[get_param_groups][lr decay] with_lr_scale={with_lr_scale}, lr_scale={lr_scale}')
72
+ para_groups, para_groups_dbg = {}, {}
73
+ names, paras = [], []
74
+ names_no_grad = []
75
+ count, numel = 0, 0
76
+ for name, para in model.named_parameters():
77
+ name = name.replace('_fsdp_wrapped_module.', '')
78
+ if not para.requires_grad:
79
+ names_no_grad.append(name)
80
+ continue # frozen weights
81
+ count += 1
82
+ numel += para.numel()
83
+ names.append(name)
84
+ paras.append(para)
85
+
86
+ if ndim_dict.get(name, 2) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
87
+ cur_wd_sc, group_name = 0., 'ND'
88
+ # elif any(k in name for k in small_wd_keys):
89
+ # cur_wd_sc, group_name = small_wd, 'small_decay'
90
+ else:
91
+ cur_wd_sc, group_name = 1., 'D'
92
+
93
+ if with_lr_scale:
94
+ layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
95
+ group_name = f'layer{layer_id}_' + group_name
96
+ cur_lr_sc = lr_scale ** scale_exp
97
+ dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
98
+ else:
99
+ cur_lr_sc = 1.
100
+ dbg = f'[no scale]'
101
+
102
+ if group_name not in para_groups:
103
+ para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
104
+ para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': dbg}
105
+ para_groups[group_name]['params'].append(para)
106
+ para_groups_dbg[group_name]['params'].append(name)
107
+
108
+ for g in para_groups_dbg.values():
109
+ g['params'] = pformat(', '.join(g['params']), width=200)
110
+
111
+ print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
112
+
113
+ for rk in range(dist.get_world_size()):
114
+ dist.barrier()
115
+ if dist.get_rank() == rk:
116
+ print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
117
+ print('')
118
+
119
+ assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
120
+ del ndim_dict
121
+ return names, paras, list(para_groups.values())
122
+
123
+
124
+ def plot():
125
+ import matplotlib.pyplot as plt
126
+ import torch.nn as nn
127
+ from torch.optim import SGD
128
+ # for sche in ('lin', 'lin0', 'lin00', 'lin0.5', 'lin0.75'):
129
+ for sche in ('lin0', ):
130
+ op = SGD(nn.Linear(3, 4).parameters(), lr=1e-3)
131
+ it, lr = [], []
132
+ iters = 500
133
+ wp_it, max_it = 1 * iters, 10 * iters
134
+ for cur_it in range(max_it):
135
+ it.append(cur_it)
136
+ lr.append(lr_wd_annealing(sche, op, 0.1, 1e-5, 1e-5, cur_it, wp_it, max_it, wpe=0.3)[0])
137
+
138
+ plt.figure()
139
+ plt.title(sche)
140
+ plt.plot(it, lr, 'b', label=sche)
141
+ plt.xlabel('it'), plt.ylabel('lr')
142
+ plt.legend()
143
+
144
+ plt.savefig('lr.jpg')
145
+
146
+
147
+ if __name__ == '__main__':
148
+ plot()
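
A dry run of `lr_wd_annealing` on a toy optimizer shows the warmup-then-cosine shape (all numbers here are illustrative, and the module's own imports are assumed to resolve):

```python
import torch.nn as nn
from torch.optim import SGD

opt = SGD(nn.Linear(3, 4).parameters(), lr=0.0)
for cur_it in (0, 50, 100, 550, 999):   # 100 warmup iters, 1000 total
    min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
        'cos', opt, peak_lr=0.1, wd=0.05, wd_end=0.0,
        cur_it=cur_it, wp_it=100, max_it=1000)
    print(f'it={cur_it:4d}  lr={max_lr:.5f}  wd={max_wd:.5f}')
```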
utils/misc.py ADDED
@@ -0,0 +1,397 @@
1
+ import datetime
2
+ import functools
3
+ import math
4
+ import os
5
+ import random
6
+ import subprocess
7
+ import sys
8
+ import threading
9
+ import time
10
+ from collections import defaultdict, deque
11
+ from typing import Iterator, List, Tuple
12
+
13
+ import numpy as np
14
+ import pytz
15
+ import torch
16
+ import torch.distributed as tdist
17
+ import torch.nn.functional as F
18
+
19
+ import infinity.utils.dist as dist
20
+
21
+ os_system = functools.partial(subprocess.call, shell=True)
22
+ def echo(info):
23
+ os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
24
+ def os_system_get_stdout(cmd):
25
+ return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
26
+ def os_system_get_stdout_stderr(cmd):
27
+ cnt = 0
28
+ while True:
29
+ try:
30
+ sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
31
+ except subprocess.TimeoutExpired:
32
+ cnt += 1
33
+ print(f'[fetch free_port file] timeout cnt={cnt}')
34
+ else:
35
+ return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
36
+
37
+
38
+ def is_pow2n(x):
39
+ return x > 0 and (x & (x - 1) == 0)
40
+
41
+
42
+ def time_str(fmt='[%m-%d %H:%M:%S]'):
43
+ return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
44
+
45
+
46
+ class DistLogger(object):
47
+ def __init__(self, lg):
48
+ self._lg = lg
49
+
50
+ @staticmethod
51
+ def do_nothing(*args, **kwargs):
52
+ pass
53
+
54
+ def __getattr__(self, attr: str):
55
+ return getattr(self._lg, attr) if self._lg is not None else DistLogger.do_nothing
56
+
57
+ class TensorboardLogger(object):
58
+ def __init__(self, log_dir, filename_suffix):
59
+ try: import tensorflow_io as tfio
60
+ except: pass
61
+ from torch.utils.tensorboard import SummaryWriter
62
+ self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
63
+ self.step = 0
64
+
65
+ def set_step(self, step=None):
66
+ if step is not None:
67
+ self.step = step
68
+ else:
69
+ self.step += 1
70
+
71
+ def loggable(self):
72
+ return self.step == 0 or (self.step + 1) % 500 == 0
73
+
74
+ def update(self, head='scalar', step=None, **kwargs):
75
+ if step is None:
76
+ step = self.step
77
+ if not self.loggable(): return
78
+ for k, v in kwargs.items():
79
+ if v is None: continue
80
+ if hasattr(v, 'item'): v = v.item()
81
+ self.writer.add_scalar(f'{head}/{k}', v, step)
82
+
83
+ def log_tensor_as_distri(self, tag, tensor1d, step=None):
84
+ if step is None:
85
+ step = self.step
86
+ if not self.loggable(): return
87
+ try:
88
+ self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
89
+ except Exception as e:
90
+ print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
91
+
92
+ def log_image(self, tag, img_chw, step=None):
93
+ if step is None:
94
+ step = self.step
95
+ if not self.loggable(): return
96
+ self.writer.add_image(tag, img_chw, step, dataformats='CHW')
97
+
98
+ def flush(self):
99
+ self.writer.flush()
100
+
101
+ def close(self):
102
+ self.writer.close()
103
+
104
+
105
+ class Low_GPU_usage(object):
106
+ def __init__(self, files, sleep_secs, verbose):
107
+ pass
108
+
109
+ def early_stop(self):
110
+ pass
111
+
112
+ def __enter__(self):
113
+ return self
114
+
115
+ def __exit__(self, exc_type, exc_val, exc_tb):
116
+ pass
117
+
118
+ class TouchingDaemonDontForgetToStartMe(threading.Thread):
119
+ def __init__(self, files: List[str], sleep_secs: int, verbose=False):
120
+ super().__init__(daemon=True)
121
+ self.files = tuple(files)
122
+ self.sleep_secs = sleep_secs
123
+ self.is_finished = False
124
+ self.verbose = verbose
125
+
126
+ f_back = sys._getframe().f_back
127
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
128
+ self.print_prefix = f' ({file_desc}, line{f_back.f_lineno:-4d}) @daemon@ '
129
+
130
+ def finishing(self):
131
+ self.is_finished = True
132
+
133
+ def run(self) -> None:
134
+ kw = {}
135
+ if tdist.is_initialized(): kw['clean'] = True
136
+
137
+ stt = time.time()
138
+ if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] start touching {self.files} per {self.sleep_secs}s ...', **kw)
139
+ while not self.is_finished:
140
+ for f in self.files:
141
+ if os.path.exists(f):
142
+ try:
143
+ os.utime(f)
144
+ fp = open(f, 'a')
145
+ fp.close()
146
+ except: pass
147
+ time.sleep(self.sleep_secs)
148
+
149
+ if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] finish touching after {time.time()-stt:.1f} secs {self.files} per {self.sleep_secs}s. ', **kw)
150
+
151
+
152
+ class SmoothedValue(object):
153
+ """Track a series of values and provide access to smoothed values over a
154
+ window or the global series average.
155
+ """
156
+
157
+ def __init__(self, window_size=30, fmt=None):
158
+ if fmt is None:
159
+ fmt = "{median:.4f} ({global_avg:.4f})"
160
+ self.deque = deque(maxlen=window_size)
161
+ self.total = 0.0
162
+ self.count = 0
163
+ self.fmt = fmt
164
+
165
+ def update(self, value, n=1):
166
+ self.deque.append(value)
167
+ self.count += n
168
+ self.total += value * n
169
+
170
+ def synchronize_between_processes(self):
171
+ """
172
+ Warning: does not synchronize the deque!
173
+ """
174
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
175
+ tdist.barrier()
176
+ tdist.all_reduce(t)
177
+ t = t.tolist()
178
+ self.count = int(t[0])
179
+ self.total = t[1]
180
+
181
+ @property
182
+ def median(self):
183
+ return np.median(self.deque) if len(self.deque) else 0
184
+
185
+ @property
186
+ def avg(self):
187
+ return sum(self.deque) / (len(self.deque) or 1)
188
+
189
+ @property
190
+ def global_avg(self):
191
+ return self.total / (self.count or 1)
192
+
193
+ @property
194
+ def max(self):
195
+ return max(self.deque) if len(self.deque) else 0
196
+
197
+ @property
198
+ def value(self):
199
+ return self.deque[-1] if len(self.deque) else 0
200
+
201
+ def time_preds(self, counts) -> Tuple[float, str, str]:
202
+ remain_secs = counts * self.median
203
+ return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
204
+
205
+ def __str__(self):
206
+ return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value)
207
+
208
+
209
+ class MetricLogger(object):
210
+ def __init__(self):
211
+ self.meters = defaultdict(SmoothedValue)
212
+ self.iter_end_t = time.time()
213
+ self.log_iters = set()
214
+ self.log_every_iter = False
215
+
216
+ def update(self, **kwargs):
217
+ # if it != 0 and it not in self.log_iters: return
218
+ for k, v in kwargs.items():
219
+ if v is None: continue
220
+ if hasattr(v, 'item'): v = v.item()
221
+ # assert isinstance(v, (float, int)), type(v)
222
+ self.meters[k].update(v)
223
+
224
+ def __getattr__(self, attr):
225
+ if attr in self.meters:
226
+ return self.meters[attr]
227
+ if attr in self.__dict__:
228
+ return self.__dict__[attr]
229
+ raise AttributeError("'{}' object has no attribute '{}'".format(
230
+ type(self).__name__, attr))
231
+
232
+ def __str__(self):
233
+ loss_str = []
234
+ for name, meter in self.meters.items():
235
+ if len(meter.deque):
236
+ loss_str.append(
237
+ "{}: {}".format(name, str(meter))
238
+ )
239
+ return ' '.join(loss_str)
240
+
241
+ def synchronize_between_processes(self):
242
+ for meter in self.meters.values():
243
+ meter.synchronize_between_processes()
244
+
245
+ def add_meter(self, name, meter):
246
+ self.meters[name] = meter
247
+
248
+ def log_every(self, start_it, max_iters, itrt, log_freq, log_every_iter=False, header=''):  # also handles logging and skipping of iterations before start_it
249
+ start_it = start_it % max_iters
250
+ self.log_iters = set(range(start_it, max_iters, log_freq))
251
+ self.log_iters.add(start_it)
252
+ self.log_iters.add(max_iters-1)
253
+ self.log_iters.add(max_iters)
254
+ self.log_every_iter = log_every_iter
255
+ self.iter_end_t = time.time()
256
+ self.iter_time = SmoothedValue(fmt='{value:.4f}')
257
+ self.data_time = SmoothedValue(fmt='{value:.3f}')
258
+ header_fmt = header + ': [{0:' + str(len(str(max_iters))) + 'd}/{1}]'
259
+
260
+ start_time = time.time()
261
+ if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
262
+ for it in range(start_it, max_iters):
263
+ obj = next(itrt)
264
+ if it < start_it: continue
265
+ self.data_time.update(time.time() - self.iter_end_t)
266
+ yield it, obj
267
+ self.iter_time.update(time.time() - self.iter_end_t)
268
+ if self.log_every_iter or it in self.log_iters:
269
+ eta_seconds = self.iter_time.avg * (max_iters - it)
270
+ print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
271
+ self.iter_end_t = time.time()
272
+ else:
273
+ if isinstance(itrt, int): itrt = range(itrt)
274
+ for it, obj in enumerate(itrt):
275
+ if it < start_it:
276
+ self.iter_end_t = time.time()
277
+ continue
278
+ self.data_time.update(time.time() - self.iter_end_t)
279
+ yield it, obj
280
+ self.iter_time.update(time.time() - self.iter_end_t)
281
+ if self.log_every_iter or it in self.log_iters:
282
+ eta_seconds = self.iter_time.avg * (max_iters - it)
283
+ print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
284
+ self.iter_end_t = time.time()
285
+ cost = time.time() - start_time
286
+ cost_str = str(datetime.timedelta(seconds=int(cost)))
287
+ print(f'{header} Cost of this ep: {cost_str} ({cost / (max_iters-start_it):.3f} s / it)', flush=True)
288
+
289
+
290
+ class NullDDP(torch.nn.Module):
291
+ def __init__(self, module, *args, **kwargs):
292
+ super(NullDDP, self).__init__()
293
+ self.module = module
294
+ self.require_backward_grad_sync = False
295
+
296
+ def forward(self, *args, **kwargs):
297
+ return self.module(*args, **kwargs)
298
+
299
+
300
+ def build_2d_sincos_position_embedding(h, w, embed_dim, temperature=10000., sc=0, verbose=True): # (1, hw**2, embed_dim)
301
+ # DiT: sc=0
302
+ # DETR: sc=2?
303
+ grid_w = torch.arange(w, dtype=torch.float32)
304
+ grid_h = torch.arange(h, dtype=torch.float32)
305
+ grid_w, grid_h = torch.meshgrid([grid_w, grid_h], indexing='ij')
306
+ if sc == 0:
307
+ scale = 1
308
+ elif sc == 1:
309
+ scale = math.pi * 2 / w
310
+ else:
311
+ scale = 1 / w
312
+ grid_w = scale * grid_w.reshape(h*w, 1) # scale * [0, 0, 0, 1, 1, 1, 2, 2, 2]
313
+ grid_h = scale * grid_h.reshape(h*w, 1) # scale * [0, 1, 2, 0, 1, 2, 0, 1, 2]
314
+
315
+ assert embed_dim % 4 == 0, f'Embed dimension ({embed_dim}) must be divisible by 4 for 2D sin-cos position embedding!'
316
+ pos_dim = embed_dim // 4
317
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
318
+ omega = (-math.log(temperature) * omega).exp()
319
+ # omega == (1/T) ** (arange(pos_dim) / pos_dim), a vector only dependent on C
320
+ out_w = grid_w * omega.view(1, pos_dim) # out_w: scale * [0*ome, 0*ome, 0*ome, 1*ome, 1*ome, 1*ome, 2*ome, 2*ome, 2*ome]
321
+ out_h = grid_h * omega.view(1, pos_dim) # out_h: scale * [0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome]
322
+ pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
323
+ if verbose: print(f'[build_2d_sincos_position_embedding @ {h} x {w}] scale_type={sc}, temperature={temperature:g}, shape={pos_emb.shape}')
324
+ return pos_emb # (1, hw**2, embed_dim)
325
+
326
+
327
+ if __name__ == '__main__':
328
+ import seaborn as sns
329
+ import matplotlib.pyplot as plt
330
+ cmap_div = sns.color_palette('icefire', as_cmap=True)
331
+
332
+ scs = [0, 1, 2]
333
+ temps = [20, 50, 100, 1000]
334
+ reso = 3.0
335
+ RR, CC = len(scs), len(temps)
336
+ plt.figure(figsize=(CC * reso, RR * reso)) # figsize=(16, 16)
337
+ for row, sc in enumerate(scs):
338
+ for col, temp in enumerate(temps):
339
+ name = f'sc={sc}, T={temp}'
340
+ hw, C = 16, 512
341
+ N = hw*hw
342
+ pe = build_2d_sincos_position_embedding(hw, hw, C, temperature=temp, sc=sc, verbose=False)[0]  # (N, C) = (hw*hw, C)
343
+
344
+ hw2 = 16
345
+ N2 = hw2*hw2
346
+ pe2 = build_2d_sincos_position_embedding(hw2, hw2, C, temperature=temp, sc=sc, verbose=False)[0]  # (N2, C) = (hw2*hw2, C)
347
+ # pe2 = pe2.flip(dims=(0,))
348
+ bchw, bchw2 = F.normalize(pe.view(hw, hw, C).permute(2, 0, 1).unsqueeze(0), dim=1), F.normalize(pe2.view(hw2, hw2, C).permute(2, 0, 1).unsqueeze(0), dim=1)
349
+ dis = [
350
+ f'{F.mse_loss(bchw, F.interpolate(bchw2, size=bchw.shape[-2], mode=inter)).item():.3f}'
351
+ for inter in ('bilinear', 'bicubic', 'nearest')
352
+ ]
353
+ dis += [
354
+ f'{F.mse_loss(F.interpolate(bchw, size=bchw2.shape[-2], mode=inter), bchw2).item():.3f}'
355
+ for inter in ('area', 'nearest')
356
+ ]
357
+ print(f'[{name:^20s}] dis: {dis}')
358
+ """
359
+ [ sc=0, T=20 ] dis: ['0.010', '0.011', '0.011', '0.009', '0.010']
360
+ [ sc=0, T=100 ] dis: ['0.007', '0.007', '0.007', '0.006', '0.007']
361
+ [ sc=0, T=1000 ] dis: ['0.005', '0.005', '0.005', '0.004', '0.005']
362
+ [ sc=0, T=10000 ] dis: ['0.004', '0.004', '0.004', '0.003', '0.004']
363
+ [ sc=1, T=20 ] dis: ['0.007', '0.008', '0.008', '0.007', '0.008']
364
+ [ sc=1, T=100 ] dis: ['0.005', '0.005', '0.005', '0.005', '0.005']
365
+ [ sc=1, T=1000 ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
366
+ [ sc=1, T=10000 ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
367
+ [ sc=2, T=20 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
368
+ [ sc=2, T=100 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
369
+ [ sc=2, T=1000 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
370
+ [ sc=2, T=10000 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
371
+ Process finished with exit code 0
372
+ """
373
+
374
+ pe = torch.from_numpy(cmap_div(pe.T.numpy())[:, :, :3]) # C, N, 3
375
+ tar_h, tar_w = 1024, 1024
376
+ pe = pe.repeat_interleave(tar_w//pe.shape[0], dim=0).repeat_interleave(tar_h//pe.shape[1], dim=1)
377
+ plt.subplot(RR, CC, 1+row*CC+col)
378
+ plt.title(name)
379
+ plt.xlabel('hxw'), plt.ylabel('C')
380
+ plt.xticks([]), plt.yticks([])
381
+ plt.imshow(pe.mul(255).round().clamp(0, 255).byte().numpy())
382
+ plt.tight_layout(h_pad=0.02)
383
+ plt.show()
384
+
385
+
386
+ def check_randomness(args):
387
+ U = 16384
388
+ t = torch.zeros(dist.get_world_size(), 4, dtype=torch.float32, device=args.device)
389
+ t0 = torch.zeros(1, dtype=torch.float32, device=args.device).random_(U)
390
+ t[dist.get_rank(), 0] = float(random.randrange(U))
391
+ t[dist.get_rank(), 1] = float(np.random.randint(U))
392
+ t[dist.get_rank(), 2] = float(torch.randint(0, U, (1,))[0])
393
+ t[dist.get_rank(), 3] = float(t0[0])
394
+ dist.allreduce(t)
395
+ for rk in range(1, dist.get_world_size()):
396
+ assert torch.allclose(t[rk - 1], t[rk]), f't={t}'
397
+ del t0, t, U
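
A single-process sketch of `MetricLogger`, assuming the module's own imports resolve (it pulls in `infinity.utils.dist`):

```python
logger = MetricLogger()
for it, batch in logger.log_every(start_it=0, max_iters=100, itrt=range(100),
                                  log_freq=20, header='toy'):
    logger.update(loss=1.0 / (it + 1))   # any keyword becomes a SmoothedValue meter
print(str(logger))                       # "loss: <median> (<global_avg>)"
```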
utils/save_and_load.py ADDED
@@ -0,0 +1,150 @@
1
+ import gc
2
+ import os
3
+ import subprocess
4
+ import time
5
+ import re
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
10
+
11
+ import glob
12
+ import shutil
13
+ from infinity.utils import arg_util
14
+ import infinity.utils.dist as dist
15
+
16
+
17
+ def glob_with_epoch_iter(pattern, recursive=False):
18
+ def extract_ep_iter(filename):
19
+ match = re.search(r'ep(\d+)-iter(\d+)', filename)
20
+ if match:
21
+ ep = int(match.group(1))
22
+ iter_idx = int(match.group(2))
23
+ return ep, iter_idx
24
+ return 0, 0
25
+ return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_ep_iter(os.path.basename(x)), reverse=True)
26
+
27
+
28
+ def glob_with_global_step(pattern, recursive=False):
29
+ def extract_ep_iter(filename):
30
+ match = re.search(r'global_step_(\d+)', filename)
31
+ if match:
32
+ iter_idx = int(match.group(1))
33
+ return iter_idx
34
+ return 0
35
+ return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_ep_iter(os.path.basename(x)), reverse=True)
36
+
37
+
38
+ class CKPTSaver(object):
39
+ def __init__(self, is_master: bool, eval_milestone: List[Tuple[float, float]]):
40
+ self.is_master = is_master
41
+ self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device())
42
+ self.sp_also: subprocess.Popen = None
43
+ self.sp_best: subprocess.Popen = None
44
+ self.sp_backup: subprocess.Popen = None
45
+ self.acc_str, self.eval_milestone = '[no acc str]', eval_milestone
46
+
47
+ def sav(
48
+ self, args: arg_util.Args, g_it: int, next_ep: int, next_it: int, trainer,
49
+ acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None,
50
+ also_save_to: str = None, best_save_to: str = None,
51
+ ):
52
+ self.time_stamp[1] = time.time()
53
+ dist.broadcast(self.time_stamp, src_rank=0)
54
+ last_save_time, cur_time = self.time_stamp.cpu().tolist()
55
+
56
+ auto_save = cur_time - last_save_time > 20 * 60
57
+ need_save = also_save_to is not None or best_save_to is not None or next_ep == args.ep or auto_save
58
+ if not need_save:
59
+ return
60
+
61
+ if acc_str is not None: self.acc_str = acc_str
62
+ if eval_milestone is not None: self.eval_milestone = eval_milestone
63
+
64
+ fname = f'ar-ckpt-giter{g_it//1000:03d}K-ep{next_ep}-iter{next_it}-last.pth' if args.gpt_training else f'ckpt-last.pth'
65
+ local_out_ckpt = os.path.join(args.local_out_path, fname)
66
+
67
+ # NOTE: all rank should call this state_dict(), not master only!
68
+ trainer_state = trainer.state_dict()
69
+
70
+ if self.is_master:
71
+ stt = time.time()
72
+ torch.save({
73
+ 'args': args.state_dict(),
74
+ 'gpt_training': args.gpt_training,
75
+ 'arch': args.model if args.gpt_training else args.vv,
76
+ 'epoch': next_ep,
77
+ 'iter': next_it,
78
+ 'trainer': trainer_state,
79
+ 'acc_str': self.acc_str,
80
+ 'milestones': self.eval_milestone,
81
+ }, local_out_ckpt)
82
+
83
+ print(f'[CKPTSaver][rank00] start: {also_save_to=} {best_save_to=} {(next_ep == args.ep)=} {auto_save=} | see {local_out_ckpt}', flush=True)
84
+ print(f'[CKPTSaver][rank00] dbg: {args.bed=}', flush=True)
85
+ if auto_save:
86
+ if self.sp_backup is not None:
87
+ self.sp_backup.wait(timeout=300); self.sp_backup.kill(); self.sp_backup.communicate()
88
+ self.time_stamp[0] = time.time()
89
+
90
+ def auto_sync(source_filename, target_filename):
91
+ cmd = f'cp -r {source_filename} {target_filename}'
92
+ self.sp_backup = subprocess.Popen(cmd, shell=True, bufsize=-1)
93
+ print(f'[CKPTSaver] auto_save cmd: {cmd}', flush=True)
94
+
95
+ local_files = glob.glob(f"{args.local_out_path}/*")
96
+ for filename in local_files:
97
+ basename = os.path.basename(filename)
98
+ target_filename = f'{args.bed}/{basename}'
99
+ if basename.endswith('.pth'):
100
+ if not os.path.isfile(target_filename):
101
+ auto_sync(filename, target_filename)
102
+ else:
103
+ auto_sync(filename, target_filename)
104
+ cost = time.time() - stt
105
+ print(f'[CKPTSaver][rank00] cost: {cost:.2f}s', flush=True)
106
+
107
+ del trainer_state
108
+ time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
109
+ dist.barrier()
110
+
111
+
112
+ def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, str, List[Tuple[float, float]], dict, dict]:
113
+ info = []
114
+ resume = ''
115
+ if args.auto_resume:
116
+ for dd in (args.local_out_path, args.bed):
117
+ all_ckpt = glob_with_epoch_iter(os.path.join(dd, pattern))
118
+ if len(all_ckpt): break
119
+ if len(all_ckpt) == 0:
120
+ info.append(f'[auto_resume] no ckpt found @ {pattern}')
121
+ info.append(f'[auto_resume quit]')
122
+ else:
123
+ resume = all_ckpt[0]
124
+ info.append(f'[auto_resume] auto load from @ {resume} ...')
125
+ else:
126
+ info.append(f'[auto_resume] disabled')
127
+ info.append(f'[auto_resume quit]')
128
+
129
+ if len(resume) == 0:
130
+ return info, 0, 0, '[no acc str]', [], {}, {}
131
+
132
+ print(f'auto resume from {resume}')
133
+
134
+ try:
135
+ ckpt = torch.load(resume, map_location='cpu')
136
+ except Exception as e:
137
+ info.append(f'[auto_resume] failed, {e} @ {resume}')
138
+ if len(all_ckpt) < 2:
139
+ return info, 0, 0, '[no acc str]', [], {}, {}
140
+ try: # another chance to load from bytenas
141
+ ckpt = torch.load(all_ckpt[1], map_location='cpu')
142
+ except Exception as e:
143
+ info.append(f'[auto_resume] failed, {e} @ {all_ckpt[1]}')
144
+ return info, 0, 0, '[no acc str]', [], {}, {}
145
+
146
+ dist.barrier()
147
+ ep, it = ckpt['epoch'], ckpt['iter']
148
+ eval_milestone = ckpt.get('milestones', [])
149
+ info.append(f'[auto_resume success] resume from ep{ep}, it{it}, eval_milestone: {eval_milestone}')
150
+ return info, ep, it, ckpt.get('acc_str', '[no acc str]'), eval_milestone, ckpt['trainer'], ckpt['args']
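
`auto_resume` takes the first element of `glob_with_epoch_iter`, i.e. the checkpoint with the highest `(ep, iter)` pair. The sort key behaves like this self-contained equivalent (filenames hypothetical):

```python
import re

def ep_iter_key(fn: str):
    m = re.search(r'ep(\d+)-iter(\d+)', fn)
    return (int(m.group(1)), int(m.group(2))) if m else (0, 0)

names = ['ar-ckpt-ep2-iter999.pth', 'ar-ckpt-ep10-iter50.pth', 'ar-ckpt-ep2-iter100.pth']
assert sorted(names, key=ep_iter_key, reverse=True)[0] == 'ar-ckpt-ep10-iter50.pth'
```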
utils/wandb_utils.py ADDED
@@ -0,0 +1,55 @@
1
+ import wandb
2
+ import torch
3
+ from torchvision.utils import make_grid
4
+ import torch.distributed as dist
5
+ from PIL import Image
6
+ import os
7
+ import argparse
8
+ import hashlib
9
+ import math
10
+
11
+
12
+ def is_main_process():
13
+ return dist.get_rank() == 0
14
+
15
+ def namespace_to_dict(namespace):
16
+ return {
17
+ k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v
18
+ for k, v in vars(namespace).items()
19
+ }
20
+
21
+
22
+ def generate_run_id(exp_name):
23
+ # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
24
+ return str(int(hashlib.sha256(exp_name.encode('utf-8')).hexdigest(), 16) % 10 ** 8)
25
+
26
+
27
+ def initialize(args, entity, exp_name, project_name):
28
+ config_dict = namespace_to_dict(args)
29
+ wandb.login(key=os.environ["WANDB_KEY"])
30
+ wandb.init(
31
+ entity=entity,
32
+ project=project_name,
33
+ name=exp_name,
34
+ config=config_dict,
35
+ id=generate_run_id(exp_name),
36
+ resume="allow",
37
+ )
38
+
39
+
40
+ def log(stats, step=None):
41
+ if is_main_process():
42
+ wandb.log({k: v for k, v in stats.items()}, step=step)
43
+
44
+
45
+ def log_image(name, sample, step=None):
46
+ if is_main_process():
47
+ sample = array2grid(sample)
48
+ wandb.log({f"{name}": wandb.Image(sample), "train_step": step})
49
+
50
+
51
+ def array2grid(x):
52
+ nrow = round(math.sqrt(x.size(0)))
53
+ x = make_grid(x, nrow=nrow, normalize=True, value_range=(-1,1))
54
+ x = x.mul(255).add_(0.5).clamp_(0,255).permute(1,2,0).to('cpu', torch.uint8).numpy()
55
+ return x
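
A hypothetical logging flow: `initialize` expects a `WANDB_KEY` environment variable, and `log`/`log_image` assume `torch.distributed` is already initialized, since they gate on rank 0. The entity and project names below are placeholders:

```python
import argparse

args = argparse.Namespace(lr=1e-4, bs=256)
initialize(args, entity='my-team', exp_name='infinity-debug', project_name='infinity')
log({'loss': 0.42}, step=1)   # only rank 0 actually sends to wandb
```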