# Spaces:
# Sleeping
# Sleeping
import ml_collections


def d(**kwargs):
    """Wrap the given keyword arguments in an ``ml_collections.ConfigDict``.

    Compact constructor used by ``get_config`` to declare each nested
    configuration section inline.

    Returns:
        A new ``ConfigDict`` whose fields are the supplied keyword arguments.
    """
    section = ml_collections.ConfigDict(kwargs)
    return section
def get_config():
    """Build and return the experiment configuration.

    The root ``ConfigDict`` first receives the top-level scalars ``seed`` and
    ``z_shape``. Each remaining section (autoencoder, train, eval, optimizer,
    lr_scheduler, nnet, muse, dataset, wds, sample) is declared in one ordered
    mapping and then attached to the root in declaration order.

    Returns:
        ml_collections.ConfigDict: the fully populated configuration.
    """
    config = ml_collections.ConfigDict()
    config.seed = 1234
    config.z_shape = (8, 16, 16)

    sections = {
        'autoencoder': d(config_file='vq-f16-jax.yaml'),
        'train': d(
            # NOTE(review): very large step count, presumably meaning
            # "train until stopped externally" — confirm.
            n_steps=999999999,
            batch_size=2048,
            log_interval=10,
            eval_interval=5000,
            save_interval=5000,
            fid_interval=50000,
            num_workers=8,
            resampled=False,
        ),
        'eval': d(n_samples=10000, sample_steps=18),
        'optimizer': d(
            name='adamw',
            lr=0.0002,
            weight_decay=0.03,
            betas=(0.99, 0.99),
        ),
        'lr_scheduler': d(name='customized', warmup_steps=5000),
        'nnet': d(
            name='uvit_t2i_vq',
            img_size=16,
            codebook_size=1024,
            in_chans=4,
            embed_dim=1152,
            depth=28,
            num_heads=16,
            mlp_ratio=4,
            qkv_bias=False,
            clip_dim=1280,
            num_clip_token=77,
            use_checkpoint=True,
            skip=True,
        ),
        'muse': d(ignore_ind=-1, smoothing=0.1, gen_temp=4.5),
        'dataset': d(name='cc3m_web', cfg=True, p_uncond=0.15),
        'wds': d(
            # Shard paths use brace-range patterns; presumably expanded by
            # the WebDataset loader — confirm against the data pipeline.
            train_data='assets/datasets/cc3m/vq_f16_jax_clipG_cc3m_train_emb/{00000..03044}.tar',
            val_data='assets/datasets/cc3m/vq_f16_jax_clipG_cc3m_val_emb/{00000..00012}.tar',
            ctx_path='assets/contexts',
            dist_eval=True,
        ),
        'sample': d(
            sample_steps=18,
            n_samples=30000,
            mini_batch_size=2,
            cfg=True,
            linear_inc_scale=True,
            scale=10.,
            path='',
        ),
    }

    # Attach sections one by one; dict iteration preserves declaration order.
    for section_name, section in sections.items():
        setattr(config, section_name, section)
    return config