#!/usr/bin/python3
import gc
import os
import os.path as osp
import random
import sys
from copy import deepcopy
from typing import Tuple, Union

import colorama
import torch
import yaml

import infinity.utils.dist as dist
from infinity.models import Infinity
from infinity.models.ema import get_ema_model
from infinity.utils import arg_util, misc
from infinity.utils.misc import os_system

def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    """Build the frozen BSQ-VAE tokenizer and the trainable Infinity transformer.

    Args:
        args: parsed arguments (see infinity.utils.arg_util.Args).
        vae_st: state dict for the BSQ-VAE.
        skip_gpt: accepted for API compatibility; unused in this file as shown.
        force_flash: if True, force customized flash attention on via args.flash.
        device: device the transformer is moved to (the VAE uses args.device).

    Returns:
        (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema); the EMA copy is None unless
        args.use_fsdp_model_ema is set.
    """
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # vae_type doubles as the codebook bit width, e.g. 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            # Spatial patchify uses a smaller VAE stride (one fewer downsampling stage).
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            # Drop the conv nets when the pipeline feeds pre-encoded latents.
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")
    if force_flash: args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
        cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0: gpt_kw['num_heads'] = args.hd
    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy model names used the 'vgpt' prefix
    print(f"{model_str=}")
    # A trailing 'c<N>' suffix on the model name selects the number of block chunks.
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from timm.models import create_model
    gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)  # model variants are looked up via timm's registry
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    # The VAE stays frozen; every transformer parameter must remain trainable.
    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for p in gpt_wo_ddp.parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
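
# A minimal usage sketch (not part of the original file): given a parsed
# arg_util.Args and a path to a BSQ-VAE checkpoint, load the state dict on
# CPU and hand both to build_vae_gpt. The helper name `_example_build` and
# the checkpoint layout (a plain torch-serialized state dict) are assumptions.
def _example_build(args: arg_util.Args, vae_ckpt_path: str):
    vae_st = torch.load(vae_ckpt_path, map_location='cpu')  # assumed: checkpoint is a raw state dict
    vae_local, gpt_wo_ddp, gpt_ema = build_vae_gpt(args, vae_st, skip_gpt=False)
    return vae_local, gpt_wo_ddp, gpt_ema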

if __name__ == '__main__':
    # NOTE: `ld` is not defined in this file as shown; it is presumably
    # provided by a part of the module not included here.
    ld(sys.argv[1])