#!/usr/bin/python3
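"""Builders for the Infinity stack: load a frozen BSQ-VAE tokenizer from a state
dict and construct the Infinity transformer (optionally with an FSDP EMA copy)."""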
import gc
import os
import os.path as osp
import random
import sys
from copy import deepcopy
from typing import Tuple, Union

import colorama
import torch
import yaml

import infinity.utils.dist as dist
from infinity.models import Infinity
from infinity.models.ema import get_ema_model
from infinity.utils import arg_util, misc
from infinity.utils.misc import os_system

def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # e.g. 18 -> a binary codebook with 2**18 entries
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            # free the conv encoder/decoder when the model is fed pre-computed VAE codes
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")
    if force_flash: args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
        cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0: gpt_kw['num_heads'] = args.hd
    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy model names used 'vgpt'
    print(f"{model_str=}")
    # an optional trailing 'c<int>' suffix selects block_chunks (e.g. '...c4' -> 4 chunks)
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from timm.models import create_model
    gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    # the VAE is frozen; every transformer parameter must be trainable
    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
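
# Usage sketch (illustrative only): `init_dist_and_get_args` and `args.vae_ckpt`
# are assumptions about the surrounding codebase, not definitions from this file.
#
#     args = arg_util.init_dist_and_get_args()                # parsed arg_util.Args
#     vae_st = torch.load(args.vae_ckpt, map_location='cpu')  # BSQ-VAE state dict
#     vae, gpt, gpt_ema = build_vae_gpt(args, vae_st, skip_gpt=False, device=args.device)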

if __name__ == '__main__':
    # NOTE: `ld` is not defined in this file; presumably a checkpoint-loading
    # helper provided elsewhere in the original codebase.
    ld(sys.argv[1])