import json
import math
import os
import random
import subprocess
import sys
import time
from collections import OrderedDict, deque
from typing import Optional, Union

import numpy as np
import torch
from tap import Tap

import infinity.utils.dist as dist


class Args(Tap):
    local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # directory for saving checkpoints
    data_path: str = ''                 # dataset
    bed: str = ''                       # directory for copying checkpoints, in addition to local_out_path
    vae_ckpt: str = ''                  # VAE ckpt
    exp_name: str = ''                  # experiment name
    ds: str = 'oi'                      # only used in GPT training::load_viz_data & FID benchmark
    model: str = ''                     # for VAE training, 'b' or any other for GPT training
    short_cap_prob: float = 0.2         # prob of training with short captions
    project_name: str = 'Infinity'      # name of wandb project
    tf32: bool = True                   # whether to use TensorFloat32
    auto_resume: bool = True            # whether to automatically resume from the last checkpoint found in args.bed
    rush_resume: str = ''               # pretrained infinity checkpoint
    nowd: int = 1                       # whether to disable weight decay on sparse params (like class token)
    enable_hybrid_shard: bool = False   # whether to use hybrid FSDP
    inner_shard_degree: int = 1         # inner degree for FSDP
    zero: int = 0                       # ds zero
    buck: str = 'chunk'                 # =0 for using module-wise
    fsdp_orig: bool = True
    enable_checkpointing: str = None    # checkpointing strategy: full-block, self-attn
    pad_to_multiplier: int = 1          # >1 for padding the seq len to a multiple of this
    log_every_iter: bool = False
    checkpoint_type: str = 'torch'      # checkpoint_type: torch, onmistore
    seed: int = None                    # 3407
    rand: bool = True                   # actual seed = seed + (dist.get_rank()*512 if rand else 0)
    device: str = 'cpu'
    task_id: str = '2493513'
    trial_id: str = '7260554'
    robust_run_id: str = '00'
    ckpt_trials = []
    real_trial_id: str = '7260552'
    chunk_nodes: int = None
    is_master_node: bool = None
    # dir
    log_txt_path: str = ''
    t5_path: str = ''                   # if not specified: automatically find from all bytenas
    online_t5: bool = True              # whether to use online t5 or load local features
    # GPT
    sdpa_mem: bool = True               # whether to use torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
    tfast: int = 0                      # compile GPT
    model_alias: str = 'b'              # [automatically set; don't specify this]
    rms: bool = False
    aln: float = 1e-3                   # multiplier of ada_lin.w's initialization
    alng: float = -1                    # multiplier of ada_lin.w[gamma channels]'s initialization, -1: the same as aln
    saln: bool = False                  # whether to use a shared adaln layer
    haln: bool = True                   # whether to use a specific adaln layer in the head layer
    nm0: bool = False                   # norm before word proj linear
    tau: float = 1                      # tau of self attention in GPT
    cos: bool = True                    # cosine attn as in swin v2
    swi: bool = False                   # whether to use FFNSwiGLU, instead of vanilla FFN
    dp: float = -1
    drop: float = 0.0                   # GPT's dropout (VAE's is --vd)
    hd: int = 0
    ca_gamma: float = -1                # >=0 for using layer-scale for cross attention
    diva: int = 1                       # rescale_attn_fc_weights
    hd0: float = 0.02                   # head.w *= hd0
    dec: int = 1                        # dec depth
    cum: int = 3                        # cumulating feature map as GPT TF input, 0: no cum; 1: cum @ next hw, 2: cum @ final hw
    rwe: bool = False                   # random word emb
    tp: float = 0.0                     # top-p
    tk: float = 0.0                     # top-k
    tini: float = 0.02                  # init parameters
    cfg: float = 0.1                    # >0: classifier-free guidance, drop cond with prob cfg
    rand_uncond = False                 # whether to use a random, unlearnable uncond embedding
    ema: float = 0.9999                 # VAE's ema ratio, not VAR's. 0.9977844 == 0.5 ** (32 / (10 * 1000)) from GANs, 0.9999 from SD
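    # Note on the ema default above: a weight EMA is the standard update
    #   w_ema <- ema * w_ema + (1 - ema) * w
    # (presumably applied in the VAE trainer, not in this file). The GAN-style value
    # 0.9977844 == 0.5 ** (32 / (10 * 1000)) corresponds to a half-life of 10k images
    # at batch size 32; 0.9999 follows Stable Diffusion.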
    tema: float = 0                     # 0.9999 in DiffiT, DiT
    fp16: int = 0                       # 1: fp16, 2: bf16, >2: fp16's max scaling multiplier. TODO: remember to force quantize-related features to fp32! The residual should preferably also be fp32 (per flash-attention). Does nn.Conv2d have a use_float16 parameter?
    fuse: bool = False                  # whether to use fused mlp
    fused_norm: bool = False            # whether to use fused norm
    flash: bool = False                 # whether to use customized flash-attn kernel
    xen: bool = False                   # whether to use xentropy
    use_flex_attn: bool = False         # whether to use flex_attn to speed up training
    stable: bool = False
    gblr: float = 1e-4
    dblr: float = None                  # =gblr if is None
    tblr: float = 6e-4
    glr: float = None
    dlr: float = None
    tlr: float = None                   # vqgan: 4e-5
    gwd: float = 0.005
    dwd: float = 0.0005
    twd: float = 0.005                  # vqgan: 0.01
    gwde: float = 0
    dwde: float = 0
    twde: float = 0
    ls: float = 0.0                     # label smoothing
    lz: float = 0.0                     # z loss from PaLM, = 1e-4; todo
    eq: int = 0                         # equalized loss
    ep: int = 100
    wp: float = 0
    wp0: float = 0.005
    wpe: float = 0.3                    # 0.001, final cosine lr = wpe * peak lr
    sche: str = ''                      # cos, exp, lin
    log_freq: int = 50                  # log frequency in the stdout
    gclip: float = 6.                   # <=0 for no grad clip on VAE
    dclip: float = 6.                   # <=0 for no grad clip on discriminator
    tclip: float = 2.                   # <=0 for no grad clip on GPT; >100 for per-param clip (%= 100 automatically)
    cdec: bool = False                  # decay the grad clip thresholds of GPT and GPT's word embed
    opt: str = 'adamw'                  # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (4x lower than Adam's) and wd=0.8 (8x higher than Adam's); e.g., with a small batch_size, Lion underperforms AdamW
    ada: str = ''                       # adam's beta0 and beta1 for VAE or GPT, '0_0.99' from StyleSwin and MAGVIT, '0.5_0.9' from VQGAN
    dada: str = ''                      # adam's beta0 and beta1 for discriminator
    oeps: float = 0                     # adam's eps, pixart uses 1e-10
    afuse: bool = True                  # fused adam
    # data
    pn: str = ''                        # pixel nums, choose from 0.06M, 0.25M, 1M
    scale_schedule: tuple = None        # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    patch_size: int = None              # [automatically set; don't specify this] = 2 ** (len(args.scale_schedule) - 1)
    resos: tuple = None                 # [automatically set; don't specify this]
    data_load_reso: int = None          # [automatically set; don't specify this]
    workers: int = 0                    # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    lbs: int = 0                        # local batch size; if lbs != 0, bs will be ignored, and will be reset as round(args.lbs / args.ac) * dist.get_world_size()
    bs: int = 0                         # global batch size; if lbs != 0, bs will be ignored
    batch_size: int = 0                 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
    glb_batch_size: int = 0             # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1                         # gradient accumulation
    r_accu: float = 1.0                 # [automatically set; don't specify this] = 1 / args.ac
    norm_eps: float = 1e-6              # norm eps for infinity
    tlen: int = 512                     # truncate text embedding to this length
    Ct5: int = 2048                     # feature dimension of text encoder
    use_bit_label: int = 1              # predict bitwise labels or index-wise labels
    bitloss_type: str = 'mean'          # mean or sum
    dynamic_resolution_across_gpus: int = 1  # allow dynamic resolution across gpus
    enable_dynamic_length_prompt: int = 0    # enable dynamic length prompt during training
    use_streaming_dataset: int = 0      # use streaming dataset
    iterable_data_buffersize: int = 90000    # streaming dataset buffer size
    save_model_iters_freq: int = 1000   # save model iter freq
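    # Rough sketch of how the three Bitwise Self-Correction knobs below are meant to
    # interact during training (illustrative only; flip_random_bits/requantize are
    # hypothetical helpers, the real logic lives in the trainer):
    #   for si in range(len(scale_schedule)):
    #       if si < noise_apply_layers and noise_apply_strength > 0:
    #           bits = flip_random_bits(bits, p=noise_apply_strength)
    #           if noise_apply_requant:
    #               feat = requantize(bits)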
    noise_apply_layers: int = -1        # Bitwise Self-Correction: apply noise to this many layers, -1 means no noise
    noise_apply_strength: float = -1    # Bitwise Self-Correction: noise strength, -1 means no noise
    noise_apply_requant: int = 1        # Bitwise Self-Correction: requantize after applying noise
    rope2d_each_sa_layer: int = 0       # apply rope2d to each self-attention layer
    rope2d_normalized_by_hw: int = 1    # apply normalized rope2d
    use_fsdp_model_ema: int = 0         # use fsdp model ema
    add_lvl_embeding_only_first_block: int = 1  # apply lvl pe embedding only in the first block, or in each block
    reweight_loss_by_scale: int = 0     # reweight loss by scale
    always_training_scales: int = 100   # truncate training scales
    vae_type: int = 1                   # here 16/32/64 is a BSQ VAE with that many quant bits
    fake_vae_input: bool = False        # fake vae input for debug
    model_init_device: str = 'cuda'     # model_init_device
    prefetch_factor: int = 2            # prefetch_factor for dataset
    apply_spatial_patchify: int = 0     # apply spatial patchify or not
    debug_bsc: int = 0                  # save figs and set breakpoint for debugging bsc and checking input
    task_type: str = 't2i'              # task type: t2i or t2v
    
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    # would be automatically set in runtime
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_id: str = ''                 # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = ''                # (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()  # [automatically set; don't specify this]
    cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:])  # [automatically set; don't specify this]
    tag: str = 'UK'                     # [automatically set; don't specify this]
    acc_all: float = None               # [automatically set; don't specify this]
    acc_real: float = None              # [automatically set; don't specify this]
    acc_fake: float = None              # [automatically set; don't specify this]
    last_Lnll: float = None             # [automatically set; don't specify this]
    last_L1: float = None               # [automatically set; don't specify this]
    last_Ld: float = None               # [automatically set; don't specify this]
    last_wei_g: float = None            # [automatically set; don't specify this]
    grad_boom: str = None               # [automatically set; don't specify this]
    diff: float = None                  # [automatically set; don't specify this]
    diffs: str = ''                     # [automatically set; don't specify this]
    diffs_ema: str = None               # [automatically set; don't specify this]
    ca_performance: str = ''            # [automatically set; don't specify this]
    cur_phase: str = ''                 # [automatically set; don't specify this]
    cur_it: str = ''                    # [automatically set; don't specify this]
    cur_ep: str = ''                    # [automatically set; don't specify this]
    remain_time: str = ''               # [automatically set; don't specify this]
    finish_time: str = ''               # [automatically set; don't specify this]
    iter_speed: float = None            # [automatically set; don't specify this]
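    # Throughput / memory / utilization stats recorded by the trainer. MFU and HFU
    # presumably stand for model FLOPs utilization and hardware FLOPs utilization
    # (achieved FLOP/s over the accelerator's peak; HFU additionally counts
    # recomputation from activation checkpointing).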
    img_per_day: float = None           # [automatically set; don't specify this]
    max_nvidia_smi: float = 0           # [automatically set; don't specify this]
    max_memory_allocated: float = None  # [automatically set; don't specify this]
    max_memory_reserved: float = None   # [automatically set; don't specify this]
    num_alloc_retries: int = None       # [automatically set; don't specify this]
    MFU: float = None                   # [automatically set; don't specify this]
    HFU: float = None                   # [automatically set; don't specify this]
    # ==================================================================================================================
    # ======================== ignore these parts below since they are only for debug use ==============================
    # ==================================================================================================================
    dbg_modified: bool = False
    dbg_ks: bool = False
    dbg_ks_last = None
    dbg_ks_fp = None
    
    def dbg_ks_this_line(self, g_it: int):
        if self.dbg_ks:
            if self.dbg_ks_last is None:
                self.dbg_ks_last = deque(maxlen=6)
            
            from utils.misc import time_str
            self.dbg_ks_fp.seek(0)
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            info = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})'
            if g_it is not None:
                info += f' [g_it: {g_it}]'
            self.dbg_ks_last.append(info)
            self.dbg_ks_fp.write('\n'.join(self.dbg_ks_last) + '\n')
            self.dbg_ks_fp.flush()
    
    dbg: bool = 'KEVIN_LOCAL' in os.environ  # only used when debugging unused params in DDP
    ks: bool = False
    nodata: bool = False                # if True, will set nova=True as well
    nodata_tlen: int = 320
    nova: bool = False                  # no val, no FID
    prof: int = 0                       # profile
    prof_freq: int = 50                 # profile
    tos_profiler_file_prefix: str = 'vgpt_default/'
    profall: int = 0
    
    @property
    def is_vae_visualization_only(self) -> bool:
        return self.v_seed > 0
    v_seed: int = 0                     # v_seed != 0 means the visualization-only mode
    @property
    def is_gpt_visualization_only(self) -> bool:
        return self.g_seed > 0
    g_seed: int = 0                     # g_seed != 0 means the visualization-only mode
    # ==================================================================================================================
    # ======================== ignore these parts above since they are only for debug use ==============================
    # ==================================================================================================================
    
    @property
    def gpt_training(self):
        return len(self.model) > 0
    
    def set_initial_seed(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            seed = self.seed + (dist.get_rank()*512 if self.rand else 0)
            torch.backends.cudnn.deterministic = True
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
    
    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:   # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed + dist.get_rank()*512)
        return g
    
    def compile_model(self, m, fast):
        if fast == 0:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m
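    # Hedged usage sketch (hypothetical driver code, not invoked from this module):
    #   args = Args(explicit_bool=True).parse_args(known_only=True)
    #   args.set_initial_seed(benchmark=True)        # seeds random/numpy/torch, offset per rank when args.rand
    #   gpt = args.compile_model(gpt, args.tfast)    # tfast: 0 = off, 1 = reduce-overhead, 2 = max-autotune, 3 = default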
    def dump_log(self):
        if not dist.is_local_master():
            return
        nd = {'is_master': dist.is_visualizer()}
        r_trial, trial = str(self.real_trial_id), str(self.trial_id)
        for k, v in {
            'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch,
            'Lnll': self.last_Lnll, 'L1': self.last_L1, 'Ld': self.last_Ld,
            'acc': self.acc_all, 'acc_r': self.acc_real, 'acc_f': self.acc_fake,
            'weiG': self.last_wei_g if (self.last_wei_g is None or math.isfinite(self.last_wei_g)) else -23333,
            'grad': self.grad_boom,
            'cur': self.cur_phase, 'cur_ep': self.cur_ep, 'cur_it': self.cur_it,
            'rema': self.remain_time, 'fini': self.finish_time,
            'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
            'bsep': f'{self.glb_batch_size}/{self.ep}',
            'G_lrwd': f'{self.glr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.gwd:g}',
            'D_lrwd': f'{self.dlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.dwd:g}',
            'T_lrwd': f'{self.tlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.twd:g}',
            'diff': self.diff, 'diffs': self.diffs, 'diffs_ema': self.diffs_ema if self.diffs_ema else None,
            'opt': self.opt, 'is_master_node': self.is_master_node,
        }.items():
            if hasattr(v, 'item'):
                v = v.item()
            if v is None or (isinstance(v, str) and len(v) == 0):
                continue
            nd[k] = v
        if r_trial == trial:
            nd.pop('trial', None)
        with open(self.log_txt_path, 'w') as fp:
            json.dump(nd, fp, indent=2)
    
    def touch_log(self):  # listener will kill me if log_txt_path is not updated for 120s
        os.utime(self.log_txt_path)  # about 2e-6 sec
    
    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d
    
    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old version
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
    
    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')


def init_dist_and_get_args():
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    if len(args.extra_args) > 0 and args.is_master_node == 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')
    
    args.set_tf32(args.tf32)
    if args.dbg:
        torch.autograd.set_detect_anomaly(True)
    
    try: os.makedirs(args.bed, exist_ok=True)
    except: pass
    try: os.makedirs(args.local_out_path, exist_ok=True)
    except: pass
    
    day3 = 60*24*3
    dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=day3 if int(os.environ.get('LONG_DBG', '0') or '0') > 0 else 30)
    
    args.tlen = max(args.tlen, args.nodata_tlen)
    if args.zero and args.tema != 0:
        args.tema = 0
        print(f'======================================================================================')
        print(f'======================== WARNING: args.tema:=0, due to zero={args.zero} ========================')
        print(f'======================================================================================\n\n')
    
    if args.nodata:
        args.nova = True
    
    if not args.tos_profiler_file_prefix.endswith('/'): args.tos_profiler_file_prefix += '/'
    
    if args.alng < 0:
        args.alng = args.aln
    
    args.device = dist.get_device()
    args.r_accu = 1 / args.ac   # gradient accumulation
    args.data_load_reso = None
    args.rand |= args.seed is None
    args.sche = args.sche or ('lin0' if args.gpt_training else 'cos')
    if args.wp == 0:
        args.wp = args.ep * 1/100
    di = {
        'b': 'bilinear', 'c': 'bicubic', 'n': 'nearest', 'a': 'area', 'aa': 'area+area',
        'at': 'auto', 'auto': 'auto', 'v': 'vae',
        'x': 'pix', 'xg': 'pix_glu', 'gx': 'pix_glu', 'g': 'pix_glu',
    }
    
    args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
    args.dada = args.dada or args.ada
    args.opt = args.opt.lower().strip()
    
    if args.lbs:
        bs_per_gpu = args.lbs / args.ac
    else:
        bs_per_gpu = args.bs / args.ac / dist.get_world_size()
    bs_per_gpu = round(bs_per_gpu)
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(args.workers, bs_per_gpu)
    args.dblr = args.dblr or args.gblr
    args.glr = args.ac * args.gblr * args.glb_batch_size / 256
    args.dlr = args.ac * args.dblr * args.glb_batch_size / 256
    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    args.gwde = args.gwde or args.gwd
    args.dwde = args.dwde or args.dwd
    args.twde = args.twde or args.twd
    
    if args.dbg_modified:
        torch.autograd.set_detect_anomaly(True)
    args.dbg_ks &= dist.is_local_master()
    if args.dbg_ks:
        args.dbg_ks_fp = open(os.path.join(args.local_out_path, 'dbg_ks.txt'), 'w')
    
    # gpt args
    if args.gpt_training:
        assert args.vae_ckpt, 'VAE ckpt must be specified when training GPT'
        from infinity.models import alias_dict, alias_dict_inv
        if args.model in alias_dict:
            args.model = alias_dict[args.model]
            args.model_alias = alias_dict_inv[args.model]
        else:
            args.model_alias = args.model
            args.model = f'infinity_{args.model}'
    
    args.task_id = '123'
    args.trial_id = '123'
    args.robust_run_id = '0'
    args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')
    
    ls = []
    if 'AUTO_RESUME' in os.environ:
        ls.append(int(os.environ['AUTO_RESUME']))
    ls = sorted(ls, reverse=True)
    ls = [str(i) for i in ls]
    args.ckpt_trials = ls
    args.real_trial_id = args.trial_id if len(ls) == 0 else str(ls[-1])
    
    args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
    args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
    assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
        f"only support no-checkpointing or full-block/full-attn checkpointing, but got {args.enable_checkpointing}."
    
    if len(args.exp_name) == 0:
        args.exp_name = os.path.basename(args.bed) or 'test_exp'
    
    if '-' in args.exp_name:
        args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
    else:
        args.tag = 'UK'
    
    if dist.is_master():
        os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')
    
    if args.sdpa_mem:
        from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
        enable_flash_sdp(True)
        enable_mem_efficient_sdp(True)
        enable_math_sdp(False)
    
    return args
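

if __name__ == '__main__':
    # Minimal self-contained sketch (not part of the training entrypoint) that mirrors
    # the batch-size / learning-rate arithmetic in init_dist_and_get_args above,
    # assuming a hypothetical 8-GPU run with --bs=512 --ac=1 --gblr=1e-4 --tblr=6e-4.
    # It avoids Args()/dist so it runs anywhere.
    world_size, bs, ac = 8, 512, 1
    gblr, tblr = 1e-4, 6e-4
    batch_size = round(bs / ac / world_size)      # per-GPU batch size -> 64
    glb_batch_size = batch_size * world_size      # global batch size  -> 512
    glr = ac * gblr * glb_batch_size / 256        # linear LR scaling  -> 2e-4
    tlr = ac * tblr * glb_batch_size / 256        #                    -> 1.2e-3
    print(f'batch_size={batch_size}, glb_batch_size={glb_batch_size}, glr={glr:g}, tlr={tlr:g}')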