Spaces:
Running
on
Zero
Running
on
Zero
MohamedRashad
commited on
Commit
·
32287b3
1
Parent(s):
a8efd17
Add initial project structure with requirements and utility functions
Browse files- .gitignore +171 -0
- app.py +475 -0
- models/__init__.py +26 -0
- models/basic.py +575 -0
- models/bitwise_self_correction.py +97 -0
- models/bsq_vae/conv.py +71 -0
- models/bsq_vae/dynamic_resolution.py +32 -0
- models/bsq_vae/flux_vqgan.py +557 -0
- models/bsq_vae/multiscale_bsq.py +718 -0
- models/bsq_vae/vae.py +255 -0
- models/ema.py +23 -0
- models/flex_attn.py +130 -0
- models/fused_op.py +27 -0
- models/infinity.py +795 -0
- models/init_param.py +33 -0
- models/t5.py +369 -0
- requirements.txt +9 -0
- utils/amp_opt.py +187 -0
- utils/arg_util.py +482 -0
- utils/csv_util.py +20 -0
- utils/dist.py +326 -0
- utils/dynamic_resolution.py +73 -0
- utils/large_file_util.py +70 -0
- utils/load.py +100 -0
- utils/lr_control.py +148 -0
- utils/misc.py +397 -0
- utils/save_and_load.py +150 -0
- utils/wandb_utils.py +55 -0
.gitignore
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# PyPI configuration file
|
171 |
+
.pypirc
|
app.py
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
3 |
+
|
4 |
+
import os.path as osp
|
5 |
+
import time
|
6 |
+
import hashlib
|
7 |
+
import argparse
|
8 |
+
import shutil
|
9 |
+
import re
|
10 |
+
import random
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import List
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
import pandas as pd
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from PIL import Image, ImageEnhance
|
20 |
+
import PIL.Image as PImage
|
21 |
+
from torchvision.transforms.functional import to_tensor
|
22 |
+
from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast, T5Tokenizer, T5ForConditionalGeneration
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
+
import gradio as gr
|
25 |
+
import spaces
|
26 |
+
|
27 |
+
from models.infinity import Infinity
|
28 |
+
from models.basic import *
|
29 |
+
from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
|
30 |
+
|
31 |
+
torch._dynamo.config.cache_size_limit = 64
|
32 |
+
|
33 |
+
# Define a function to download weights if not present
|
34 |
+
def download_weights(weights_path):
|
35 |
+
try:
|
36 |
+
model_file = weights_path / 'infinity_2b_reg.pth'
|
37 |
+
if not model_file.exists():
|
38 |
+
hf_hub_download(repo_id="FoundationVision/Infinity", filename="infinity_2b_reg.pth", local_dir=str(weights_path))
|
39 |
+
|
40 |
+
vae_file = weights_path / 'infinity_vae_d32reg.pth'
|
41 |
+
if not vae_file.exists():
|
42 |
+
hf_hub_download(repo_id="FoundationVision/Infinity", filename="infinity_vae_d32reg.pth", local_dir=str(weights_path))
|
43 |
+
|
44 |
+
# For the text encoder, we need to download the entire model
|
45 |
+
text_encoder_ckpt = weights_path / 'flan-t5-xl'
|
46 |
+
if not text_encoder_ckpt.exists():
|
47 |
+
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
|
48 |
+
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
|
49 |
+
tokenizer.save_pretrained(text_encoder_ckpt)
|
50 |
+
model.save_pretrained(text_encoder_ckpt)
|
51 |
+
except Exception as e:
|
52 |
+
print(f"Error downloading weights: {e}")
|
53 |
+
|
54 |
+
def extract_key_val(text):
|
55 |
+
pattern = r'<(.+?):(.+?)>'
|
56 |
+
matches = re.findall(pattern, text)
|
57 |
+
key_val = {}
|
58 |
+
for match in matches:
|
59 |
+
key_val[match[0]] = match[1].lstrip()
|
60 |
+
return key_val
|
61 |
+
|
62 |
+
def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
|
63 |
+
if enable_positive_prompt:
|
64 |
+
print(f'before positive_prompt aug: {prompt}')
|
65 |
+
prompt = aug_with_positive_prompt(prompt)
|
66 |
+
print(f'after positive_prompt aug: {prompt}')
|
67 |
+
print(f'prompt={prompt}')
|
68 |
+
captions = [prompt]
|
69 |
+
tokens = text_tokenizer(text=captions, max_length=512, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
|
70 |
+
input_ids = tokens.input_ids.cuda(non_blocking=True)
|
71 |
+
mask = tokens.attention_mask.cuda(non_blocking=True)
|
72 |
+
text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
|
73 |
+
lens: List[int] = mask.sum(dim=-1).tolist()
|
74 |
+
cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
|
75 |
+
Ltext = max(lens)
|
76 |
+
kv_compact = []
|
77 |
+
for len_i, feat_i in zip(lens, text_features.unbind(0)):
|
78 |
+
kv_compact.append(feat_i[:len_i])
|
79 |
+
kv_compact = torch.cat(kv_compact, dim=0)
|
80 |
+
text_cond_tuple = (kv_compact, lens, cu_seqlens_k, Ltext)
|
81 |
+
return text_cond_tuple
|
82 |
+
|
83 |
+
def aug_with_positive_prompt(prompt):
|
84 |
+
for key in ['man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult', 'teenager', 'employee',
|
85 |
+
'employer', 'worker', 'mother', 'father', 'sister', 'brother', 'grandmother', 'grandfather', 'son', 'daughter']:
|
86 |
+
if key in prompt:
|
87 |
+
prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
|
88 |
+
break
|
89 |
+
return prompt
|
90 |
+
|
91 |
+
def enhance_image(image):
|
92 |
+
for t in range(1):
|
93 |
+
contrast_image = image.copy()
|
94 |
+
contrast_enhancer = ImageEnhance.Contrast(contrast_image)
|
95 |
+
contrast_image = contrast_enhancer.enhance(1.05) # 增强对比度
|
96 |
+
color_image = contrast_image.copy()
|
97 |
+
color_enhancer = ImageEnhance.Color(color_image)
|
98 |
+
color_image = color_enhancer.enhance(1.05) # 增强饱和度
|
99 |
+
return color_image
|
100 |
+
|
101 |
+
def gen_one_img(
|
102 |
+
infinity_test,
|
103 |
+
vae,
|
104 |
+
text_tokenizer,
|
105 |
+
text_encoder,
|
106 |
+
prompt,
|
107 |
+
cfg_list=[],
|
108 |
+
tau_list=[],
|
109 |
+
negative_prompt='',
|
110 |
+
scale_schedule=None,
|
111 |
+
top_k=900,
|
112 |
+
top_p=0.97,
|
113 |
+
cfg_sc=3,
|
114 |
+
cfg_exp_k=0.0,
|
115 |
+
cfg_insertion_layer=-5,
|
116 |
+
vae_type=0,
|
117 |
+
gumbel=0,
|
118 |
+
softmax_merge_topk=-1,
|
119 |
+
gt_leak=-1,
|
120 |
+
gt_ls_Bl=None,
|
121 |
+
g_seed=None,
|
122 |
+
sampling_per_bits=1,
|
123 |
+
enable_positive_prompt=0,
|
124 |
+
):
|
125 |
+
sstt = time.time()
|
126 |
+
if not isinstance(cfg_list, list):
|
127 |
+
cfg_list = [cfg_list] * len(scale_schedule)
|
128 |
+
if not isinstance(tau_list, list):
|
129 |
+
tau_list = [tau_list] * len(scale_schedule)
|
130 |
+
text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
|
131 |
+
if negative_prompt:
|
132 |
+
negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
|
133 |
+
else:
|
134 |
+
negative_label_B_or_BLT = None
|
135 |
+
print(f'cfg: {cfg_list}, tau: {tau_list}')
|
136 |
+
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
|
137 |
+
stt = time.time()
|
138 |
+
_, _, img_list = infinity_test.autoregressive_infer_cfg(
|
139 |
+
vae=vae,
|
140 |
+
scale_schedule=scale_schedule,
|
141 |
+
label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
|
142 |
+
B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
|
143 |
+
cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
|
144 |
+
returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
|
145 |
+
cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
|
146 |
+
vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
|
147 |
+
ret_img=True, trunk_scale=1000,
|
148 |
+
gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
|
149 |
+
sampling_per_bits=sampling_per_bits,
|
150 |
+
)
|
151 |
+
print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
|
152 |
+
img = img_list[0]
|
153 |
+
return img
|
154 |
+
|
155 |
+
def get_prompt_id(prompt):
|
156 |
+
md5 = hashlib.md5()
|
157 |
+
md5.update(prompt.encode('utf-8'))
|
158 |
+
prompt_id = md5.hexdigest()
|
159 |
+
return prompt_id
|
160 |
+
|
161 |
+
def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
|
162 |
+
print('[Save slim model]')
|
163 |
+
full_ckpt = torch.load(infinity_model_path, map_location=device)
|
164 |
+
infinity_slim = full_ckpt['trainer'][key]
|
165 |
+
# ema_state_dict = cpu_d['trainer'].get('gpt_ema_fsdp', state_dict)
|
166 |
+
if not save_file:
|
167 |
+
save_file = osp.splitext(infinity_model_path)[0] + '-slim.pth'
|
168 |
+
print(f'Save to {save_file}')
|
169 |
+
torch.save(infinity_slim, save_file)
|
170 |
+
print('[Save slim model] done')
|
171 |
+
return save_file
|
172 |
+
|
173 |
+
def load_tokenizer(t5_path =''):
|
174 |
+
print(f'[Loading tokenizer and text encoder]')
|
175 |
+
text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
|
176 |
+
text_tokenizer.model_max_length = 512
|
177 |
+
text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
|
178 |
+
text_encoder.to('cuda')
|
179 |
+
text_encoder.eval()
|
180 |
+
text_encoder.requires_grad_(False)
|
181 |
+
return text_tokenizer, text_encoder
|
182 |
+
|
183 |
+
def load_infinity(
|
184 |
+
rope2d_each_sa_layer,
|
185 |
+
rope2d_normalized_by_hw,
|
186 |
+
use_scale_schedule_embedding,
|
187 |
+
pn,
|
188 |
+
use_bit_label,
|
189 |
+
add_lvl_embeding_only_first_block,
|
190 |
+
model_path='',
|
191 |
+
scale_schedule=None,
|
192 |
+
vae=None,
|
193 |
+
device='cuda',
|
194 |
+
model_kwargs=None,
|
195 |
+
text_channels=2048,
|
196 |
+
apply_spatial_patchify=0,
|
197 |
+
use_flex_attn=False,
|
198 |
+
bf16=False,
|
199 |
+
):
|
200 |
+
print(f'[Loading Infinity]')
|
201 |
+
text_maxlen = 512
|
202 |
+
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
|
203 |
+
infinity_test: Infinity = Infinity(
|
204 |
+
vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
|
205 |
+
shared_aln=True, raw_scale_schedule=scale_schedule,
|
206 |
+
checkpointing='full-block',
|
207 |
+
customized_flash_attn=False,
|
208 |
+
fused_norm=True,
|
209 |
+
pad_to_multiplier=128,
|
210 |
+
use_flex_attn=use_flex_attn,
|
211 |
+
add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
|
212 |
+
use_bit_label=use_bit_label,
|
213 |
+
rope2d_each_sa_layer=rope2d_each_sa_layer,
|
214 |
+
rope2d_normalized_by_hw=rope2d_normalized_by_hw,
|
215 |
+
pn=pn,
|
216 |
+
apply_spatial_patchify=apply_spatial_patchify,
|
217 |
+
inference_mode=True,
|
218 |
+
train_h_div_w_list=[1.0],
|
219 |
+
**model_kwargs,
|
220 |
+
).to(device=device)
|
221 |
+
print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
|
222 |
+
|
223 |
+
if bf16:
|
224 |
+
for block in infinity_test.unregistered_blocks:
|
225 |
+
block.bfloat16()
|
226 |
+
|
227 |
+
infinity_test.eval()
|
228 |
+
infinity_test.requires_grad_(False)
|
229 |
+
|
230 |
+
infinity_test.cuda()
|
231 |
+
torch.cuda.empty_cache()
|
232 |
+
|
233 |
+
print(f'[Load Infinity weights]')
|
234 |
+
state_dict = torch.load(model_path, map_location=device)
|
235 |
+
print(infinity_test.load_state_dict(state_dict))
|
236 |
+
infinity_test.rng = torch.Generator(device=device)
|
237 |
+
return infinity_test
|
238 |
+
|
239 |
+
def transform(pil_img, tgt_h, tgt_w):
|
240 |
+
width, height = pil_img.size
|
241 |
+
if width / height <= tgt_w / tgt_h:
|
242 |
+
resized_width = tgt_w
|
243 |
+
resized_height = int(tgt_w / (width / height))
|
244 |
+
else:
|
245 |
+
resized_height = tgt_h
|
246 |
+
resized_width = int((width / height) * tgt_h)
|
247 |
+
pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
|
248 |
+
# crop the center out
|
249 |
+
arr = np.array(pil_img)
|
250 |
+
crop_y = (arr.shape[0] - tgt_h) // 2
|
251 |
+
crop_x = (arr.shape[1] - tgt_w) // 2
|
252 |
+
im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
|
253 |
+
return im.add(im).add_(-1)
|
254 |
+
|
255 |
+
def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
|
256 |
+
pil_image = Image.open(image_path).convert('RGB')
|
257 |
+
inp = transform(pil_image, tgt_h, tgt_w)
|
258 |
+
inp = inp.unsqueeze(0).to(device)
|
259 |
+
scale_schedule = [(item[0], item[1], item[2]) for item in scale_schedule]
|
260 |
+
t1 = time.time()
|
261 |
+
h, z, _, all_bit_indices, _, infinity_input = vae.encode(inp, scale_schedule=scale_schedule)
|
262 |
+
t2 = time.time()
|
263 |
+
recons_img = vae.decode(z)[0]
|
264 |
+
if len(recons_img.shape) == 4:
|
265 |
+
recons_img = recons_img.squeeze(1)
|
266 |
+
print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
|
267 |
+
t3 = time.time()
|
268 |
+
print(f'vae encode takes {t2-t1:.2f}s, decode takes {t3-t2:.2f}s')
|
269 |
+
recons_img = (recons_img + 1) / 2
|
270 |
+
recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
|
271 |
+
gt_img = (inp[0] + 1) / 2
|
272 |
+
gt_img = gt_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
|
273 |
+
print(recons_img.shape, gt_img.shape)
|
274 |
+
return gt_img, recons_img, all_bit_indices
|
275 |
+
|
276 |
+
def load_visual_tokenizer(args):
|
277 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
278 |
+
# load vae
|
279 |
+
if args.vae_type in [16,18,20,24,32,64]:
|
280 |
+
from models.bsq_vae.vae import vae_model
|
281 |
+
schedule_mode = "dynamic"
|
282 |
+
codebook_dim = args.vae_type
|
283 |
+
codebook_size = 2**codebook_dim
|
284 |
+
if args.apply_spatial_patchify:
|
285 |
+
patch_size = 8
|
286 |
+
encoder_ch_mult=[1, 2, 4, 4]
|
287 |
+
decoder_ch_mult=[1, 2, 4, 4]
|
288 |
+
else:
|
289 |
+
patch_size = 16
|
290 |
+
encoder_ch_mult=[1, 2, 4, 4, 4]
|
291 |
+
decoder_ch_mult=[1, 2, 4, 4, 4]
|
292 |
+
vae = vae_model(args.vae_path, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
|
293 |
+
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(device)
|
294 |
+
else:
|
295 |
+
raise ValueError(f'vae_type={args.vae_type} not supported')
|
296 |
+
return vae
|
297 |
+
|
298 |
+
def load_transformer(vae, args):
|
299 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
300 |
+
model_path = args.model_path
|
301 |
+
if args.checkpoint_type == 'torch':
|
302 |
+
# copy large model to local; save slim to local; and copy slim to nas; load local slim model
|
303 |
+
if osp.exists(args.cache_dir):
|
304 |
+
local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
|
305 |
+
else:
|
306 |
+
local_model_path = model_path
|
307 |
+
if args.enable_model_cache:
|
308 |
+
slim_model_path = model_path.replace('ar-', 'slim-')
|
309 |
+
local_slim_model_path = local_model_path.replace('ar-', 'slim-')
|
310 |
+
os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
|
311 |
+
print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
|
312 |
+
print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
|
313 |
+
if not osp.exists(local_slim_model_path):
|
314 |
+
if osp.exists(slim_model_path):
|
315 |
+
print(f'copy {slim_model_path} to {local_slim_model_path}')
|
316 |
+
shutil.copyfile(slim_model_path, local_slim_model_path)
|
317 |
+
else:
|
318 |
+
if not osp.exists(local_model_path):
|
319 |
+
print(f'copy {model_path} to {local_model_path}')
|
320 |
+
shutil.copyfile(model_path, local_model_path)
|
321 |
+
save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
|
322 |
+
print(f'copy {local_slim_model_path} to {slim_model_path}')
|
323 |
+
if not osp.exists(slim_model_path):
|
324 |
+
shutil.copyfile(local_slim_model_path, slim_model_path)
|
325 |
+
os.remove(local_model_path)
|
326 |
+
os.remove(model_path)
|
327 |
+
slim_model_path = local_slim_model_path
|
328 |
+
else:
|
329 |
+
slim_model_path = model_path
|
330 |
+
print(f'load checkpoint from {slim_model_path}')
|
331 |
+
|
332 |
+
if args.model_type == 'infinity_2b':
|
333 |
+
kwargs_model = dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8) # 2b model
|
334 |
+
elif args.model_type == 'infinity_layer12':
|
335 |
+
kwargs_model = dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
|
336 |
+
elif args.model_type == 'infinity_layer16':
|
337 |
+
kwargs_model = dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
|
338 |
+
elif args.model_type == 'infinity_layer24':
|
339 |
+
kwargs_model = dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
|
340 |
+
elif args.model_type == 'infinity_layer32':
|
341 |
+
kwargs_model = dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
|
342 |
+
elif args.model_type == 'infinity_layer40':
|
343 |
+
kwargs_model = dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
|
344 |
+
elif args.model_type == 'infinity_layer48':
|
345 |
+
kwargs_model = dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4)
|
346 |
+
infinity = load_infinity(
|
347 |
+
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
|
348 |
+
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
|
349 |
+
use_scale_schedule_embedding=args.use_scale_schedule_embedding,
|
350 |
+
pn=args.pn,
|
351 |
+
use_bit_label=args.use_bit_label,
|
352 |
+
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
|
353 |
+
model_path=slim_model_path,
|
354 |
+
scale_schedule=None,
|
355 |
+
vae=vae,
|
356 |
+
device=device,
|
357 |
+
model_kwargs=kwargs_model,
|
358 |
+
text_channels=args.text_channels,
|
359 |
+
apply_spatial_patchify=args.apply_spatial_patchify,
|
360 |
+
use_flex_attn=args.use_flex_attn,
|
361 |
+
bf16=args.bf16,
|
362 |
+
)
|
363 |
+
return infinity
|
364 |
+
|
365 |
+
# Set up paths
|
366 |
+
weights_path = Path(__file__).parent / 'weights'
|
367 |
+
weights_path.mkdir(exist_ok=True)
|
368 |
+
download_weights(weights_path)
|
369 |
+
|
370 |
+
# Device setup
|
371 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
372 |
+
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
|
373 |
+
|
374 |
+
# Define args
|
375 |
+
args = argparse.Namespace(
|
376 |
+
pn='1M',
|
377 |
+
model_path=str(weights_path / 'infinity_2b_reg.pth'),
|
378 |
+
cfg_insertion_layer=0,
|
379 |
+
vae_type=32,
|
380 |
+
vae_path=str(weights_path / 'infinity_vae_d32reg.pth'),
|
381 |
+
add_lvl_embeding_only_first_block=1,
|
382 |
+
use_bit_label=1,
|
383 |
+
model_type='infinity_2b',
|
384 |
+
rope2d_each_sa_layer=1,
|
385 |
+
rope2d_normalized_by_hw=2,
|
386 |
+
use_scale_schedule_embedding=0,
|
387 |
+
sampling_per_bits=1,
|
388 |
+
text_encoder_ckpt=str(weights_path / 'flan-t5-xl'),
|
389 |
+
text_channels=2048,
|
390 |
+
apply_spatial_patchify=0,
|
391 |
+
h_div_w_template=1.000,
|
392 |
+
use_flex_attn=0,
|
393 |
+
cache_dir='/dev/shm',
|
394 |
+
checkpoint_type='torch',
|
395 |
+
seed=0,
|
396 |
+
bf16=1 if dtype == torch.bfloat16 else 0,
|
397 |
+
save_file='tmp.jpg',
|
398 |
+
enable_model_cache=False,
|
399 |
+
)
|
400 |
+
|
401 |
+
# Load models
|
402 |
+
text_tokenizer, text_encoder = load_tokenizer(t5_path=str(weights_path / 'flan-t5-xl'))
|
403 |
+
vae = load_visual_tokenizer(args)
|
404 |
+
infinity = load_transformer(vae, args)
|
405 |
+
|
406 |
+
# Define the image generation function
|
407 |
+
@spaces.GPU
|
408 |
+
def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt):
|
409 |
+
try:
|
410 |
+
args.prompt = prompt
|
411 |
+
args.cfg = cfg
|
412 |
+
args.tau = tau
|
413 |
+
args.h_div_w = h_div_w
|
414 |
+
args.seed = seed
|
415 |
+
args.enable_positive_prompt = enable_positive_prompt
|
416 |
+
|
417 |
+
# Find the closest h_div_w_template
|
418 |
+
h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
|
419 |
+
|
420 |
+
# Get scale_schedule based on h_div_w_template_
|
421 |
+
scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
|
422 |
+
scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
|
423 |
+
|
424 |
+
# Generate the image
|
425 |
+
generated_image = gen_one_img(
|
426 |
+
infinity,
|
427 |
+
vae,
|
428 |
+
text_tokenizer,
|
429 |
+
text_encoder,
|
430 |
+
prompt,
|
431 |
+
g_seed=seed,
|
432 |
+
gt_leak=0,
|
433 |
+
gt_ls_Bl=None,
|
434 |
+
cfg_list=cfg,
|
435 |
+
tau_list=tau,
|
436 |
+
scale_schedule=scale_schedule,
|
437 |
+
cfg_insertion_layer=[args.cfg_insertion_layer],
|
438 |
+
vae_type=args.vae_type,
|
439 |
+
sampling_per_bits=args.sampling_per_bits,
|
440 |
+
enable_positive_prompt=enable_positive_prompt,
|
441 |
+
)
|
442 |
+
|
443 |
+
# Convert the image to RGB and uint8
|
444 |
+
image = generated_image.cpu().numpy()
|
445 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
446 |
+
image = np.uint8(image)
|
447 |
+
|
448 |
+
return image
|
449 |
+
except Exception as e:
|
450 |
+
print(f"Error generating image: {e}")
|
451 |
+
return None
|
452 |
+
|
453 |
+
# Set up Gradio interface
|
454 |
+
with gr.Blocks() as demo:
|
455 |
+
gr.Markdown("<h1><center>Infinity Image Generator</center></h1>")
|
456 |
+
|
457 |
+
with gr.Row():
|
458 |
+
prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise")
|
459 |
+
cfg = gr.Slider(label="CFG", minimum=1, maximum=10, step=0.5, value=3)
|
460 |
+
tau = gr.Slider(label="Tau", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
|
461 |
+
h_div_w = gr.Slider(label="Aspect Ratio (Height/Width)", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
|
462 |
+
seed = gr.Number(label="Seed", value=random.randint(0, 10000))
|
463 |
+
enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False)
|
464 |
+
|
465 |
+
generate_button = gr.Button("Generate Image")
|
466 |
+
output_image = gr.Image(label="Generated Image", type="pil")
|
467 |
+
|
468 |
+
generate_button.click(
|
469 |
+
generate_image,
|
470 |
+
inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
|
471 |
+
outputs=output_image
|
472 |
+
)
|
473 |
+
|
474 |
+
# Launch the Gradio app
|
475 |
+
demo.launch()
|
models/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from timm.loss import SoftTargetCrossEntropy
|
3 |
+
|
4 |
+
from timm.models.layers import DropPath
|
5 |
+
|
6 |
+
from .infinity import Infinity, sample_with_top_k_top_p_also_inplace_modifying_logits_
|
7 |
+
|
8 |
+
def _ex_repr(self):
|
9 |
+
return ', '.join(
|
10 |
+
f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
|
11 |
+
for k, v in vars(self).items()
|
12 |
+
if not k.startswith('_') and k != 'training'
|
13 |
+
and not isinstance(v, (torch.nn.Module, torch.Tensor))
|
14 |
+
)
|
15 |
+
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy): # no longer __repr__ DropPath with drop_prob
|
16 |
+
if hasattr(clz, 'extra_repr'):
|
17 |
+
clz.extra_repr = _ex_repr
|
18 |
+
else:
|
19 |
+
clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
|
20 |
+
|
21 |
+
DropPath.__repr__ = lambda self: f'{type(self).__name__}(...)'
|
22 |
+
|
23 |
+
alias_dict = {}
|
24 |
+
for d in range(6, 40+2, 2):
|
25 |
+
alias_dict[f'd{d}'] = f'infinity_d{d}'
|
26 |
+
alias_dict_inv = {v: k for k, v in alias_dict.items()}
|
models/basic.py
ADDED
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Definitions of blocks of VAR transformer model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
from functools import partial
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import numpy as np
|
14 |
+
from timm.models.layers import DropPath, drop_path
|
15 |
+
from torch.utils.checkpoint import checkpoint
|
16 |
+
|
17 |
+
# Import flash_attn's attention
|
18 |
+
from flash_attn import flash_attn_func # q, k, or v: BLHc, ret: BLHc
|
19 |
+
from flash_attn import flash_attn_varlen_kvpacked_func # qkv: N3Hc, ret: NHc
|
20 |
+
|
21 |
+
from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
|
22 |
+
|
23 |
+
# Import flash_attn's fused ops
|
24 |
+
try:
|
25 |
+
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
26 |
+
from flash_attn.ops.rms_norm import dropout_add_rms_norm
|
27 |
+
from flash_attn.ops.rms_norm import rms_norm as rms_norm_impl
|
28 |
+
from flash_attn.ops.fused_dense import fused_mlp_func
|
29 |
+
flash_fused_op_installed = True
|
30 |
+
except ImportError:
|
31 |
+
dropout_add_layer_norm = dropout_add_rms_norm = fused_mlp_func = None
|
32 |
+
flash_fused_op_installed = False
|
33 |
+
|
34 |
+
def rms_norm_impl(x, weight, epsilon):
|
35 |
+
return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(epsilon))) * weight
|
36 |
+
|
37 |
+
|
38 |
+
def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0):
|
39 |
+
# split the dimension into half, one for x and one for y
|
40 |
+
half_dim = dim // 2
|
41 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim)) # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
|
42 |
+
t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
|
43 |
+
t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
|
44 |
+
t_height = t_height / scaling_factor
|
45 |
+
freqs_height = torch.outer(t_height, inv_freq) # (max_height, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2), namely y*theta
|
46 |
+
t_width = t_width / scaling_factor
|
47 |
+
freqs_width = torch.outer(t_width, inv_freq) # (max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2), namely x*theta
|
48 |
+
freqs_grid_map = torch.concat([
|
49 |
+
freqs_height[:, None, :].expand(-1, max_width, -1), # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2)
|
50 |
+
freqs_width[None, :, :].expand(max_height, -1, -1), # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2)
|
51 |
+
], dim=-1) # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d))
|
52 |
+
freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
|
53 |
+
# (2, max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d))
|
54 |
+
|
55 |
+
rope2d_freqs_grid = {}
|
56 |
+
for h_div_w in dynamic_resolution_h_w:
|
57 |
+
scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['scales']
|
58 |
+
_, ph, pw = scale_schedule[-1]
|
59 |
+
max_edge_length = freqs_grid_map.shape[1]
|
60 |
+
if ph >= pw:
|
61 |
+
uph, upw = max_edge_length, int(max_edge_length / ph * pw)
|
62 |
+
else:
|
63 |
+
uph, upw = int(max_edge_length / pw * ph), max_edge_length
|
64 |
+
rope_cache_list = []
|
65 |
+
for (_, ph, pw) in scale_schedule:
|
66 |
+
ph_mul_pw = ph * pw
|
67 |
+
if rope2d_normalized_by_hw == 1: # downsample
|
68 |
+
rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0,3,1,2]), size=(ph, pw), mode='bilinear', align_corners=True)
|
69 |
+
rope_cache = rope_cache.permute([0,2,3,1]) # (2, ph, pw, half_head_dim)
|
70 |
+
elif rope2d_normalized_by_hw == 2: # star stylee
|
71 |
+
_, uph, upw = scale_schedule[-1]
|
72 |
+
indices = torch.stack([
|
73 |
+
(torch.arange(ph) * (uph / ph)).reshape(ph, 1).expand(ph, pw),
|
74 |
+
(torch.arange(pw) * (upw / pw)).reshape(1, pw).expand(ph, pw),
|
75 |
+
], dim=-1).round().int() # (ph, pw, 2)
|
76 |
+
indices = indices.reshape(-1, 2) # (ph*pw, 2)
|
77 |
+
rope_cache = freqs_grid_map[:, indices[:,0], indices[:,1], :] # (2, ph*pw, half_head_dim)
|
78 |
+
rope_cache = rope_cache.reshape(2, ph, pw, -1)
|
79 |
+
elif rope2d_normalized_by_hw == 0:
|
80 |
+
rope_cache = freqs_grid_map[:, :ph, :pw, :] # (2, ph, pw, half_head_dim)
|
81 |
+
else:
|
82 |
+
raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
|
83 |
+
rope_cache_list.append(rope_cache.reshape(2, ph_mul_pw, -1))
|
84 |
+
cat_rope_cache = torch.cat(rope_cache_list, 1) # (2, seq_len, half_head_dim)
|
85 |
+
if cat_rope_cache.shape[1] % pad_to_multiplier:
|
86 |
+
pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache.shape[1] % pad_to_multiplier, half_dim)
|
87 |
+
cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
|
88 |
+
cat_rope_cache = cat_rope_cache[:,None,None,None] # (2, 1, 1, 1, seq_len, half_dim)
|
89 |
+
for pn in dynamic_resolution_h_w[h_div_w]:
|
90 |
+
scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['scales']
|
91 |
+
tmp_scale_schedule = [(1, h, w) for _, h, w in scale_schedule]
|
92 |
+
rope2d_freqs_grid[str(tuple(tmp_scale_schedule))] = cat_rope_cache
|
93 |
+
return rope2d_freqs_grid
|
94 |
+
|
95 |
+
|
96 |
+
def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier, rope2d_normalized_by_hw, scale_ind):
|
97 |
+
qk = torch.stack((q, k), dim=0) #(2, batch_size, heads, seq_len, head_dim)
|
98 |
+
device_type = qk.device.type
|
99 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
100 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
101 |
+
seq_len = qk.shape[3]
|
102 |
+
start = 0
|
103 |
+
if scale_ind >= 1:
|
104 |
+
assert len(scale_schedule[0]) == 3
|
105 |
+
start = np.sum([item[0] * item[1] * item[2] for item in scale_schedule[:scale_ind]])
|
106 |
+
rope2d_freqs_grid[str(tuple(scale_schedule))] = rope2d_freqs_grid[str(tuple(scale_schedule))].to(qk.device)
|
107 |
+
assert start+seq_len <= rope2d_freqs_grid[str(tuple(scale_schedule))].shape[4]
|
108 |
+
rope_cache = rope2d_freqs_grid[str(tuple(scale_schedule))][:, :, :, :, start:start+seq_len] # rope_cache shape: [2, 1, 1, 1, seq_len, half_head_dim]
|
109 |
+
qk = qk.reshape(*qk.shape[:-1], -1, 2) #(2, batch_size, heads, seq_len, half_head_dim, 2)
|
110 |
+
qk = torch.stack([
|
111 |
+
rope_cache[0] * qk[...,0] - rope_cache[1] * qk[...,1],
|
112 |
+
rope_cache[1] * qk[...,0] + rope_cache[0] * qk[...,1],
|
113 |
+
], dim=-1) # (2, batch_size, heads, seq_len, half_head_dim, 2), here stack + reshape should not be concate
|
114 |
+
qk = qk.reshape(*qk.shape[:-2], -1) #(2, batch_size, heads, seq_len, head_dim)
|
115 |
+
q, k = qk.unbind(dim=0) # (batch_size, heads, seq_len, head_dim)
|
116 |
+
return q, k
|
117 |
+
|
118 |
+
|
119 |
+
class FastRMSNorm(nn.Module):
|
120 |
+
def __init__(self, C, eps=1e-6, elementwise_affine=True):
|
121 |
+
super().__init__()
|
122 |
+
self.C = C
|
123 |
+
self.eps = eps
|
124 |
+
self.elementwise_affine = elementwise_affine
|
125 |
+
if self.elementwise_affine:
|
126 |
+
self.weight = nn.Parameter(torch.ones(C))
|
127 |
+
else:
|
128 |
+
self.register_buffer('weight', torch.ones(C))
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
src_type = x.dtype
|
132 |
+
return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(src_type)
|
133 |
+
|
134 |
+
def extra_repr(self) -> str:
|
135 |
+
return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}'
|
136 |
+
|
137 |
+
|
138 |
+
def get_dropout_layer(p):
|
139 |
+
return nn.Dropout(p, inplace=True) if p > 0 else nn.Identity()
|
140 |
+
|
141 |
+
|
142 |
+
class FFN(nn.Module):
|
143 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_mlp=False):
|
144 |
+
super().__init__()
|
145 |
+
self.fused_mlp_func = fused_mlp_func if fused_mlp else None
|
146 |
+
out_features = out_features or in_features
|
147 |
+
hidden_features = hidden_features or in_features
|
148 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
149 |
+
self.act = nn.GELU(approximate='tanh')
|
150 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
151 |
+
self.drop = get_dropout_layer(drop)
|
152 |
+
self.heuristic = -1
|
153 |
+
|
154 |
+
def forward(self, x):
|
155 |
+
if self.fused_mlp_func is not None:
|
156 |
+
return self.drop(self.fused_mlp_func(
|
157 |
+
x=x,
|
158 |
+
weight1=self.fc1.weight,
|
159 |
+
weight2=self.fc2.weight,
|
160 |
+
bias1=self.fc1.bias,
|
161 |
+
bias2=self.fc2.bias,
|
162 |
+
activation='gelu_approx',
|
163 |
+
save_pre_act=self.training,
|
164 |
+
return_residual=False,
|
165 |
+
checkpoint_lvl=0,
|
166 |
+
heuristic=self.heuristic,
|
167 |
+
process_group=None,
|
168 |
+
))
|
169 |
+
else:
|
170 |
+
return self.drop(self.fc2( self.act(self.fc1(x)) ))
|
171 |
+
|
172 |
+
def extra_repr(self) -> str:
|
173 |
+
return f'fused_mlp={self.fused_mlp_func is not None}'
|
174 |
+
|
175 |
+
|
176 |
+
class FFNSwiGLU(nn.Module):
|
177 |
+
def __init__(self, in_features, hidden_features, out_features=None, drop=0., fused_mlp=False):
|
178 |
+
super().__init__()
|
179 |
+
self.fused_mlp_func = None
|
180 |
+
hidden_features = round(2 * hidden_features / 3 / 256) * 256
|
181 |
+
|
182 |
+
out_features = out_features or in_features
|
183 |
+
self.fcg = nn.Linear(in_features, hidden_features, bias=False)
|
184 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
|
185 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
|
186 |
+
self.drop = get_dropout_layer(drop)
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
return self.drop(self.fc2( F.silu(self.fcg(x), inplace=True).mul_(self.fc1(x)) ))
|
190 |
+
|
191 |
+
def extra_repr(self) -> str:
|
192 |
+
return f'fused_mlp={self.fused_mlp_func is not None}'
|
193 |
+
|
194 |
+
|
195 |
+
class SelfAttention(nn.Module):
|
196 |
+
def __init__(
|
197 |
+
self, embed_dim=768, num_heads=12,
|
198 |
+
proj_drop=0., tau=1, cos_attn=False, customized_flash_attn=True, use_flex_attn=False,
|
199 |
+
batch_size=2, pad_to_multiplier=1, rope2d_normalized_by_hw=0,
|
200 |
+
):
|
201 |
+
"""
|
202 |
+
:param embed_dim: model's width
|
203 |
+
:param num_heads: num heads of multi-head attention
|
204 |
+
:param proj_drop: always 0 for testing
|
205 |
+
:param tau: always 1
|
206 |
+
:param cos_attn: always True: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11
|
207 |
+
:param customized_flash_attn:
|
208 |
+
"""
|
209 |
+
super().__init__()
|
210 |
+
assert embed_dim % num_heads == 0
|
211 |
+
self.using_flash = customized_flash_attn
|
212 |
+
|
213 |
+
self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
|
214 |
+
self.tau, self.cos_attn = tau, cos_attn
|
215 |
+
if self.cos_attn:
|
216 |
+
self.scale = 1
|
217 |
+
size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1)
|
218 |
+
# size: 11H1 or 1H11
|
219 |
+
self.scale_mul_1H11 = nn.Parameter(torch.full(size=size, fill_value=4.0).log(), requires_grad=True)
|
220 |
+
self.max_scale_mul = torch.log(torch.tensor(100)).item()
|
221 |
+
else:
|
222 |
+
self.scale = 1 / math.sqrt(self.head_dim) / self.tau
|
223 |
+
|
224 |
+
self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
|
225 |
+
self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
|
226 |
+
self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
|
227 |
+
|
228 |
+
self.proj = nn.Linear(embed_dim, embed_dim)
|
229 |
+
self.proj_drop = get_dropout_layer(proj_drop)
|
230 |
+
|
231 |
+
self.caching = False # kv caching: only used during inference
|
232 |
+
self.cached_k = None # kv caching: only used during inference
|
233 |
+
self.cached_v = None # kv caching: only used during inference
|
234 |
+
|
235 |
+
self.batch_size = batch_size
|
236 |
+
self.use_flex_attn = use_flex_attn
|
237 |
+
self.pad_to_multiplier = pad_to_multiplier
|
238 |
+
|
239 |
+
self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
|
240 |
+
|
241 |
+
|
242 |
+
def kv_caching(self, enable: bool): # kv caching: only used during inference
|
243 |
+
self.caching = enable
|
244 |
+
self.cached_k = None
|
245 |
+
self.cached_v = None
|
246 |
+
|
247 |
+
# NOTE: attn_bias_or_two_vector is None during inference
|
248 |
+
def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.IntTensor, torch.IntTensor]], attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0):
|
249 |
+
"""
|
250 |
+
:param (fp32) x: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
|
251 |
+
:param (fp32) attn_bias_or_two_vector:
|
252 |
+
if not using_flash:
|
253 |
+
a block-wise, lower-triangle matrix, like:
|
254 |
+
[[[[0, -, -, -, -, -, -, -, -, -, -, -, -, -],
|
255 |
+
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
|
256 |
+
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
|
257 |
+
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
|
258 |
+
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
|
259 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
260 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
261 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
262 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
263 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
264 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
265 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
266 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
267 |
+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]
|
268 |
+
where 0 means visible and - means invisible (-inf)
|
269 |
+
else:
|
270 |
+
a tuple of two 1-dim int vector (VAR_visible_kvlen, VAR_invisible_qlen)
|
271 |
+
:return: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
|
272 |
+
"""
|
273 |
+
# x: fp32
|
274 |
+
B, L, C = x.shape
|
275 |
+
|
276 |
+
# qkv: amp, bf16
|
277 |
+
qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim) # BL3Hc
|
278 |
+
if self.using_flash: q, k, v = qkv.unbind(dim=2); L_dim = 1 # q or k or v: all are shaped in (B:batch_size, L:seq_len, H:heads, c:head_dim)
|
279 |
+
else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); L_dim = 2 # q or k or v: all are shaped in (B:batch_size, H:heads, L:seq_len, c:head_dim)
|
280 |
+
|
281 |
+
if self.cos_attn: # always True
|
282 |
+
scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() # 11H1 (flash), or 1H11 (not flash)
|
283 |
+
q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous() # fp32
|
284 |
+
k = F.normalize(k, dim=-1, eps=1e-12).contiguous() # fp32
|
285 |
+
v = v.contiguous() # bf16
|
286 |
+
else: # be contiguous, to make kernel happy
|
287 |
+
q = q.contiguous() # bf16
|
288 |
+
k = k.contiguous() # bf16
|
289 |
+
v = v.contiguous() # bf16
|
290 |
+
if rope2d_freqs_grid is not None:
|
291 |
+
q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind) #, freqs_cis=freqs_cis)
|
292 |
+
if self.caching: # kv caching: only used during inference
|
293 |
+
if self.cached_k is None: self.cached_k = k; self.cached_v = v
|
294 |
+
else: k = self.cached_k = torch.cat((self.cached_k, k), dim=L_dim); v = self.cached_v = torch.cat((self.cached_v, v), dim=L_dim)
|
295 |
+
|
296 |
+
if self.using_flash:
|
297 |
+
if attn_bias_or_two_vector is not None: # training
|
298 |
+
kw = dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1])
|
299 |
+
else: # inference (autoregressive sampling)
|
300 |
+
kw = dict()
|
301 |
+
oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C)
|
302 |
+
else:
|
303 |
+
# if self.cos_attn: q, k are in fp32; v is in bf16
|
304 |
+
# else: q, k, v are in bf16
|
305 |
+
if self.use_flex_attn and attn_fn is not None:
|
306 |
+
oup = attn_fn(q, k, v, scale=self.scale).transpose(1, 2).reshape(B, L, C)
|
307 |
+
else:
|
308 |
+
oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
|
309 |
+
# oup: bf16
|
310 |
+
|
311 |
+
return self.proj_drop(self.proj(oup))
|
312 |
+
|
313 |
+
def extra_repr(self) -> str:
|
314 |
+
tail = ''
|
315 |
+
return f'using_flash={self.using_flash}, tau={self.tau}, cos_attn={self.cos_attn}{tail}'
|
316 |
+
|
317 |
+
|
318 |
+
class CrossAttention(nn.Module):
|
319 |
+
def __init__(
|
320 |
+
self, for_attn_pool=False, embed_dim=768, kv_dim=4096, num_heads=12,
|
321 |
+
proj_drop=0., cos_attn=False,
|
322 |
+
):
|
323 |
+
"""
|
324 |
+
:param for_attn_pool: only used in VAR.text_proj_for_sos
|
325 |
+
:param embed_dim: Q's dim
|
326 |
+
:param kv_dim: K's and V's dim
|
327 |
+
:param num_heads: num heads of multi-head attention
|
328 |
+
:param proj_drop: proj drop out
|
329 |
+
:param cos_attn: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11
|
330 |
+
"""
|
331 |
+
cos_attn = False # TODO: never use cos attn in cross attention with T5 kv
|
332 |
+
super().__init__()
|
333 |
+
self.for_attn_pool = for_attn_pool
|
334 |
+
self.embed_dim = embed_dim
|
335 |
+
self.kv_dim = kv_dim
|
336 |
+
assert embed_dim % num_heads == 0
|
337 |
+
self.num_heads, self.head_dim = num_heads, embed_dim // num_heads # =64
|
338 |
+
self.cos_attn = cos_attn
|
339 |
+
if self.cos_attn:
|
340 |
+
self.scale = 1
|
341 |
+
self.scale_mul_1H1 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
|
342 |
+
self.max_scale_mul = torch.log(torch.tensor(100)).item()
|
343 |
+
else:
|
344 |
+
self.scale = 1 / math.sqrt(self.head_dim)
|
345 |
+
|
346 |
+
if for_attn_pool:
|
347 |
+
q = torch.empty(1, self.num_heads, self.head_dim)
|
348 |
+
nn.init.trunc_normal_(q, mean=0, std=math.sqrt(1 / embed_dim / 3))
|
349 |
+
self.mat_q = nn.Parameter(q)
|
350 |
+
else:
|
351 |
+
self.mat_q = nn.Linear(embed_dim, embed_dim, bias=True)
|
352 |
+
self.mat_kv = nn.Linear(kv_dim, embed_dim*2, bias=False)
|
353 |
+
self.v_bias = nn.Parameter(torch.zeros(embed_dim))
|
354 |
+
self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
|
355 |
+
|
356 |
+
self.proj = nn.Linear(embed_dim, embed_dim)
|
357 |
+
self.proj_drop = get_dropout_layer(proj_drop)
|
358 |
+
|
359 |
+
def forward(self, q, ca_kv):
|
360 |
+
"""
|
361 |
+
:param q: shaped as (batch, seq_len, Q_dim)
|
362 |
+
:param ca_kv: contains several vectors, each of which is shaped as (len_i, KV_dim). We have [len_1xKV_dim, len_2xKV_dim, len_3xKV_dim, ...] and lens == [len_1, len_2, len_3, ...]
|
363 |
+
- kv_compact: shaped as (sum(lens), KV_dim)
|
364 |
+
- cu_seqlens_k: cumulated sum of lens
|
365 |
+
- max_seqlen_k: int, max(lens)
|
366 |
+
NOTE: seq_len (num of Qs) can reach 10k; but len_i (num of KVs) must <= 256
|
367 |
+
|
368 |
+
:return: shaped as (batch, seq_len, Q_dim)
|
369 |
+
"""
|
370 |
+
kv_compact, cu_seqlens_k, max_seqlen_k = ca_kv
|
371 |
+
N = kv_compact.shape[0]
|
372 |
+
|
373 |
+
kv_compact = F.linear(kv_compact, weight=self.mat_kv.weight, bias=torch.cat((self.zero_k_bias, self.v_bias))).view(N, 2, self.num_heads, self.head_dim) # NC => N2Hc
|
374 |
+
# attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
|
375 |
+
|
376 |
+
if not self.for_attn_pool:
|
377 |
+
B, Lq = q.shape[:2]
|
378 |
+
q_compact = self.mat_q(q).view(-1, self.num_heads, self.head_dim)
|
379 |
+
else:
|
380 |
+
B = cu_seqlens_k.shape[0] - 1
|
381 |
+
Lq = 1
|
382 |
+
q_compact = self.mat_q.repeat(B, 1, 1).to(dtype=kv_compact.dtype)
|
383 |
+
|
384 |
+
if self.cos_attn: # always False
|
385 |
+
scale_mul = self.scale_mul_1H1.clamp_max(self.max_scale_mul).exp()
|
386 |
+
k, v = kv_compact.unbind(dim=1)
|
387 |
+
q_compact = F.normalize(q_compact, dim=-1).mul(scale_mul)
|
388 |
+
k = F.normalize(k, dim=-1)
|
389 |
+
kv_compact = torch.stack((k, v), dim=1)
|
390 |
+
|
391 |
+
q_compact = q_compact.contiguous()
|
392 |
+
kv_compact = kv_compact.contiguous()
|
393 |
+
|
394 |
+
cu_seqlens_q = torch.arange(0, Lq * (B+1), Lq, dtype=torch.int32, device=q_compact.device)
|
395 |
+
if q_compact.dtype == torch.float32: # todo: fp16 or bf16?
|
396 |
+
oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
|
397 |
+
oup = oup.float()
|
398 |
+
else:
|
399 |
+
oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
|
400 |
+
|
401 |
+
return self.proj_drop(self.proj(oup))
|
402 |
+
|
403 |
+
def extra_repr(self) -> str:
|
404 |
+
return f'Cq={self.embed_dim}, Ckv={self.kv_dim}, cos_attn={self.cos_attn}'
|
405 |
+
|
406 |
+
|
407 |
+
class SelfAttnBlock(nn.Module):
|
408 |
+
def __init__(
|
409 |
+
self, embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
|
410 |
+
num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
|
411 |
+
swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
|
412 |
+
):
|
413 |
+
super(SelfAttnBlock, self).__init__()
|
414 |
+
self.C, self.D = embed_dim, cond_dim
|
415 |
+
self.drop_path_rate = drop_path
|
416 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
417 |
+
self.attn = SelfAttention(
|
418 |
+
embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn, attn_fn = attn_fn
|
419 |
+
)
|
420 |
+
self.using_swiglu = swiglu
|
421 |
+
self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)
|
422 |
+
|
423 |
+
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
|
424 |
+
self.fused_norm_func = fused_norm_func
|
425 |
+
self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
|
426 |
+
|
427 |
+
self.shared_aln = shared_aln
|
428 |
+
if self.shared_aln:
|
429 |
+
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
|
430 |
+
else:
|
431 |
+
lin = nn.Linear(cond_dim, 6*embed_dim)
|
432 |
+
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)
|
433 |
+
|
434 |
+
# NOTE: attn_bias_or_two_vector is None during inference
|
435 |
+
def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector): # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
|
436 |
+
with torch.cuda.amp.autocast(enabled=False):
|
437 |
+
if self.shared_aln: # always True; (1, 1, 6, C) + (B, 1, 6, C)
|
438 |
+
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
|
439 |
+
else:
|
440 |
+
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
|
441 |
+
|
442 |
+
if self.fused_ada_norm is None:
|
443 |
+
x = x + self.drop_path(self.attn( self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
|
444 |
+
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
|
445 |
+
else:
|
446 |
+
x = x + self.drop_path(self.attn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
|
447 |
+
x = x + self.drop_path(self.ffn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
|
448 |
+
return x
|
449 |
+
|
450 |
+
def extra_repr(self) -> str:
|
451 |
+
return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}'
|
452 |
+
|
453 |
+
|
454 |
+
class CrossAttnBlock(nn.Module):
|
455 |
+
def __init__(
|
456 |
+
self,
|
457 |
+
embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
|
458 |
+
num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
|
459 |
+
swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
|
460 |
+
use_flex_attn=False, batch_size=2, pad_to_multiplier=1, apply_rope2d=False, rope2d_normalized_by_hw=False,
|
461 |
+
):
|
462 |
+
super(CrossAttnBlock, self).__init__()
|
463 |
+
self.C, self.D = embed_dim, cond_dim
|
464 |
+
self.drop_path_rate = drop_path
|
465 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
466 |
+
self.sa = SelfAttention(
|
467 |
+
embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
|
468 |
+
use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
|
469 |
+
)
|
470 |
+
self.ca = CrossAttention(embed_dim=embed_dim, kv_dim=kv_dim, num_heads=num_heads, proj_drop=drop, cos_attn=cos_attn)
|
471 |
+
self.using_swiglu = swiglu
|
472 |
+
self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)
|
473 |
+
|
474 |
+
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
|
475 |
+
self.fused_norm_func = fused_norm_func
|
476 |
+
self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
|
477 |
+
self.ca_norm = norm_layer(embed_dim, elementwise_affine=True)
|
478 |
+
|
479 |
+
self.shared_aln = shared_aln
|
480 |
+
if self.shared_aln: # always True
|
481 |
+
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
|
482 |
+
else:
|
483 |
+
lin = nn.Linear(cond_dim, 6*embed_dim)
|
484 |
+
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)
|
485 |
+
|
486 |
+
if cross_attn_layer_scale >= 0:
|
487 |
+
self.ca_gamma = nn.Parameter(cross_attn_layer_scale * torch.ones(embed_dim), requires_grad=True)
|
488 |
+
else:
|
489 |
+
self.ca_gamma = 1
|
490 |
+
|
491 |
+
self.checkpointing_sa_only = checkpointing_sa_only
|
492 |
+
|
493 |
+
# NOTE: attn_bias_or_two_vector is None during inference
|
494 |
+
def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0): # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
|
495 |
+
with torch.cuda.amp.autocast(enabled=False): # disable half precision
|
496 |
+
if self.shared_aln: # always True; (1, 1, 6, C) + (B, 1, 6, C)
|
497 |
+
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
|
498 |
+
else:
|
499 |
+
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
|
500 |
+
|
501 |
+
if self.fused_norm_func is None:
|
502 |
+
x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1)
|
503 |
+
if self.checkpointing_sa_only and self.training:
|
504 |
+
x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
|
505 |
+
else:
|
506 |
+
x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
|
507 |
+
x = x + self.drop_path(x_sa.mul_(gamma1))
|
508 |
+
x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
|
509 |
+
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
|
510 |
+
else:
|
511 |
+
x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1)
|
512 |
+
if self.checkpointing_sa_only and self.training:
|
513 |
+
x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
|
514 |
+
else:
|
515 |
+
x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind)
|
516 |
+
x = x + self.drop_path(x_sa.mul_(gamma1))
|
517 |
+
x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
|
518 |
+
x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
|
519 |
+
return x
|
520 |
+
|
521 |
+
def extra_repr(self) -> str:
|
522 |
+
return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}, ca_gamma={"<learnable>" if isinstance(self.ca_gamma, nn.Parameter) else self.ca_gamma}'
|
523 |
+
|
524 |
+
|
525 |
+
class AdaLNBeforeHead(nn.Module):
|
526 |
+
def __init__(self, C, D, act: bool, norm_layer: partial, fused_norm_func=None): # C: embed_dim, D: cond_dim
|
527 |
+
super().__init__()
|
528 |
+
self.C, self.D = C, D
|
529 |
+
self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
|
530 |
+
self.fused_norm_func = fused_norm_func
|
531 |
+
self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
|
532 |
+
lin = nn.Linear(D, 2*C)
|
533 |
+
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)
|
534 |
+
|
535 |
+
def forward(self, x_BLC: torch.Tensor, cond_BD: Optional[torch.Tensor]):
|
536 |
+
scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
|
537 |
+
if self.fused_norm_func is None:
|
538 |
+
return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
|
539 |
+
else:
|
540 |
+
return self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x_BLC, scale=scale, shift=shift)
|
541 |
+
|
542 |
+
|
543 |
+
def main():
|
544 |
+
dev = 'cpu' # 'cuda' if torch.cuda.is_available() else 'cpu'
|
545 |
+
rng = torch.Generator(device=dev)
|
546 |
+
# for Li in ([1, 3, 5], [1, 3]):
|
547 |
+
rng.manual_seed(0)
|
548 |
+
B, H, cq, ckv = 4, 8, 64, 96
|
549 |
+
Cq = H*cq
|
550 |
+
Ckv = H*ckv
|
551 |
+
|
552 |
+
Li = [5, 4, 7, 6]
|
553 |
+
Lq = 10
|
554 |
+
L = max(Li)
|
555 |
+
attn_bias = torch.zeros(B, 1, Lq, L, device=dev)
|
556 |
+
for i, x in enumerate(Li):
|
557 |
+
attn_bias[i, 0, :, x:] = -torch.inf
|
558 |
+
|
559 |
+
q = torch.randn(B, Lq, H, cq, generator=rng, device=dev)
|
560 |
+
k = torch.randn(B, L, H, ckv, generator=rng, device=dev)
|
561 |
+
v = torch.randn(B, L, H, ckv, generator=rng, device=dev)
|
562 |
+
tq, tk, tv = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # BHLc
|
563 |
+
|
564 |
+
seqlen_k = torch.tensor(Li, dtype=torch.int32, device=dev)
|
565 |
+
cu_seqlens_k = F.pad(torch.cumsum(seqlen_k, dim=0, dtype=torch.torch.int32), (1, 0))
|
566 |
+
kv = torch.stack([k, v], dim=2)
|
567 |
+
kv_compact = torch.cat([kv[i, :Li[i]] for i in range(B)], dim=0)
|
568 |
+
|
569 |
+
ca = CrossAttention(for_attn_pool=False, embed_dim=Cq, kv_dim=Ckv, num_heads=H)
|
570 |
+
CrossAttention.forward
|
571 |
+
ca(q, (kv_compact, cu_seqlens_k, max(Li))).mean().backward()
|
572 |
+
|
573 |
+
|
574 |
+
if __name__ == '__main__':
|
575 |
+
main()
|
models/bitwise_self_correction.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
def labels2image(all_indices, label_type='int_label', scale_schedule=None):
|
10 |
+
summed_codes, recons_imgs = self.vae.decode_from_indices(all_indices, scale_schedule, label_type)
|
11 |
+
recons_img = recons_imgs[0]
|
12 |
+
recons_img = (recons_img + 1) / 2
|
13 |
+
recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
|
14 |
+
return recons_img
|
15 |
+
|
16 |
+
def features2image(raw_features):
|
17 |
+
recons_imgs = self.vae.decode(raw_features.squeeze(-3))
|
18 |
+
recons_img = recons_imgs[0]
|
19 |
+
recons_img = (recons_img + 1) / 2
|
20 |
+
recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
|
21 |
+
return recons_img
|
22 |
+
|
23 |
+
class BitwiseSelfCorrection(object):
|
24 |
+
def __init__(self, vae, args):
|
25 |
+
self.noise_apply_layers = args.noise_apply_layers
|
26 |
+
self.noise_apply_requant = args.noise_apply_requant
|
27 |
+
self.noise_apply_strength = args.noise_apply_strength
|
28 |
+
self.apply_spatial_patchify = args.apply_spatial_patchify
|
29 |
+
self.vae = vae
|
30 |
+
self.debug_bsc = args.debug_bsc
|
31 |
+
|
32 |
+
def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
|
33 |
+
with torch.amp.autocast('cuda', enabled = False):
|
34 |
+
B = raw_features.shape[0]
|
35 |
+
if raw_features.dim() == 4:
|
36 |
+
codes_out = raw_features.unsqueeze(2)
|
37 |
+
else:
|
38 |
+
codes_out = raw_features
|
39 |
+
cum_var_input = 0
|
40 |
+
gt_all_bit_indices = []
|
41 |
+
pred_all_bit_indices = []
|
42 |
+
x_BLC_wo_prefix = []
|
43 |
+
for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
|
44 |
+
residual = codes_out - cum_var_input
|
45 |
+
if si != len(vae_scale_schedule)-1:
|
46 |
+
residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
|
47 |
+
quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual) # quantized shape: [B, d_vae, 1, h, w], bit_indices shape: [B,1,h,w,d_vae]
|
48 |
+
gt_all_bit_indices.append(bit_indices)
|
49 |
+
if si < self.noise_apply_layers:
|
50 |
+
noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
|
51 |
+
mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
|
52 |
+
pred_bit_indices = bit_indices.clone()
|
53 |
+
pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
|
54 |
+
pred_all_bit_indices.append(pred_bit_indices)
|
55 |
+
if self.noise_apply_requant:
|
56 |
+
quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
|
57 |
+
else:
|
58 |
+
pred_all_bit_indices.append(bit_indices)
|
59 |
+
cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()
|
60 |
+
if si < len(vae_scale_schedule)-1:
|
61 |
+
this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous()
|
62 |
+
if self.apply_spatial_patchify:
|
63 |
+
# (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2)
|
64 |
+
this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
|
65 |
+
x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1)) # (B,H/2*W/2,4C) or (B,H*W,C)
|
66 |
+
|
67 |
+
if self.apply_spatial_patchify:
|
68 |
+
gt_ms_idx_Bl = []
|
69 |
+
for item in gt_all_bit_indices:
|
70 |
+
# item shape: (B,1,H,W,d)
|
71 |
+
item = item.squeeze(1).permute(0,3,1,2) # (B,d,H,W)
|
72 |
+
# (B,d,H,W) -> (B,4d,H/2,W/2)
|
73 |
+
item = torch.nn.functional.pixel_unshuffle(item, 2)
|
74 |
+
# (B,4d,H/2,W/2) -> (B,H/2,W/2,4d) -> (B,H/2*w/2,4d)
|
75 |
+
item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
|
76 |
+
gt_ms_idx_Bl.append(item)
|
77 |
+
else:
|
78 |
+
gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
|
79 |
+
x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1)
|
80 |
+
|
81 |
+
if self.debug_bsc:
|
82 |
+
self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)
|
83 |
+
|
84 |
+
return x_BLC_wo_prefix, gt_ms_idx_Bl
|
85 |
+
|
86 |
+
def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices):
|
87 |
+
gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255
|
88 |
+
gt_img = gt_img[0].permute(1,2,0).cpu().numpy().astype(np.uint8)[:,:,::-1]
|
89 |
+
recons_img_2 = labels2image(gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
|
90 |
+
recons_img_3 = labels2image(pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
|
91 |
+
cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3], axis=1)
|
92 |
+
save_path = osp.abspath('non_teacher_force.jpg')
|
93 |
+
cv2.imwrite(save_path, cat_image)
|
94 |
+
print(f'Save to {save_path}')
|
95 |
+
import pdb; pdb.set_trace()
|
96 |
+
print(cat_image.shape)
|
97 |
+
|
models/bsq_vae/conv.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from einops import rearrange
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class Conv(nn.Module):
|
8 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, cnn_type="2d", causal_offset=0, temporal_down=False):
|
9 |
+
super().__init__()
|
10 |
+
self.cnn_type = cnn_type
|
11 |
+
self.slice_seq_len = 17
|
12 |
+
|
13 |
+
if cnn_type == "2d":
|
14 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
|
15 |
+
if cnn_type == "3d":
|
16 |
+
if temporal_down == False:
|
17 |
+
stride = (1, stride, stride)
|
18 |
+
else:
|
19 |
+
stride = (stride, stride, stride)
|
20 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0)
|
21 |
+
if isinstance(kernel_size, int):
|
22 |
+
kernel_size = (kernel_size, kernel_size, kernel_size)
|
23 |
+
self.padding = (
|
24 |
+
kernel_size[0] - 1 + causal_offset, # Temporal causal padding
|
25 |
+
padding, # Height padding
|
26 |
+
padding # Width padding
|
27 |
+
)
|
28 |
+
self.causal_offset = causal_offset
|
29 |
+
self.stride = stride
|
30 |
+
self.kernel_size = kernel_size
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
if self.cnn_type == "2d":
|
34 |
+
if x.ndim == 5:
|
35 |
+
B, C, T, H, W = x.shape
|
36 |
+
x = rearrange(x, "B C T H W -> (B T) C H W")
|
37 |
+
x = self.conv(x)
|
38 |
+
x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
|
39 |
+
return x
|
40 |
+
else:
|
41 |
+
return self.conv(x)
|
42 |
+
if self.cnn_type == "3d":
|
43 |
+
assert self.stride[0] == 1 or self.stride[0] == 2, f"only temporal stride = 1 or 2 are supported"
|
44 |
+
xs = []
|
45 |
+
for i in range(0, x.shape[2], self.slice_seq_len+self.stride[0]-1):
|
46 |
+
st = i
|
47 |
+
en = min(i+self.slice_seq_len, x.shape[2])
|
48 |
+
_x = x[:,:,st:en,:,:]
|
49 |
+
if i == 0:
|
50 |
+
_x = F.pad(_x, (self.padding[2], self.padding[2], # Width
|
51 |
+
self.padding[1], self.padding[1], # Height
|
52 |
+
self.padding[0], 0)) # Temporal
|
53 |
+
else:
|
54 |
+
padding_0 = self.kernel_size[0] - 1
|
55 |
+
_x = F.pad(_x, (self.padding[2], self.padding[2], # Width
|
56 |
+
self.padding[1], self.padding[1], # Height
|
57 |
+
padding_0, 0)) # Temporal
|
58 |
+
_x[:,:,:padding_0,
|
59 |
+
self.padding[1]:_x.shape[-2]-self.padding[1],
|
60 |
+
self.padding[2]:_x.shape[-1]-self.padding[2]] += x[:,:,i-padding_0:i,:,:]
|
61 |
+
_x = self.conv(_x)
|
62 |
+
xs.append(_x)
|
63 |
+
try:
|
64 |
+
x = torch.cat(xs, dim=2)
|
65 |
+
except:
|
66 |
+
device = x.device
|
67 |
+
del x
|
68 |
+
xs = [_x.cpu().pin_memory() for _x in xs]
|
69 |
+
torch.cuda.empty_cache()
|
70 |
+
x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device)
|
71 |
+
return x
|
models/bsq_vae/dynamic_resolution.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import tqdm
|
4 |
+
|
5 |
+
vae_stride = 16
|
6 |
+
ratio2hws = {
|
7 |
+
1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)],
|
8 |
+
1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)],
|
9 |
+
1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)],
|
10 |
+
1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)],
|
11 |
+
1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)],
|
12 |
+
2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)],
|
13 |
+
2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)],
|
14 |
+
3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)],
|
15 |
+
}
|
16 |
+
full_ratio2hws = {}
|
17 |
+
for ratio, hws in ratio2hws.items():
|
18 |
+
full_ratio2hws[ratio] = hws
|
19 |
+
full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws]
|
20 |
+
|
21 |
+
dynamic_resolution_h_w = {}
|
22 |
+
predefined_HW_Scales_dynamic = {}
|
23 |
+
for ratio in full_ratio2hws:
|
24 |
+
dynamic_resolution_h_w[ratio] ={}
|
25 |
+
for ind, leng in enumerate([7, 10, 13]):
|
26 |
+
h, w = full_ratio2hws[ratio][leng-1][0], full_ratio2hws[ratio][leng-1][1] # feature map size
|
27 |
+
pixel = (h * vae_stride, w * vae_stride) # The original image (H, W)
|
28 |
+
dynamic_resolution_h_w[ratio][pixel[1]] = {
|
29 |
+
'pixel': pixel,
|
30 |
+
'scales': full_ratio2hws[ratio][:leng]
|
31 |
+
} # W as key
|
32 |
+
predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng]
|
models/bsq_vae/flux_vqgan.py
ADDED
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import imageio
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from einops import rearrange
|
7 |
+
from torch import Tensor, nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision
|
10 |
+
from torchvision import transforms
|
11 |
+
from safetensors.torch import load_file
|
12 |
+
import torch.utils.checkpoint as checkpoint
|
13 |
+
|
14 |
+
from .conv import Conv
|
15 |
+
from .multiscale_bsq import MultiScaleBSQ
|
16 |
+
|
17 |
+
ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
|
18 |
+
|
19 |
+
class Normalize(nn.Module):
|
20 |
+
def __init__(self, in_channels, norm_type, norm_axis="spatial"):
|
21 |
+
super().__init__()
|
22 |
+
self.norm_axis = norm_axis
|
23 |
+
assert norm_type in ['group', 'batch', "no"]
|
24 |
+
if norm_type == 'group':
|
25 |
+
if in_channels % 32 == 0:
|
26 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
27 |
+
elif in_channels % 24 == 0:
|
28 |
+
self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True)
|
29 |
+
else:
|
30 |
+
raise NotImplementedError
|
31 |
+
elif norm_type == 'batch':
|
32 |
+
self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False) # Runtime Error: grad inplace if set track_running_stats to True
|
33 |
+
elif norm_type == 'no':
|
34 |
+
self.norm = nn.Identity()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
if self.norm_axis == "spatial":
|
38 |
+
if x.ndim == 4:
|
39 |
+
x = self.norm(x)
|
40 |
+
else:
|
41 |
+
B, C, T, H, W = x.shape
|
42 |
+
x = rearrange(x, "B C T H W -> (B T) C H W")
|
43 |
+
x = self.norm(x)
|
44 |
+
x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
|
45 |
+
elif self.norm_axis == "spatial-temporal":
|
46 |
+
x = self.norm(x)
|
47 |
+
else:
|
48 |
+
raise NotImplementedError
|
49 |
+
return x
|
50 |
+
|
51 |
+
def swish(x: Tensor) -> Tensor:
|
52 |
+
try:
|
53 |
+
return x * torch.sigmoid(x)
|
54 |
+
except:
|
55 |
+
device = x.device
|
56 |
+
x = x.cpu().pin_memory()
|
57 |
+
return (x*torch.sigmoid(x)).to(device=device)
|
58 |
+
|
59 |
+
|
60 |
+
class AttnBlock(nn.Module):
|
61 |
+
def __init__(self, in_channels, norm_type='group', cnn_param=None):
|
62 |
+
super().__init__()
|
63 |
+
self.in_channels = in_channels
|
64 |
+
|
65 |
+
self.norm = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
|
66 |
+
|
67 |
+
self.q = Conv(in_channels, in_channels, kernel_size=1)
|
68 |
+
self.k = Conv(in_channels, in_channels, kernel_size=1)
|
69 |
+
self.v = Conv(in_channels, in_channels, kernel_size=1)
|
70 |
+
self.proj_out = Conv(in_channels, in_channels, kernel_size=1)
|
71 |
+
|
72 |
+
def attention(self, h_: Tensor) -> Tensor:
|
73 |
+
B, _, T, _, _ = h_.shape
|
74 |
+
h_ = self.norm(h_)
|
75 |
+
h_ = rearrange(h_, "B C T H W -> (B T) C H W") # spatial attention only
|
76 |
+
q = self.q(h_)
|
77 |
+
k = self.k(h_)
|
78 |
+
v = self.v(h_)
|
79 |
+
|
80 |
+
b, c, h, w = q.shape
|
81 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
82 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
83 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
84 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
85 |
+
|
86 |
+
return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)
|
87 |
+
|
88 |
+
def forward(self, x: Tensor) -> Tensor:
|
89 |
+
return x + self.proj_out(self.attention(x))
|
90 |
+
|
91 |
+
|
92 |
+
class ResnetBlock(nn.Module):
|
93 |
+
def __init__(self, in_channels: int, out_channels: int, norm_type='group', cnn_param=None):
|
94 |
+
super().__init__()
|
95 |
+
self.in_channels = in_channels
|
96 |
+
out_channels = in_channels if out_channels is None else out_channels
|
97 |
+
self.out_channels = out_channels
|
98 |
+
|
99 |
+
self.norm1 = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
|
100 |
+
if cnn_param["res_conv_2d"] in ["half", "full"]:
|
101 |
+
self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
|
102 |
+
else:
|
103 |
+
self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
|
104 |
+
self.norm2 = Normalize(out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
|
105 |
+
if cnn_param["res_conv_2d"] in ["full"]:
|
106 |
+
self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
|
107 |
+
else:
|
108 |
+
self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
|
109 |
+
if self.in_channels != self.out_channels:
|
110 |
+
self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
h = x
|
114 |
+
h = self.norm1(h)
|
115 |
+
h = swish(h)
|
116 |
+
h = self.conv1(h)
|
117 |
+
|
118 |
+
h = self.norm2(h)
|
119 |
+
h = swish(h)
|
120 |
+
h = self.conv2(h)
|
121 |
+
|
122 |
+
if self.in_channels != self.out_channels:
|
123 |
+
x = self.nin_shortcut(x)
|
124 |
+
|
125 |
+
return x + h
|
126 |
+
|
127 |
+
|
128 |
+
class Downsample(nn.Module):
|
129 |
+
def __init__(self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False):
|
130 |
+
super().__init__()
|
131 |
+
assert spatial_down == True
|
132 |
+
if cnn_type == "2d":
|
133 |
+
self.pad = (0,1,0,1)
|
134 |
+
if cnn_type == "3d":
|
135 |
+
self.pad = (0,1,0,1,0,0) # add padding to the right for h-axis and w-axis. No padding for t-axis
|
136 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
137 |
+
self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0, cnn_type=cnn_type, temporal_down=temporal_down)
|
138 |
+
|
139 |
+
def forward(self, x: Tensor):
|
140 |
+
x = nn.functional.pad(x, self.pad, mode="constant", value=0)
|
141 |
+
x = self.conv(x)
|
142 |
+
return x
|
143 |
+
|
144 |
+
|
145 |
+
class Upsample(nn.Module):
|
146 |
+
def __init__(self, in_channels, cnn_type="2d", spatial_up=False, temporal_up=False, use_pxsl=False):
|
147 |
+
super().__init__()
|
148 |
+
if cnn_type == "2d":
|
149 |
+
self.scale_factor = 2
|
150 |
+
self.causal_offset = 0
|
151 |
+
else:
|
152 |
+
assert spatial_up == True
|
153 |
+
if temporal_up:
|
154 |
+
self.scale_factor = (2,2,2)
|
155 |
+
self.causal_offset = -1
|
156 |
+
else:
|
157 |
+
self.scale_factor = (1,2,2)
|
158 |
+
self.causal_offset = 0
|
159 |
+
self.use_pxsl = use_pxsl
|
160 |
+
if self.use_pxsl:
|
161 |
+
self.conv = Conv(in_channels, in_channels*4, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
|
162 |
+
self.pxsl = nn.PixelShuffle(2)
|
163 |
+
else:
|
164 |
+
self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
|
165 |
+
|
166 |
+
def forward(self, x: Tensor):
|
167 |
+
if self.use_pxsl:
|
168 |
+
x = self.conv(x)
|
169 |
+
x = self.pxsl(x)
|
170 |
+
else:
|
171 |
+
try:
|
172 |
+
x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
|
173 |
+
except:
|
174 |
+
# shard across channel
|
175 |
+
_xs = []
|
176 |
+
for i in range(x.shape[1]):
|
177 |
+
_x = F.interpolate(x[:,i:i+1,...], scale_factor=self.scale_factor, mode="nearest")
|
178 |
+
_xs.append(_x)
|
179 |
+
x = torch.cat(_xs, dim=1)
|
180 |
+
x = self.conv(x)
|
181 |
+
return x
|
182 |
+
|
183 |
+
|
184 |
+
class Encoder(nn.Module):
|
185 |
+
def __init__(
|
186 |
+
self,
|
187 |
+
ch: int,
|
188 |
+
ch_mult: list[int],
|
189 |
+
num_res_blocks: int,
|
190 |
+
z_channels: int,
|
191 |
+
in_channels = 3,
|
192 |
+
patch_size=8, temporal_patch_size=4,
|
193 |
+
norm_type='group', cnn_param=None,
|
194 |
+
use_checkpoint=False,
|
195 |
+
use_vae=True,
|
196 |
+
):
|
197 |
+
super().__init__()
|
198 |
+
self.max_down = np.log2(patch_size)
|
199 |
+
self.temporal_max_down = np.log2(temporal_patch_size)
|
200 |
+
self.temporal_down_offset = self.max_down - self.temporal_max_down
|
201 |
+
self.ch = ch
|
202 |
+
self.num_resolutions = len(ch_mult)
|
203 |
+
self.num_res_blocks = num_res_blocks
|
204 |
+
self.in_channels = in_channels
|
205 |
+
self.cnn_param = cnn_param
|
206 |
+
self.use_checkpoint = use_checkpoint
|
207 |
+
# downsampling
|
208 |
+
# self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
209 |
+
# cnn_param["cnn_type"] = "2d" for images, cnn_param["cnn_type"] = "3d" for videos
|
210 |
+
if cnn_param["conv_in_out_2d"] == "yes": # "yes" for video
|
211 |
+
self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
|
212 |
+
else:
|
213 |
+
self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
|
214 |
+
|
215 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
216 |
+
self.in_ch_mult = in_ch_mult
|
217 |
+
self.down = nn.ModuleList()
|
218 |
+
block_in = self.ch
|
219 |
+
for i_level in range(self.num_resolutions):
|
220 |
+
block = nn.ModuleList()
|
221 |
+
attn = nn.ModuleList()
|
222 |
+
block_in = ch * in_ch_mult[i_level]
|
223 |
+
block_out = ch * ch_mult[i_level]
|
224 |
+
for _ in range(self.num_res_blocks):
|
225 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
|
226 |
+
block_in = block_out
|
227 |
+
down = nn.Module()
|
228 |
+
down.block = block
|
229 |
+
down.attn = attn
|
230 |
+
# downsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE
|
231 |
+
spatial_down = True if i_level < self.max_down else False
|
232 |
+
temporal_down = True if i_level < self.max_down and i_level >= self.temporal_down_offset else False
|
233 |
+
if spatial_down or temporal_down:
|
234 |
+
down.downsample = Downsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_down=spatial_down, temporal_down=temporal_down)
|
235 |
+
self.down.append(down)
|
236 |
+
|
237 |
+
# middle
|
238 |
+
self.mid = nn.Module()
|
239 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
|
240 |
+
if cnn_param["cnn_attention"] == "yes":
|
241 |
+
self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
|
242 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
|
243 |
+
|
244 |
+
# end
|
245 |
+
self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
|
246 |
+
if cnn_param["conv_inner_2d"] == "yes":
|
247 |
+
self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
|
248 |
+
else:
|
249 |
+
self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
|
250 |
+
|
251 |
+
def forward(self, x, return_hidden=False):
|
252 |
+
if not self.use_checkpoint:
|
253 |
+
return self._forward(x, return_hidden=return_hidden)
|
254 |
+
else:
|
255 |
+
return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False)
|
256 |
+
|
257 |
+
def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
|
258 |
+
# downsampling
|
259 |
+
h0 = self.conv_in(x)
|
260 |
+
hs = [h0]
|
261 |
+
for i_level in range(self.num_resolutions):
|
262 |
+
for i_block in range(self.num_res_blocks):
|
263 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
264 |
+
if len(self.down[i_level].attn) > 0:
|
265 |
+
h = self.down[i_level].attn[i_block](h)
|
266 |
+
hs.append(h)
|
267 |
+
if hasattr(self.down[i_level], "downsample"):
|
268 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
269 |
+
|
270 |
+
# middle
|
271 |
+
h = hs[-1]
|
272 |
+
hs_mid = [h]
|
273 |
+
h = self.mid.block_1(h)
|
274 |
+
if self.cnn_param["cnn_attention"] == "yes":
|
275 |
+
h = self.mid.attn_1(h)
|
276 |
+
h = self.mid.block_2(h)
|
277 |
+
hs_mid.append(h)
|
278 |
+
# end
|
279 |
+
h = self.norm_out(h)
|
280 |
+
h = swish(h)
|
281 |
+
h = self.conv_out(h)
|
282 |
+
if return_hidden:
|
283 |
+
return h, hs, hs_mid
|
284 |
+
else:
|
285 |
+
return h
|
286 |
+
|
287 |
+
|
288 |
+
class Decoder(nn.Module):
|
289 |
+
def __init__(
|
290 |
+
self,
|
291 |
+
ch: int,
|
292 |
+
ch_mult: list[int],
|
293 |
+
num_res_blocks: int,
|
294 |
+
z_channels: int,
|
295 |
+
out_ch = 3,
|
296 |
+
patch_size=8, temporal_patch_size=4,
|
297 |
+
norm_type="group", cnn_param=None,
|
298 |
+
use_checkpoint=False,
|
299 |
+
use_freq_dec=False, # use frequency features for decoder
|
300 |
+
use_pxsf=False
|
301 |
+
):
|
302 |
+
super().__init__()
|
303 |
+
self.max_up = np.log2(patch_size)
|
304 |
+
self.temporal_max_up = np.log2(temporal_patch_size)
|
305 |
+
self.temporal_up_offset = self.max_up - self.temporal_max_up
|
306 |
+
self.ch = ch
|
307 |
+
self.num_resolutions = len(ch_mult)
|
308 |
+
self.num_res_blocks = num_res_blocks
|
309 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
310 |
+
self.cnn_param = cnn_param
|
311 |
+
self.use_checkpoint = use_checkpoint
|
312 |
+
self.use_freq_dec = use_freq_dec
|
313 |
+
self.use_pxsf = use_pxsf
|
314 |
+
|
315 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
316 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
317 |
+
|
318 |
+
# z to block_in
|
319 |
+
if cnn_param["conv_inner_2d"] == "yes":
|
320 |
+
self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d")
|
321 |
+
else:
|
322 |
+
self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
|
323 |
+
|
324 |
+
# middle
|
325 |
+
self.mid = nn.Module()
|
326 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
|
327 |
+
if cnn_param["cnn_attention"] == "yes":
|
328 |
+
self.mid.attn_1 = AttnBlock(block_in, norm_type=norm_type, cnn_param=cnn_param)
|
329 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
|
330 |
+
|
331 |
+
# upsampling
|
332 |
+
self.up = nn.ModuleList()
|
333 |
+
for i_level in reversed(range(self.num_resolutions)):
|
334 |
+
block = nn.ModuleList()
|
335 |
+
attn = nn.ModuleList()
|
336 |
+
block_out = ch * ch_mult[i_level]
|
337 |
+
for _ in range(self.num_res_blocks + 1):
|
338 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
|
339 |
+
block_in = block_out
|
340 |
+
up = nn.Module()
|
341 |
+
up.block = block
|
342 |
+
up.attn = attn
|
343 |
+
# upsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE, offset 1 compared with encoder
|
344 |
+
# https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
|
345 |
+
spatial_up = True if 1 <= i_level <= self.max_up else False
|
346 |
+
temporal_up = True if 1 <= i_level <= self.max_up and i_level >= self.temporal_up_offset+1 else False
|
347 |
+
if spatial_up or temporal_up:
|
348 |
+
up.upsample = Upsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_up=spatial_up, temporal_up=temporal_up, use_pxsl=self.use_pxsf)
|
349 |
+
self.up.insert(0, up) # prepend to get consistent order
|
350 |
+
|
351 |
+
# end
|
352 |
+
self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
|
353 |
+
if cnn_param["conv_in_out_2d"] == "yes":
|
354 |
+
self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
|
355 |
+
else:
|
356 |
+
self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
|
357 |
+
|
358 |
+
def forward(self, z):
|
359 |
+
if not self.use_checkpoint:
|
360 |
+
return self._forward(z)
|
361 |
+
else:
|
362 |
+
return checkpoint.checkpoint(self._forward, z, use_reentrant=False)
|
363 |
+
|
364 |
+
def _forward(self, z: Tensor) -> Tensor:
|
365 |
+
# z to block_in
|
366 |
+
h = self.conv_in(z)
|
367 |
+
|
368 |
+
# middle
|
369 |
+
h = self.mid.block_1(h)
|
370 |
+
if self.cnn_param["cnn_attention"] == "yes":
|
371 |
+
h = self.mid.attn_1(h)
|
372 |
+
h = self.mid.block_2(h)
|
373 |
+
|
374 |
+
# upsampling
|
375 |
+
for i_level in reversed(range(self.num_resolutions)):
|
376 |
+
for i_block in range(self.num_res_blocks + 1):
|
377 |
+
h = self.up[i_level].block[i_block](h)
|
378 |
+
if len(self.up[i_level].attn) > 0:
|
379 |
+
h = self.up[i_level].attn[i_block](h)
|
380 |
+
if hasattr(self.up[i_level], "upsample"):
|
381 |
+
h = self.up[i_level].upsample(h)
|
382 |
+
|
383 |
+
# end
|
384 |
+
h = self.norm_out(h)
|
385 |
+
h = swish(h)
|
386 |
+
h = self.conv_out(h)
|
387 |
+
return h
|
388 |
+
|
389 |
+
|
390 |
+
class AutoEncoder(nn.Module):
|
391 |
+
def __init__(self, args):
|
392 |
+
super().__init__()
|
393 |
+
self.args = args
|
394 |
+
cnn_param = dict(
|
395 |
+
cnn_type=args.cnn_type,
|
396 |
+
conv_in_out_2d=args.conv_in_out_2d,
|
397 |
+
res_conv_2d=args.res_conv_2d,
|
398 |
+
cnn_attention=args.cnn_attention,
|
399 |
+
cnn_norm_axis=args.cnn_norm_axis,
|
400 |
+
conv_inner_2d=args.conv_inner_2d,
|
401 |
+
)
|
402 |
+
self.encoder = Encoder(
|
403 |
+
ch=args.base_ch,
|
404 |
+
ch_mult=args.encoder_ch_mult,
|
405 |
+
num_res_blocks=args.num_res_blocks,
|
406 |
+
z_channels=args.codebook_dim,
|
407 |
+
patch_size=args.patch_size,
|
408 |
+
temporal_patch_size=args.temporal_patch_size,
|
409 |
+
cnn_param=cnn_param,
|
410 |
+
use_checkpoint=args.use_checkpoint,
|
411 |
+
use_vae=args.use_vae,
|
412 |
+
)
|
413 |
+
self.decoder = Decoder(
|
414 |
+
ch=args.base_ch,
|
415 |
+
ch_mult=args.decoder_ch_mult,
|
416 |
+
num_res_blocks=args.num_res_blocks,
|
417 |
+
z_channels=args.codebook_dim,
|
418 |
+
patch_size=args.patch_size,
|
419 |
+
temporal_patch_size=args.temporal_patch_size,
|
420 |
+
cnn_param=cnn_param,
|
421 |
+
use_checkpoint=args.use_checkpoint,
|
422 |
+
use_freq_dec=args.use_freq_dec,
|
423 |
+
use_pxsf=args.use_pxsf # pixelshuffle for upsampling
|
424 |
+
)
|
425 |
+
self.z_drop = nn.Dropout(args.z_drop)
|
426 |
+
self.scale_factor = 0.3611
|
427 |
+
self.shift_factor = 0.1159
|
428 |
+
self.codebook_dim = self.embed_dim = args.codebook_dim
|
429 |
+
|
430 |
+
self.gan_feat_weight = args.gan_feat_weight
|
431 |
+
self.video_perceptual_weight = args.video_perceptual_weight
|
432 |
+
self.recon_loss_type = args.recon_loss_type
|
433 |
+
self.l1_weight = args.l1_weight
|
434 |
+
self.use_vae = args.use_vae
|
435 |
+
self.kl_weight = args.kl_weight
|
436 |
+
self.lfq_weight = args.lfq_weight
|
437 |
+
self.image_gan_weight = args.image_gan_weight # image GAN loss weight
|
438 |
+
self.video_gan_weight = args.video_gan_weight # video GAN loss weight
|
439 |
+
self.perceptual_weight = args.perceptual_weight
|
440 |
+
self.flux_weight = args.flux_weight
|
441 |
+
self.cycle_weight = args.cycle_weight
|
442 |
+
self.cycle_feat_weight = args.cycle_feat_weight
|
443 |
+
self.cycle_gan_weight = args.cycle_gan_weight
|
444 |
+
|
445 |
+
self.flux_image_encoder = None
|
446 |
+
|
447 |
+
if not args.use_vae:
|
448 |
+
if args.quantizer_type == 'MultiScaleBSQ':
|
449 |
+
self.quantizer = MultiScaleBSQ(
|
450 |
+
dim = args.codebook_dim, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
|
451 |
+
codebook_size = args.codebook_size, # codebook size, must be a power of 2
|
452 |
+
entropy_loss_weight = args.entropy_loss_weight, # how much weight to place on entropy loss
|
453 |
+
diversity_gamma = args.diversity_gamma, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
|
454 |
+
preserve_norm=args.preserve_norm, # preserve norm of the input for BSQ
|
455 |
+
ln_before_quant=args.ln_before_quant, # use layer norm before quantization
|
456 |
+
ln_init_by_sqrt=args.ln_init_by_sqrt, # layer norm init value 1/sqrt(d)
|
457 |
+
commitment_loss_weight=args.commitment_loss_weight, # loss weight of commitment loss
|
458 |
+
new_quant=args.new_quant,
|
459 |
+
use_decay_factor=args.use_decay_factor,
|
460 |
+
mask_out=args.mask_out,
|
461 |
+
use_stochastic_depth=args.use_stochastic_depth,
|
462 |
+
drop_rate=args.drop_rate,
|
463 |
+
schedule_mode=args.schedule_mode,
|
464 |
+
keep_first_quant=args.keep_first_quant,
|
465 |
+
keep_last_quant=args.keep_last_quant,
|
466 |
+
remove_residual_detach=args.remove_residual_detach,
|
467 |
+
use_out_phi=args.use_out_phi,
|
468 |
+
use_out_phi_res=args.use_out_phi_res,
|
469 |
+
random_flip = args.random_flip,
|
470 |
+
flip_prob = args.flip_prob,
|
471 |
+
flip_mode = args.flip_mode,
|
472 |
+
max_flip_lvl = args.max_flip_lvl,
|
473 |
+
random_flip_1lvl = args.random_flip_1lvl,
|
474 |
+
flip_lvl_idx = args.flip_lvl_idx,
|
475 |
+
drop_when_test = args.drop_when_test,
|
476 |
+
drop_lvl_idx = args.drop_lvl_idx,
|
477 |
+
drop_lvl_num = args.drop_lvl_num,
|
478 |
+
)
|
479 |
+
self.quantize = self.quantizer
|
480 |
+
self.vocab_size = args.codebook_size
|
481 |
+
else:
|
482 |
+
raise NotImplementedError(f"{args.quantizer_type} not supported")
|
483 |
+
|
484 |
+
|
485 |
+
def forward(self, x):
|
486 |
+
is_image = x.ndim == 4
|
487 |
+
if not is_image:
|
488 |
+
B, C, T, H, W = x.shape
|
489 |
+
else:
|
490 |
+
B, C, H, W = x.shape
|
491 |
+
T = 1
|
492 |
+
enc_dtype = ptdtype[self.args.encoder_dtype]
|
493 |
+
|
494 |
+
with torch.amp.autocast("cuda", dtype=enc_dtype):
|
495 |
+
h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
|
496 |
+
hs = [_h.detach() for _h in hs]
|
497 |
+
hs_mid = [_h.detach() for _h in hs_mid]
|
498 |
+
h = h.to(dtype=torch.float32)
|
499 |
+
# print(z.shape)
|
500 |
+
# Multiscale LFQ
|
501 |
+
z, all_indices, all_loss = self.quantizer(h)
|
502 |
+
x_recon = self.decoder(z)
|
503 |
+
vq_output = {
|
504 |
+
"commitment_loss": torch.mean(all_loss) * self.lfq_weight, # here commitment loss is sum of commitment loss and entropy penalty
|
505 |
+
"encodings": all_indices,
|
506 |
+
}
|
507 |
+
return x_recon, vq_output
|
508 |
+
|
509 |
+
def encode_for_raw_features(self, x, scale_schedule, return_residual_norm_per_scale=False):
|
510 |
+
is_image = x.ndim == 4
|
511 |
+
if not is_image:
|
512 |
+
B, C, T, H, W = x.shape
|
513 |
+
else:
|
514 |
+
B, C, H, W = x.shape
|
515 |
+
T = 1
|
516 |
+
|
517 |
+
enc_dtype = ptdtype[self.args.encoder_dtype]
|
518 |
+
with torch.amp.autocast("cuda", dtype=enc_dtype):
|
519 |
+
h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
|
520 |
+
|
521 |
+
hs = [_h.detach() for _h in hs]
|
522 |
+
hs_mid = [_h.detach() for _h in hs_mid]
|
523 |
+
h = h.to(dtype=torch.float32)
|
524 |
+
return h, hs, hs_mid
|
525 |
+
|
526 |
+
def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
|
527 |
+
h, hs, hs_mid = self.encode_for_raw_features(x, scale_schedule, return_residual_norm_per_scale)
|
528 |
+
# Multiscale LFQ
|
529 |
+
z, all_indices, all_bit_indices, residual_norm_per_scale, all_loss, var_input = self.quantizer(h, scale_schedule=scale_schedule, return_residual_norm_per_scale=return_residual_norm_per_scale)
|
530 |
+
return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input
|
531 |
+
|
532 |
+
def decode(self, z):
|
533 |
+
x_recon = self.decoder(z)
|
534 |
+
x_recon = torch.clamp(x_recon, min=-1, max=1)
|
535 |
+
return x_recon
|
536 |
+
|
537 |
+
def decode_from_indices(self, all_indices, scale_schedule, label_type):
|
538 |
+
summed_codes = 0
|
539 |
+
for idx_Bl in all_indices:
|
540 |
+
codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
|
541 |
+
summed_codes += F.interpolate(codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up)
|
542 |
+
assert summed_codes.shape[-3] == 1
|
543 |
+
x_recon = self.decoder(summed_codes.squeeze(-3))
|
544 |
+
x_recon = torch.clamp(x_recon, min=-1, max=1)
|
545 |
+
return summed_codes, x_recon
|
546 |
+
|
547 |
+
@staticmethod
|
548 |
+
def add_model_specific_args(parent_parser):
|
549 |
+
parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
|
550 |
+
parser.add_argument("--flux_weight", type=float, default=0)
|
551 |
+
parser.add_argument("--cycle_weight", type=float, default=0)
|
552 |
+
parser.add_argument("--cycle_feat_weight", type=float, default=0)
|
553 |
+
parser.add_argument("--cycle_gan_weight", type=float, default=0)
|
554 |
+
parser.add_argument("--cycle_loop", type=int, default=0)
|
555 |
+
parser.add_argument("--z_drop", type=float, default=0.)
|
556 |
+
return parser
|
557 |
+
|
models/bsq_vae/multiscale_bsq.py
ADDED
@@ -0,0 +1,718 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Binary Spherical Quantization
|
3 |
+
Proposed in https://arxiv.org/abs/2406.07548
|
4 |
+
|
5 |
+
In the simplest setup, each dimension is quantized into {-1, 1}.
|
6 |
+
An entropy penalty is used to encourage utilization.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import random
|
10 |
+
from math import log2, ceil
|
11 |
+
from functools import partial, cache
|
12 |
+
from collections import namedtuple
|
13 |
+
from contextlib import nullcontext
|
14 |
+
|
15 |
+
import torch.distributed as dist
|
16 |
+
from torch.distributed import nn as dist_nn
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn, einsum
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.nn import Module
|
22 |
+
from torch.amp import autocast
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from einops import rearrange, reduce, pack, unpack
|
26 |
+
|
27 |
+
# from einx import get_at
|
28 |
+
|
29 |
+
from .dynamic_resolution import predefined_HW_Scales_dynamic
|
30 |
+
|
31 |
+
# constants
|
32 |
+
|
33 |
+
Return = namedtuple('Return', ['quantized', 'indices', 'bit_indices', 'entropy_aux_loss'])
|
34 |
+
|
35 |
+
LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
|
36 |
+
|
37 |
+
# distributed helpers
|
38 |
+
|
39 |
+
@cache
|
40 |
+
def is_distributed():
|
41 |
+
return dist.is_initialized() and dist.get_world_size() > 1
|
42 |
+
|
43 |
+
def maybe_distributed_mean(t):
|
44 |
+
if not is_distributed():
|
45 |
+
return t
|
46 |
+
|
47 |
+
dist_nn.all_reduce(t)
|
48 |
+
t = t / dist.get_world_size()
|
49 |
+
return t
|
50 |
+
|
51 |
+
# helper functions
|
52 |
+
|
53 |
+
def exists(v):
|
54 |
+
return v is not None
|
55 |
+
|
56 |
+
def identity(t):
|
57 |
+
return t
|
58 |
+
|
59 |
+
def default(*args):
|
60 |
+
for arg in args:
|
61 |
+
if exists(arg):
|
62 |
+
return arg() if callable(arg) else arg
|
63 |
+
return None
|
64 |
+
|
65 |
+
def round_up_multiple(num, mult):
|
66 |
+
return ceil(num / mult) * mult
|
67 |
+
|
68 |
+
def pack_one(t, pattern):
|
69 |
+
return pack([t], pattern)
|
70 |
+
|
71 |
+
def unpack_one(t, ps, pattern):
|
72 |
+
return unpack(t, ps, pattern)[0]
|
73 |
+
|
74 |
+
def l2norm(t):
|
75 |
+
return F.normalize(t, dim = -1)
|
76 |
+
|
77 |
+
# entropy
|
78 |
+
|
79 |
+
def log(t, eps = 1e-5):
|
80 |
+
return t.clamp(min = eps).log()
|
81 |
+
|
82 |
+
def entropy(prob):
|
83 |
+
return (-prob * log(prob)).sum(dim=-1)
|
84 |
+
|
85 |
+
# cosine sim linear
|
86 |
+
|
87 |
+
class CosineSimLinear(Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
dim_in,
|
91 |
+
dim_out,
|
92 |
+
scale = 1.
|
93 |
+
):
|
94 |
+
super().__init__()
|
95 |
+
self.scale = scale
|
96 |
+
self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
|
97 |
+
|
98 |
+
def forward(self, x):
|
99 |
+
x = F.normalize(x, dim = -1)
|
100 |
+
w = F.normalize(self.weight, dim = 0)
|
101 |
+
return (x @ w) * self.scale
|
102 |
+
|
103 |
+
|
104 |
+
def get_latent2scale_schedule(T: int, H: int, W: int, mode="original"):
|
105 |
+
assert mode in ["original", "dynamic", "dense", "same1", "same2", "same3"]
|
106 |
+
predefined_HW_Scales = {
|
107 |
+
# 256 * 256
|
108 |
+
(32, 32): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 9), (13, 13), (18, 18), (24, 24), (32, 32)],
|
109 |
+
(16, 16): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)],
|
110 |
+
# 1024x1024
|
111 |
+
(64, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (7, 7), (9, 9), (12, 12), (16, 16), (21, 21), (27, 27), (36, 36), (48, 48), (64, 64)],
|
112 |
+
|
113 |
+
(36, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 12), (13, 16), (18, 24), (24, 32), (32, 48), (36, 64)],
|
114 |
+
}
|
115 |
+
if mode == "dynamic":
|
116 |
+
predefined_HW_Scales.update(predefined_HW_Scales_dynamic)
|
117 |
+
elif mode == "dense":
|
118 |
+
predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)]
|
119 |
+
predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
|
120 |
+
predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
|
121 |
+
elif mode.startswith("same"):
|
122 |
+
num_quant = int(mode[len("same"):])
|
123 |
+
predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)]
|
124 |
+
predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)]
|
125 |
+
predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)]
|
126 |
+
|
127 |
+
predefined_T_Scales = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 17, 17, 17, 17]
|
128 |
+
patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)]
|
129 |
+
if len(predefined_T_Scales) < len(patch_THW_shape_per_scale):
|
130 |
+
# print("warning: the length of predefined_T_Scales is less than the length of patch_THW_shape_per_scale!")
|
131 |
+
predefined_T_Scales += [predefined_T_Scales[-1]] * (len(patch_THW_shape_per_scale) - len(predefined_T_Scales))
|
132 |
+
patch_THW_shape_per_scale = [(min(T, t), h, w ) for (h, w), t in zip(patch_THW_shape_per_scale, predefined_T_Scales[:len(patch_THW_shape_per_scale)])]
|
133 |
+
return patch_THW_shape_per_scale
|
134 |
+
|
135 |
+
class LayerNorm(nn.Module):
|
136 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
137 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
138 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
139 |
+
with shape (batch_size, channels, height, width).
|
140 |
+
normalized_shape: int
|
141 |
+
"""
|
142 |
+
def __init__(self, normalized_shape, norm_weight=False, eps=1e-6, data_format="channels_first"):
|
143 |
+
super().__init__()
|
144 |
+
if norm_weight:
|
145 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape)/(normalized_shape**0.5))
|
146 |
+
else:
|
147 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
148 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
149 |
+
self.eps = eps
|
150 |
+
self.data_format = data_format
|
151 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
152 |
+
raise NotImplementedError
|
153 |
+
self.normalized_shape = (normalized_shape, )
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
if self.data_format == "channels_last":
|
157 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
158 |
+
elif self.data_format == "channels_first":
|
159 |
+
u = x.mean(1, keepdim=True)
|
160 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
161 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
162 |
+
if x.ndim == 4: # (b, c, h, w)
|
163 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
164 |
+
elif x.ndim == 5: # (b, c, t, h, w)
|
165 |
+
x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
|
166 |
+
else:
|
167 |
+
raise ValueError("the number of dimensions of the input should be 4 or 5")
|
168 |
+
return x
|
169 |
+
|
170 |
+
class MultiScaleBSQ(Module):
|
171 |
+
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
|
172 |
+
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
*,
|
176 |
+
dim,
|
177 |
+
codebook_size,
|
178 |
+
soft_clamp_input_value = None,
|
179 |
+
aux_loss = False, # intermediate auxiliary loss
|
180 |
+
ln_before_quant=False, # add a LN before multi-scale RQ
|
181 |
+
ln_init_by_sqrt=False, # weight init by 1/sqrt(d)
|
182 |
+
use_decay_factor=False,
|
183 |
+
use_stochastic_depth=False,
|
184 |
+
drop_rate=0.,
|
185 |
+
schedule_mode="original", # ["original", "dynamic", "dense"]
|
186 |
+
keep_first_quant=False,
|
187 |
+
keep_last_quant=False,
|
188 |
+
remove_residual_detach=False,
|
189 |
+
random_flip = False,
|
190 |
+
flip_prob = 0.5,
|
191 |
+
flip_mode = "stochastic", # "stochastic", "deterministic"
|
192 |
+
max_flip_lvl = 1,
|
193 |
+
random_flip_1lvl = False, # random flip one level each time
|
194 |
+
flip_lvl_idx = None,
|
195 |
+
drop_when_test=False,
|
196 |
+
drop_lvl_idx=None,
|
197 |
+
drop_lvl_num=0,
|
198 |
+
**kwargs
|
199 |
+
):
|
200 |
+
super().__init__()
|
201 |
+
codebook_dim = int(log2(codebook_size))
|
202 |
+
|
203 |
+
requires_projection = codebook_dim != dim
|
204 |
+
self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
|
205 |
+
self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
|
206 |
+
self.has_projections = requires_projection
|
207 |
+
self.layernorm = LayerNorm(codebook_dim, norm_weight=ln_init_by_sqrt) if ln_before_quant else nn.Identity()
|
208 |
+
self.use_stochastic_depth = use_stochastic_depth
|
209 |
+
self.drop_rate = drop_rate
|
210 |
+
self.remove_residual_detach = remove_residual_detach
|
211 |
+
self.random_flip = random_flip
|
212 |
+
self.flip_prob = flip_prob
|
213 |
+
self.flip_mode = flip_mode
|
214 |
+
self.max_flip_lvl = max_flip_lvl
|
215 |
+
self.random_flip_1lvl = random_flip_1lvl
|
216 |
+
self.flip_lvl_idx = flip_lvl_idx
|
217 |
+
assert (random_flip and random_flip_1lvl) == False
|
218 |
+
self.drop_when_test = drop_when_test
|
219 |
+
self.drop_lvl_idx = drop_lvl_idx
|
220 |
+
self.drop_lvl_num = drop_lvl_num
|
221 |
+
if self.drop_when_test:
|
222 |
+
assert drop_lvl_idx is not None
|
223 |
+
assert drop_lvl_num > 0
|
224 |
+
|
225 |
+
self.lfq = BSQ(
|
226 |
+
dim = codebook_dim,
|
227 |
+
codebook_scale = 1/np.sqrt(codebook_dim),
|
228 |
+
soft_clamp_input_value = soft_clamp_input_value,
|
229 |
+
# experimental_softplus_entropy_loss=True,
|
230 |
+
# entropy_loss_offset=2,
|
231 |
+
**kwargs
|
232 |
+
)
|
233 |
+
|
234 |
+
self.z_interplote_up = 'trilinear'
|
235 |
+
self.z_interplote_down = 'area'
|
236 |
+
|
237 |
+
self.use_decay_factor = use_decay_factor
|
238 |
+
self.schedule_mode = schedule_mode
|
239 |
+
self.keep_first_quant = keep_first_quant
|
240 |
+
self.keep_last_quant = keep_last_quant
|
241 |
+
if self.use_stochastic_depth and self.drop_rate > 0:
|
242 |
+
assert self.keep_first_quant or self.keep_last_quant
|
243 |
+
|
244 |
+
@property
|
245 |
+
def codebooks(self):
|
246 |
+
return self.lfq.codebook
|
247 |
+
|
248 |
+
def get_codes_from_indices(self, indices_list):
|
249 |
+
all_codes = []
|
250 |
+
for indices in indices_list:
|
251 |
+
codes = self.lfq.indices_to_codes(indices)
|
252 |
+
all_codes.append(codes)
|
253 |
+
_, _, T, H, W = all_codes[-1].size()
|
254 |
+
summed_codes = 0
|
255 |
+
for code in all_codes:
|
256 |
+
summed_codes += F.interpolate(code, size=(T, H, W), mode=self.z_interplote_up)
|
257 |
+
return summed_codes
|
258 |
+
|
259 |
+
def get_output_from_indices(self, indices):
|
260 |
+
codes = self.get_codes_from_indices(indices)
|
261 |
+
codes_summed = reduce(codes, 'q ... -> ...', 'sum')
|
262 |
+
return self.project_out(codes_summed)
|
263 |
+
|
264 |
+
def flip_quant(self, x):
|
265 |
+
assert self.flip_mode == 'stochastic'
|
266 |
+
flip_mask = torch.rand_like(x) < self.flip_prob
|
267 |
+
x = x.clone()
|
268 |
+
x[flip_mask] = -x[flip_mask]
|
269 |
+
return x
|
270 |
+
|
271 |
+
def forward(
|
272 |
+
self,
|
273 |
+
x,
|
274 |
+
scale_schedule=None,
|
275 |
+
mask = None,
|
276 |
+
return_all_codes = False,
|
277 |
+
return_residual_norm_per_scale = False
|
278 |
+
):
|
279 |
+
if x.ndim == 4:
|
280 |
+
x = x.unsqueeze(2)
|
281 |
+
B, C, T, H, W = x.size()
|
282 |
+
|
283 |
+
if scale_schedule is None:
|
284 |
+
if self.schedule_mode.startswith("same"):
|
285 |
+
scale_num = int(self.schedule_mode[len("same"):])
|
286 |
+
assert T == 1
|
287 |
+
scale_schedule = [(1, H, W)] * scale_num
|
288 |
+
else:
|
289 |
+
scale_schedule = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode)
|
290 |
+
scale_num = len(scale_schedule)
|
291 |
+
|
292 |
+
# x = self.project_in(x)
|
293 |
+
x = x.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c)
|
294 |
+
x = self.project_in(x)
|
295 |
+
x = x.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w)
|
296 |
+
x = self.layernorm(x)
|
297 |
+
|
298 |
+
quantized_out = 0.
|
299 |
+
residual = x
|
300 |
+
|
301 |
+
all_losses = []
|
302 |
+
all_indices = []
|
303 |
+
all_bit_indices = []
|
304 |
+
var_inputs = []
|
305 |
+
residual_norm_per_scale = []
|
306 |
+
|
307 |
+
# go through the layers
|
308 |
+
out_fact = init_out_fact = 1.0
|
309 |
+
# residual_list = []
|
310 |
+
# interpolate_residual_list = []
|
311 |
+
# quantized_list = []
|
312 |
+
if self.drop_when_test:
|
313 |
+
drop_lvl_start = self.drop_lvl_idx
|
314 |
+
drop_lvl_end = self.drop_lvl_idx + self.drop_lvl_num
|
315 |
+
scale_num = len(scale_schedule)
|
316 |
+
with autocast('cuda', enabled = False):
|
317 |
+
for si, (pt, ph, pw) in enumerate(scale_schedule):
|
318 |
+
out_fact = max(0.1, out_fact) if self.use_decay_factor else init_out_fact
|
319 |
+
if (pt, ph, pw) != (T, H, W):
|
320 |
+
interpolate_residual = F.interpolate(residual, size=(pt, ph, pw), mode=self.z_interplote_down)
|
321 |
+
else:
|
322 |
+
interpolate_residual = residual
|
323 |
+
if return_residual_norm_per_scale:
|
324 |
+
residual_norm_per_scale.append((torch.abs(interpolate_residual) < 0.05 * self.lfq.codebook_scale).sum() / interpolate_residual.numel())
|
325 |
+
# residual_list.append(torch.norm(residual.detach(), dim=1).mean())
|
326 |
+
# interpolate_residual_list.append(torch.norm(interpolate_residual.detach(), dim=1).mean())
|
327 |
+
if self.training and self.use_stochastic_depth and random.random() < self.drop_rate:
|
328 |
+
if (si == 0 and self.keep_first_quant) or (si == scale_num - 1 and self.keep_last_quant):
|
329 |
+
quantized, indices, _, loss = self.lfq(interpolate_residual)
|
330 |
+
quantized = quantized * out_fact
|
331 |
+
all_indices.append(indices)
|
332 |
+
all_losses.append(loss)
|
333 |
+
else:
|
334 |
+
quantized = torch.zeros_like(interpolate_residual)
|
335 |
+
elif self.drop_when_test and drop_lvl_start <= si < drop_lvl_end:
|
336 |
+
continue
|
337 |
+
else:
|
338 |
+
# residual_norm = torch.norm(interpolate_residual.detach(), dim=1) # (b, t, h, w)
|
339 |
+
# print(si, residual_norm.min(), residual_norm.max(), residual_norm.mean())
|
340 |
+
quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)
|
341 |
+
if self.random_flip and si < self.max_flip_lvl:
|
342 |
+
quantized = self.flip_quant(quantized)
|
343 |
+
if self.random_flip_1lvl and si == self.flip_lvl_idx:
|
344 |
+
quantized = self.flip_quant(quantized)
|
345 |
+
quantized = quantized * out_fact
|
346 |
+
all_indices.append(indices)
|
347 |
+
# quantized_list.append(torch.norm(quantized.detach(), dim=1).mean())
|
348 |
+
if (pt, ph, pw) != (T, H, W):
|
349 |
+
quantized = F.interpolate(quantized, size=(T, H, W), mode=self.z_interplote_up).contiguous()
|
350 |
+
|
351 |
+
if self.remove_residual_detach:
|
352 |
+
residual = residual - quantized
|
353 |
+
else:
|
354 |
+
residual = residual - quantized.detach()
|
355 |
+
quantized_out = quantized_out + quantized
|
356 |
+
|
357 |
+
all_bit_indices.append(bit_indices)
|
358 |
+
all_losses.append(loss)
|
359 |
+
if si != scale_num - 1:
|
360 |
+
var_inputs.append(F.interpolate(quantized_out, size=scale_schedule[si+1], mode=self.z_interplote_down).contiguous())
|
361 |
+
|
362 |
+
if self.use_decay_factor:
|
363 |
+
out_fact -= 0.1
|
364 |
+
# print("residual_list:", residual_list)
|
365 |
+
# print("interpolate_residual_list:", interpolate_residual_list)
|
366 |
+
# print("quantized_list:", quantized_list)
|
367 |
+
# import ipdb; ipdb.set_trace()
|
368 |
+
# project out, if needed
|
369 |
+
quantized_out = quantized_out.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c)
|
370 |
+
quantized_out = self.project_out(quantized_out)
|
371 |
+
quantized_out = quantized_out.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w)
|
372 |
+
|
373 |
+
# image
|
374 |
+
if quantized_out.size(2) == 1:
|
375 |
+
quantized_out = quantized_out.squeeze(2)
|
376 |
+
|
377 |
+
# stack all losses and indices
|
378 |
+
|
379 |
+
all_losses = torch.stack(all_losses, dim = -1)
|
380 |
+
|
381 |
+
ret = (quantized_out, all_indices, all_bit_indices, residual_norm_per_scale, all_losses, var_inputs)
|
382 |
+
|
383 |
+
if not return_all_codes:
|
384 |
+
return ret
|
385 |
+
|
386 |
+
# whether to return all codes from all codebooks across layers
|
387 |
+
all_codes = self.get_codes_from_indices(all_indices)
|
388 |
+
|
389 |
+
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
390 |
+
|
391 |
+
return (*ret, all_codes)
|
392 |
+
|
393 |
+
|
394 |
+
class BSQ(Module):
|
395 |
+
def __init__(
|
396 |
+
self,
|
397 |
+
*,
|
398 |
+
dim = None,
|
399 |
+
codebook_size = None,
|
400 |
+
entropy_loss_weight = 0.1,
|
401 |
+
commitment_loss_weight = 0.25,
|
402 |
+
diversity_gamma = 1.,
|
403 |
+
straight_through_activation = nn.Identity(),
|
404 |
+
num_codebooks = 1,
|
405 |
+
keep_num_codebooks_dim = None,
|
406 |
+
codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
|
407 |
+
frac_per_sample_entropy = 1., # make less than 1. to only use a random fraction of the probs for per sample entropy
|
408 |
+
has_projections = None,
|
409 |
+
projection_has_bias = True,
|
410 |
+
soft_clamp_input_value = None,
|
411 |
+
cosine_sim_project_in = False,
|
412 |
+
cosine_sim_project_in_scale = None,
|
413 |
+
channel_first = None,
|
414 |
+
experimental_softplus_entropy_loss = False,
|
415 |
+
entropy_loss_offset = 5., # how much to shift the loss before softplus
|
416 |
+
spherical = True, # from https://arxiv.org/abs/2406.07548
|
417 |
+
force_quantization_f32 = True, # will force the quantization step to be full precision
|
418 |
+
inv_temperature = 100.0,
|
419 |
+
gamma0=1.0, gamma=1.0, zeta=1.0,
|
420 |
+
preserve_norm = False, # whether to preserve the original norm info
|
421 |
+
new_quant = False, # new quant function,
|
422 |
+
mask_out = False, # mask the output as 0 in some conditions
|
423 |
+
use_out_phi = False, # use output phi network
|
424 |
+
use_out_phi_res = False, # residual out phi
|
425 |
+
):
|
426 |
+
super().__init__()
|
427 |
+
|
428 |
+
# some assert validations
|
429 |
+
|
430 |
+
assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
|
431 |
+
assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
|
432 |
+
|
433 |
+
codebook_size = default(codebook_size, lambda: 2 ** dim)
|
434 |
+
self.codebook_size = codebook_size
|
435 |
+
|
436 |
+
codebook_dim = int(log2(codebook_size))
|
437 |
+
codebook_dims = codebook_dim * num_codebooks
|
438 |
+
dim = default(dim, codebook_dims)
|
439 |
+
self.codebook_dims = codebook_dims
|
440 |
+
|
441 |
+
has_projections = default(has_projections, dim != codebook_dims)
|
442 |
+
|
443 |
+
if cosine_sim_project_in:
|
444 |
+
cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
|
445 |
+
project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
|
446 |
+
else:
|
447 |
+
project_in_klass = partial(nn.Linear, bias = projection_has_bias)
|
448 |
+
|
449 |
+
self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity() # nn.Identity()
|
450 |
+
self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity() # nn.Identity()
|
451 |
+
self.has_projections = has_projections
|
452 |
+
|
453 |
+
self.out_phi = nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity()
|
454 |
+
self.use_out_phi_res = use_out_phi_res
|
455 |
+
if self.use_out_phi_res:
|
456 |
+
self.out_phi_scale = nn.Parameter(torch.zeros(codebook_dims), requires_grad=True) # init as zero
|
457 |
+
|
458 |
+
self.dim = dim
|
459 |
+
self.codebook_dim = codebook_dim
|
460 |
+
self.num_codebooks = num_codebooks
|
461 |
+
|
462 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
463 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
464 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
465 |
+
|
466 |
+
# channel first
|
467 |
+
|
468 |
+
self.channel_first = channel_first
|
469 |
+
|
470 |
+
# straight through activation
|
471 |
+
|
472 |
+
self.activation = straight_through_activation
|
473 |
+
|
474 |
+
# For BSQ (binary spherical quantization)
|
475 |
+
if not spherical:
|
476 |
+
raise ValueError("For BSQ, spherical must be True.")
|
477 |
+
self.persample_entropy_compute = 'analytical'
|
478 |
+
self.inv_temperature = inv_temperature
|
479 |
+
self.gamma0 = gamma0 # loss weight for entropy penalty
|
480 |
+
self.gamma = gamma # loss weight for entropy penalty
|
481 |
+
self.zeta = zeta # loss weight for entire entropy penalty
|
482 |
+
self.preserve_norm = preserve_norm
|
483 |
+
self.new_quant = new_quant
|
484 |
+
self.mask_out = mask_out
|
485 |
+
|
486 |
+
# entropy aux loss related weights
|
487 |
+
|
488 |
+
assert 0 < frac_per_sample_entropy <= 1.
|
489 |
+
self.frac_per_sample_entropy = frac_per_sample_entropy
|
490 |
+
|
491 |
+
self.diversity_gamma = diversity_gamma
|
492 |
+
self.entropy_loss_weight = entropy_loss_weight
|
493 |
+
|
494 |
+
# codebook scale
|
495 |
+
|
496 |
+
self.codebook_scale = codebook_scale
|
497 |
+
|
498 |
+
# commitment loss
|
499 |
+
|
500 |
+
self.commitment_loss_weight = commitment_loss_weight
|
501 |
+
|
502 |
+
# whether to soft clamp the input value from -value to value
|
503 |
+
|
504 |
+
self.soft_clamp_input_value = soft_clamp_input_value
|
505 |
+
assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
|
506 |
+
|
507 |
+
# whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
|
508 |
+
|
509 |
+
self.entropy_loss_offset = entropy_loss_offset
|
510 |
+
self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
|
511 |
+
|
512 |
+
# for no auxiliary loss, during inference
|
513 |
+
|
514 |
+
self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
|
515 |
+
self.register_buffer('zero', torch.tensor(0.), persistent = False)
|
516 |
+
|
517 |
+
# whether to force quantization step to be f32
|
518 |
+
|
519 |
+
self.force_quantization_f32 = force_quantization_f32
|
520 |
+
|
521 |
+
# codes
|
522 |
+
|
523 |
+
# all_codes = torch.arange(codebook_size)
|
524 |
+
# bits = ((all_codes[..., None].int() & self.mask) != 0).float()
|
525 |
+
# codebook = self.bits_to_codes(bits)
|
526 |
+
|
527 |
+
# self.register_buffer('codebook', codebook.float(), persistent = False)
|
528 |
+
|
529 |
+
def bits_to_codes(self, bits):
|
530 |
+
return bits * self.codebook_scale * 2 - self.codebook_scale
|
531 |
+
|
532 |
+
# @property
|
533 |
+
# def dtype(self):
|
534 |
+
# return self.codebook.dtype
|
535 |
+
|
536 |
+
def indices_to_codes(
|
537 |
+
self,
|
538 |
+
indices,
|
539 |
+
label_type = 'int_label',
|
540 |
+
project_out = True
|
541 |
+
):
|
542 |
+
assert label_type in ['int_label', 'bit_label']
|
543 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
544 |
+
should_transpose = default(self.channel_first, is_img_or_video)
|
545 |
+
|
546 |
+
if not self.keep_num_codebooks_dim:
|
547 |
+
if label_type == 'int_label':
|
548 |
+
indices = rearrange(indices, '... -> ... 1')
|
549 |
+
else:
|
550 |
+
indices = indices.unsqueeze(-2)
|
551 |
+
|
552 |
+
# indices to codes, which are bits of either -1 or 1
|
553 |
+
|
554 |
+
if label_type == 'int_label':
|
555 |
+
assert indices[..., None].int().min() > 0
|
556 |
+
bits = ((indices[..., None].int() & self.mask) != 0).float() # .to(self.dtype)
|
557 |
+
else:
|
558 |
+
bits = indices
|
559 |
+
|
560 |
+
codes = self.bits_to_codes(bits)
|
561 |
+
|
562 |
+
codes = l2norm(codes) # must normalize when using BSQ
|
563 |
+
|
564 |
+
codes = rearrange(codes, '... c d -> ... (c d)')
|
565 |
+
|
566 |
+
# whether to project codes out to original dimensions
|
567 |
+
# if the input feature dimensions were not log2(codebook size)
|
568 |
+
|
569 |
+
if project_out:
|
570 |
+
codes = self.project_out(codes)
|
571 |
+
|
572 |
+
# rearrange codes back to original shape
|
573 |
+
|
574 |
+
if should_transpose:
|
575 |
+
codes = rearrange(codes, 'b ... d -> b d ...')
|
576 |
+
|
577 |
+
return codes
|
578 |
+
|
579 |
+
def quantize(self, z):
|
580 |
+
assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
|
581 |
+
|
582 |
+
zhat = torch.where(z > 0,
|
583 |
+
torch.tensor(1, dtype=z.dtype, device=z.device),
|
584 |
+
torch.tensor(-1, dtype=z.dtype, device=z.device))
|
585 |
+
return z + (zhat - z).detach()
|
586 |
+
|
587 |
+
def quantize_new(self, z):
|
588 |
+
assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
|
589 |
+
|
590 |
+
zhat = torch.where(z > 0,
|
591 |
+
torch.tensor(1, dtype=z.dtype, device=z.device),
|
592 |
+
torch.tensor(-1, dtype=z.dtype, device=z.device))
|
593 |
+
|
594 |
+
q_scale = 1. / (self.codebook_dims ** 0.5)
|
595 |
+
zhat = q_scale * zhat # on unit sphere
|
596 |
+
|
597 |
+
return z + (zhat - z).detach()
|
598 |
+
|
599 |
+
def soft_entropy_loss(self, z):
|
600 |
+
if self.persample_entropy_compute == 'analytical':
|
601 |
+
# if self.l2_norm:
|
602 |
+
p = torch.sigmoid(-4 * z / (self.codebook_dims ** 0.5) * self.inv_temperature)
|
603 |
+
# else:
|
604 |
+
# p = torch.sigmoid(-4 * z * self.inv_temperature)
|
605 |
+
prob = torch.stack([p, 1-p], dim=-1) # (b, h, w, 18, 2)
|
606 |
+
per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() # (b,h,w,18)->(b,h,w)->scalar
|
607 |
+
else:
|
608 |
+
per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
|
609 |
+
|
610 |
+
# macro average of the probability of each subgroup
|
611 |
+
avg_prob = reduce(prob, '... g d ->g d', 'mean') # (18, 2)
|
612 |
+
codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
|
613 |
+
|
614 |
+
# the approximation of the entropy is the sum of the entropy of each subgroup
|
615 |
+
return per_sample_entropy, codebook_entropy.sum(), avg_prob
|
616 |
+
|
617 |
+
def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
|
618 |
+
if normalize: # False
|
619 |
+
probs = (count + eps) / (count + eps).sum(dim=dim, keepdim =True)
|
620 |
+
else: # True
|
621 |
+
probs = count
|
622 |
+
H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
|
623 |
+
return H
|
624 |
+
|
625 |
+
def forward(
|
626 |
+
self,
|
627 |
+
x,
|
628 |
+
return_loss_breakdown = False,
|
629 |
+
mask = None,
|
630 |
+
entropy_weight=0.1
|
631 |
+
):
|
632 |
+
"""
|
633 |
+
einstein notation
|
634 |
+
b - batch
|
635 |
+
n - sequence (or flattened spatial dimensions)
|
636 |
+
d - feature dimension, which is also log2(codebook size)
|
637 |
+
c - number of codebook dim
|
638 |
+
"""
|
639 |
+
|
640 |
+
is_img_or_video = x.ndim >= 4
|
641 |
+
should_transpose = default(self.channel_first, is_img_or_video)
|
642 |
+
|
643 |
+
# standardize image or video into (batch, seq, dimension)
|
644 |
+
|
645 |
+
if should_transpose:
|
646 |
+
x = rearrange(x, 'b d ... -> b ... d')
|
647 |
+
x, ps = pack_one(x, 'b * d') # x.shape [b, hwt, c]
|
648 |
+
|
649 |
+
assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
|
650 |
+
|
651 |
+
x = self.project_in(x)
|
652 |
+
|
653 |
+
# split out number of codebooks
|
654 |
+
|
655 |
+
x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
|
656 |
+
|
657 |
+
x = l2norm(x)
|
658 |
+
|
659 |
+
# whether to force quantization step to be full precision or not
|
660 |
+
|
661 |
+
force_f32 = self.force_quantization_f32
|
662 |
+
|
663 |
+
quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
|
664 |
+
|
665 |
+
indices = None
|
666 |
+
with quantization_context():
|
667 |
+
|
668 |
+
if force_f32:
|
669 |
+
orig_dtype = x.dtype
|
670 |
+
x = x.float()
|
671 |
+
|
672 |
+
# use straight-through gradients (optionally with custom activation fn) if training
|
673 |
+
if self.new_quant:
|
674 |
+
quantized = self.quantize_new(x)
|
675 |
+
|
676 |
+
# calculate indices
|
677 |
+
bit_indices = (quantized > 0).int()
|
678 |
+
entropy_penalty = persample_entropy = cb_entropy = self.zero
|
679 |
+
commit_loss = self.zero
|
680 |
+
|
681 |
+
# input back to original dtype if needed
|
682 |
+
|
683 |
+
if force_f32:
|
684 |
+
x = x.type(orig_dtype)
|
685 |
+
|
686 |
+
# merge back codebook dim
|
687 |
+
x = quantized # rename quantized to x for output
|
688 |
+
x = rearrange(x, 'b n c d -> b n (c d)')
|
689 |
+
|
690 |
+
# project out to feature dimension if needed
|
691 |
+
|
692 |
+
x = self.project_out(x)
|
693 |
+
|
694 |
+
# reconstitute image or video dimensions
|
695 |
+
|
696 |
+
if should_transpose:
|
697 |
+
x = unpack_one(x, ps, 'b * d')
|
698 |
+
x = rearrange(x, 'b ... d -> b d ...')
|
699 |
+
|
700 |
+
bit_indices = unpack_one(bit_indices, ps, 'b * c d')
|
701 |
+
|
702 |
+
# whether to remove single codebook dim
|
703 |
+
|
704 |
+
if not self.keep_num_codebooks_dim:
|
705 |
+
bit_indices = rearrange(bit_indices, '... 1 d -> ... d')
|
706 |
+
|
707 |
+
# complete aux loss
|
708 |
+
|
709 |
+
aux_loss = commit_loss * self.commitment_loss_weight + (self.zeta * entropy_penalty / self.inv_temperature)*entropy_weight
|
710 |
+
# returns
|
711 |
+
|
712 |
+
ret = Return(x, indices, bit_indices, aux_loss)
|
713 |
+
|
714 |
+
if not return_loss_breakdown:
|
715 |
+
return ret
|
716 |
+
|
717 |
+
return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss)
|
718 |
+
|
models/bsq_vae/vae.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from infinity.models.bsq_vae.flux_vqgan import AutoEncoder
|
5 |
+
|
6 |
+
def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
|
7 |
+
delete_keys = []
|
8 |
+
loaded_keys = []
|
9 |
+
for key in state_dict:
|
10 |
+
if key.startswith(prefix):
|
11 |
+
_key = key[len(prefix):]
|
12 |
+
if _key in model.state_dict():
|
13 |
+
# load nn.Conv2d or nn.Linear to nn.Linear
|
14 |
+
if use_linear and (".q.weight" in key or ".k.weight" in key or ".v.weight" in key or ".proj_out.weight" in key):
|
15 |
+
load_weights = state_dict[key].squeeze()
|
16 |
+
elif _key.endswith(".conv.weight") and expand:
|
17 |
+
if model.state_dict()[_key].shape == state_dict[key].shape:
|
18 |
+
# 2D cnn to 2D cnn
|
19 |
+
load_weights = state_dict[key]
|
20 |
+
else:
|
21 |
+
# 2D cnn to 3D cnn
|
22 |
+
_expand_dim = model.state_dict()[_key].shape[2]
|
23 |
+
load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
|
24 |
+
else:
|
25 |
+
load_weights = state_dict[key]
|
26 |
+
model.state_dict()[_key].copy_(load_weights)
|
27 |
+
delete_keys.append(key)
|
28 |
+
loaded_keys.append(prefix+_key)
|
29 |
+
# load nn.Conv2d to Conv class
|
30 |
+
conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]
|
31 |
+
if any(k in _key for k in conv_list):
|
32 |
+
if _key.endswith(".weight"):
|
33 |
+
conv_key = _key.replace(".weight", ".conv.weight")
|
34 |
+
if conv_key and conv_key in model.state_dict():
|
35 |
+
if model.state_dict()[conv_key].shape == state_dict[key].shape:
|
36 |
+
# 2D cnn to 2D cnn
|
37 |
+
load_weights = state_dict[key]
|
38 |
+
else:
|
39 |
+
# 2D cnn to 3D cnn
|
40 |
+
_expand_dim = model.state_dict()[conv_key].shape[2]
|
41 |
+
load_weights = state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
|
42 |
+
model.state_dict()[conv_key].copy_(load_weights)
|
43 |
+
delete_keys.append(key)
|
44 |
+
loaded_keys.append(prefix+conv_key)
|
45 |
+
if _key.endswith(".bias"):
|
46 |
+
conv_key = _key.replace(".bias", ".conv.bias")
|
47 |
+
if conv_key and conv_key in model.state_dict():
|
48 |
+
model.state_dict()[conv_key].copy_(state_dict[key])
|
49 |
+
delete_keys.append(key)
|
50 |
+
loaded_keys.append(prefix+conv_key)
|
51 |
+
# load nn.GroupNorm to Normalize class
|
52 |
+
if "norm" in _key:
|
53 |
+
if _key.endswith(".weight"):
|
54 |
+
norm_key = _key.replace(".weight", ".norm.weight")
|
55 |
+
if norm_key and norm_key in model.state_dict():
|
56 |
+
model.state_dict()[norm_key].copy_(state_dict[key])
|
57 |
+
delete_keys.append(key)
|
58 |
+
loaded_keys.append(prefix+norm_key)
|
59 |
+
if _key.endswith(".bias"):
|
60 |
+
norm_key = _key.replace(".bias", ".norm.bias")
|
61 |
+
if norm_key and norm_key in model.state_dict():
|
62 |
+
model.state_dict()[norm_key].copy_(state_dict[key])
|
63 |
+
delete_keys.append(key)
|
64 |
+
loaded_keys.append(prefix+norm_key)
|
65 |
+
|
66 |
+
for key in delete_keys:
|
67 |
+
del state_dict[key]
|
68 |
+
|
69 |
+
return model, state_dict, loaded_keys
|
70 |
+
|
71 |
+
|
72 |
+
def vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True, patch_size=16, encoder_ch_mult=[1, 2, 4, 4, 4], decoder_ch_mult=[1, 2, 4, 4, 4],):
|
73 |
+
args=argparse.Namespace(
|
74 |
+
vqgan_ckpt=vqgan_ckpt,
|
75 |
+
sd_ckpt=None,
|
76 |
+
inference_type='image',
|
77 |
+
save='./imagenet_val_bsq',
|
78 |
+
save_prediction=True,
|
79 |
+
image_recon4video=False,
|
80 |
+
junke_old=False,
|
81 |
+
device='cuda',
|
82 |
+
max_steps=1000000.0,
|
83 |
+
log_every=1,
|
84 |
+
visu_every=1000,
|
85 |
+
ckpt_every=1000,
|
86 |
+
default_root_dir='',
|
87 |
+
compile='no',
|
88 |
+
ema='no',
|
89 |
+
lr=0.0001,
|
90 |
+
beta1=0.9,
|
91 |
+
beta2=0.95,
|
92 |
+
warmup_steps=0,
|
93 |
+
optim_type='Adam',
|
94 |
+
disc_optim_type=None,
|
95 |
+
lr_min=0.0,
|
96 |
+
warmup_lr_init=0.0,
|
97 |
+
max_grad_norm=1.0,
|
98 |
+
max_grad_norm_disc=1.0,
|
99 |
+
disable_sch=False,
|
100 |
+
patch_size=patch_size,
|
101 |
+
temporal_patch_size=4,
|
102 |
+
embedding_dim=256,
|
103 |
+
codebook_dim=codebook_dim,
|
104 |
+
num_quantizers=8,
|
105 |
+
quantizer_type='MultiScaleBSQ',
|
106 |
+
use_vae=False,
|
107 |
+
use_freq_enc=False,
|
108 |
+
use_freq_dec=False,
|
109 |
+
preserve_norm=False,
|
110 |
+
ln_before_quant=False,
|
111 |
+
ln_init_by_sqrt=False,
|
112 |
+
use_pxsf=False,
|
113 |
+
new_quant=True,
|
114 |
+
use_decay_factor=False,
|
115 |
+
mask_out=False,
|
116 |
+
use_stochastic_depth=False,
|
117 |
+
drop_rate=0.0,
|
118 |
+
schedule_mode=schedule_mode,
|
119 |
+
lr_drop=None,
|
120 |
+
lr_drop_rate=0.1,
|
121 |
+
keep_first_quant=False,
|
122 |
+
keep_last_quant=False,
|
123 |
+
remove_residual_detach=False,
|
124 |
+
use_out_phi=False,
|
125 |
+
use_out_phi_res=False,
|
126 |
+
use_lecam_reg=False,
|
127 |
+
lecam_weight=0.05,
|
128 |
+
perceptual_model='vgg16',
|
129 |
+
base_ch_disc=64,
|
130 |
+
random_flip=False,
|
131 |
+
flip_prob=0.5,
|
132 |
+
flip_mode='stochastic',
|
133 |
+
max_flip_lvl=1,
|
134 |
+
not_load_optimizer=False,
|
135 |
+
use_lecam_reg_zero=False,
|
136 |
+
freeze_encoder=False,
|
137 |
+
rm_downsample=False,
|
138 |
+
random_flip_1lvl=False,
|
139 |
+
flip_lvl_idx=0,
|
140 |
+
drop_when_test=False,
|
141 |
+
drop_lvl_idx=0,
|
142 |
+
drop_lvl_num=1,
|
143 |
+
disc_version='v1',
|
144 |
+
magvit_disc=False,
|
145 |
+
sigmoid_in_disc=False,
|
146 |
+
activation_in_disc='leaky_relu',
|
147 |
+
apply_blur=False,
|
148 |
+
apply_noise=False,
|
149 |
+
dis_warmup_steps=0,
|
150 |
+
dis_lr_multiplier=1.0,
|
151 |
+
dis_minlr_multiplier=False,
|
152 |
+
disc_channels=64,
|
153 |
+
disc_layers=3,
|
154 |
+
discriminator_iter_start=0,
|
155 |
+
disc_pretrain_iter=0,
|
156 |
+
disc_optim_steps=1,
|
157 |
+
disc_warmup=0,
|
158 |
+
disc_pool='no',
|
159 |
+
disc_pool_size=1000,
|
160 |
+
advanced_disc=False,
|
161 |
+
recon_loss_type='l1',
|
162 |
+
video_perceptual_weight=0.0,
|
163 |
+
image_gan_weight=1.0,
|
164 |
+
video_gan_weight=1.0,
|
165 |
+
image_disc_weight=0.0,
|
166 |
+
video_disc_weight=0.0,
|
167 |
+
l1_weight=4.0,
|
168 |
+
gan_feat_weight=0.0,
|
169 |
+
perceptual_weight=0.0,
|
170 |
+
kl_weight=0.0,
|
171 |
+
lfq_weight=0.0,
|
172 |
+
entropy_loss_weight=0.1,
|
173 |
+
commitment_loss_weight=0.25,
|
174 |
+
diversity_gamma=1,
|
175 |
+
norm_type='group',
|
176 |
+
disc_loss_type='hinge',
|
177 |
+
use_checkpoint=False,
|
178 |
+
precision='fp32',
|
179 |
+
encoder_dtype='fp32',
|
180 |
+
upcast_attention='',
|
181 |
+
upcast_tf32=False,
|
182 |
+
tokenizer='flux',
|
183 |
+
pretrained=None,
|
184 |
+
pretrained_mode='full',
|
185 |
+
inflation_pe=False,
|
186 |
+
init_vgen='no',
|
187 |
+
no_init_idis=False,
|
188 |
+
init_idis='keep',
|
189 |
+
init_vdis='no',
|
190 |
+
enable_nan_detector=False,
|
191 |
+
turn_on_profiler=False,
|
192 |
+
profiler_scheduler_wait_steps=10,
|
193 |
+
debug=True,
|
194 |
+
video_logger=False,
|
195 |
+
bytenas='',
|
196 |
+
username='',
|
197 |
+
seed=1234,
|
198 |
+
vq_to_vae=False,
|
199 |
+
load_not_strict=False,
|
200 |
+
zero=0,
|
201 |
+
bucket_cap_mb=40,
|
202 |
+
manual_gc_interval=1000,
|
203 |
+
data_path=[''],
|
204 |
+
data_type=[''],
|
205 |
+
dataset_list=['imagenet'],
|
206 |
+
fps=-1,
|
207 |
+
dataaug='resizecrop',
|
208 |
+
multi_resolution=False,
|
209 |
+
random_bucket_ratio=0.0,
|
210 |
+
sequence_length=16,
|
211 |
+
resolution=[256, 256],
|
212 |
+
batch_size=[1],
|
213 |
+
num_workers=0,
|
214 |
+
image_channels=3,
|
215 |
+
codebook_size=codebook_size,
|
216 |
+
codebook_l2_norm=True,
|
217 |
+
codebook_show_usage=True,
|
218 |
+
commit_loss_beta=0.25,
|
219 |
+
entropy_loss_ratio=0.0,
|
220 |
+
base_ch=128,
|
221 |
+
num_res_blocks=2,
|
222 |
+
encoder_ch_mult=encoder_ch_mult,
|
223 |
+
decoder_ch_mult=decoder_ch_mult,
|
224 |
+
dropout_p=0.0,
|
225 |
+
cnn_type='2d',
|
226 |
+
cnn_version='v1',
|
227 |
+
conv_in_out_2d='no',
|
228 |
+
conv_inner_2d='no',
|
229 |
+
res_conv_2d='no',
|
230 |
+
cnn_attention='no',
|
231 |
+
cnn_norm_axis='spatial',
|
232 |
+
flux_weight=0,
|
233 |
+
cycle_weight=0,
|
234 |
+
cycle_feat_weight=0,
|
235 |
+
cycle_gan_weight=0,
|
236 |
+
cycle_loop=0,
|
237 |
+
z_drop=0.0)
|
238 |
+
|
239 |
+
vae = AutoEncoder(args)
|
240 |
+
use_vae = vae.use_vae
|
241 |
+
if not use_vae:
|
242 |
+
num_codes = args.codebook_size
|
243 |
+
if isinstance(vqgan_ckpt, str):
|
244 |
+
state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
|
245 |
+
else:
|
246 |
+
state_dict = args.vqgan_ckpt
|
247 |
+
if state_dict:
|
248 |
+
if args.ema == "yes":
|
249 |
+
vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["ema"], prefix="", expand=False)
|
250 |
+
else:
|
251 |
+
vae, new_state_dict, loaded_keys = load_cnn(vae, state_dict["vae"], prefix="", expand=False)
|
252 |
+
if test_mode:
|
253 |
+
vae.eval()
|
254 |
+
[p.requires_grad_(False) for p in vae.parameters()]
|
255 |
+
return vae
|
models/ema.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
|
6 |
+
def get_ema_model(model):
|
7 |
+
ema_model = copy.deepcopy(model)
|
8 |
+
ema_model.eval()
|
9 |
+
for param in ema_model.parameters():
|
10 |
+
param.requires_grad = False
|
11 |
+
return ema_model
|
12 |
+
|
13 |
+
@torch.no_grad()
|
14 |
+
def update_ema(ema_model, model, decay=0.9999):
|
15 |
+
"""
|
16 |
+
Step the EMA model towards the current model.
|
17 |
+
"""
|
18 |
+
ema_params = OrderedDict(ema_model.named_parameters())
|
19 |
+
model_params = OrderedDict(model.named_parameters())
|
20 |
+
|
21 |
+
for name, param in model_params.items():
|
22 |
+
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
|
23 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
models/flex_attn.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Wrap torch's flex attention and handle mess info or potentially refactor
|
3 |
+
"""
|
4 |
+
from functools import partial
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
try:
|
10 |
+
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
11 |
+
flex_attention_available = True
|
12 |
+
except ImportError:
|
13 |
+
print(f"[Warning] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}")
|
14 |
+
flex_attention_available = False
|
15 |
+
|
16 |
+
def _causal_mask(b, h, q_idx, kv_idx):
|
17 |
+
return q_idx >= kv_idx
|
18 |
+
|
19 |
+
def _length_to_offsets(lengths, device):
|
20 |
+
"""Converts a list of lengths to a list of offsets.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
lengths: A list of lengths.
|
24 |
+
|
25 |
+
"""
|
26 |
+
offsets = [0]
|
27 |
+
offsets.extend(lengths)
|
28 |
+
offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
|
29 |
+
offsets = torch.cumsum(offsets, dim=-1)
|
30 |
+
return offsets
|
31 |
+
|
32 |
+
def _generate_var_mask_mod(offsets):
|
33 |
+
"""Generates mask mods that apply to inputs to flex attention in the sequence stacked
|
34 |
+
format.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
offsets: This tensor should be of shape(num_documents + 1)
|
38 |
+
this should contain the cumulative counts of document tokens.
|
39 |
+
e.g. if you have 3 documents of length 2, 4, 3 then
|
40 |
+
offsets = [0, 2, 6, 9]
|
41 |
+
|
42 |
+
Note:
|
43 |
+
What is the sequence stacked format? When assembling batches of inputs, we
|
44 |
+
take multiple sequences and stack them together to form 1 large sequence. We then
|
45 |
+
use masking to ensure that the attention scores are only applied to tokens within
|
46 |
+
the same document.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def _offsets_to_doc_ids_tensor(offsets):
|
50 |
+
device = offsets.device
|
51 |
+
counts = offsets[1:] - offsets[:-1]
|
52 |
+
return torch.repeat_interleave(
|
53 |
+
torch.arange(len(counts), device=device, dtype=torch.int32), counts
|
54 |
+
)
|
55 |
+
|
56 |
+
document_id = _offsets_to_doc_ids_tensor(offsets)
|
57 |
+
|
58 |
+
def var_mask_mod(b, h, q_idx, kv_idx):
|
59 |
+
same_doc = document_id[q_idx] == document_id[kv_idx]
|
60 |
+
causal_mask = _causal_mask(b, h, q_idx, kv_idx)
|
61 |
+
return same_doc | causal_mask
|
62 |
+
|
63 |
+
return var_mask_mod
|
64 |
+
|
65 |
+
def _generate_var_infer_mask_with_kv_cache(lengths):
|
66 |
+
kv_len = sum(lengths)
|
67 |
+
def var_mask_mod(b, h, q_idx, kv_idx):
|
68 |
+
return kv_idx < kv_len
|
69 |
+
|
70 |
+
return var_mask_mod
|
71 |
+
|
72 |
+
class FlexAttn(nn.Module):
|
73 |
+
def __init__(
|
74 |
+
self, block_scales:list, mask_type:str, B, H, L:int, auto_padding=False
|
75 |
+
):
|
76 |
+
"""
|
77 |
+
:param block_scales: accept VAR's block sizes like [(1,1), (2,2), (3,3)]
|
78 |
+
:param mask_type: var/causal
|
79 |
+
:param B: batch size
|
80 |
+
:param H: heads num
|
81 |
+
:param L: sequence length
|
82 |
+
"""
|
83 |
+
super().__init__()
|
84 |
+
if not flex_attention_available:
|
85 |
+
raise NotImplementedError((f"[Error] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}"))
|
86 |
+
|
87 |
+
self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache"]
|
88 |
+
self.auto_padding = auto_padding
|
89 |
+
|
90 |
+
self.flex_attention = torch.compile(flex_attention)
|
91 |
+
|
92 |
+
self.block_scales = block_scales
|
93 |
+
self.lengths = [ x * y * z for x,y,z in block_scales]
|
94 |
+
|
95 |
+
self.offsets = _length_to_offsets(self.lengths, device='cuda')
|
96 |
+
|
97 |
+
# if L paded to align 128, block need to cover padding area
|
98 |
+
if self.offsets[-1] < L:
|
99 |
+
self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0)
|
100 |
+
|
101 |
+
if mask_type == "var":
|
102 |
+
self.mask_mod = _generate_var_mask_mod(self.offsets)
|
103 |
+
self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True)
|
104 |
+
elif mask_type == "causal":
|
105 |
+
self.mask_mod = _causal_mask
|
106 |
+
self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True)
|
107 |
+
elif mask_type == 'var_infer_mask_with_kv_cache':
|
108 |
+
self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths)
|
109 |
+
self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True)
|
110 |
+
else:
|
111 |
+
raise NotImplementedError(f"{mask_type} not supportted in FlexAttn, support type:{self.support_mask_type}")
|
112 |
+
|
113 |
+
|
114 |
+
def forward(self, q, k, v, scale = None):
|
115 |
+
if self.auto_padding:
|
116 |
+
q_pad_len = (128 - q.shape[-2] % 128) % 128
|
117 |
+
kv_pad_len = (128 - k.shape[-2] % 128) % 128
|
118 |
+
q_pad = F.pad(q, (0, 0, 0, q_pad_len))
|
119 |
+
k_pad = F.pad(k, (0, 0, 0, kv_pad_len))
|
120 |
+
v_pad = F.pad(v, (0, 0, 0, kv_pad_len))
|
121 |
+
oup = self.flex_attention(q_pad.to(v_pad.dtype), k_pad.to(v.dtype), v_pad, block_mask = self.block_mask, scale = scale)
|
122 |
+
if q_pad_len > 0:
|
123 |
+
oup = oup[:,:,:-q_pad_len]
|
124 |
+
else:
|
125 |
+
oup = self.flex_attention(q.to(v.dtype), k.to(v.dtype), v, block_mask = self.block_mask, scale = scale)
|
126 |
+
return oup
|
127 |
+
|
128 |
+
def extra_repr(self) -> str:
|
129 |
+
tail = ''
|
130 |
+
return f'block size:{self.block_scales} {tail}'
|
models/fused_op.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from copy import deepcopy
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn as nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
|
10 |
+
@torch.compile(fullgraph=True)
|
11 |
+
def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
|
12 |
+
x = x.float()
|
13 |
+
return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight
|
14 |
+
|
15 |
+
|
16 |
+
@torch.compile(fullgraph=True)
|
17 |
+
def fused_ada_layer_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
|
18 |
+
x = x.float()
|
19 |
+
x = F.layer_norm(input=x, normalized_shape=(C,), weight=None, bias=None, eps=eps)
|
20 |
+
return x.mul(scale.add(1)).add_(shift)
|
21 |
+
|
22 |
+
|
23 |
+
@torch.compile(fullgraph=True)
|
24 |
+
def fused_ada_rms_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
|
25 |
+
x = x.float()
|
26 |
+
x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps)))
|
27 |
+
return x.mul(scale.add(1)).add_(shift)
|
models/infinity.py
ADDED
@@ -0,0 +1,795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Definition of Infinity transformer model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
from contextlib import nullcontext
|
9 |
+
from functools import partial
|
10 |
+
from typing import List, Optional, Tuple, Union, Dict, Any
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from timm.models import register_model
|
16 |
+
from torch.utils.checkpoint import checkpoint
|
17 |
+
from PIL import Image
|
18 |
+
import numpy as np
|
19 |
+
from torch.nn.attention.flex_attention import flex_attention
|
20 |
+
|
21 |
+
import infinity.utils.dist as dist
|
22 |
+
from infinity.utils.dist import for_visualize
|
23 |
+
from infinity.models.basic import flash_attn_func, flash_fused_op_installed, AdaLNBeforeHead, CrossAttnBlock, SelfAttnBlock, CrossAttention, FastRMSNorm, precompute_rope2d_freqs_grid
|
24 |
+
from infinity.utils import misc
|
25 |
+
from infinity.models.flex_attn import FlexAttn
|
26 |
+
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
|
27 |
+
|
28 |
+
try:
|
29 |
+
from infinity.models.fused_op import fused_ada_layer_norm, fused_ada_rms_norm
|
30 |
+
except:
|
31 |
+
fused_ada_layer_norm, fused_ada_rms_norm = None, None
|
32 |
+
|
33 |
+
|
34 |
+
class MultiInpIdentity(nn.Module):
|
35 |
+
def forward(self, x, *args, **kwargs):
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
class TextAttentivePool(nn.Module):
|
40 |
+
def __init__(self, Ct5: int, D: int):
|
41 |
+
super().__init__()
|
42 |
+
self.Ct5, self.D = Ct5, D
|
43 |
+
if D > 4096:
|
44 |
+
self.head_dim = 64
|
45 |
+
else:
|
46 |
+
self.head_dim = 128
|
47 |
+
|
48 |
+
self.num_heads = Ct5 // self.head_dim
|
49 |
+
self.ca = CrossAttention(for_attn_pool=True, embed_dim=self.D, kv_dim=Ct5, num_heads=self.num_heads)
|
50 |
+
def forward(self, ca_kv):
|
51 |
+
return self.ca(None, ca_kv).squeeze(1)
|
52 |
+
|
53 |
+
class SharedAdaLin(nn.Linear):
|
54 |
+
def forward(self, cond_BD):
|
55 |
+
C = self.weight.shape[0] // 6
|
56 |
+
return super().forward(cond_BD).reshape(-1, 1, 6, C) # B16C
|
57 |
+
|
58 |
+
|
59 |
+
class MultipleLayers(nn.Module):
|
60 |
+
def __init__(self, ls, num_blocks_in_a_chunk, index):
|
61 |
+
super().__init__()
|
62 |
+
self.module = nn.ModuleList()
|
63 |
+
for i in range(index, index+num_blocks_in_a_chunk):
|
64 |
+
self.module.append(ls[i])
|
65 |
+
|
66 |
+
def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None):
|
67 |
+
h = x
|
68 |
+
for m in self.module:
|
69 |
+
if checkpointing_full_block:
|
70 |
+
h = torch.utils.checkpoint.checkpoint(m, h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
|
71 |
+
else:
|
72 |
+
h = m(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
|
73 |
+
return h
|
74 |
+
|
75 |
+
class Infinity(nn.Module):
|
76 |
+
def __init__(
|
77 |
+
self, vae_local,
|
78 |
+
text_channels=0, text_maxlen=0, # text-cond generation
|
79 |
+
selecting_idx=None, # class-cond generation
|
80 |
+
embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture
|
81 |
+
drop_rate=0., drop_path_rate=0., # drop out and drop path
|
82 |
+
norm_eps=1e-6, rms_norm=False, # norm layer
|
83 |
+
shared_aln=False, head_aln=True, # adaptive norm
|
84 |
+
cond_drop_rate=0.1, # for classifier-free guidance
|
85 |
+
rand_uncond=False,
|
86 |
+
cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
|
87 |
+
raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
|
88 |
+
head_depth=1,
|
89 |
+
top_p=0.0, top_k=0.0,
|
90 |
+
customized_flash_attn=False, fused_mlp=False, fused_norm=False,
|
91 |
+
block_chunks=1,
|
92 |
+
checkpointing=None,
|
93 |
+
pad_to_multiplier=0,
|
94 |
+
use_flex_attn=False,
|
95 |
+
batch_size=2,
|
96 |
+
add_lvl_embeding_only_first_block=1,
|
97 |
+
use_bit_label=1,
|
98 |
+
rope2d_each_sa_layer=0,
|
99 |
+
rope2d_normalized_by_hw=0,
|
100 |
+
pn=None,
|
101 |
+
train_h_div_w_list=None,
|
102 |
+
video_frames=1,
|
103 |
+
always_training_scales=20,
|
104 |
+
apply_spatial_patchify = 0,
|
105 |
+
inference_mode=False,
|
106 |
+
):
|
107 |
+
# set hyperparameters
|
108 |
+
self.C = embed_dim
|
109 |
+
self.inference_mode = inference_mode
|
110 |
+
self.apply_spatial_patchify = apply_spatial_patchify
|
111 |
+
if self.apply_spatial_patchify:
|
112 |
+
self.d_vae = vae_local.embed_dim * 4
|
113 |
+
else:
|
114 |
+
self.d_vae = vae_local.embed_dim
|
115 |
+
self.use_bit_label = use_bit_label
|
116 |
+
self.codebook_dim = self.d_vae
|
117 |
+
self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
|
118 |
+
self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
|
119 |
+
self.Ct5 = text_channels
|
120 |
+
self.depth = depth
|
121 |
+
self.num_heads = num_heads
|
122 |
+
self.batch_size = batch_size
|
123 |
+
self.mlp_ratio = mlp_ratio
|
124 |
+
self.cond_drop_rate = cond_drop_rate
|
125 |
+
self.norm_eps = norm_eps
|
126 |
+
self.prog_si = -1
|
127 |
+
self.pn = pn
|
128 |
+
self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
|
129 |
+
self.video_frames = video_frames
|
130 |
+
self.always_training_scales = always_training_scales
|
131 |
+
|
132 |
+
assert add_lvl_embeding_only_first_block in [0,1]
|
133 |
+
self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
|
134 |
+
assert rope2d_each_sa_layer in [0,1]
|
135 |
+
self.rope2d_each_sa_layer = rope2d_each_sa_layer
|
136 |
+
self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
|
137 |
+
print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \
|
138 |
+
self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
|
139 |
+
head_up_method = ''
|
140 |
+
word_patch_size = 1 if head_up_method in {'', 'no'} else 2
|
141 |
+
if word_patch_size > 1:
|
142 |
+
assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'
|
143 |
+
|
144 |
+
self.checkpointing = checkpointing
|
145 |
+
self.pad_to_multiplier = max(1, pad_to_multiplier)
|
146 |
+
|
147 |
+
customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
|
148 |
+
self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
|
149 |
+
if customized_flash_attn and not customized_kernel_installed:
|
150 |
+
import inspect, warnings
|
151 |
+
file_path = inspect.getsourcefile(flash_attn_func)
|
152 |
+
line_number = inspect.getsourcelines(flash_attn_func)[1]
|
153 |
+
info = (
|
154 |
+
f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
|
155 |
+
f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
|
156 |
+
f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
|
157 |
+
)
|
158 |
+
warnings.warn(info, ImportWarning)
|
159 |
+
print(info, flush=True)
|
160 |
+
|
161 |
+
self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying
|
162 |
+
self.first_l = 1
|
163 |
+
# solve top-p top-k sampling hyperparameters
|
164 |
+
self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
|
165 |
+
if self.top_p < 1e-5: self.top_p = 0
|
166 |
+
if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0
|
167 |
+
|
168 |
+
t = torch.zeros(dist.get_world_size(), device=dist.get_device())
|
169 |
+
t[dist.get_rank()] = float(flash_fused_op_installed)
|
170 |
+
dist.barrier()
|
171 |
+
dist.allreduce(t)
|
172 |
+
assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'
|
173 |
+
|
174 |
+
super().__init__()
|
175 |
+
self.rng = torch.Generator(device=dist.get_device())
|
176 |
+
self.maybe_record_function = nullcontext
|
177 |
+
self.text_maxlen = text_maxlen
|
178 |
+
self.t2i = text_channels != 0
|
179 |
+
|
180 |
+
# [inp & position embedding]
|
181 |
+
init_std = math.sqrt(1 / self.C / 3)
|
182 |
+
self.norm0_cond = nn.Identity()
|
183 |
+
if self.t2i:
|
184 |
+
self.selecting_idx = None
|
185 |
+
self.num_classes = 0
|
186 |
+
self.D = self.C
|
187 |
+
|
188 |
+
cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
|
189 |
+
rng = torch.Generator(device='cpu')
|
190 |
+
rng.manual_seed(0)
|
191 |
+
torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
|
192 |
+
cfg_uncond /= self.Ct5 ** 0.5
|
193 |
+
if rand_uncond:
|
194 |
+
self.register_buffer('cfg_uncond', cfg_uncond)
|
195 |
+
else:
|
196 |
+
self.cfg_uncond = nn.Parameter(cfg_uncond)
|
197 |
+
|
198 |
+
self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
|
199 |
+
self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
|
200 |
+
self.text_proj_for_ca = nn.Sequential(
|
201 |
+
nn.Linear(self.Ct5, self.D),
|
202 |
+
nn.GELU(approximate='tanh'),
|
203 |
+
nn.Linear(self.D, self.D),
|
204 |
+
)
|
205 |
+
else: # class-label cond
|
206 |
+
if selecting_idx is None:
|
207 |
+
num_classes = 1000
|
208 |
+
print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
|
209 |
+
selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
|
210 |
+
self.selecting_idx = selecting_idx
|
211 |
+
self.num_classes = selecting_idx.shape[-1]
|
212 |
+
self.D = self.C
|
213 |
+
self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
|
214 |
+
nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
|
215 |
+
|
216 |
+
self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
|
217 |
+
nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
|
218 |
+
if self.rope2d_each_sa_layer:
|
219 |
+
rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
|
220 |
+
self.rope2d_freqs_grid = rope2d_freqs_grid
|
221 |
+
else:
|
222 |
+
raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
|
223 |
+
self.lvl_embed = nn.Embedding(15, self.C)
|
224 |
+
nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
|
225 |
+
|
226 |
+
# [input layers] input norm && input embedding
|
227 |
+
norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
|
228 |
+
self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
|
229 |
+
self.word_embed = nn.Linear(self.d_vae, self.C)
|
230 |
+
|
231 |
+
# [shared adaptive layernorm mapping network]
|
232 |
+
self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
|
233 |
+
|
234 |
+
# fused norm
|
235 |
+
if fused_norm:
|
236 |
+
fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
|
237 |
+
if fused_norm_func is not None: # pre-compile
|
238 |
+
B = 2
|
239 |
+
x = torch.randn(B, 1, self.C).requires_grad_(True)
|
240 |
+
scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
|
241 |
+
shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
|
242 |
+
# fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward()
|
243 |
+
del B, x, scale, shift
|
244 |
+
else:
|
245 |
+
fused_norm_func = None
|
246 |
+
|
247 |
+
# [backbone and head]
|
248 |
+
self.use_flex_attn = use_flex_attn
|
249 |
+
self.attn_fn_compile_dict = {}
|
250 |
+
self.batch_size = batch_size
|
251 |
+
if self.use_flex_attn:
|
252 |
+
self.attn_fn_compile_dict = self.compile_flex_attn()
|
253 |
+
|
254 |
+
self.drop_path_rate = drop_path_rate
|
255 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing)
|
256 |
+
self.unregistered_blocks = []
|
257 |
+
for block_idx in range(depth):
|
258 |
+
block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
|
259 |
+
embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
|
260 |
+
num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
|
261 |
+
swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
|
262 |
+
checkpointing_sa_only=self.checkpointing == 'self-attn',
|
263 |
+
use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
|
264 |
+
)
|
265 |
+
self.unregistered_blocks.append(block)
|
266 |
+
|
267 |
+
# [head]
|
268 |
+
V = self.V
|
269 |
+
if head_aln:
|
270 |
+
self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
|
271 |
+
self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
|
272 |
+
else:
|
273 |
+
self.head_nm = MultiInpIdentity()
|
274 |
+
self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
|
275 |
+
|
276 |
+
self.num_block_chunks = block_chunks or 1
|
277 |
+
self.num_blocks_in_a_chunk = depth // block_chunks
|
278 |
+
print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
|
279 |
+
assert self.num_blocks_in_a_chunk * block_chunks == depth
|
280 |
+
if self.num_block_chunks == 1:
|
281 |
+
self.blocks = nn.ModuleList(self.unregistered_blocks)
|
282 |
+
else:
|
283 |
+
self.block_chunks = nn.ModuleList()
|
284 |
+
for i in range(self.num_block_chunks):
|
285 |
+
self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
|
286 |
+
print(
|
287 |
+
f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
|
288 |
+
f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
|
289 |
+
f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
|
290 |
+
end='\n\n', flush=True
|
291 |
+
)
|
292 |
+
|
293 |
+
|
294 |
+
def compile_flex_attn(self):
|
295 |
+
attn_fn_compile_dict = {}
|
296 |
+
for h_div_w in self.train_h_div_w_list:
|
297 |
+
h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
|
298 |
+
full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales']
|
299 |
+
if self.inference_mode:
|
300 |
+
apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule)))
|
301 |
+
mask_type = "infinity_infer_mask_with_kv_cache"
|
302 |
+
auto_padding = True
|
303 |
+
else:
|
304 |
+
mask_type = 'var'
|
305 |
+
auto_padding = False
|
306 |
+
apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))]
|
307 |
+
for scales_num in apply_flex_attn_scales:
|
308 |
+
print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
|
309 |
+
scale_schedule = full_scale_schedule[:scales_num]
|
310 |
+
scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule]
|
311 |
+
patchs_nums_tuple = tuple(scale_schedule)
|
312 |
+
SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
|
313 |
+
aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
|
314 |
+
attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
|
315 |
+
mask_type = mask_type,
|
316 |
+
B = self.batch_size,
|
317 |
+
H = self.num_heads,
|
318 |
+
L = aligned_L,
|
319 |
+
auto_padding=auto_padding)
|
320 |
+
attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
|
321 |
+
|
322 |
+
if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos)
|
323 |
+
scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule]
|
324 |
+
patchs_nums_tuple = tuple(scale_schedule)
|
325 |
+
SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
|
326 |
+
aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
|
327 |
+
attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
|
328 |
+
mask_type = mask_type,
|
329 |
+
B = self.batch_size,
|
330 |
+
H = self.num_heads,
|
331 |
+
L = aligned_L)
|
332 |
+
attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
|
333 |
+
return attn_fn_compile_dict
|
334 |
+
|
335 |
+
def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
|
336 |
+
"""
|
337 |
+
:param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim)
|
338 |
+
:param cond_BD: shaped (B or batch_size, D or cond_dim)
|
339 |
+
:param tau: temperature
|
340 |
+
:return: logits, shaped (B or batch_size, V or vocabulary_size)
|
341 |
+
"""
|
342 |
+
with torch.amp.autocast('cuda', enabled=False):
|
343 |
+
return self.head(self.head_nm(h.float(), cond_BD.float()))
|
344 |
+
|
345 |
+
def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
|
346 |
+
bs, seq_len, c = feature.shape
|
347 |
+
patch_t, patch_h, patch_w = scale_schedule[scale_ind]
|
348 |
+
t_mul_h_mul_w = patch_t * patch_h * patch_w
|
349 |
+
assert t_mul_h_mul_w + need_to_pad == seq_len
|
350 |
+
feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device))
|
351 |
+
return feature
|
352 |
+
|
353 |
+
def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
|
354 |
+
ptr = 0
|
355 |
+
x_BLC_list = []
|
356 |
+
for scale_ind, patch_t_h_w in enumerate(scale_schedule):
|
357 |
+
scale_seq_len = np.array(patch_t_h_w).prod()
|
358 |
+
x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c]
|
359 |
+
ptr += scale_seq_len
|
360 |
+
x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule)
|
361 |
+
x_BLC_list.append(x_BLC_this_scale)
|
362 |
+
assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}'
|
363 |
+
x_BLC_list.append(x_BLC[:,ptr:])
|
364 |
+
x_BLC = torch.cat(x_BLC_list, dim=1)
|
365 |
+
return x_BLC
|
366 |
+
|
367 |
+
def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
|
368 |
+
cfg_infer=False,
|
369 |
+
**kwargs,
|
370 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV
|
371 |
+
"""
|
372 |
+
label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k)
|
373 |
+
:return: logits BLV, V is vocab_size
|
374 |
+
"""
|
375 |
+
if cfg_infer:
|
376 |
+
return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
|
377 |
+
|
378 |
+
x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32
|
379 |
+
B = x_BLC_wo_prefix.shape[0]
|
380 |
+
|
381 |
+
# [1. get input sequence x_BLC]
|
382 |
+
with torch.amp.autocast('cuda', enabled=False):
|
383 |
+
kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
|
384 |
+
# drop cond
|
385 |
+
total = 0
|
386 |
+
for le in lens:
|
387 |
+
if random.random() < self.cond_drop_rate:
|
388 |
+
kv_compact[total:total+le] = self.cfg_uncond[:le]
|
389 |
+
total += le
|
390 |
+
must_on_graph = self.cfg_uncond[0, 0] * 0
|
391 |
+
kv_compact = self.text_norm(kv_compact).contiguous()
|
392 |
+
sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32
|
393 |
+
kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
|
394 |
+
kv_compact[0, 0] += must_on_graph
|
395 |
+
ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
|
396 |
+
|
397 |
+
cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32
|
398 |
+
|
399 |
+
sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
|
400 |
+
x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
|
401 |
+
|
402 |
+
# [1.1. pad the seqlen dim]
|
403 |
+
l_end = x_BLC.shape[1]
|
404 |
+
need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0
|
405 |
+
|
406 |
+
if self.customized_flash_attn:
|
407 |
+
Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
|
408 |
+
Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
|
409 |
+
attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
|
410 |
+
# todo: solve need_to_pad here
|
411 |
+
elif self.use_flex_attn:
|
412 |
+
if need_to_pad:
|
413 |
+
x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
|
414 |
+
assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
|
415 |
+
attn_bias_or_two_vector = None
|
416 |
+
else:
|
417 |
+
d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
|
418 |
+
dT = d.transpose(1, 2) # dT: 11L
|
419 |
+
attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
|
420 |
+
attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL
|
421 |
+
if need_to_pad:
|
422 |
+
attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
|
423 |
+
attn_bias[0, 0, l_end:, 0] = 0
|
424 |
+
x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
|
425 |
+
attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)
|
426 |
+
|
427 |
+
if self.use_flex_attn:
|
428 |
+
attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
|
429 |
+
else:
|
430 |
+
attn_fn = None
|
431 |
+
|
432 |
+
# [2. block loop]
|
433 |
+
SelfAttnBlock.forward, CrossAttnBlock.forward
|
434 |
+
checkpointing_full_block = self.checkpointing == 'full-block' and self.training
|
435 |
+
if self.num_block_chunks == 1:
|
436 |
+
for i, b in enumerate(self.blocks):
|
437 |
+
if self.add_lvl_embeding_only_first_block and i == 0:
|
438 |
+
x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
|
439 |
+
if not self.add_lvl_embeding_only_first_block:
|
440 |
+
x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
|
441 |
+
if checkpointing_full_block:
|
442 |
+
x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
|
443 |
+
else:
|
444 |
+
x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
|
445 |
+
else:
|
446 |
+
for i, chunk in enumerate(self.block_chunks): # this path
|
447 |
+
if self.add_lvl_embeding_only_first_block and i == 0:
|
448 |
+
x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
|
449 |
+
if not self.add_lvl_embeding_only_first_block:
|
450 |
+
x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
|
451 |
+
x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
|
452 |
+
|
453 |
+
# [3. unpad the seqlen dim, and then get logits]
|
454 |
+
return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size
|
455 |
+
|
456 |
+
@torch.no_grad()
|
457 |
+
def autoregressive_infer_cfg(
|
458 |
+
self,
|
459 |
+
vae=None,
|
460 |
+
scale_schedule=None,
|
461 |
+
label_B_or_BLT=None,
|
462 |
+
B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
|
463 |
+
g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
|
464 |
+
returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
|
465 |
+
cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
|
466 |
+
vae_type=0, softmax_merge_topk=-1, ret_img=False,
|
467 |
+
trunk_scale=1000,
|
468 |
+
gt_leak=0, gt_ls_Bl=None,
|
469 |
+
inference_mode=False,
|
470 |
+
save_img_path=None,
|
471 |
+
sampling_per_bits=1,
|
472 |
+
): # returns List[idx_Bl]
|
473 |
+
if g_seed is None: rng = None
|
474 |
+
else: self.rng.manual_seed(g_seed); rng = self.rng
|
475 |
+
assert len(cfg_list) >= len(scale_schedule)
|
476 |
+
assert len(tau_list) >= len(scale_schedule)
|
477 |
+
|
478 |
+
# scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify,
|
479 |
+
# we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w
|
480 |
+
if self.apply_spatial_patchify:
|
481 |
+
vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
|
482 |
+
else:
|
483 |
+
vae_scale_schedule = scale_schedule
|
484 |
+
|
485 |
+
kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
|
486 |
+
if any(np.array(cfg_list) != 1):
|
487 |
+
bs = 2*B
|
488 |
+
if not negative_label_B_or_BLT:
|
489 |
+
kv_compact_un = kv_compact.clone()
|
490 |
+
total = 0
|
491 |
+
for le in lens:
|
492 |
+
kv_compact_un[total:total+le] = (self.cfg_uncond)[:le]
|
493 |
+
total += le
|
494 |
+
kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
|
495 |
+
cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
|
496 |
+
else:
|
497 |
+
kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
|
498 |
+
kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
|
499 |
+
cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
|
500 |
+
max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
|
501 |
+
else:
|
502 |
+
bs = B
|
503 |
+
|
504 |
+
kv_compact = self.text_norm(kv_compact)
|
505 |
+
sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096]
|
506 |
+
kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096]
|
507 |
+
ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
|
508 |
+
last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
|
509 |
+
|
510 |
+
with torch.amp.autocast('cuda', enabled=False):
|
511 |
+
cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
|
512 |
+
accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
|
513 |
+
idx_Bl_list, idx_Bld_list = [], []
|
514 |
+
|
515 |
+
if inference_mode:
|
516 |
+
for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
|
517 |
+
else:
|
518 |
+
assert self.num_block_chunks > 1
|
519 |
+
for block_chunk_ in self.block_chunks:
|
520 |
+
for module in block_chunk_.module.module:
|
521 |
+
(module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)
|
522 |
+
|
523 |
+
abs_cfg_insertion_layers = []
|
524 |
+
add_cfg_on_logits, add_cfg_on_probs = False, False
|
525 |
+
leng = len(self.unregistered_blocks)
|
526 |
+
for item in cfg_insertion_layer:
|
527 |
+
if item == 0: # add cfg on logits
|
528 |
+
add_cfg_on_logits = True
|
529 |
+
elif item == 1: # add cfg on probs
|
530 |
+
add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs
|
531 |
+
elif item < 0: # determine to add cfg at item-th layer's output
|
532 |
+
assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}'
|
533 |
+
abs_cfg_insertion_layers.append(leng+item)
|
534 |
+
else:
|
535 |
+
raise ValueError(f'cfg_insertion_layer: {item} is not valid')
|
536 |
+
|
537 |
+
num_stages_minus_1 = len(scale_schedule)-1
|
538 |
+
summed_codes = 0
|
539 |
+
for si, pn in enumerate(scale_schedule): # si: i-th segment
|
540 |
+
cfg = cfg_list[si]
|
541 |
+
if si >= trunk_scale:
|
542 |
+
break
|
543 |
+
cur_L += np.array(pn).prod()
|
544 |
+
|
545 |
+
need_to_pad = 0
|
546 |
+
attn_fn = None
|
547 |
+
if self.use_flex_attn:
|
548 |
+
# need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier
|
549 |
+
# if need_to_pad:
|
550 |
+
# last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad))
|
551 |
+
attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None)
|
552 |
+
|
553 |
+
# assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
|
554 |
+
layer_idx = 0
|
555 |
+
for block_idx, b in enumerate(self.block_chunks):
|
556 |
+
# last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int
|
557 |
+
if self.add_lvl_embeding_only_first_block and block_idx == 0:
|
558 |
+
last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
|
559 |
+
if not self.add_lvl_embeding_only_first_block:
|
560 |
+
last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
|
561 |
+
|
562 |
+
for m in b.module:
|
563 |
+
last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
|
564 |
+
if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
|
565 |
+
# print(f'add cfg={cfg} on {layer_idx}-th layer output')
|
566 |
+
last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:]
|
567 |
+
last_stage = torch.cat((last_stage, last_stage), 0)
|
568 |
+
layer_idx += 1
|
569 |
+
|
570 |
+
if (cfg != 1) and add_cfg_on_logits:
|
571 |
+
# print(f'add cfg on add_cfg_on_logits')
|
572 |
+
logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si])
|
573 |
+
logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
|
574 |
+
else:
|
575 |
+
logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])
|
576 |
+
|
577 |
+
if self.use_bit_label:
|
578 |
+
tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
|
579 |
+
logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
|
580 |
+
idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
|
581 |
+
idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
|
582 |
+
else:
|
583 |
+
idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
|
584 |
+
if vae_type != 0:
|
585 |
+
assert returns_vemb
|
586 |
+
if si < gt_leak:
|
587 |
+
idx_Bld = gt_ls_Bl[si]
|
588 |
+
else:
|
589 |
+
assert pn[0] == 1
|
590 |
+
idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d]
|
591 |
+
if self.apply_spatial_patchify: # unpatchify operation
|
592 |
+
idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w]
|
593 |
+
idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w]
|
594 |
+
idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d]
|
595 |
+
idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d]
|
596 |
+
|
597 |
+
idx_Bld_list.append(idx_Bld)
|
598 |
+
codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
|
599 |
+
if si != num_stages_minus_1:
|
600 |
+
summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
|
601 |
+
last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_down) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
|
602 |
+
last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w]
|
603 |
+
if self.apply_spatial_patchify: # patchify operation
|
604 |
+
last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w]
|
605 |
+
last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w]
|
606 |
+
last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d]
|
607 |
+
else:
|
608 |
+
summed_codes += codes
|
609 |
+
else:
|
610 |
+
if si < gt_leak:
|
611 |
+
idx_Bl = gt_ls_Bl[si]
|
612 |
+
h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC
|
613 |
+
|
614 |
+
# h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
|
615 |
+
h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
|
616 |
+
ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
|
617 |
+
idx_Bl_list.append(idx_Bl)
|
618 |
+
if si != num_stages_minus_1:
|
619 |
+
accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)
|
620 |
+
|
621 |
+
if si != num_stages_minus_1:
|
622 |
+
last_stage = self.word_embed(self.norm0_ve(last_stage))
|
623 |
+
last_stage = last_stage.repeat(bs//B, 1, 1)
|
624 |
+
|
625 |
+
if inference_mode:
|
626 |
+
for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
|
627 |
+
else:
|
628 |
+
assert self.num_block_chunks > 1
|
629 |
+
for block_chunk_ in self.block_chunks:
|
630 |
+
for module in block_chunk_.module.module:
|
631 |
+
(module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)
|
632 |
+
|
633 |
+
if not ret_img:
|
634 |
+
return ret, idx_Bl_list, []
|
635 |
+
|
636 |
+
if vae_type != 0:
|
637 |
+
img = vae.decode(summed_codes.squeeze(-3))
|
638 |
+
else:
|
639 |
+
img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)
|
640 |
+
|
641 |
+
img = (img + 1) / 2
|
642 |
+
img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
|
643 |
+
return ret, idx_Bl_list, img
|
644 |
+
|
645 |
+
@for_visualize
|
646 |
+
def vis_key_params(self, ep):
|
647 |
+
return
|
648 |
+
|
649 |
+
def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
|
650 |
+
for k in state_dict:
|
651 |
+
if 'cfg_uncond' in k:
|
652 |
+
old, new = state_dict[k], self.cfg_uncond.data
|
653 |
+
min_tlen = min(old.shape[0], new.shape[0])
|
654 |
+
if min_tlen == old.shape[0]:
|
655 |
+
state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:]))
|
656 |
+
else:
|
657 |
+
state_dict[k] = old[:min_tlen]
|
658 |
+
|
659 |
+
for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
|
660 |
+
state_dict.pop(buf_name, None)
|
661 |
+
if hasattr(self, buf_name):
|
662 |
+
state_dict[buf_name] = getattr(self, buf_name)
|
663 |
+
|
664 |
+
return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
|
665 |
+
|
666 |
+
def special_init(
|
667 |
+
self,
|
668 |
+
aln_init: float,
|
669 |
+
aln_gamma_init: float,
|
670 |
+
scale_head: float,
|
671 |
+
scale_proj: int,
|
672 |
+
):
|
673 |
+
# init head's norm
|
674 |
+
if isinstance(self.head_nm, AdaLNBeforeHead):
|
675 |
+
self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head
|
676 |
+
if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
|
677 |
+
self.head_nm.ada_lin[-1].bias.data.zero_()
|
678 |
+
|
679 |
+
# init head's proj
|
680 |
+
if scale_head >= 0:
|
681 |
+
if isinstance(self.head, nn.Linear):
|
682 |
+
self.head.weight.data.mul_(scale_head)
|
683 |
+
self.head.bias.data.zero_()
|
684 |
+
elif isinstance(self.head, nn.Sequential):
|
685 |
+
self.head[-1].weight.data.mul_(scale_head)
|
686 |
+
self.head[-1].bias.data.zero_()
|
687 |
+
|
688 |
+
depth = len(self.unregistered_blocks)
|
689 |
+
for block_idx, sab in enumerate(self.unregistered_blocks):
|
690 |
+
sab: Union[SelfAttnBlock, CrossAttnBlock]
|
691 |
+
# init proj
|
692 |
+
scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx))
|
693 |
+
if scale_proj == 1:
|
694 |
+
if self.t2i:
|
695 |
+
sab.sa.proj.weight.data.mul_(scale)
|
696 |
+
sab.ca.proj.weight.data.mul_(scale)
|
697 |
+
else:
|
698 |
+
sab.attn.proj.weight.data.mul_(scale)
|
699 |
+
sab.ffn.fc2.weight.data.mul_(scale)
|
700 |
+
# if sab.using_swiglu:
|
701 |
+
# nn.init.ones_(sab.ffn.fcg.bias)
|
702 |
+
# nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
|
703 |
+
|
704 |
+
# init ada_lin
|
705 |
+
if hasattr(sab, 'ada_lin'):
|
706 |
+
lin = sab.ada_lin[-1]
|
707 |
+
lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma
|
708 |
+
lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift
|
709 |
+
if hasattr(lin, 'bias') and lin.bias is not None:
|
710 |
+
lin.bias.data.zero_()
|
711 |
+
elif hasattr(sab, 'ada_gss'):
|
712 |
+
sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma
|
713 |
+
sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift
|
714 |
+
|
715 |
+
def extra_repr(self):
|
716 |
+
return f'drop_path_rate={self.drop_path_rate}'
|
717 |
+
|
718 |
+
def get_layer_id_and_scale_exp(self, para_name: str):
|
719 |
+
raise NotImplementedError
|
720 |
+
|
721 |
+
|
722 |
+
def sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
|
723 |
+
B, l, V = logits_BlV.shape
|
724 |
+
if top_k > 0:
|
725 |
+
top_k = min(top_k, V)
|
726 |
+
idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
|
727 |
+
logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
|
728 |
+
if top_p > 0:
|
729 |
+
sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
|
730 |
+
sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
|
731 |
+
sorted_idx_to_remove[..., -1:] = False
|
732 |
+
logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
|
733 |
+
# sample (have to squeeze cuz multinomial can only be used on 2D tensor)
|
734 |
+
replacement = num_samples >= 0
|
735 |
+
num_samples = abs(num_samples)
|
736 |
+
return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
|
737 |
+
|
738 |
+
def sampling_with_top_k_top_p_also_inplace_modifying_probs_(probs_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
|
739 |
+
B, l, V = probs_BlV.shape
|
740 |
+
if top_k > 0:
|
741 |
+
top_k = min(top_k, V)
|
742 |
+
idx_to_remove = probs_BlV < probs_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
|
743 |
+
probs_BlV.masked_fill_(idx_to_remove, 0)
|
744 |
+
if top_p > 0:
|
745 |
+
sorted_probs, sorted_idx = probs_BlV.sort(dim=-1, descending=False)
|
746 |
+
sorted_idx_to_remove = sorted_probs.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
|
747 |
+
sorted_idx_to_remove[..., -1:] = False
|
748 |
+
probs_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), 0)
|
749 |
+
# sample (have to squeeze cuz multinomial can only be used on 2D tensor)
|
750 |
+
probs_BlV = probs_BlV / probs_BlV.sum(-1, keepdims=True)
|
751 |
+
replacement = num_samples >= 0
|
752 |
+
num_samples = abs(num_samples)
|
753 |
+
return torch.multinomial(probs_BlV.view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
|
754 |
+
|
755 |
+
|
756 |
+
def get_params_num(d, w, mlp):
|
757 |
+
m = round(mlp * w / 256) * 256
|
758 |
+
s = d * (w**2 * 8 + w*m * 2) # sa+ca, mlp
|
759 |
+
s += w**2 * 6 # saln
|
760 |
+
s += 4096 * w # pred
|
761 |
+
s += 32 * w # we
|
762 |
+
|
763 |
+
Ct5 = 4096
|
764 |
+
s += Ct5*w * 4 # T5 attn pool
|
765 |
+
s += Ct5*w + w*w # T5 mlp
|
766 |
+
return f'{s/1e9:.2f}B'
|
767 |
+
|
768 |
+
|
769 |
+
TIMM_KEYS = {'img_size', 'pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay', 'global_pool'}
|
770 |
+
|
771 |
+
@register_model
|
772 |
+
def infinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs): return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
773 |
+
|
774 |
+
@register_model
|
775 |
+
def infinity_20b(depth=58, embed_dim=4608, num_heads=4608//128, drop_path_rate=0.25, **kwargs): return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
776 |
+
|
777 |
+
# model configuration for scaling Infinity transformer
|
778 |
+
@register_model
|
779 |
+
def infinity_layer12(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, **kwargs):
|
780 |
+
return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
781 |
+
@register_model
|
782 |
+
def infinity_layer16(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, **kwargs):
|
783 |
+
return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
784 |
+
@register_model
|
785 |
+
def infinity_layer24(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, **kwargs):
|
786 |
+
return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
787 |
+
@register_model
|
788 |
+
def infinity_layer32(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, **kwargs):
|
789 |
+
return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
790 |
+
@register_model
|
791 |
+
def infinity_layer40(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, **kwargs):
|
792 |
+
return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
793 |
+
@register_model
|
794 |
+
def infinity_layer48(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, **kwargs):
|
795 |
+
return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **{k: v for k, v in kwargs.items() if k not in TIMM_KEYS})
|
models/init_param.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def init_weights(model: nn.Module, conv_std_or_gain: float = 0.02, other_std: float = 0.02):
|
5 |
+
"""
|
6 |
+
:param model: the model to be inited
|
7 |
+
:param conv_std_or_gain: how to init every conv layer `m`
|
8 |
+
> 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
|
9 |
+
< 0: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
|
10 |
+
:param other_std: how to init every linear layer or embedding layer
|
11 |
+
use nn.init.trunc_normal_(m.weight.data, std=other_std)
|
12 |
+
"""
|
13 |
+
skip = abs(conv_std_or_gain) > 10
|
14 |
+
if skip: return
|
15 |
+
print(f'[init_weights] {type(model).__name__} with {"std" if conv_std_or_gain > 0 else "gain"}={abs(conv_std_or_gain):g}')
|
16 |
+
for m in model.modules():
|
17 |
+
if isinstance(m, nn.Linear):
|
18 |
+
nn.init.trunc_normal_(m.weight.data, std=other_std)
|
19 |
+
if m.bias is not None:
|
20 |
+
nn.init.constant_(m.bias.data, 0.)
|
21 |
+
elif isinstance(m, nn.Embedding):
|
22 |
+
nn.init.trunc_normal_(m.weight.data, std=other_std)
|
23 |
+
if m.padding_idx is not None:
|
24 |
+
m.weight.data[m.padding_idx].zero_()
|
25 |
+
elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
|
26 |
+
nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain) if conv_std_or_gain > 0 else nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain) # todo: StyleSwin: (..., gain=.02)
|
27 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
28 |
+
nn.init.constant_(m.bias.data, 0.)
|
29 |
+
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
30 |
+
if m.bias is not None:
|
31 |
+
nn.init.constant_(m.bias.data, 0.)
|
32 |
+
if m.weight is not None:
|
33 |
+
nn.init.constant_(m.weight.data, 1.)
|
models/t5.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
import numpy as np
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
8 |
+
|
9 |
+
import ftfy
|
10 |
+
import html
|
11 |
+
from bs4 import BeautifulSoup
|
12 |
+
import urllib.parse as ul
|
13 |
+
|
14 |
+
|
15 |
+
class T5Embedder:
|
16 |
+
|
17 |
+
available_models = ['t5-v1_1-xxl']
|
18 |
+
bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
|
19 |
+
|
20 |
+
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
|
21 |
+
t5_model_kwargs=None, torch_dtype=torch.bfloat16, use_offload_folder=None, model_max_length=512, padding="max_length", clean_caption_func_name="clean_caption"):
|
22 |
+
self.device = torch.device(device)
|
23 |
+
self.torch_dtype = torch_dtype
|
24 |
+
if t5_model_kwargs is None:
|
25 |
+
t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
|
26 |
+
if use_offload_folder is not None:
|
27 |
+
t5_model_kwargs['offload_folder'] = use_offload_folder
|
28 |
+
t5_model_kwargs['device_map'] = {
|
29 |
+
'shared': self.device,
|
30 |
+
'encoder.embed_tokens': self.device,
|
31 |
+
'encoder.block.0': self.device,
|
32 |
+
'encoder.block.1': self.device,
|
33 |
+
'encoder.block.2': self.device,
|
34 |
+
'encoder.block.3': self.device,
|
35 |
+
'encoder.block.4': self.device,
|
36 |
+
'encoder.block.5': self.device,
|
37 |
+
'encoder.block.6': self.device,
|
38 |
+
'encoder.block.7': self.device,
|
39 |
+
'encoder.block.8': self.device,
|
40 |
+
'encoder.block.9': self.device,
|
41 |
+
'encoder.block.10': self.device,
|
42 |
+
'encoder.block.11': self.device,
|
43 |
+
'encoder.block.12': 'disk',
|
44 |
+
'encoder.block.13': 'disk',
|
45 |
+
'encoder.block.14': 'disk',
|
46 |
+
'encoder.block.15': 'disk',
|
47 |
+
'encoder.block.16': 'disk',
|
48 |
+
'encoder.block.17': 'disk',
|
49 |
+
'encoder.block.18': 'disk',
|
50 |
+
'encoder.block.19': 'disk',
|
51 |
+
'encoder.block.20': 'disk',
|
52 |
+
'encoder.block.21': 'disk',
|
53 |
+
'encoder.block.22': 'disk',
|
54 |
+
'encoder.block.23': 'disk',
|
55 |
+
'encoder.final_layer_norm': 'disk',
|
56 |
+
'encoder.dropout': 'disk',
|
57 |
+
}
|
58 |
+
else:
|
59 |
+
t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}
|
60 |
+
|
61 |
+
self.use_text_preprocessing = use_text_preprocessing
|
62 |
+
self.hf_token = hf_token
|
63 |
+
self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
|
64 |
+
self.dir_or_name = dir_or_name
|
65 |
+
tokenizer_path, path = dir_or_name, dir_or_name
|
66 |
+
if local_cache:
|
67 |
+
cache_dir = os.path.join(self.cache_dir, dir_or_name)
|
68 |
+
tokenizer_path, path = cache_dir, cache_dir
|
69 |
+
elif dir_or_name in self.available_models:
|
70 |
+
cache_dir = os.path.join(self.cache_dir, dir_or_name)
|
71 |
+
for filename in [
|
72 |
+
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
|
73 |
+
'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
|
74 |
+
]:
|
75 |
+
hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
|
76 |
+
force_filename=filename, token=self.hf_token)
|
77 |
+
tokenizer_path, path = cache_dir, cache_dir
|
78 |
+
else:
|
79 |
+
cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
|
80 |
+
for filename in [
|
81 |
+
'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
|
82 |
+
]:
|
83 |
+
hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
|
84 |
+
force_filename=filename, token=self.hf_token)
|
85 |
+
tokenizer_path = cache_dir
|
86 |
+
|
87 |
+
print(f"Loading T5 from {tokenizer_path}")
|
88 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
89 |
+
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
|
90 |
+
self.model_max_length = model_max_length
|
91 |
+
self.padding = padding
|
92 |
+
self.clean_caption_func = self.__getattribute__(clean_caption_func_name)
|
93 |
+
|
94 |
+
@torch.no_grad()
|
95 |
+
def get_text_embeddings(self, texts):
|
96 |
+
import time
|
97 |
+
start_time = time.time()
|
98 |
+
|
99 |
+
texts = [self.text_preprocessing(text) for text in texts]
|
100 |
+
# print("text_preprocessing: ", time.time() - start_time)
|
101 |
+
|
102 |
+
text_tokens_and_mask = self.tokenizer(
|
103 |
+
texts,
|
104 |
+
max_length=self.model_max_length,
|
105 |
+
padding=self.padding,
|
106 |
+
truncation=True,
|
107 |
+
return_attention_mask=True,
|
108 |
+
add_special_tokens=True,
|
109 |
+
return_tensors='pt'
|
110 |
+
)
|
111 |
+
|
112 |
+
# print("tokenizer: ", time.time() - start_time)
|
113 |
+
|
114 |
+
text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'].to(self.device)
|
115 |
+
text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'].to(self.device)
|
116 |
+
|
117 |
+
with torch.no_grad():
|
118 |
+
text_encoder_embs = self.model(
|
119 |
+
input_ids=text_tokens_and_mask['input_ids'],
|
120 |
+
attention_mask=text_tokens_and_mask['attention_mask'],
|
121 |
+
)['last_hidden_state'].detach()
|
122 |
+
|
123 |
+
# print("model: ", time.time() - start_time)
|
124 |
+
return text_encoder_embs, text_tokens_and_mask['attention_mask'], text_tokens_and_mask['input_ids'], texts
|
125 |
+
|
126 |
+
def text_preprocessing(self, text):
|
127 |
+
if self.use_text_preprocessing:
|
128 |
+
try:
|
129 |
+
# The exact text cleaning as was in the training stage:
|
130 |
+
text = self.clean_caption_func(text)
|
131 |
+
text = self.clean_caption_func(text)
|
132 |
+
return text
|
133 |
+
except Exception as e:
|
134 |
+
print(f"Error in text preprocessing: {e} with text: {text}")
|
135 |
+
print(traceback.format_exc())
|
136 |
+
return text
|
137 |
+
else:
|
138 |
+
return text.lower().strip()
|
139 |
+
|
140 |
+
@staticmethod
|
141 |
+
def basic_clean(text):
|
142 |
+
text = ftfy.fix_text(text)
|
143 |
+
text = html.unescape(html.unescape(text))
|
144 |
+
return text.strip()
|
145 |
+
|
146 |
+
def clean_caption(self, caption):
|
147 |
+
caption = str(caption)
|
148 |
+
caption = ul.unquote_plus(caption)
|
149 |
+
caption = caption.strip().lower()
|
150 |
+
caption = re.sub('<person>', 'person', caption)
|
151 |
+
# urls:
|
152 |
+
caption = re.sub(
|
153 |
+
r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
|
154 |
+
'', caption) # regex for urls
|
155 |
+
caption = re.sub(
|
156 |
+
r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
|
157 |
+
'', caption) # regex for urls
|
158 |
+
# html:
|
159 |
+
try:
|
160 |
+
caption = BeautifulSoup(caption, features='html.parser').text
|
161 |
+
except Exception as e:
|
162 |
+
print(f"Error parsing caption:{caption} with html.parser: {e}")
|
163 |
+
|
164 |
+
# @<nickname>
|
165 |
+
caption = re.sub(r'@[\w\d]+\b', '', caption)
|
166 |
+
|
167 |
+
# 31C0—31EF CJK Strokes
|
168 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
169 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
170 |
+
# 3300—33FF CJK Compatibility
|
171 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
172 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
173 |
+
# 4E00—9FFF CJK Unified Ideographs
|
174 |
+
caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
|
175 |
+
caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
|
176 |
+
caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
|
177 |
+
caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
|
178 |
+
caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
|
179 |
+
caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
|
180 |
+
caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
|
181 |
+
#######################################################
|
182 |
+
|
183 |
+
# все виды тире / all types of dash --> "-"
|
184 |
+
caption = re.sub(
|
185 |
+
r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
|
186 |
+
'-', caption)
|
187 |
+
|
188 |
+
# кавычки к одному стандарту
|
189 |
+
caption = re.sub(r'[`´«»“”¨]', '"', caption)
|
190 |
+
caption = re.sub(r'[‘’]', "'", caption)
|
191 |
+
|
192 |
+
# "
|
193 |
+
caption = re.sub(r'"?', '', caption)
|
194 |
+
# &
|
195 |
+
caption = re.sub(r'&', '', caption)
|
196 |
+
|
197 |
+
# ip adresses:
|
198 |
+
caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
|
199 |
+
|
200 |
+
# article ids:
|
201 |
+
caption = re.sub(r'\d:\d\d\s+$', '', caption)
|
202 |
+
|
203 |
+
# \n
|
204 |
+
caption = re.sub(r'\\n', ' ', caption)
|
205 |
+
|
206 |
+
# "#123"
|
207 |
+
caption = re.sub(r'#\d{1,3}\b', '', caption)
|
208 |
+
# "#12345.."
|
209 |
+
caption = re.sub(r'#\d{5,}\b', '', caption)
|
210 |
+
# "123456.."
|
211 |
+
caption = re.sub(r'\b\d{6,}\b', '', caption)
|
212 |
+
# filenames:
|
213 |
+
caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
|
214 |
+
|
215 |
+
#
|
216 |
+
caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
|
217 |
+
caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
|
218 |
+
|
219 |
+
caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
220 |
+
caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
|
221 |
+
|
222 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
223 |
+
regex2 = re.compile(r'(?:\-|\_)')
|
224 |
+
if len(re.findall(regex2, caption)) > 3:
|
225 |
+
caption = re.sub(regex2, ' ', caption)
|
226 |
+
|
227 |
+
caption = self.basic_clean(caption)
|
228 |
+
|
229 |
+
caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
|
230 |
+
caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
|
231 |
+
caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
|
232 |
+
|
233 |
+
caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
|
234 |
+
caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
|
235 |
+
caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
|
236 |
+
caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
|
237 |
+
caption = re.sub(r'\bpage\s+\d+\b', '', caption)
|
238 |
+
|
239 |
+
caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
|
240 |
+
|
241 |
+
caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
|
242 |
+
|
243 |
+
caption = re.sub(r'\b\s+\:\s+', r': ', caption)
|
244 |
+
caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
|
245 |
+
caption = re.sub(r'\s+', ' ', caption)
|
246 |
+
|
247 |
+
caption.strip()
|
248 |
+
|
249 |
+
caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
|
250 |
+
caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
|
251 |
+
caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
|
252 |
+
caption = re.sub(r'^\.\S+$', '', caption)
|
253 |
+
|
254 |
+
return caption.strip()
|
255 |
+
|
256 |
+
|
257 |
+
def clean_caption_simplify(self, caption):
|
258 |
+
# 将 caption 转换为字符串
|
259 |
+
caption = str(caption)
|
260 |
+
|
261 |
+
# 解码 URL 编码的字符串
|
262 |
+
caption = ul.unquote_plus(caption)
|
263 |
+
|
264 |
+
# 去除首尾空格并转换为小写
|
265 |
+
caption = caption.strip().lower()
|
266 |
+
|
267 |
+
# 将 '<person>' 替换为 'person'
|
268 |
+
caption = re.sub('<person>', 'person', caption)
|
269 |
+
|
270 |
+
# 移除 URL
|
271 |
+
caption = re.sub(
|
272 |
+
r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
|
273 |
+
'', caption) # 匹配以 http:// 或 https:// 开头的 URL
|
274 |
+
caption = re.sub(
|
275 |
+
r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
|
276 |
+
'', caption) # 匹配以 www. 开头的 URL
|
277 |
+
|
278 |
+
# 解析 HTML 并删除 HTML 标签
|
279 |
+
caption = BeautifulSoup(caption, features='html.parser').text
|
280 |
+
|
281 |
+
# 移除 @nickname 标签
|
282 |
+
caption = re.sub(r'@[\w\d]+\b', '', caption)
|
283 |
+
|
284 |
+
# 移除特定 Unicode 范围的字符:CJK 相关字符
|
285 |
+
caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) # CJK 笔划
|
286 |
+
caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) # 片假名语音扩展
|
287 |
+
caption = re.sub(r'[\u3200-\u32ff]+', '', caption) # 圆括号中的 CJK 字母和月份
|
288 |
+
caption = re.sub(r'[\u3300-\u33ff]+', '', caption) # CJK 兼容性
|
289 |
+
caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) # CJK 统一表意符号扩展 A
|
290 |
+
caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) # 易经卦象符号
|
291 |
+
caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) # CJK 统一表意符号
|
292 |
+
|
293 |
+
# 所有类型的破折号替换为 "-"
|
294 |
+
caption = re.sub(
|
295 |
+
r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',
|
296 |
+
'-', caption) # 匹配各种 Unicode 破折号
|
297 |
+
|
298 |
+
# 统一不同类型的引号
|
299 |
+
caption = re.sub(r'[`´«»“”¨]', '"', caption) # 将各种引号替换为标准引号
|
300 |
+
caption = re.sub(r'[‘’]', "'", caption) # 将左单引号和右单引号替换为标准单引号
|
301 |
+
|
302 |
+
# 移除 " 和 &
|
303 |
+
caption = re.sub(r'"?', '', caption) # 移除 HTML 实体 "
|
304 |
+
caption = re.sub(r'&', '', caption) # 移除 HTML 实体 &
|
305 |
+
|
306 |
+
# 移除 IP 地址
|
307 |
+
caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) # 匹配 IPv4 地址
|
308 |
+
|
309 |
+
# 移除文章 ID 格式
|
310 |
+
caption = re.sub(r'\d:\d\d\s+$', '', caption) # 匹配类似 '1:23 ' 的格式
|
311 |
+
|
312 |
+
# 移除 \n 转义字符
|
313 |
+
caption = re.sub(r'\\n', ' ', caption)
|
314 |
+
|
315 |
+
# 移除特定格式的标签
|
316 |
+
# caption = re.sub(r'#\d{1,3}\b', '', caption) # #123 移除 # 加 1 到 3 位数字的标签
|
317 |
+
# caption = re.sub(r'#\d{5,}\b', '', caption) # #12345.. 移除 # 加 5 位或以上数字的标签
|
318 |
+
# caption = re.sub(r'\b\d{6,}\b', '', caption) # 123456.. 移除 6 位或以上的纯数字
|
319 |
+
|
320 |
+
# 移除文件名
|
321 |
+
caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) # 匹配图片和视频文件,匹配完整的文件名,包括文件名本身和扩展名。
|
322 |
+
|
323 |
+
# 简化多重引号和点
|
324 |
+
caption = re.sub(r'[\"\']{2,}', r'"', caption) # 连续的双引号替换为一个双引号
|
325 |
+
caption = re.sub(r'[\.]{2,}', r' ', caption) # 连续���点替换为空格
|
326 |
+
|
327 |
+
# 使用通用标点正则表达式清理无效标点
|
328 |
+
caption = re.sub(self.bad_punct_regex, r' ', caption) # 自定义的无效标点正则表达式
|
329 |
+
caption = re.sub(r'\s+\.\s+', r' ', caption) # 移除空格和点
|
330 |
+
|
331 |
+
# 过滤带有太多破折号或下划线的文本
|
332 |
+
regex2 = re.compile(r'(?:\-|\_)')
|
333 |
+
if len(re.findall(regex2, caption)) > 3:
|
334 |
+
caption = re.sub(regex2, ' ', caption)
|
335 |
+
|
336 |
+
# 基本清理
|
337 |
+
caption = self.basic_clean(caption)
|
338 |
+
|
339 |
+
# 移除特定格式的短字符串
|
340 |
+
# caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # 匹配三个字母以下加三个数字以上的字符串
|
341 |
+
# caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # 匹配字母数字混合的字符串
|
342 |
+
# caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 匹配数字字母混合的字符串
|
343 |
+
|
344 |
+
# 移除特定的广告或指令性短语
|
345 |
+
# caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) # 匹配 'worldwide free shipping', 'free shipping'
|
346 |
+
# caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) # 匹配 'free download', 'download free'
|
347 |
+
# caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) # 匹配 'click for ...' 或 'click on ...'
|
348 |
+
# caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) # 匹配文件扩展名,匹配独立的扩展名或扩展名后可能跟随的特定词汇的场景
|
349 |
+
# caption = re.sub(r'\bpage\s+\d+\b', '', caption) # 匹配 'page 123'
|
350 |
+
|
351 |
+
# 移除复杂模式的字符串
|
352 |
+
# caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # 123A456B789
|
353 |
+
|
354 |
+
# 移除特定的矩形标识符
|
355 |
+
caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
|
356 |
+
|
357 |
+
# 修复多余的空白和标点
|
358 |
+
caption = re.sub(r'\b\s+\:\s+', r': ', caption)
|
359 |
+
caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
|
360 |
+
caption = re.sub(r'\s+', ' ', caption)
|
361 |
+
|
362 |
+
# 去除首尾的多余字符
|
363 |
+
caption.strip()
|
364 |
+
caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
|
365 |
+
caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
|
366 |
+
caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
|
367 |
+
caption = re.sub(r'^\.\S+$', '', caption)
|
368 |
+
|
369 |
+
return caption.strip()
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
random
|
2 |
+
torch
|
3 |
+
opencv-python
|
4 |
+
numpy
|
5 |
+
gradio
|
6 |
+
huggingface-hub
|
7 |
+
transformers
|
8 |
+
argparse
|
9 |
+
spaces
|
utils/amp_opt.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import signal
|
4 |
+
import sys
|
5 |
+
import time
|
6 |
+
from typing import List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
10 |
+
# from memory_profiler import profile
|
11 |
+
|
12 |
+
import infinity.utils.dist as dist
|
13 |
+
from infinity.utils import misc
|
14 |
+
|
15 |
+
class NullCtx:
|
16 |
+
def __enter__(self):
|
17 |
+
pass
|
18 |
+
|
19 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
20 |
+
pass
|
21 |
+
|
22 |
+
|
23 |
+
def handle_timeout(signum, frame):
|
24 |
+
raise TimeoutError('took too long')
|
25 |
+
|
26 |
+
|
27 |
+
def per_param_clip_grad_norm_(parameters, thresh: float, stable=False, fp=None) -> (float, float):
|
28 |
+
skipped, max_grad = [], 0
|
29 |
+
for pi, p in enumerate(parameters):
|
30 |
+
if p.grad is not None:
|
31 |
+
g = p.grad.data.norm(2).item() + 1e-7
|
32 |
+
max_grad = max(max_grad, g)
|
33 |
+
clip_coef = thresh / g
|
34 |
+
if clip_coef < 1:
|
35 |
+
if stable and clip_coef < 0.2:
|
36 |
+
skipped.append(clip_coef)
|
37 |
+
p.grad.data.mul_(0) # todo NOTE: inf.mul_(0)==nan will shrink the scale ratio, but inf.zero_()==0 won't
|
38 |
+
else:
|
39 |
+
p.grad.data.mul_(clip_coef)
|
40 |
+
|
41 |
+
# if fp is not None: fp.write(f'[per_param_clip_grad_norm_:47] finished.\n'); fp.flush()
|
42 |
+
return 0 if len(skipped) == 0 else math.log10(max(min(skipped), 1e-7)), max_grad
|
43 |
+
|
44 |
+
|
45 |
+
class AmpOptimizer:
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
model_name_3letters: str, mixed_precision: int,
|
49 |
+
optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP],
|
50 |
+
r_accu: float, grad_clip: float, zero: int,
|
51 |
+
):
|
52 |
+
self.enable_amp = mixed_precision > 0
|
53 |
+
self.zero = zero
|
54 |
+
if self.enable_amp:
|
55 |
+
self.using_fp16_rather_bf16 = mixed_precision != 2
|
56 |
+
self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768)
|
57 |
+
|
58 |
+
# todo: on both V100 and A100, torch.get_autocast_gpu_dtype() returns fp16, not bf16.
|
59 |
+
self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0) # todo: cache_enabled=False
|
60 |
+
if self.using_fp16_rather_bf16:
|
61 |
+
self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
|
62 |
+
else:
|
63 |
+
self.scaler = None
|
64 |
+
else:
|
65 |
+
self.using_fp16_rather_bf16 = True
|
66 |
+
self.amp_ctx = NullCtx()
|
67 |
+
self.scaler = None
|
68 |
+
|
69 |
+
t = torch.zeros(dist.get_world_size())
|
70 |
+
t[dist.get_rank()] = float(self.enable_amp)
|
71 |
+
dist.allreduce(t)
|
72 |
+
assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}'
|
73 |
+
|
74 |
+
t = torch.zeros(dist.get_world_size())
|
75 |
+
t[dist.get_rank()] = float(self.using_fp16_rather_bf16)
|
76 |
+
dist.allreduce(t)
|
77 |
+
assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}'
|
78 |
+
|
79 |
+
self.model_name_3letters = model_name_3letters
|
80 |
+
self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp
|
81 |
+
self.r_accu = r_accu
|
82 |
+
|
83 |
+
self.paras = self.names = ... # todo: solve EMA-related codes
|
84 |
+
|
85 |
+
self.grad_clip, self.grad_clip_we = grad_clip, 0 # todo: disable wclip
|
86 |
+
if self.grad_clip > 100:
|
87 |
+
self.grad_clip %= 100
|
88 |
+
self.per_param = True
|
89 |
+
else:
|
90 |
+
self.per_param = False
|
91 |
+
self.per_param = False # todo: disable wclip
|
92 |
+
|
93 |
+
self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
|
94 |
+
self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm') # deepspeed's optimizer
|
95 |
+
|
96 |
+
self.fp = None
|
97 |
+
self.last_orig_norm: torch.Tensor = torch.tensor(0.1)
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def log_param(self, ep: int):
|
101 |
+
if self.zero == 0:
|
102 |
+
for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items():
|
103 |
+
values: List[float]
|
104 |
+
if len(values) == 1: # e.g., cls token will only have one value
|
105 |
+
values.append(values[0])
|
106 |
+
else:
|
107 |
+
...
|
108 |
+
# todo: log params
|
109 |
+
|
110 |
+
# @profile(precision=4, stream=open('amp_sc.log', 'w+'))
|
111 |
+
def backward_clip_step(
|
112 |
+
self, ep: int, it: int, g_it: int, stepping: bool, logging_params: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False,
|
113 |
+
) -> Tuple[torch.Tensor, Optional[float]]:
|
114 |
+
# backward
|
115 |
+
loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
|
116 |
+
orig_norm = scaler_sc = None
|
117 |
+
# if self.fp is not None:
|
118 |
+
# if g_it % 20 == 0: self.fp.seek(0); self.fp.truncate(0)
|
119 |
+
if self.scaler is not None:
|
120 |
+
self.scaler.scale(loss).backward(retain_graph=False, create_graph=False) # retain_graph=retain_graph, create_graph=create_graph
|
121 |
+
else:
|
122 |
+
loss.backward(retain_graph=False, create_graph=False)
|
123 |
+
# if self.fp is not None: self.fp.write(f'[backward_clip_step:131] [it{it}, g_it{g_it}] after backward\n'); self.fp.flush()
|
124 |
+
|
125 |
+
# clip gradients then step optimizer
|
126 |
+
if stepping:
|
127 |
+
if self.scaler is not None: self.scaler.unscale_(self.optimizer) # now the gradient can be correctly got
|
128 |
+
# if self.fp is not None: self.fp.write(f'[backward_clip_step:137] [it{it}, g_it{g_it}] after scaler.unscale_\n'); self.fp.flush()
|
129 |
+
|
130 |
+
skipped, orig_norm = 0, self.last_orig_norm
|
131 |
+
# try:
|
132 |
+
if self.fp is not None:
|
133 |
+
if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0)
|
134 |
+
self.fp.write(f'<ep{ep} it{it} {g_it}>\n'); self.fp.flush()
|
135 |
+
if self.early_clipping:
|
136 |
+
c = self.grad_clip * clip_decay_ratio
|
137 |
+
if self.zero:
|
138 |
+
orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c)
|
139 |
+
else:
|
140 |
+
orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c)
|
141 |
+
|
142 |
+
# if self.fp is not None: self.fp.write(f'[backward_clip_step:175] [it{it}, g_it{g_it}] before opt step\n'); self.fp.flush()
|
143 |
+
if self.scaler is not None:
|
144 |
+
self.scaler: torch.cuda.amp.GradScaler
|
145 |
+
if self.zero:
|
146 |
+
# synchronize found_inf_per_device before calling step, so that even if only some ranks found inf on their sharded params, all other ranks will know
|
147 |
+
# otherwise, when saving FSDP optimizer state, it will cause AssertionError saying "Different ranks have different values for step."
|
148 |
+
for optimizer_state in self.scaler._per_optimizer_states.values():
|
149 |
+
for t in optimizer_state['found_inf_per_device'].values():
|
150 |
+
dist.allreduce(t) # ideally, each rank only has one single t; so no need to use async allreduce
|
151 |
+
|
152 |
+
self.scaler.step(self.optimizer)
|
153 |
+
scaler_sc: Optional[float] = self.scaler.get_scale()
|
154 |
+
if scaler_sc > self.max_sc: # fp16 will overflow when >65536, so multiply 32768 could be dangerous
|
155 |
+
# print(f'[fp16 scaling] too large loss scale {scaler_sc}! (clip to {self.max_sc:g})')
|
156 |
+
self.scaler.update(new_scale=self.max_sc)
|
157 |
+
else:
|
158 |
+
self.scaler.update()
|
159 |
+
try:
|
160 |
+
scaler_sc = float(math.log2(scaler_sc))
|
161 |
+
except Exception as e:
|
162 |
+
print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
|
163 |
+
time.sleep(1)
|
164 |
+
print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
|
165 |
+
raise e
|
166 |
+
else:
|
167 |
+
self.optimizer.step()
|
168 |
+
|
169 |
+
if self.late_clipping:
|
170 |
+
orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm
|
171 |
+
self.last_orig_norm = orig_norm
|
172 |
+
# no zero_grad calling here, gonna log those gradients!
|
173 |
+
return orig_norm, scaler_sc
|
174 |
+
|
175 |
+
def state_dict(self):
|
176 |
+
return {
|
177 |
+
'optimizer': self.optimizer.state_dict()
|
178 |
+
} if self.scaler is None else {
|
179 |
+
'scaler': self.scaler.state_dict(),
|
180 |
+
'optimizer': self.optimizer.state_dict()
|
181 |
+
}
|
182 |
+
|
183 |
+
def load_state_dict(self, state, strict=True):
|
184 |
+
if self.scaler is not None:
|
185 |
+
try: self.scaler.load_state_dict(state['scaler'])
|
186 |
+
except Exception as e: print(f'[fp16 load_state_dict err] {e}')
|
187 |
+
self.optimizer.load_state_dict(state['optimizer'])
|
utils/arg_util.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from collections import OrderedDict, deque
|
9 |
+
from typing import Optional, Union
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from tap import Tap
|
14 |
+
|
15 |
+
import infinity.utils.dist as dist
|
16 |
+
|
17 |
+
|
18 |
+
class Args(Tap):
|
19 |
+
local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # directory for save checkpoints
|
20 |
+
data_path: str = '' # dataset
|
21 |
+
bed: str = '' # bed directory for copy checkpoints apart from local_out_path
|
22 |
+
vae_ckpt: str = '' # VAE ckpt
|
23 |
+
exp_name: str = '' # experiment name
|
24 |
+
ds: str = 'oi' # only used in GPT training::load_viz_data & FID benchmark
|
25 |
+
model: str = '' # for VAE training, 'b' or any other for GPT training
|
26 |
+
short_cap_prob: float = 0.2 # prob for training with short captions
|
27 |
+
project_name: str = 'Infinity' # name of wandb project
|
28 |
+
tf32: bool = True # whether to use TensorFloat32
|
29 |
+
auto_resume: bool = True # whether to automatically resume from the last checkpoint found in args.bed
|
30 |
+
rush_resume: str = '' # pretrained infinity checkpoint
|
31 |
+
nowd: int = 1 # whether to disable weight decay on sparse params (like class token)
|
32 |
+
enable_hybrid_shard: bool = False # whether to use hybrid FSDP
|
33 |
+
inner_shard_degree: int = 1 # inner degree for FSDP
|
34 |
+
zero: int = 0 # ds zero
|
35 |
+
buck: str = 'chunk' # =0 for using module-wise
|
36 |
+
fsdp_orig: bool = True
|
37 |
+
enable_checkpointing: str = None # checkpointing strategy: full-block, self-attn
|
38 |
+
pad_to_multiplier: int = 1 # >1 for padding the seq len to a multiplier of this
|
39 |
+
log_every_iter: bool = False
|
40 |
+
checkpoint_type: str = 'torch' # checkpoint_type: torch, onmistore
|
41 |
+
seed: int = None # 3407
|
42 |
+
rand: bool = True # actual seed = seed + (dist.get_rank()*512 if rand else 0)
|
43 |
+
device: str = 'cpu'
|
44 |
+
task_id: str = '2493513'
|
45 |
+
trial_id: str = '7260554'
|
46 |
+
robust_run_id: str = '00'
|
47 |
+
ckpt_trials = []
|
48 |
+
real_trial_id: str = '7260552'
|
49 |
+
chunk_nodes: int = None
|
50 |
+
is_master_node: bool = None
|
51 |
+
# dir
|
52 |
+
log_txt_path: str = ''
|
53 |
+
t5_path: str = '' # if not specified: automatically find from all bytenas
|
54 |
+
online_t5: bool = True # whether to use online t5 or load local features
|
55 |
+
# GPT
|
56 |
+
sdpa_mem: bool = True # whether to use with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
|
57 |
+
tfast: int = 0 # compile GPT
|
58 |
+
model_alias: str = 'b' # [automatically set; don't specify this]
|
59 |
+
rms: bool = False
|
60 |
+
aln: float = 1e-3 # multiplier of ada_lin.w's initialization
|
61 |
+
alng: float = -1 # multiplier of ada_lin.w[gamma channels]'s initialization, -1: the same as aln
|
62 |
+
saln: bool = False # whether to use a shared adaln layer
|
63 |
+
haln: bool = True # whether to use a specific adaln layer in head layer
|
64 |
+
nm0: bool = False # norm before word proj linear
|
65 |
+
tau: float = 1 # tau of self attention in GPT
|
66 |
+
cos: bool = True # cosine attn as in swin v2
|
67 |
+
swi: bool = False # whether to use FFNSwiGLU, instead of vanilla FFN
|
68 |
+
dp: float = -1
|
69 |
+
drop: float = 0.0 # GPT's dropout (VAE's is --vd)
|
70 |
+
hd: int = 0
|
71 |
+
ca_gamma: float = -1 # >=0 for using layer-scale for cross attention
|
72 |
+
diva: int = 1 # rescale_attn_fc_weights
|
73 |
+
hd0: float = 0.02 # head.w *= hd0
|
74 |
+
dec: int = 1 # dec depth
|
75 |
+
cum: int = 3 # cumulating fea map as GPT TF input, 0: not cum; 1: cum @ next hw, 2: cum @ final hw
|
76 |
+
rwe: bool = False # random word emb
|
77 |
+
tp: float = 0.0 # top-p
|
78 |
+
tk: float = 0.0 # top-k
|
79 |
+
tini: float = 0.02 # init parameters
|
80 |
+
cfg: float = 0.1 # >0: classifier-free guidance, drop cond with prob cfg
|
81 |
+
rand_uncond = False # whether to use random, unlearnable uncond embeding
|
82 |
+
ema: float = 0.9999 # VAE's ema ratio, not VAR's. 0.9977844 == 0.5 ** (32 / (10 * 1000)) from gans, 0.9999 from SD
|
83 |
+
tema: float = 0 # 0.9999 in DiffiT, DiT
|
84 |
+
fp16: int = 0 # 1: fp16, 2: bf16, >2: fp16's max scaling multiplier todo: 记得让quantize相关的feature都强制fp32!另外residueal最好也是fp32(根据flash-attention)nn.Conv2d有一个参数是use_float16?
|
85 |
+
fuse: bool = False # whether to use fused mlp
|
86 |
+
fused_norm: bool = False # whether to use fused norm
|
87 |
+
flash: bool = False # whether to use customized flash-attn kernel
|
88 |
+
xen: bool = False # whether to use xentropy
|
89 |
+
use_flex_attn: bool = False # whether to use flex_attn to speedup training
|
90 |
+
stable: bool = False
|
91 |
+
gblr: float = 1e-4
|
92 |
+
dblr: float = None # =gblr if is None
|
93 |
+
tblr: float = 6e-4
|
94 |
+
glr: float = None
|
95 |
+
dlr: float = None
|
96 |
+
tlr: float = None # vqgan: 4e-5
|
97 |
+
gwd: float = 0.005
|
98 |
+
dwd: float = 0.0005
|
99 |
+
twd: float = 0.005 # vqgan: 0.01
|
100 |
+
gwde: float = 0
|
101 |
+
dwde: float = 0
|
102 |
+
twde: float = 0
|
103 |
+
ls: float = 0.0 # label smooth
|
104 |
+
lz: float = 0.0 # z loss from PaLM = 1e-4 todo
|
105 |
+
eq: int = 0 # equalized loss
|
106 |
+
ep: int = 100
|
107 |
+
wp: float = 0
|
108 |
+
wp0: float = 0.005
|
109 |
+
wpe: float = 0.3 # 0.001, final cosine lr = wpe * peak lr
|
110 |
+
sche: str = '' # cos, exp, lin
|
111 |
+
log_freq: int = 50 # log frequency in the stdout
|
112 |
+
gclip: float = 6. # <=0 for not grad clip VAE
|
113 |
+
dclip: float = 6. # <=0 for not grad clip discriminator
|
114 |
+
tclip: float = 2. # <=0 for not grad clip GPT; >100 for per-param clip (%= 100 automatically)
|
115 |
+
cdec: bool = False # decay the grad clip thresholds of GPT and GPT's word embed
|
116 |
+
opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5(比Adam学习率低四倍)和wd=0.8(比Adam高八倍);比如在小的 batch_size 时,Lion 的表现不如 AdamW
|
117 |
+
ada: str = '' # adam's beta0 and beta1 for VAE or GPT, '0_0.99' from style-swin and magvit, '0.5_0.9' from VQGAN
|
118 |
+
dada: str = '' # adam's beta0 and beta1 for discriminator
|
119 |
+
oeps: float = 0 # adam's eps, pixart uses 1e-10
|
120 |
+
afuse: bool = True # fused adam
|
121 |
+
# data
|
122 |
+
pn: str = '' # pixel nums, choose from 0.06M, 0.25M, 1M
|
123 |
+
scale_schedule: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
|
124 |
+
patch_size: int = None # [automatically set; don't specify this] = 2 ** (len(args.scale_schedule) - 1)
|
125 |
+
resos: tuple = None # [automatically set; don't specify this]
|
126 |
+
data_load_reso: int = None # [automatically set; don't specify this]
|
127 |
+
workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
|
128 |
+
lbs: int = 0 # local batch size; if lbs != 0, bs will be ignored, and will be reset as round(args.lbs / args.ac) * dist.get_world_size()
|
129 |
+
bs: int = 0 # global batch size; if lbs != 0, bs will be ignored
|
130 |
+
batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
|
131 |
+
glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
|
132 |
+
ac: int = 1 # gradient accumulation
|
133 |
+
r_accu: float = 1.0 # [automatically set; don't specify this] = 1 / args.ac
|
134 |
+
norm_eps: float = 1e-6 # norm eps for infinity
|
135 |
+
tlen: int = 512 # truncate text embedding to this length
|
136 |
+
Ct5: int = 2048 # feature dimension of text encoder
|
137 |
+
use_bit_label: int = 1 # pred bitwise labels or index-wise labels
|
138 |
+
bitloss_type: str = 'mean' # mean or sum
|
139 |
+
dynamic_resolution_across_gpus: int = 1 # allow dynamic resolution across gpus
|
140 |
+
enable_dynamic_length_prompt: int = 0 # enable dynamic length prompt during training
|
141 |
+
use_streaming_dataset: int = 0 # use streaming dataset
|
142 |
+
iterable_data_buffersize: int = 90000 # streaming dataset buffer size
|
143 |
+
save_model_iters_freq: int = 1000 # save model iter freq
|
144 |
+
noise_apply_layers: int = -1 # Bitwise Self-Correction: apply noise to layers, -1 means not apply noise
|
145 |
+
noise_apply_strength: float = -1 # Bitwise Self-Correction: apply noise strength, -1 means not apply noise
|
146 |
+
noise_apply_requant: int = 1 # Bitwise Self-Correction: requant after apply noise
|
147 |
+
rope2d_each_sa_layer: int = 0 # apply rope2d to each self-attention layer
|
148 |
+
rope2d_normalized_by_hw: int = 1 # apply normalized rope2d
|
149 |
+
use_fsdp_model_ema: int = 0 # use fsdp model ema
|
150 |
+
add_lvl_embeding_only_first_block: int = 1 # apply lvl pe embedding only first block or each block
|
151 |
+
reweight_loss_by_scale: int = 0 # reweight loss by scale
|
152 |
+
always_training_scales: int = 100 # trunc training scales
|
153 |
+
vae_type: int = 1 # here 16/32/64 is bsq vae of different quant bits
|
154 |
+
fake_vae_input: bool = False # fake vae input for debug
|
155 |
+
model_init_device: str = 'cuda' # model_init_device
|
156 |
+
prefetch_factor: int = 2 # prefetch_factor for dataset
|
157 |
+
apply_spatial_patchify: int = 0 # apply apply_spatial_patchify or not
|
158 |
+
debug_bsc: int = 0 # save figs and set breakpoint for debug bsc and check input
|
159 |
+
task_type: str = 't2i' # take type to t2i or t2v
|
160 |
+
|
161 |
+
|
162 |
+
############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
|
163 |
+
############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
|
164 |
+
############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
|
165 |
+
|
166 |
+
|
167 |
+
# would be automatically set in runtime
|
168 |
+
branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
|
169 |
+
commit_id: str = '' # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
|
170 |
+
commit_msg: str = ''# (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
|
171 |
+
cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:]) # [automatically set; don't specify this]
|
172 |
+
tag: str = 'UK' # [automatically set; don't specify this]
|
173 |
+
acc_all: float = None # [automatically set; don't specify this]
|
174 |
+
acc_real: float = None # [automatically set; don't specify this]
|
175 |
+
acc_fake: float = None # [automatically set; don't specify this]
|
176 |
+
last_Lnll: float = None # [automatically set; don't specify this]
|
177 |
+
last_L1: float = None # [automatically set; don't specify this]
|
178 |
+
last_Ld: float = None # [automatically set; don't specify this]
|
179 |
+
last_wei_g: float = None # [automatically set; don't specify this]
|
180 |
+
grad_boom: str = None # [automatically set; don't specify this]
|
181 |
+
diff: float = None # [automatically set; don't specify this]
|
182 |
+
diffs: str = '' # [automatically set; don't specify this]
|
183 |
+
diffs_ema: str = None # [automatically set; don't specify this]
|
184 |
+
ca_performance: str = '' # [automatically set; don't specify this]
|
185 |
+
cur_phase: str = '' # [automatically set; don't specify this]
|
186 |
+
cur_it: str = '' # [automatically set; don't specify this]
|
187 |
+
cur_ep: str = '' # [automatically set; don't specify this]
|
188 |
+
remain_time: str = '' # [automatically set; don't specify this]
|
189 |
+
finish_time: str = '' # [automatically set; don't specify this]
|
190 |
+
iter_speed: float = None # [automatically set; don't specify this]
|
191 |
+
img_per_day: float = None # [automatically set; don't specify this]
|
192 |
+
max_nvidia_smi: float = 0 # [automatically set; don't specify this]
|
193 |
+
max_memory_allocated: float = None # [automatically set; don't specify this]
|
194 |
+
max_memory_reserved: float = None # [automatically set; don't specify this]
|
195 |
+
num_alloc_retries: int = None # [automatically set; don't specify this]
|
196 |
+
MFU: float = None # [automatically set; don't specify this]
|
197 |
+
HFU: float = None # [automatically set; don't specify this]
|
198 |
+
# ==================================================================================================================
|
199 |
+
# ======================== ignore these parts below since they are only for debug use ==============================
|
200 |
+
# ==================================================================================================================
|
201 |
+
dbg_modified: bool = False
|
202 |
+
dbg_ks: bool = False
|
203 |
+
dbg_ks_last = None
|
204 |
+
dbg_ks_fp = None
|
205 |
+
def dbg_ks_this_line(self, g_it: int):
|
206 |
+
if self.dbg_ks:
|
207 |
+
if self.dbg_ks_last is None:
|
208 |
+
self.dbg_ks_last = deque(maxlen=6)
|
209 |
+
|
210 |
+
from utils.misc import time_str
|
211 |
+
self.dbg_ks_fp.seek(0)
|
212 |
+
f_back = sys._getframe().f_back
|
213 |
+
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
|
214 |
+
info = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})'
|
215 |
+
if g_it is not None:
|
216 |
+
info += f' [g_it: {g_it}]'
|
217 |
+
|
218 |
+
self.dbg_ks_last.append(info)
|
219 |
+
self.dbg_ks_fp.write('\n'.join(self.dbg_ks_last) + '\n')
|
220 |
+
self.dbg_ks_fp.flush()
|
221 |
+
|
222 |
+
dbg: bool = 'KEVIN_LOCAL' in os.environ # only used when debug about unused param in DDP
|
223 |
+
ks: bool = False
|
224 |
+
nodata: bool = False # if True, will set nova=True as well
|
225 |
+
nodata_tlen: int = 320
|
226 |
+
nova: bool = False # no val, no FID
|
227 |
+
prof: int = 0 # profile
|
228 |
+
prof_freq: int = 50 # profile
|
229 |
+
tos_profiler_file_prefix: str = 'vgpt_default/'
|
230 |
+
profall: int = 0
|
231 |
+
@property
|
232 |
+
def is_vae_visualization_only(self) -> bool:
|
233 |
+
return self.v_seed > 0
|
234 |
+
v_seed: int = 0 # v_seed != 0 means the visualization-only mode
|
235 |
+
@property
|
236 |
+
def is_gpt_visualization_only(self) -> bool:
|
237 |
+
return self.g_seed > 0
|
238 |
+
g_seed: int = 0 # g_seed != 0 means the visualization-only mode
|
239 |
+
# ==================================================================================================================
|
240 |
+
# ======================== ignore these parts above since they are only for debug use ==============================
|
241 |
+
# ==================================================================================================================
|
242 |
+
|
243 |
+
@property
|
244 |
+
def gpt_training(self):
|
245 |
+
return len(self.model) > 0
|
246 |
+
|
247 |
+
def set_initial_seed(self, benchmark: bool):
|
248 |
+
torch.backends.cudnn.enabled = True
|
249 |
+
torch.backends.cudnn.benchmark = benchmark
|
250 |
+
if self.seed is None:
|
251 |
+
torch.backends.cudnn.deterministic = False
|
252 |
+
else:
|
253 |
+
seed = self.seed + (dist.get_rank()*512 if self.rand else 0)
|
254 |
+
torch.backends.cudnn.deterministic = True
|
255 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
256 |
+
random.seed(seed)
|
257 |
+
np.random.seed(seed)
|
258 |
+
torch.manual_seed(seed)
|
259 |
+
if torch.cuda.is_available():
|
260 |
+
torch.cuda.manual_seed(seed)
|
261 |
+
torch.cuda.manual_seed_all(seed)
|
262 |
+
|
263 |
+
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
|
264 |
+
if self.seed is None:
|
265 |
+
return None
|
266 |
+
g = torch.Generator()
|
267 |
+
g.manual_seed(self.seed + dist.get_rank()*512)
|
268 |
+
return g
|
269 |
+
|
270 |
+
def compile_model(self, m, fast):
|
271 |
+
if fast == 0:
|
272 |
+
return m
|
273 |
+
return torch.compile(m, mode={
|
274 |
+
1: 'reduce-overhead',
|
275 |
+
2: 'max-autotune',
|
276 |
+
3: 'default',
|
277 |
+
}[fast]) if hasattr(torch, 'compile') else m
|
278 |
+
|
279 |
+
def dump_log(self):
|
280 |
+
if not dist.is_local_master():
|
281 |
+
return
|
282 |
+
nd = {'is_master': dist.is_visualizer()}
|
283 |
+
r_trial, trial = str(self.real_trial_id), str(self.trial_id)
|
284 |
+
for k, v in {
|
285 |
+
'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch,
|
286 |
+
'Lnll': self.last_Lnll, 'L1': self.last_L1,
|
287 |
+
'Ld': self.last_Ld,
|
288 |
+
'acc': self.acc_all, 'acc_r': self.acc_real, 'acc_f': self.acc_fake,
|
289 |
+
'weiG': self.last_wei_g if (self.last_wei_g is None or math.isfinite(self.last_wei_g)) else -23333,
|
290 |
+
'grad': self.grad_boom,
|
291 |
+
|
292 |
+
'cur': self.cur_phase, 'cur_ep': self.cur_ep, 'cur_it': self.cur_it,
|
293 |
+
'rema': self.remain_time, 'fini': self.finish_time, 'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
|
294 |
+
'bsep': f'{self.glb_batch_size}/{self.ep}',
|
295 |
+
'G_lrwd': f'{self.glr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.gwd:g}',
|
296 |
+
'D_lrwd': f'{self.dlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.dwd:g}',
|
297 |
+
'T_lrwd': f'{self.tlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.twd:g}',
|
298 |
+
'diff': self.diff, 'diffs': self.diffs, 'diffs_ema': self.diffs_ema if self.diffs_ema else None,
|
299 |
+
'opt': self.opt,
|
300 |
+
'is_master_node': self.is_master_node,
|
301 |
+
}.items():
|
302 |
+
if hasattr(v, 'item'):v = v.item()
|
303 |
+
if v is None or (isinstance(v, str) and len(v) == 0): continue
|
304 |
+
nd[k] = v
|
305 |
+
if r_trial == trial:
|
306 |
+
nd.pop('trial', None)
|
307 |
+
|
308 |
+
with open(self.log_txt_path, 'w') as fp:
|
309 |
+
json.dump(nd, fp, indent=2)
|
310 |
+
|
311 |
+
def touch_log(self): # listener will kill me if log_txt_path is not updated for 120s
|
312 |
+
os.utime(self.log_txt_path) # about 2e-6 sec
|
313 |
+
|
314 |
+
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
|
315 |
+
d = (OrderedDict if key_ordered else dict)()
|
316 |
+
# self.as_dict() would contain methods, but we only need variables
|
317 |
+
for k in self.class_variables.keys():
|
318 |
+
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
|
319 |
+
d[k] = getattr(self, k)
|
320 |
+
return d
|
321 |
+
|
322 |
+
def load_state_dict(self, d: Union[OrderedDict, dict, str]):
|
323 |
+
if isinstance(d, str): # for compatibility with old version
|
324 |
+
d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
|
325 |
+
for k in d.keys():
|
326 |
+
if k in {'is_large_model', 'gpt_training'}:
|
327 |
+
continue
|
328 |
+
try:
|
329 |
+
setattr(self, k, d[k])
|
330 |
+
except Exception as e:
|
331 |
+
print(f'k={k}, v={d[k]}')
|
332 |
+
raise e
|
333 |
+
|
334 |
+
@staticmethod
|
335 |
+
def set_tf32(tf32: bool):
|
336 |
+
if torch.cuda.is_available():
|
337 |
+
torch.backends.cudnn.allow_tf32 = bool(tf32)
|
338 |
+
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
|
339 |
+
if hasattr(torch, 'set_float32_matmul_precision'):
|
340 |
+
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
|
341 |
+
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
|
342 |
+
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
|
343 |
+
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
|
344 |
+
|
345 |
+
def __str__(self):
|
346 |
+
s = []
|
347 |
+
for k in self.class_variables.keys():
|
348 |
+
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
|
349 |
+
s.append(f' {k:20s}: {getattr(self, k)}')
|
350 |
+
s = '\n'.join(s)
|
351 |
+
return f'{{\n{s}\n}}\n'
|
352 |
+
|
353 |
+
|
354 |
+
def init_dist_and_get_args():
|
355 |
+
for i in range(len(sys.argv)):
|
356 |
+
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
|
357 |
+
del sys.argv[i]
|
358 |
+
break
|
359 |
+
args = Args(explicit_bool=True).parse_args(known_only=True)
|
360 |
+
args.chunk_nodes = int(os.environ.get('CK', '') or '0')
|
361 |
+
|
362 |
+
if len(args.extra_args) > 0 and args.is_master_node == 0:
|
363 |
+
print(f'======================================================================================')
|
364 |
+
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
|
365 |
+
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
|
366 |
+
print(f'======================================================================================\n\n')
|
367 |
+
|
368 |
+
args.set_tf32(args.tf32)
|
369 |
+
if args.dbg:
|
370 |
+
torch.autograd.set_detect_anomaly(True)
|
371 |
+
|
372 |
+
try: os.makedirs(args.bed, exist_ok=True)
|
373 |
+
except: pass
|
374 |
+
try: os.makedirs(args.local_out_path, exist_ok=True)
|
375 |
+
except: pass
|
376 |
+
|
377 |
+
day3 = 60*24*3
|
378 |
+
dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=day3 if int(os.environ.get('LONG_DBG', '0') or '0') > 0 else 30)
|
379 |
+
|
380 |
+
args.tlen = max(args.tlen, args.nodata_tlen)
|
381 |
+
if args.zero and args.tema != 0:
|
382 |
+
args.tema = 0
|
383 |
+
print(f'======================================================================================')
|
384 |
+
print(f'======================== WARNING: args.tema:=0, due to zero={args.zero} ========================')
|
385 |
+
print(f'======================================================================================\n\n')
|
386 |
+
|
387 |
+
if args.nodata:
|
388 |
+
args.nova = True
|
389 |
+
|
390 |
+
if not args.tos_profiler_file_prefix.endswith('/'): args.tos_profiler_file_prefix += '/'
|
391 |
+
|
392 |
+
if args.alng < 0:
|
393 |
+
args.alng = args.aln
|
394 |
+
|
395 |
+
args.device = dist.get_device()
|
396 |
+
args.r_accu = 1 / args.ac # gradient accumulation
|
397 |
+
args.data_load_reso = None
|
398 |
+
args.rand |= args.seed is None
|
399 |
+
args.sche = args.sche or ('lin0' if args.gpt_training else 'cos')
|
400 |
+
if args.wp == 0:
|
401 |
+
args.wp = args.ep * 1/100
|
402 |
+
|
403 |
+
di = {
|
404 |
+
'b': 'bilinear', 'c': 'bicubic', 'n': 'nearest', 'a': 'area', 'aa': 'area+area',
|
405 |
+
'at': 'auto', 'auto': 'auto',
|
406 |
+
'v': 'vae',
|
407 |
+
'x': 'pix', 'xg': 'pix_glu', 'gx': 'pix_glu', 'g': 'pix_glu'
|
408 |
+
}
|
409 |
+
|
410 |
+
args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
|
411 |
+
args.dada = args.dada or args.ada
|
412 |
+
args.opt = args.opt.lower().strip()
|
413 |
+
|
414 |
+
if args.lbs:
|
415 |
+
bs_per_gpu = args.lbs / args.ac
|
416 |
+
else:
|
417 |
+
bs_per_gpu = args.bs / args.ac / dist.get_world_size()
|
418 |
+
bs_per_gpu = round(bs_per_gpu)
|
419 |
+
args.batch_size = bs_per_gpu
|
420 |
+
args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
|
421 |
+
args.workers = min(args.workers, bs_per_gpu)
|
422 |
+
args.dblr = args.dblr or args.gblr
|
423 |
+
args.glr = args.ac * args.gblr * args.glb_batch_size / 256
|
424 |
+
args.dlr = args.ac * args.dblr * args.glb_batch_size / 256
|
425 |
+
args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
|
426 |
+
args.gwde = args.gwde or args.gwd
|
427 |
+
args.dwde = args.dwde or args.dwd
|
428 |
+
args.twde = args.twde or args.twd
|
429 |
+
|
430 |
+
if args.dbg_modified:
|
431 |
+
torch.autograd.set_detect_anomaly(True)
|
432 |
+
args.dbg_ks &= dist.is_local_master()
|
433 |
+
if args.dbg_ks:
|
434 |
+
args.dbg_ks_fp = open(os.path.join(args.local_out_path, 'dbg_ks.txt'), 'w')
|
435 |
+
|
436 |
+
# gpt args
|
437 |
+
if args.gpt_training:
|
438 |
+
assert args.vae_ckpt, 'VAE ckpt must be specified when training GPT'
|
439 |
+
from infinity.models import alias_dict, alias_dict_inv
|
440 |
+
if args.model in alias_dict:
|
441 |
+
args.model = alias_dict[args.model]
|
442 |
+
args.model_alias = alias_dict_inv[args.model]
|
443 |
+
else:
|
444 |
+
args.model_alias = args.model
|
445 |
+
args.model = f'infinity_{args.model}'
|
446 |
+
|
447 |
+
args.task_id = '123'
|
448 |
+
args.trial_id = '123'
|
449 |
+
args.robust_run_id = '0'
|
450 |
+
args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')
|
451 |
+
|
452 |
+
ls = '[]'
|
453 |
+
if 'AUTO_RESUME' in os.environ:
|
454 |
+
ls.append(int(os.environ['AUTO_RESUME']))
|
455 |
+
ls = sorted(ls, reverse=True)
|
456 |
+
ls = [str(i) for i in ls]
|
457 |
+
args.ckpt_trials = ls
|
458 |
+
args.real_trial_id = args.trial_id if len(ls) == 0 else str(ls[-1])
|
459 |
+
|
460 |
+
args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
|
461 |
+
args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
|
462 |
+
assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
|
463 |
+
f"only support no-checkpointing or full-block/full-attn checkpointing, but got {args.enable_checkpointing}."
|
464 |
+
|
465 |
+
if len(args.exp_name) == 0:
|
466 |
+
args.exp_name = os.path.basename(args.bed) or 'test_exp'
|
467 |
+
|
468 |
+
if '-' in args.exp_name:
|
469 |
+
args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
|
470 |
+
else:
|
471 |
+
args.tag = 'UK'
|
472 |
+
|
473 |
+
if dist.is_master():
|
474 |
+
os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')
|
475 |
+
|
476 |
+
if args.sdpa_mem:
|
477 |
+
from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
|
478 |
+
enable_flash_sdp(True)
|
479 |
+
enable_mem_efficient_sdp(True)
|
480 |
+
enable_math_sdp(False)
|
481 |
+
|
482 |
+
return args
|
utils/csv_util.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import csv
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def write_dicts2csv_file(input_dict_list, csv_filename):
|
9 |
+
os.makedirs(osp.dirname(csv_filename), exist_ok=True)
|
10 |
+
with open(csv_filename, mode='w', newline='', encoding='utf-8') as file:
|
11 |
+
fieldnames = input_dict_list[0].keys()
|
12 |
+
writer = csv.DictWriter(file, fieldnames=fieldnames)
|
13 |
+
writer.writeheader()
|
14 |
+
writer.writerows(input_dict_list)
|
15 |
+
print(f'"{csv_filename}" has been written.')
|
16 |
+
|
17 |
+
def load_csv_as_dicts(csv_filename):
|
18 |
+
with open(csv_filename, mode='r', newline='', encoding='utf-8') as csvfile:
|
19 |
+
reader = csv.DictReader(csvfile)
|
20 |
+
return list(reader)
|
utils/dist.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from typing import List
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import pytz
|
9 |
+
import torch
|
10 |
+
import torch.distributed as tdist
|
11 |
+
import torch.multiprocessing as mp
|
12 |
+
|
13 |
+
|
14 |
+
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
|
15 |
+
__rank_str_zfill = '0'
|
16 |
+
__initialized = False
|
17 |
+
|
18 |
+
|
19 |
+
def initialized():
|
20 |
+
return __initialized
|
21 |
+
|
22 |
+
|
23 |
+
def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
|
24 |
+
global __device
|
25 |
+
if not torch.cuda.is_available():
|
26 |
+
print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
|
27 |
+
return
|
28 |
+
elif 'RANK' not in os.environ:
|
29 |
+
torch.cuda.set_device(gpu_id_if_not_distibuted)
|
30 |
+
__device = torch.empty(1).cuda().device
|
31 |
+
print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
|
32 |
+
return
|
33 |
+
# then 'RANK' must exist
|
34 |
+
global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
|
35 |
+
local_rank = global_rank % num_gpus
|
36 |
+
torch.cuda.set_device(local_rank)
|
37 |
+
|
38 |
+
# ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
|
39 |
+
"""
|
40 |
+
if mp.get_start_method(allow_none=True) is None:
|
41 |
+
method = 'fork' if fork else 'spawn'
|
42 |
+
print(f'[dist initialize] mp method={method}')
|
43 |
+
mp.set_start_method(method)
|
44 |
+
"""
|
45 |
+
tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
|
46 |
+
|
47 |
+
global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
|
48 |
+
__local_rank = local_rank
|
49 |
+
__rank, __world_size = tdist.get_rank(), tdist.get_world_size()
|
50 |
+
__rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
|
51 |
+
__device = torch.device(local_rank)
|
52 |
+
__initialized = True
|
53 |
+
|
54 |
+
assert tdist.is_initialized(), 'torch.distributed is not initialized!'
|
55 |
+
print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
|
56 |
+
|
57 |
+
|
58 |
+
def get_rank():
|
59 |
+
return __rank
|
60 |
+
|
61 |
+
|
62 |
+
def get_rank_given_group(group: tdist.ProcessGroup):
|
63 |
+
return tdist.get_rank(group=group)
|
64 |
+
|
65 |
+
|
66 |
+
def get_rank_str_zfill():
|
67 |
+
return __rank_str_zfill
|
68 |
+
|
69 |
+
|
70 |
+
def get_local_rank():
|
71 |
+
return __local_rank
|
72 |
+
|
73 |
+
|
74 |
+
def get_world_size():
|
75 |
+
return __world_size
|
76 |
+
|
77 |
+
|
78 |
+
def get_device():
|
79 |
+
return __device
|
80 |
+
|
81 |
+
|
82 |
+
def set_gpu_id(gpu_id: int):
|
83 |
+
if gpu_id is None: return
|
84 |
+
global __device
|
85 |
+
if isinstance(gpu_id, (str, int)):
|
86 |
+
torch.cuda.set_device(int(gpu_id))
|
87 |
+
__device = torch.empty(1).cuda().device
|
88 |
+
else:
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
|
92 |
+
def is_master():
|
93 |
+
return __rank == 0
|
94 |
+
|
95 |
+
|
96 |
+
def is_local_master():
|
97 |
+
return __local_rank == 0
|
98 |
+
|
99 |
+
|
100 |
+
def is_visualizer():
|
101 |
+
return __rank == 0
|
102 |
+
# return __rank == max(__world_size - 8, 0)
|
103 |
+
|
104 |
+
|
105 |
+
def parallelize(net, syncbn=False):
|
106 |
+
if syncbn:
|
107 |
+
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
|
108 |
+
net = net.cuda()
|
109 |
+
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
|
110 |
+
return net
|
111 |
+
|
112 |
+
|
113 |
+
def new_group(ranks: List[int]):
|
114 |
+
if __initialized:
|
115 |
+
return tdist.new_group(ranks=ranks)
|
116 |
+
return None
|
117 |
+
|
118 |
+
|
119 |
+
def new_local_machine_group():
|
120 |
+
if __initialized:
|
121 |
+
cur_subgroup, subgroups = tdist.new_subgroups()
|
122 |
+
return cur_subgroup
|
123 |
+
return None
|
124 |
+
|
125 |
+
|
126 |
+
def barrier():
|
127 |
+
if __initialized:
|
128 |
+
tdist.barrier()
|
129 |
+
|
130 |
+
|
131 |
+
def allreduce(t: torch.Tensor, async_op=False):
|
132 |
+
if __initialized:
|
133 |
+
if not t.is_cuda:
|
134 |
+
cu = t.detach().cuda()
|
135 |
+
ret = tdist.all_reduce(cu, async_op=async_op)
|
136 |
+
t.copy_(cu.cpu())
|
137 |
+
else:
|
138 |
+
ret = tdist.all_reduce(t, async_op=async_op)
|
139 |
+
return ret
|
140 |
+
return None
|
141 |
+
|
142 |
+
|
143 |
+
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
|
144 |
+
if __initialized:
|
145 |
+
if not t.is_cuda:
|
146 |
+
t = t.cuda()
|
147 |
+
ls = [torch.empty_like(t) for _ in range(__world_size)]
|
148 |
+
tdist.all_gather(ls, t)
|
149 |
+
else:
|
150 |
+
ls = [t]
|
151 |
+
if cat:
|
152 |
+
ls = torch.cat(ls, dim=0)
|
153 |
+
return ls
|
154 |
+
|
155 |
+
|
156 |
+
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
|
157 |
+
if __initialized:
|
158 |
+
if not t.is_cuda:
|
159 |
+
t = t.cuda()
|
160 |
+
|
161 |
+
t_size = torch.tensor(t.size(), device=t.device)
|
162 |
+
ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
|
163 |
+
tdist.all_gather(ls_size, t_size)
|
164 |
+
|
165 |
+
max_B = max(size[0].item() for size in ls_size)
|
166 |
+
pad = max_B - t_size[0].item()
|
167 |
+
if pad:
|
168 |
+
pad_size = (pad, *t.size()[1:])
|
169 |
+
t = torch.cat((t, t.new_empty(pad_size)), dim=0)
|
170 |
+
|
171 |
+
ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
|
172 |
+
tdist.all_gather(ls_padded, t)
|
173 |
+
ls = []
|
174 |
+
for t, size in zip(ls_padded, ls_size):
|
175 |
+
ls.append(t[:size[0].item()])
|
176 |
+
else:
|
177 |
+
ls = [t]
|
178 |
+
if cat:
|
179 |
+
ls = torch.cat(ls, dim=0)
|
180 |
+
return ls
|
181 |
+
|
182 |
+
|
183 |
+
def broadcast(t: torch.Tensor, src_rank) -> None:
|
184 |
+
if __initialized:
|
185 |
+
if not t.is_cuda:
|
186 |
+
cu = t.detach().cuda()
|
187 |
+
tdist.broadcast(cu, src=src_rank)
|
188 |
+
t.copy_(cu.cpu())
|
189 |
+
else:
|
190 |
+
tdist.broadcast(t, src=src_rank)
|
191 |
+
|
192 |
+
|
193 |
+
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
|
194 |
+
if not initialized():
|
195 |
+
return torch.tensor([val]) if fmt is None else [fmt % val]
|
196 |
+
|
197 |
+
ts = torch.zeros(__world_size)
|
198 |
+
ts[__rank] = val
|
199 |
+
allreduce(ts)
|
200 |
+
if fmt is None:
|
201 |
+
return ts
|
202 |
+
return [fmt % v for v in ts.cpu().numpy().tolist()]
|
203 |
+
|
204 |
+
|
205 |
+
def master_only(func):
|
206 |
+
@functools.wraps(func)
|
207 |
+
def wrapper(*args, **kwargs):
|
208 |
+
force = kwargs.pop('force', False)
|
209 |
+
if force or is_master():
|
210 |
+
ret = func(*args, **kwargs)
|
211 |
+
else:
|
212 |
+
ret = None
|
213 |
+
barrier()
|
214 |
+
return ret
|
215 |
+
return wrapper
|
216 |
+
|
217 |
+
|
218 |
+
def local_master_only(func):
|
219 |
+
@functools.wraps(func)
|
220 |
+
def wrapper(*args, **kwargs):
|
221 |
+
force = kwargs.pop('force', False)
|
222 |
+
if force or is_local_master():
|
223 |
+
ret = func(*args, **kwargs)
|
224 |
+
else:
|
225 |
+
ret = None
|
226 |
+
barrier()
|
227 |
+
return ret
|
228 |
+
return wrapper
|
229 |
+
|
230 |
+
|
231 |
+
def for_visualize(func):
|
232 |
+
@functools.wraps(func)
|
233 |
+
def wrapper(*args, **kwargs):
|
234 |
+
if is_visualizer():
|
235 |
+
# with torch.no_grad():
|
236 |
+
ret = func(*args, **kwargs)
|
237 |
+
else:
|
238 |
+
ret = None
|
239 |
+
return ret
|
240 |
+
return wrapper
|
241 |
+
|
242 |
+
|
243 |
+
def finalize():
|
244 |
+
if __initialized:
|
245 |
+
tdist.destroy_process_group()
|
246 |
+
|
247 |
+
|
248 |
+
def init_distributed_mode(local_out_path, fork=False, only_sync_master=False, timeout_minutes=30):
|
249 |
+
try:
|
250 |
+
__initialize(fork=fork, timeout_minutes=timeout_minutes)
|
251 |
+
barrier()
|
252 |
+
except RuntimeError as e:
|
253 |
+
print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
|
254 |
+
raise e
|
255 |
+
|
256 |
+
if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
|
257 |
+
_change_builtin_print(is_local_master())
|
258 |
+
if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
|
259 |
+
sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
|
260 |
+
|
261 |
+
|
262 |
+
def _change_builtin_print(is_master):
|
263 |
+
import builtins as __builtin__
|
264 |
+
|
265 |
+
builtin_print = __builtin__.print
|
266 |
+
if type(builtin_print) != type(open):
|
267 |
+
return
|
268 |
+
|
269 |
+
def prt(*args, **kwargs):
|
270 |
+
force = kwargs.pop('force', False)
|
271 |
+
clean = kwargs.pop('clean', False)
|
272 |
+
deeper = kwargs.pop('deeper', False)
|
273 |
+
if is_master or force:
|
274 |
+
if not clean:
|
275 |
+
f_back = sys._getframe().f_back
|
276 |
+
if deeper and f_back.f_back is not None:
|
277 |
+
f_back = f_back.f_back
|
278 |
+
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
|
279 |
+
time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
|
280 |
+
builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
|
281 |
+
else:
|
282 |
+
builtin_print(*args, **kwargs)
|
283 |
+
|
284 |
+
__builtin__.print = prt
|
285 |
+
|
286 |
+
|
287 |
+
class BackupStreamToFile(object):
|
288 |
+
def __init__(self, local_output_dir, for_stdout=True):
|
289 |
+
self.for_stdout = for_stdout
|
290 |
+
self.terminal_stream = sys.stdout if for_stdout else sys.stderr
|
291 |
+
fname = os.path.join(local_output_dir, 'b1_stdout.txt' if for_stdout else 'b2_stderr.txt')
|
292 |
+
existing = os.path.exists(fname)
|
293 |
+
self.file_stream = open(fname, 'a')
|
294 |
+
if existing:
|
295 |
+
time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
|
296 |
+
self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
|
297 |
+
self.file_stream.flush()
|
298 |
+
os.system(f'ln -s {fname} /opt/tiger/run_trial/ >/dev/null 2>&1')
|
299 |
+
self.enabled = True
|
300 |
+
|
301 |
+
def write(self, message):
|
302 |
+
self.terminal_stream.write(message)
|
303 |
+
self.file_stream.write(message)
|
304 |
+
|
305 |
+
def flush(self):
|
306 |
+
self.terminal_stream.flush()
|
307 |
+
self.file_stream.flush()
|
308 |
+
|
309 |
+
def isatty(self):
|
310 |
+
return True
|
311 |
+
|
312 |
+
def close(self):
|
313 |
+
if not self.enabled:
|
314 |
+
return
|
315 |
+
self.enabled = False
|
316 |
+
self.file_stream.flush()
|
317 |
+
self.file_stream.close()
|
318 |
+
if self.for_stdout:
|
319 |
+
sys.stdout = self.terminal_stream
|
320 |
+
sys.stdout.flush()
|
321 |
+
else:
|
322 |
+
sys.stderr = self.terminal_stream
|
323 |
+
sys.stderr.flush()
|
324 |
+
|
325 |
+
def __del__(self):
|
326 |
+
self.close()
|
utils/dynamic_resolution.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import tqdm
|
4 |
+
|
5 |
+
vae_stride = 16
|
6 |
+
ratio2hws = {
|
7 |
+
1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64)],
|
8 |
+
1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56)],
|
9 |
+
1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54)],
|
10 |
+
1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52)],
|
11 |
+
1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48)],
|
12 |
+
2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45)],
|
13 |
+
2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40)],
|
14 |
+
3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37)],
|
15 |
+
}
|
16 |
+
predefined_t = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 21]
|
17 |
+
|
18 |
+
full_ratio2hws = {}
|
19 |
+
for ratio, hws in ratio2hws.items():
|
20 |
+
full_ratio2hws[ratio] = hws
|
21 |
+
if ratio != 1.000:
|
22 |
+
full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws]
|
23 |
+
|
24 |
+
dynamic_resolution_h_w = {}
|
25 |
+
for ratio in full_ratio2hws:
|
26 |
+
dynamic_resolution_h_w[ratio] ={}
|
27 |
+
for ind, leng in enumerate([7, 10, 12, 13]):
|
28 |
+
h_div_w = full_ratio2hws[ratio][leng-1][0] / full_ratio2hws[ratio][leng-1][1]
|
29 |
+
assert np.abs(h_div_w-ratio) < 0.01, f'{full_ratio2hws[ratio][leng-1]}: {h_div_w} != {ratio}'
|
30 |
+
pixel = (full_ratio2hws[ratio][leng-1][0] * vae_stride, full_ratio2hws[ratio][leng-1][1] * vae_stride)
|
31 |
+
if ind == 0:
|
32 |
+
total_pixels = '0.06M'
|
33 |
+
elif ind == 1:
|
34 |
+
total_pixels = '0.25M'
|
35 |
+
elif ind == 2:
|
36 |
+
total_pixels = '0.60M'
|
37 |
+
else:
|
38 |
+
total_pixels = '1M'
|
39 |
+
|
40 |
+
scales = full_ratio2hws[ratio][:leng]
|
41 |
+
scales = [ (t, h, w) for t, (h, w) in zip(predefined_t, scales) ]
|
42 |
+
dynamic_resolution_h_w[ratio][total_pixels] = {
|
43 |
+
'pixel': pixel,
|
44 |
+
'scales': scales
|
45 |
+
}
|
46 |
+
|
47 |
+
h_div_w_templates = []
|
48 |
+
for h_div_w in dynamic_resolution_h_w.keys():
|
49 |
+
h_div_w_templates.append(h_div_w)
|
50 |
+
h_div_w_templates = np.array(h_div_w_templates)
|
51 |
+
|
52 |
+
def get_h_div_w_template2indices(h_div_w_list, h_div_w_templates):
|
53 |
+
indices = list(range(len(h_div_w_list)))
|
54 |
+
h_div_w_template2indices = {}
|
55 |
+
pbar = tqdm.tqdm(total=len(indices), desc='get_h_div_w_template2indices...')
|
56 |
+
for h_div_w, index in zip(h_div_w_list, indices):
|
57 |
+
pbar.update(1)
|
58 |
+
nearest_h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w-h_div_w_templates))]
|
59 |
+
if nearest_h_div_w_template_ not in h_div_w_template2indices:
|
60 |
+
h_div_w_template2indices[nearest_h_div_w_template_] = []
|
61 |
+
h_div_w_template2indices[nearest_h_div_w_template_].append(index)
|
62 |
+
for h_div_w_template_, sub_indices in h_div_w_template2indices.items():
|
63 |
+
h_div_w_template2indices[h_div_w_template_] = np.array(sub_indices)
|
64 |
+
return h_div_w_template2indices
|
65 |
+
|
66 |
+
if __name__ == '__main__':
|
67 |
+
for h_div_w_template in dynamic_resolution_h_w:
|
68 |
+
for total_pixels in dynamic_resolution_h_w[h_div_w_template]:
|
69 |
+
scales = np.array(dynamic_resolution_h_w[h_div_w_template][total_pixels]['scales'])
|
70 |
+
seq_len = np.sum(scales[:,0]*scales[:,1])
|
71 |
+
if total_pixels == '1M':
|
72 |
+
string = f'{h_div_w_template}, {total_pixels}, {dynamic_resolution_h_w[h_div_w_template][total_pixels]}, seq_len: {seq_len}'.replace(', ', ',')
|
73 |
+
print(string)
|
utils/large_file_util.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import time
|
4 |
+
import itertools
|
5 |
+
import shutil
|
6 |
+
import glob
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import tqdm
|
10 |
+
import numpy as np
|
11 |
+
import threading
|
12 |
+
|
13 |
+
def save_lines(lines, filename):
|
14 |
+
os.makedirs(osp.dirname(filename), exist_ok=True)
|
15 |
+
with open(filename, 'w') as f:
|
16 |
+
f.writelines(lines)
|
17 |
+
del lines
|
18 |
+
|
19 |
+
def get_part_jsonls(filepath, total_line_number, parts=512):
|
20 |
+
dirname, filename, ext = osp.dirname(filepath), osp.splitext(osp.basename(filepath))[0], osp.splitext(osp.basename(filepath))[1]
|
21 |
+
if parts == 1:
|
22 |
+
return False, {1: filepath}
|
23 |
+
save_dir = osp.join(dirname, f'{parts:04d}_parts')
|
24 |
+
chunk_id2save_files = {}
|
25 |
+
missing = False
|
26 |
+
chunk_size = int(total_line_number/parts)
|
27 |
+
for chunk_id in range(1, parts+1):
|
28 |
+
if chunk_id == parts:
|
29 |
+
num_of_lines = total_line_number - chunk_size * (parts-1)
|
30 |
+
else:
|
31 |
+
num_of_lines = chunk_size
|
32 |
+
chunk_id2save_files[chunk_id] = osp.join(save_dir, f'{filename}_{chunk_id:04d}_{parts:04d}_{num_of_lines:09d}{ext}')
|
33 |
+
if not osp.exists(chunk_id2save_files[chunk_id]):
|
34 |
+
missing = True
|
35 |
+
return missing, chunk_id2save_files
|
36 |
+
|
37 |
+
def split_large_txt_files(filepath, chunk_id2save_files):
|
38 |
+
thread_list = []
|
39 |
+
chunk_id = 1
|
40 |
+
with open(filepath, 'r') as f:
|
41 |
+
chunk = []
|
42 |
+
pbar = tqdm.tqdm(total=len(chunk_id2save_files))
|
43 |
+
for line in f:
|
44 |
+
chunk.append(line)
|
45 |
+
cur_chunk_size = int(osp.splitext(osp.basename(chunk_id2save_files[chunk_id]))[0].split('_')[-1])
|
46 |
+
if len(chunk) >= cur_chunk_size:
|
47 |
+
pbar.update(1)
|
48 |
+
thread_list.append(threading.Thread(target=save_lines, args=(chunk, chunk_id2save_files[chunk_id])))
|
49 |
+
thread_list[-1].start()
|
50 |
+
chunk = []
|
51 |
+
chunk_id += 1
|
52 |
+
if len(chunk):
|
53 |
+
import ipdb; ipdb.set_trace()
|
54 |
+
assert not len(chunk)
|
55 |
+
for thread in thread_list:
|
56 |
+
thread.join()
|
57 |
+
|
58 |
+
if __name__ == '__main__':
|
59 |
+
parser = argparse.ArgumentParser()
|
60 |
+
parser.add_argument('--jsonl_folder', type=str, default='')
|
61 |
+
parser.add_argument('--parts', type=int, default=600)
|
62 |
+
args = parser.parse_args()
|
63 |
+
for jsonl_filepath in sorted(glob.glob(osp.join(args.jsonl_folder, '*.jsonl'))):
|
64 |
+
print(jsonl_filepath)
|
65 |
+
t1 = time.time()
|
66 |
+
line_num = int(jsonl_filepath.split('_')[-1].split('.')[0])
|
67 |
+
missing, chunk_id2save_files = get_part_jsonls(jsonl_filepath, line_num, parts=args.parts)
|
68 |
+
split_large_txt_files(jsonl_filepath, chunk_id2save_files)
|
69 |
+
t2 = time.time()
|
70 |
+
print(f'split takes {t2-t1}s')
|
utils/load.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
import gc
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import random
|
6 |
+
import sys
|
7 |
+
from copy import deepcopy
|
8 |
+
from typing import Tuple, Union
|
9 |
+
|
10 |
+
import colorama
|
11 |
+
import torch
|
12 |
+
import yaml
|
13 |
+
|
14 |
+
import infinity.utils.dist as dist
|
15 |
+
|
16 |
+
from infinity.models import Infinity
|
17 |
+
from infinity.models.ema import get_ema_model
|
18 |
+
from infinity.utils import arg_util, misc
|
19 |
+
from infinity.utils.misc import os_system
|
20 |
+
|
21 |
+
|
22 |
+
def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
|
23 |
+
if args.vae_type in [8,16,18,20,24,32,64,128]:
|
24 |
+
from infinity.models.bsq_vae.vae import vae_model
|
25 |
+
schedule_mode = "dynamic"
|
26 |
+
codebook_dim = args.vae_type # 18
|
27 |
+
codebook_size = 2**codebook_dim
|
28 |
+
if args.apply_spatial_patchify:
|
29 |
+
patch_size = 8
|
30 |
+
encoder_ch_mult=[1, 2, 4, 4]
|
31 |
+
decoder_ch_mult=[1, 2, 4, 4]
|
32 |
+
else:
|
33 |
+
patch_size = 16
|
34 |
+
encoder_ch_mult=[1, 2, 4, 4, 4]
|
35 |
+
decoder_ch_mult=[1, 2, 4, 4, 4]
|
36 |
+
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
|
37 |
+
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
|
38 |
+
if args.fake_vae_input:
|
39 |
+
vae_local.encoder = None
|
40 |
+
vae_local.decoder = None
|
41 |
+
torch.cuda.empty_cache()
|
42 |
+
else:
|
43 |
+
raise ValueError(f"vae_type {args.vae_type} not supported")
|
44 |
+
if force_flash: args.flash = True
|
45 |
+
gpt_kw = dict(
|
46 |
+
pretrained=False, global_pool='',
|
47 |
+
text_channels=args.Ct5, text_maxlen=args.tlen,
|
48 |
+
norm_eps=args.norm_eps, rms_norm=args.rms,
|
49 |
+
shared_aln=args.saln, head_aln=args.haln,
|
50 |
+
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
|
51 |
+
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
|
52 |
+
raw_scale_schedule=args.scale_schedule,
|
53 |
+
head_depth=args.dec,
|
54 |
+
top_p=args.tp, top_k=args.tk,
|
55 |
+
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
|
56 |
+
checkpointing=args.enable_checkpointing,
|
57 |
+
pad_to_multiplier=args.pad_to_multiplier,
|
58 |
+
use_flex_attn=args.use_flex_attn,
|
59 |
+
batch_size=args.batch_size,
|
60 |
+
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
|
61 |
+
use_bit_label=args.use_bit_label,
|
62 |
+
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
|
63 |
+
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
|
64 |
+
pn=args.pn,
|
65 |
+
train_h_div_w_list=args.train_h_div_w_list,
|
66 |
+
always_training_scales=args.always_training_scales,
|
67 |
+
apply_spatial_patchify=args.apply_spatial_patchify,
|
68 |
+
)
|
69 |
+
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
|
70 |
+
if args.hd > 0: gpt_kw['num_heads'] = args.hd
|
71 |
+
|
72 |
+
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
|
73 |
+
gpt_kw['vae_local'] = vae_local
|
74 |
+
|
75 |
+
model_str = args.model.replace('vgpt', 'infinity') # legacy
|
76 |
+
print(f"{model_str=}")
|
77 |
+
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
|
78 |
+
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
|
79 |
+
block_chunks = int(block_chunks)
|
80 |
+
else:
|
81 |
+
block_chunks = 1
|
82 |
+
gpt_kw['block_chunks'] = block_chunks
|
83 |
+
|
84 |
+
from infinity.models import Infinity
|
85 |
+
from timm.models import create_model
|
86 |
+
gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
|
87 |
+
if args.use_fsdp_model_ema:
|
88 |
+
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
|
89 |
+
else:
|
90 |
+
gpt_wo_ddp_ema = None
|
91 |
+
gpt_wo_ddp = gpt_wo_ddp.to(device)
|
92 |
+
|
93 |
+
assert all(not p.requires_grad for p in vae_local.parameters())
|
94 |
+
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
|
95 |
+
|
96 |
+
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == '__main__':
|
100 |
+
ld(sys.argv[1])
|
utils/lr_control.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from pprint import pformat
|
3 |
+
from typing import Tuple, List, Dict, Union
|
4 |
+
|
5 |
+
import torch.nn
|
6 |
+
import infinity.utils.dist as dist
|
7 |
+
|
8 |
+
|
9 |
+
def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
|
10 |
+
"""Decay the learning rate with half-cycle cosine after warmup"""
|
11 |
+
wp_it = round(wp_it)
|
12 |
+
|
13 |
+
if cur_it < wp_it:
|
14 |
+
cur_lr = wp0 + (1-wp0) * cur_it / wp_it
|
15 |
+
else:
|
16 |
+
pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
|
17 |
+
rest = 1 - pasd # [1, 0]
|
18 |
+
if sche_type == 'cos':
|
19 |
+
cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
|
20 |
+
elif sche_type == 'lin':
|
21 |
+
T = 0.15; max_rest = 1-T
|
22 |
+
if pasd < T: cur_lr = 1
|
23 |
+
else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
|
24 |
+
elif sche_type == 'lin0':
|
25 |
+
T = 0.05; max_rest = 1-T
|
26 |
+
if pasd < T: cur_lr = 1
|
27 |
+
else: cur_lr = wpe + (1-wpe) * rest / max_rest
|
28 |
+
elif sche_type == 'lin00':
|
29 |
+
cur_lr = wpe + (1-wpe) * rest
|
30 |
+
elif sche_type.startswith('lin'):
|
31 |
+
T = float(sche_type[3:]); max_rest = 1-T
|
32 |
+
wpe_mid = wpe + (1-wpe) * max_rest
|
33 |
+
wpe_mid = (1 + wpe_mid) / 2
|
34 |
+
if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
|
35 |
+
else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
|
36 |
+
elif sche_type == 'exp':
|
37 |
+
T = 0.15; max_rest = 1-T
|
38 |
+
if pasd < T: cur_lr = 1
|
39 |
+
else:
|
40 |
+
expo = (pasd-T) / max_rest * math.log(wpe)
|
41 |
+
cur_lr = math.exp(expo)
|
42 |
+
else:
|
43 |
+
raise NotImplementedError(f'unknown sche_type {sche_type}')
|
44 |
+
|
45 |
+
cur_lr *= peak_lr
|
46 |
+
pasd = cur_it / (max_it-1)
|
47 |
+
cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
|
48 |
+
|
49 |
+
inf = 1e6
|
50 |
+
min_lr, max_lr = inf, -1
|
51 |
+
min_wd, max_wd = inf, -1
|
52 |
+
for param_group in optimizer.param_groups:
|
53 |
+
param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
|
54 |
+
max_lr = max(max_lr, param_group['lr'])
|
55 |
+
min_lr = min(min_lr, param_group['lr'])
|
56 |
+
|
57 |
+
param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
|
58 |
+
max_wd = max(max_wd, param_group['weight_decay'])
|
59 |
+
if param_group['weight_decay'] > 0:
|
60 |
+
min_wd = min(min_wd, param_group['weight_decay'])
|
61 |
+
|
62 |
+
if min_lr == inf: min_lr = -1
|
63 |
+
if min_wd == inf: min_wd = -1
|
64 |
+
return min_lr, max_lr, min_wd, max_wd
|
65 |
+
|
66 |
+
|
67 |
+
def filter_params(model, ndim_dict, nowd_keys=(), lr_scale=0.0) -> Tuple[
|
68 |
+
List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
|
69 |
+
]:
|
70 |
+
with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale <= 1
|
71 |
+
print(f'[get_param_groups][lr decay] with_lr_scale={with_lr_scale}, lr_scale={lr_scale}')
|
72 |
+
para_groups, para_groups_dbg = {}, {}
|
73 |
+
names, paras = [], []
|
74 |
+
names_no_grad = []
|
75 |
+
count, numel = 0, 0
|
76 |
+
for name, para in model.named_parameters():
|
77 |
+
name = name.replace('_fsdp_wrapped_module.', '')
|
78 |
+
if not para.requires_grad:
|
79 |
+
names_no_grad.append(name)
|
80 |
+
continue # frozen weights
|
81 |
+
count += 1
|
82 |
+
numel += para.numel()
|
83 |
+
names.append(name)
|
84 |
+
paras.append(para)
|
85 |
+
|
86 |
+
if ndim_dict.get(name, 2) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
|
87 |
+
cur_wd_sc, group_name = 0., 'ND'
|
88 |
+
# elif any(k in name for k in small_wd_keys):
|
89 |
+
# cur_wd_sc, group_name = small_wd, 'small_decay'
|
90 |
+
else:
|
91 |
+
cur_wd_sc, group_name = 1., 'D'
|
92 |
+
|
93 |
+
if with_lr_scale:
|
94 |
+
layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
|
95 |
+
group_name = f'layer{layer_id}_' + group_name
|
96 |
+
cur_lr_sc = lr_scale ** scale_exp
|
97 |
+
dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
|
98 |
+
else:
|
99 |
+
cur_lr_sc = 1.
|
100 |
+
dbg = f'[no scale]'
|
101 |
+
|
102 |
+
if group_name not in para_groups:
|
103 |
+
para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
|
104 |
+
para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': dbg}
|
105 |
+
para_groups[group_name]['params'].append(para)
|
106 |
+
para_groups_dbg[group_name]['params'].append(name)
|
107 |
+
|
108 |
+
for g in para_groups_dbg.values():
|
109 |
+
g['params'] = pformat(', '.join(g['params']), width=200)
|
110 |
+
|
111 |
+
print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
|
112 |
+
|
113 |
+
for rk in range(dist.get_world_size()):
|
114 |
+
dist.barrier()
|
115 |
+
if dist.get_rank() == rk:
|
116 |
+
print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
|
117 |
+
print('')
|
118 |
+
|
119 |
+
assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
|
120 |
+
del ndim_dict
|
121 |
+
return names, paras, list(para_groups.values())
|
122 |
+
|
123 |
+
|
124 |
+
def plot():
|
125 |
+
import matplotlib.pyplot as plt
|
126 |
+
import torch.nn as nn
|
127 |
+
from torch.optim import SGD
|
128 |
+
# for sche in ('lin', 'lin0', 'lin00', 'lin0.5', 'lin0.75'):
|
129 |
+
for sche in ('lin0', ):
|
130 |
+
op = SGD(nn.Linear(3, 4).parameters(), lr=1e-3)
|
131 |
+
it, lr = [], []
|
132 |
+
iters = 500
|
133 |
+
wp_it, max_it = 1 * iters, 10 * iters
|
134 |
+
for cur_it in range(max_it):
|
135 |
+
it.append(cur_it)
|
136 |
+
lr.append(lr_wd_annealing(sche, op, 0.1, 1e-5, 1e-5, cur_it, wp_it, max_it, wpe=0.3)[0])
|
137 |
+
|
138 |
+
plt.figure()
|
139 |
+
plt.title(sche)
|
140 |
+
plt.plot(it, lr, 'b', label=sche)
|
141 |
+
plt.xlabel('it'), plt.ylabel('lr')
|
142 |
+
plt.legend()
|
143 |
+
|
144 |
+
plt.savefig('lr.jpg')
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == '__main__':
|
148 |
+
plot()
|
utils/misc.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import functools
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import subprocess
|
7 |
+
import sys
|
8 |
+
import threading
|
9 |
+
import time
|
10 |
+
from collections import defaultdict, deque
|
11 |
+
from typing import Iterator, List, Tuple
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import pytz
|
15 |
+
import torch
|
16 |
+
import torch.distributed as tdist
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
import infinity.utils.dist as dist
|
20 |
+
|
21 |
+
os_system = functools.partial(subprocess.call, shell=True)
|
22 |
+
def echo(info):
|
23 |
+
os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
|
24 |
+
def os_system_get_stdout(cmd):
|
25 |
+
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
|
26 |
+
def os_system_get_stdout_stderr(cmd):
|
27 |
+
cnt = 0
|
28 |
+
while True:
|
29 |
+
try:
|
30 |
+
sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
|
31 |
+
except subprocess.TimeoutExpired:
|
32 |
+
cnt += 1
|
33 |
+
print(f'[fetch free_port file] timeout cnt={cnt}')
|
34 |
+
else:
|
35 |
+
return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
|
36 |
+
|
37 |
+
|
38 |
+
def is_pow2n(x):
|
39 |
+
return x > 0 and (x & (x - 1) == 0)
|
40 |
+
|
41 |
+
|
42 |
+
def time_str(fmt='[%m-%d %H:%M:%S]'):
|
43 |
+
return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
|
44 |
+
|
45 |
+
|
46 |
+
class DistLogger(object):
|
47 |
+
def __init__(self, lg):
|
48 |
+
self._lg = lg
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def do_nothing(*args, **kwargs):
|
52 |
+
pass
|
53 |
+
|
54 |
+
def __getattr__(self, attr: str):
|
55 |
+
return getattr(self._lg, attr) if self._lg is not None else DistLogger.do_nothing
|
56 |
+
|
57 |
+
class TensorboardLogger(object):
|
58 |
+
def __init__(self, log_dir, filename_suffix):
|
59 |
+
try: import tensorflow_io as tfio
|
60 |
+
except: pass
|
61 |
+
from torch.utils.tensorboard import SummaryWriter
|
62 |
+
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
|
63 |
+
self.step = 0
|
64 |
+
|
65 |
+
def set_step(self, step=None):
|
66 |
+
if step is not None:
|
67 |
+
self.step = step
|
68 |
+
else:
|
69 |
+
self.step += 1
|
70 |
+
|
71 |
+
def loggable(self):
|
72 |
+
return self.step == 0 or (self.step + 1) % 500 == 0
|
73 |
+
|
74 |
+
def update(self, head='scalar', step=None, **kwargs):
|
75 |
+
if step is None:
|
76 |
+
step = self.step
|
77 |
+
if not self.loggable(): return
|
78 |
+
for k, v in kwargs.items():
|
79 |
+
if v is None: continue
|
80 |
+
if hasattr(v, 'item'): v = v.item()
|
81 |
+
self.writer.add_scalar(f'{head}/{k}', v, step)
|
82 |
+
|
83 |
+
def log_tensor_as_distri(self, tag, tensor1d, step=None):
|
84 |
+
if step is None:
|
85 |
+
step = self.step
|
86 |
+
if not self.loggable(): return
|
87 |
+
try:
|
88 |
+
self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
|
89 |
+
except Exception as e:
|
90 |
+
print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
|
91 |
+
|
92 |
+
def log_image(self, tag, img_chw, step=None):
|
93 |
+
if step is None:
|
94 |
+
step = self.step
|
95 |
+
if not self.loggable(): return
|
96 |
+
self.writer.add_image(tag, img_chw, step, dataformats='CHW')
|
97 |
+
|
98 |
+
def flush(self):
|
99 |
+
self.writer.flush()
|
100 |
+
|
101 |
+
def close(self):
|
102 |
+
self.writer.close()
|
103 |
+
|
104 |
+
|
105 |
+
class Low_GPU_usage(object):
|
106 |
+
def __init__(self, files, sleep_secs, verbose):
|
107 |
+
pass
|
108 |
+
|
109 |
+
def early_stop(self):
|
110 |
+
pass
|
111 |
+
|
112 |
+
def __enter__(self):
|
113 |
+
return self
|
114 |
+
|
115 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
116 |
+
pass
|
117 |
+
|
118 |
+
class TouchingDaemonDontForgetToStartMe(threading.Thread):
|
119 |
+
def __init__(self, files: List[str], sleep_secs: int, verbose=False):
|
120 |
+
super().__init__(daemon=True)
|
121 |
+
self.files = tuple(files)
|
122 |
+
self.sleep_secs = sleep_secs
|
123 |
+
self.is_finished = False
|
124 |
+
self.verbose = verbose
|
125 |
+
|
126 |
+
f_back = sys._getframe().f_back
|
127 |
+
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
|
128 |
+
self.print_prefix = f' ({file_desc}, line{f_back.f_lineno:-4d}) @daemon@ '
|
129 |
+
|
130 |
+
def finishing(self):
|
131 |
+
self.is_finished = True
|
132 |
+
|
133 |
+
def run(self) -> None:
|
134 |
+
kw = {}
|
135 |
+
if tdist.is_initialized(): kw['clean'] = True
|
136 |
+
|
137 |
+
stt = time.time()
|
138 |
+
if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] start touching {self.files} per {self.sleep_secs}s ...', **kw)
|
139 |
+
while not self.is_finished:
|
140 |
+
for f in self.files:
|
141 |
+
if os.path.exists(f):
|
142 |
+
try:
|
143 |
+
os.utime(f)
|
144 |
+
fp = open(f, 'a')
|
145 |
+
fp.close()
|
146 |
+
except: pass
|
147 |
+
time.sleep(self.sleep_secs)
|
148 |
+
|
149 |
+
if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] finish touching after {time.time()-stt:.1f} secs {self.files} per {self.sleep_secs}s. ', **kw)
|
150 |
+
|
151 |
+
|
152 |
+
class SmoothedValue(object):
|
153 |
+
"""Track a series of values and provide access to smoothed values over a
|
154 |
+
window or the global series average.
|
155 |
+
"""
|
156 |
+
|
157 |
+
def __init__(self, window_size=30, fmt=None):
|
158 |
+
if fmt is None:
|
159 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
160 |
+
self.deque = deque(maxlen=window_size)
|
161 |
+
self.total = 0.0
|
162 |
+
self.count = 0
|
163 |
+
self.fmt = fmt
|
164 |
+
|
165 |
+
def update(self, value, n=1):
|
166 |
+
self.deque.append(value)
|
167 |
+
self.count += n
|
168 |
+
self.total += value * n
|
169 |
+
|
170 |
+
def synchronize_between_processes(self):
|
171 |
+
"""
|
172 |
+
Warning: does not synchronize the deque!
|
173 |
+
"""
|
174 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
175 |
+
tdist.barrier()
|
176 |
+
tdist.all_reduce(t)
|
177 |
+
t = t.tolist()
|
178 |
+
self.count = int(t[0])
|
179 |
+
self.total = t[1]
|
180 |
+
|
181 |
+
@property
|
182 |
+
def median(self):
|
183 |
+
return np.median(self.deque) if len(self.deque) else 0
|
184 |
+
|
185 |
+
@property
|
186 |
+
def avg(self):
|
187 |
+
return sum(self.deque) / (len(self.deque) or 1)
|
188 |
+
|
189 |
+
@property
|
190 |
+
def global_avg(self):
|
191 |
+
return self.total / (self.count or 1)
|
192 |
+
|
193 |
+
@property
|
194 |
+
def max(self):
|
195 |
+
return max(self.deque) if len(self.deque) else 0
|
196 |
+
|
197 |
+
@property
|
198 |
+
def value(self):
|
199 |
+
return self.deque[-1] if len(self.deque) else 0
|
200 |
+
|
201 |
+
def time_preds(self, counts) -> Tuple[float, str, str]:
|
202 |
+
remain_secs = counts * self.median
|
203 |
+
return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
|
204 |
+
|
205 |
+
def __str__(self):
|
206 |
+
return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value)
|
207 |
+
|
208 |
+
|
209 |
+
class MetricLogger(object):
|
210 |
+
def __init__(self):
|
211 |
+
self.meters = defaultdict(SmoothedValue)
|
212 |
+
self.iter_end_t = time.time()
|
213 |
+
self.log_iters = set()
|
214 |
+
self.log_every_iter = False
|
215 |
+
|
216 |
+
def update(self, **kwargs):
|
217 |
+
# if it != 0 and it not in self.log_iters: return
|
218 |
+
for k, v in kwargs.items():
|
219 |
+
if v is None: continue
|
220 |
+
if hasattr(v, 'item'): v = v.item()
|
221 |
+
# assert isinstance(v, (float, int)), type(v)
|
222 |
+
self.meters[k].update(v)
|
223 |
+
|
224 |
+
def __getattr__(self, attr):
|
225 |
+
if attr in self.meters:
|
226 |
+
return self.meters[attr]
|
227 |
+
if attr in self.__dict__:
|
228 |
+
return self.__dict__[attr]
|
229 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
230 |
+
type(self).__name__, attr))
|
231 |
+
|
232 |
+
def __str__(self):
|
233 |
+
loss_str = []
|
234 |
+
for name, meter in self.meters.items():
|
235 |
+
if len(meter.deque):
|
236 |
+
loss_str.append(
|
237 |
+
"{}: {}".format(name, str(meter))
|
238 |
+
)
|
239 |
+
return ' '.join(loss_str)
|
240 |
+
|
241 |
+
def synchronize_between_processes(self):
|
242 |
+
for meter in self.meters.values():
|
243 |
+
meter.synchronize_between_processes()
|
244 |
+
|
245 |
+
def add_meter(self, name, meter):
|
246 |
+
self.meters[name] = meter
|
247 |
+
|
248 |
+
def log_every(self, start_it, max_iters, itrt, log_freq, log_every_iter=False, header=''): # also solve logging & skipping iterations before start_it
|
249 |
+
start_it = start_it % max_iters
|
250 |
+
self.log_iters = set(range(start_it, max_iters, log_freq))
|
251 |
+
self.log_iters.add(start_it)
|
252 |
+
self.log_iters.add(max_iters-1)
|
253 |
+
self.log_iters.add(max_iters)
|
254 |
+
self.log_every_iter = log_every_iter
|
255 |
+
self.iter_end_t = time.time()
|
256 |
+
self.iter_time = SmoothedValue(fmt='{value:.4f}')
|
257 |
+
self.data_time = SmoothedValue(fmt='{value:.3f}')
|
258 |
+
header_fmt = header + ': [{0:' + str(len(str(max_iters))) + 'd}/{1}]'
|
259 |
+
|
260 |
+
start_time = time.time()
|
261 |
+
if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
|
262 |
+
for it in range(start_it, max_iters):
|
263 |
+
obj = next(itrt)
|
264 |
+
if it < start_it: continue
|
265 |
+
self.data_time.update(time.time() - self.iter_end_t)
|
266 |
+
yield it, obj
|
267 |
+
self.iter_time.update(time.time() - self.iter_end_t)
|
268 |
+
if self.log_every_iter or it in self.log_iters:
|
269 |
+
eta_seconds = self.iter_time.avg * (max_iters - it)
|
270 |
+
print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
|
271 |
+
self.iter_end_t = time.time()
|
272 |
+
else:
|
273 |
+
if isinstance(itrt, int): itrt = range(itrt)
|
274 |
+
for it, obj in enumerate(itrt):
|
275 |
+
if it < start_it:
|
276 |
+
self.iter_end_t = time.time()
|
277 |
+
continue
|
278 |
+
self.data_time.update(time.time() - self.iter_end_t)
|
279 |
+
yield it, obj
|
280 |
+
self.iter_time.update(time.time() - self.iter_end_t)
|
281 |
+
if self.log_every_iter or it in self.log_iters:
|
282 |
+
eta_seconds = self.iter_time.avg * (max_iters - it)
|
283 |
+
print(f'{header_fmt.format(it, max_iters)} eta: {str(datetime.timedelta(seconds=int(eta_seconds)))} {str(self)} T: {self.iter_time.value:.3f}s dataT: {self.data_time.value*1e3:.1f}ms', flush=True)
|
284 |
+
self.iter_end_t = time.time()
|
285 |
+
cost = time.time() - start_time
|
286 |
+
cost_str = str(datetime.timedelta(seconds=int(cost)))
|
287 |
+
print(f'{header} Cost of this ep: {cost_str} ({cost / (max_iters-start_it):.3f} s / it)', flush=True)
|
288 |
+
|
289 |
+
|
290 |
+
class NullDDP(torch.nn.Module):
|
291 |
+
def __init__(self, module, *args, **kwargs):
|
292 |
+
super(NullDDP, self).__init__()
|
293 |
+
self.module = module
|
294 |
+
self.require_backward_grad_sync = False
|
295 |
+
|
296 |
+
def forward(self, *args, **kwargs):
|
297 |
+
return self.module(*args, **kwargs)
|
298 |
+
|
299 |
+
|
300 |
+
def build_2d_sincos_position_embedding(h, w, embed_dim, temperature=10000., sc=0, verbose=True): # (1, hw**2, embed_dim)
|
301 |
+
# DiT: sc=0
|
302 |
+
# DETR: sc=2?
|
303 |
+
grid_w = torch.arange(w, dtype=torch.float32)
|
304 |
+
grid_h = torch.arange(h, dtype=torch.float32)
|
305 |
+
grid_w, grid_h = torch.meshgrid([grid_w, grid_h], indexing='ij')
|
306 |
+
if sc == 0:
|
307 |
+
scale = 1
|
308 |
+
elif sc == 1:
|
309 |
+
scale = math.pi * 2 / w
|
310 |
+
else:
|
311 |
+
scale = 1 / w
|
312 |
+
grid_w = scale * grid_w.reshape(h*w, 1) # scale * [0, 0, 0, 1, 1, 1, 2, 2, 2]
|
313 |
+
grid_h = scale * grid_h.reshape(h*w, 1) # scale * [0, 1, 2, 0, 1, 2, 0, 1, 2]
|
314 |
+
|
315 |
+
assert embed_dim % 4 == 0, f'Embed dimension ({embed_dim}) must be divisible by 4 for 2D sin-cos position embedding!'
|
316 |
+
pos_dim = embed_dim // 4
|
317 |
+
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
318 |
+
omega = (-math.log(temperature) * omega).exp()
|
319 |
+
# omega == (1/T) ** (arange(pos_dim) / pos_dim), a vector only dependent on C
|
320 |
+
out_w = grid_w * omega.view(1, pos_dim) # out_w: scale * [0*ome, 0*ome, 0*ome, 1*ome, 1*ome, 1*ome, 2*ome, 2*ome, 2*ome]
|
321 |
+
out_h = grid_h * omega.view(1, pos_dim) # out_h: scale * [0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome, 0*ome, 1*ome, 2*ome]
|
322 |
+
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
|
323 |
+
if verbose: print(f'[build_2d_sincos_position_embedding @ {hw} x {hw}] scale_type={sc}, temperature={temperature:g}, shape={pos_emb.shape}')
|
324 |
+
return pos_emb # (1, hw**2, embed_dim)
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == '__main__':
|
328 |
+
import seaborn as sns
|
329 |
+
import matplotlib.pyplot as plt
|
330 |
+
cmap_div = sns.color_palette('icefire', as_cmap=True)
|
331 |
+
|
332 |
+
scs = [0, 1, 2]
|
333 |
+
temps = [20, 50, 100, 1000]
|
334 |
+
reso = 3.0
|
335 |
+
RR, CC = len(scs), len(temps)
|
336 |
+
plt.figure(figsize=(CC * reso, RR * reso)) # figsize=(16, 16)
|
337 |
+
for row, sc in enumerate(scs):
|
338 |
+
for col, temp in enumerate(temps):
|
339 |
+
name = f'sc={sc}, T={temp}'
|
340 |
+
hw, C = 16, 512
|
341 |
+
N = hw*hw
|
342 |
+
pe = build_2d_sincos_position_embedding(hw, C, temperature=temp, sc=sc, verbose=False)[0] # N, C = 64, 16
|
343 |
+
|
344 |
+
hw2 = 16
|
345 |
+
N2 = hw2*hw2
|
346 |
+
pe2 = build_2d_sincos_position_embedding(hw2, C, temperature=temp, sc=sc, verbose=False)[0] # N, C = 64, 16
|
347 |
+
# pe2 = pe2.flip(dims=(0,))
|
348 |
+
bchw, bchw2 = F.normalize(pe.view(hw, hw, C).permute(2, 0, 1).unsqueeze(0), dim=1), F.normalize(pe2.view(hw2, hw2, C).permute(2, 0, 1).unsqueeze(0), dim=1)
|
349 |
+
dis = [
|
350 |
+
f'{F.mse_loss(bchw, F.interpolate(bchw2, size=bchw.shape[-2], mode=inter)).item():.3f}'
|
351 |
+
for inter in ('bilinear', 'bicubic', 'nearest')
|
352 |
+
]
|
353 |
+
dis += [
|
354 |
+
f'{F.mse_loss(F.interpolate(bchw, size=bchw2.shape[-2], mode=inter), bchw2).item():.3f}'
|
355 |
+
for inter in ('area', 'nearest')
|
356 |
+
]
|
357 |
+
print(f'[{name:^20s}] dis: {dis}')
|
358 |
+
"""
|
359 |
+
[ sc=0, T=20 ] dis: ['0.010', '0.011', '0.011', '0.009', '0.010']
|
360 |
+
[ sc=0, T=100 ] dis: ['0.007', '0.007', '0.007', '0.006', '0.007']
|
361 |
+
[ sc=0, T=1000 ] dis: ['0.005', '0.005', '0.005', '0.004', '0.005']
|
362 |
+
[ sc=0, T=10000 ] dis: ['0.004', '0.004', '0.004', '0.003', '0.004']
|
363 |
+
[ sc=1, T=20 ] dis: ['0.007', '0.008', '0.008', '0.007', '0.008']
|
364 |
+
[ sc=1, T=100 ] dis: ['0.005', '0.005', '0.005', '0.005', '0.005']
|
365 |
+
[ sc=1, T=1000 ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
|
366 |
+
[ sc=1, T=10000 ] dis: ['0.003', '0.003', '0.003', '0.003', '0.003']
|
367 |
+
[ sc=2, T=20 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
|
368 |
+
[ sc=2, T=100 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
|
369 |
+
[ sc=2, T=1000 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
|
370 |
+
[ sc=2, T=10000 ] dis: ['0.000', '0.000', '0.000', '0.000', '0.000']
|
371 |
+
Process finished with exit code 0
|
372 |
+
"""
|
373 |
+
|
374 |
+
pe = torch.from_numpy(cmap_div(pe.T.numpy())[:, :, :3]) # C, N, 3
|
375 |
+
tar_h, tar_w = 1024, 1024
|
376 |
+
pe = pe.repeat_interleave(tar_w//pe.shape[0], dim=0).repeat_interleave(tar_h//pe.shape[1], dim=1)
|
377 |
+
plt.subplot(RR, CC, 1+row*CC+col)
|
378 |
+
plt.title(name)
|
379 |
+
plt.xlabel('hxw'), plt.ylabel('C')
|
380 |
+
plt.xticks([]), plt.yticks([])
|
381 |
+
plt.imshow(pe.mul(255).round().clamp(0, 255).byte().numpy())
|
382 |
+
plt.tight_layout(h_pad=0.02)
|
383 |
+
plt.show()
|
384 |
+
|
385 |
+
|
386 |
+
def check_randomness(args):
|
387 |
+
U = 16384
|
388 |
+
t = torch.zeros(dist.get_world_size(), 4, dtype=torch.float32, device=args.device)
|
389 |
+
t0 = torch.zeros(1, dtype=torch.float32, device=args.device).random_(U)
|
390 |
+
t[dist.get_rank(), 0] = float(random.randrange(U))
|
391 |
+
t[dist.get_rank(), 1] = float(np.random.randint(U))
|
392 |
+
t[dist.get_rank(), 2] = float(torch.randint(0, U, (1,))[0])
|
393 |
+
t[dist.get_rank(), 3] = float(t0[0])
|
394 |
+
dist.allreduce(t)
|
395 |
+
for rk in range(1, dist.get_world_size()):
|
396 |
+
assert torch.allclose(t[rk - 1], t[rk]), f't={t}'
|
397 |
+
del t0, t, U
|
utils/save_and_load.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
import time
|
5 |
+
import re
|
6 |
+
from typing import List, Optional, Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
10 |
+
|
11 |
+
import glob
|
12 |
+
import shutil
|
13 |
+
from infinity.utils import arg_util
|
14 |
+
import infinity.utils.dist as dist
|
15 |
+
|
16 |
+
|
17 |
+
def glob_with_epoch_iter(pattern, recursive=False):
|
18 |
+
def extract_ep_iter(filename):
|
19 |
+
match = re.search(r'ep(\d+)-iter(\d+)', filename)
|
20 |
+
if match:
|
21 |
+
ep = int(match.group(1))
|
22 |
+
iter_idx = int(match.group(2))
|
23 |
+
return ep, iter_idx
|
24 |
+
return 0, 0
|
25 |
+
return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_ep_iter(os.path.basename(x)), reverse=True)
|
26 |
+
|
27 |
+
|
28 |
+
def glob_with_global_step(pattern, recursive=False):
|
29 |
+
def extract_ep_iter(filename):
|
30 |
+
match = re.search(r'global_step_(\d+)', filename)
|
31 |
+
if match:
|
32 |
+
iter_idx = int(match.group(1))
|
33 |
+
return iter_idx
|
34 |
+
return 0
|
35 |
+
return sorted(glob.glob(pattern, recursive=recursive), key=lambda x: extract_ep_iter(os.path.basename(x)), reverse=True)
|
36 |
+
|
37 |
+
|
38 |
+
class CKPTSaver(object):
|
39 |
+
def __init__(self, is_master: bool, eval_milestone: List[Tuple[float, float]]):
|
40 |
+
self.is_master = is_master
|
41 |
+
self.time_stamp = torch.tensor([time.time() - 1e5, time.time()], device=dist.get_device())
|
42 |
+
self.sp_also: subprocess.Popen = None
|
43 |
+
self.sp_best: subprocess.Popen = None
|
44 |
+
self.sp_backup: subprocess.Popen = None
|
45 |
+
self.acc_str, self.eval_milestone = '[no acc str]', eval_milestone
|
46 |
+
|
47 |
+
def sav(
|
48 |
+
self, args: arg_util.Args, g_it: int, next_ep: int, next_it: int, trainer,
|
49 |
+
acc_str: Optional[str] = None, eval_milestone: Optional[List[Tuple[float, float]]] = None,
|
50 |
+
also_save_to: str = None, best_save_to: str = None,
|
51 |
+
):
|
52 |
+
self.time_stamp[1] = time.time()
|
53 |
+
dist.broadcast(self.time_stamp, src_rank=0)
|
54 |
+
last_save_time, cur_time = self.time_stamp.cpu().tolist()
|
55 |
+
|
56 |
+
auto_save = cur_time - last_save_time > 20 * 60
|
57 |
+
need_save = also_save_to is not None or best_save_to is not None or next_ep == args.ep or auto_save
|
58 |
+
if not need_save:
|
59 |
+
return
|
60 |
+
|
61 |
+
if acc_str is not None: self.acc_str = acc_str
|
62 |
+
if eval_milestone is not None: self.eval_milestone = eval_milestone
|
63 |
+
|
64 |
+
fname = f'ar-ckpt-giter{g_it//1000:03d}K-ep{next_ep}-iter{next_it}-last.pth' if args.gpt_training else f'ckpt-last.pth'
|
65 |
+
local_out_ckpt = os.path.join(args.local_out_path, fname)
|
66 |
+
|
67 |
+
# NOTE: all rank should call this state_dict(), not master only!
|
68 |
+
trainer_state = trainer.state_dict()
|
69 |
+
|
70 |
+
if self.is_master:
|
71 |
+
stt = time.time()
|
72 |
+
torch.save({
|
73 |
+
'args': args.state_dict(),
|
74 |
+
'gpt_training': args.gpt_training,
|
75 |
+
'arch': args.model if args.gpt_training else args.vv,
|
76 |
+
'epoch': next_ep,
|
77 |
+
'iter': next_it,
|
78 |
+
'trainer': trainer_state,
|
79 |
+
'acc_str': self.acc_str,
|
80 |
+
'milestones': self.eval_milestone,
|
81 |
+
}, local_out_ckpt)
|
82 |
+
|
83 |
+
print(f'[CKPTSaver][rank00] start: {also_save_to=} {best_save_to=} {(next_ep == args.ep)=} {auto_save=} | see {local_out_ckpt}', flush=True)
|
84 |
+
print(f'[CKPTSaver][rank00] dbg: {args.bed=}', flush=True)
|
85 |
+
if auto_save:
|
86 |
+
if self.sp_backup is not None:
|
87 |
+
self.sp_backup.wait(timeout=300); self.sp_backup.kill(); self.sp_backup.communicate()
|
88 |
+
self.time_stamp[0] = time.time()
|
89 |
+
|
90 |
+
def auto_sync(source_filename, target_filename):
|
91 |
+
cmd = f'cp -r {source_filename} {target_filename}'
|
92 |
+
self.sp_backup = subprocess.Popen(cmd, shell=True, bufsize=-1)
|
93 |
+
print(f'[CKPTSaver] auto_save cmd: {cmd}', flush=True)
|
94 |
+
|
95 |
+
local_files = glob.glob(f"{args.local_out_path}/*")
|
96 |
+
for filename in local_files:
|
97 |
+
basename = os.path.basename(filename)
|
98 |
+
target_filename = f'{args.bed}/{basename}'
|
99 |
+
if basename.endswith('.pth'):
|
100 |
+
if not os.path.isfile(target_filename):
|
101 |
+
auto_sync(filename, target_filename)
|
102 |
+
else:
|
103 |
+
auto_sync(filename, target_filename)
|
104 |
+
cost = time.time() - stt
|
105 |
+
print(f'[CKPTSaver][rank00] cost: {cost:.2f}s', flush=True)
|
106 |
+
|
107 |
+
del trainer_state
|
108 |
+
time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
|
109 |
+
dist.barrier()
|
110 |
+
|
111 |
+
|
112 |
+
def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, str, List[Tuple[float, float]], dict, dict]:
|
113 |
+
info = []
|
114 |
+
resume = ''
|
115 |
+
if args.auto_resume:
|
116 |
+
for dd in (args.local_out_path, args.bed):
|
117 |
+
all_ckpt = glob_with_epoch_iter(os.path.join(dd, pattern))
|
118 |
+
if len(all_ckpt): break
|
119 |
+
if len(all_ckpt) == 0:
|
120 |
+
info.append(f'[auto_resume] no ckpt found @ {pattern}')
|
121 |
+
info.append(f'[auto_resume quit]')
|
122 |
+
else:
|
123 |
+
resume = all_ckpt[0]
|
124 |
+
info.append(f'[auto_resume] auto load from @ {resume} ...')
|
125 |
+
else:
|
126 |
+
info.append(f'[auto_resume] disabled')
|
127 |
+
info.append(f'[auto_resume quit]')
|
128 |
+
|
129 |
+
if len(resume) == 0:
|
130 |
+
return info, 0, 0, '[no acc str]', [], {}, {}
|
131 |
+
|
132 |
+
print(f'auto resume from {resume}')
|
133 |
+
|
134 |
+
try:
|
135 |
+
ckpt = torch.load(resume, map_location='cpu')
|
136 |
+
except Exception as e:
|
137 |
+
info.append(f'[auto_resume] failed, {e} @ {resume}')
|
138 |
+
if len(all_ckpt) < 2:
|
139 |
+
return info, 0, 0, '[no acc str]', [], {}, {}
|
140 |
+
try: # another chance to load from bytenas
|
141 |
+
ckpt = torch.load(all_ckpt[1], map_location='cpu')
|
142 |
+
except Exception as e:
|
143 |
+
info.append(f'[auto_resume] failed, {e} @ {all_ckpt[1]}')
|
144 |
+
return info, 0, 0, '[no acc str]', [], {}, {}
|
145 |
+
|
146 |
+
dist.barrier()
|
147 |
+
ep, it = ckpt['epoch'], ckpt['iter']
|
148 |
+
eval_milestone = ckpt.get('milestones', [])
|
149 |
+
info.append(f'[auto_resume success] resume from ep{ep}, it{it}, eval_milestone: {eval_milestone}')
|
150 |
+
return info, ep, it, ckpt.get('acc_str', '[no acc str]'), eval_milestone, ckpt['trainer'], ckpt['args']
|
utils/wandb_utils.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import wandb
|
2 |
+
import torch
|
3 |
+
from torchvision.utils import make_grid
|
4 |
+
import torch.distributed as dist
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
import argparse
|
8 |
+
import hashlib
|
9 |
+
import math
|
10 |
+
|
11 |
+
|
12 |
+
def is_main_process():
|
13 |
+
return dist.get_rank() == 0
|
14 |
+
|
15 |
+
def namespace_to_dict(namespace):
|
16 |
+
return {
|
17 |
+
k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v
|
18 |
+
for k, v in vars(namespace).items()
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def generate_run_id(exp_name):
|
23 |
+
# https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
|
24 |
+
return str(int(hashlib.sha256(exp_name.encode('utf-8')).hexdigest(), 16) % 10 ** 8)
|
25 |
+
|
26 |
+
|
27 |
+
def initialize(args, entity, exp_name, project_name):
|
28 |
+
config_dict = namespace_to_dict(args)
|
29 |
+
wandb.login(key=os.environ["WANDB_KEY"])
|
30 |
+
wandb.init(
|
31 |
+
entity=entity,
|
32 |
+
project=project_name,
|
33 |
+
name=exp_name,
|
34 |
+
config=config_dict,
|
35 |
+
id=generate_run_id(exp_name),
|
36 |
+
resume="allow",
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
def log(stats, step=None):
|
41 |
+
if is_main_process():
|
42 |
+
wandb.log({k: v for k, v in stats.items()}, step=step)
|
43 |
+
|
44 |
+
|
45 |
+
def log_image(name, sample, step=None):
|
46 |
+
if is_main_process():
|
47 |
+
sample = array2grid(sample)
|
48 |
+
wandb.log({f"{name}": wandb.Image(sample), "train_step": step})
|
49 |
+
|
50 |
+
|
51 |
+
def array2grid(x):
|
52 |
+
nrow = round(math.sqrt(x.size(0)))
|
53 |
+
x = make_grid(x, nrow=nrow, normalize=True, value_range=(-1,1))
|
54 |
+
x = x.mul(255).add_(0.5).clamp_(0,255).permute(1,2,0).to('cpu', torch.uint8).numpy()
|
55 |
+
return x
|