# styledrop/configs/imagenet256_base_vq_jax.py
# Commit 28c6826 ("StyleDrop init", zideliu); original file size 1.64 kB.
import ml_collections
def d(**kwargs):
    """Build an ``ml_collections.ConfigDict`` from keyword arguments.

    Shorthand for declaring a nested config section inline, e.g.
    ``d(name='adamw', lr=4e-4)``. Keyword order is the order the entries
    are handed to ``ConfigDict``.
    """
    section = ml_collections.ConfigDict(initial_dictionary=kwargs)
    return section
def get_config():
    """Return the experiment configuration as an ``ml_collections.ConfigDict``.

    Top-level fields, in insertion order: ``seed``, ``z_shape``,
    ``autoencoder``, ``train``, ``eval``, ``optimizer``, ``lr_scheduler``,
    ``nnet``, ``muse``, ``dataset`` and ``sample``. Every nested section is
    built with the ``d`` helper.
    """
    conf = ml_collections.ConfigDict()

    # Global settings.
    conf.seed = 1234
    # Latent-code shape; presumably (channels, height, width) -- TODO confirm.
    conf.z_shape = (8, 16, 16)

    # Autoencoder: only the name of its YAML config file is set here.
    conf.autoencoder = d(config_file='vq-f16-jax.yaml')

    # Training-loop schedule (intervals presumably counted in steps).
    conf.train = d(
        n_steps=99_999_999,
        batch_size=2048,
        log_interval=10,
        eval_interval=5000,
        save_interval=5000,
        fid_interval=50_000,
    )

    # Periodic evaluation during training.
    conf.eval = d(
        n_samples=10_000,
        sample_steps=12,
    )

    # Optimizer hyper-parameters.
    conf.optimizer = d(
        name='adamw',
        lr=4e-4,
        weight_decay=0.03,
        betas=(0.99, 0.99),
    )

    # Learning-rate schedule.
    conf.lr_scheduler = d(
        name='customized',
        warmup_steps=5000,
    )

    # Network architecture.
    # NOTE(review): num_classes=1001 presumably means 1000 ImageNet classes
    # plus one extra (unconditional) label -- confirm against the nnet code.
    conf.nnet = d(
        name='uvit_vq',
        img_size=16,
        codebook_size=1024,
        in_chans=256,
        patch_size=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        num_classes=1001,
        use_checkpoint=False,
        skip=True,
    )

    # MUSE-specific settings; semantics of these keys are defined by the
    # consuming training code (presumably loss ignore index, label
    # smoothing and generation temperature -- TODO confirm).
    conf.muse = d(
        ignore_ind=-1,
        smoothing=0.1,
        gen_temp=4.5,
    )

    # Dataset of precomputed features.
    # p_uncond is presumably the label-drop probability used for
    # classifier-free guidance training -- TODO confirm.
    conf.dataset = d(
        name='imagenet256_features',
        path='assets/datasets/imagenet256_vq_features/vq-f16-jax',
        cfg=True,
        p_uncond=0.15,
    )

    # Standalone sampling settings.
    conf.sample = d(
        sample_steps=12,
        n_samples=50_000,
        mini_batch_size=50,
        cfg=True,
        linear_inc_scale=True,
        scale=3.0,
        path='',
    )

    return conf