#!/usr/bin/env bash
#
# Launches train/train_video_vae.py twice with torchrun, one run after the
# other:
#   Stage 1 - image/video mixed training with NUM_FRAMES=17.
#   Stage 2 - video-only training with --use_context_parallel and
#             NUM_FRAMES=33, loading the weights named by VAE_CKPT_PATH.
#
# Several values below (PATH/vae_ckpt, /PATH/output_dir) look like
# placeholders: edit them before running.

# Stop at the first failing command. Without this, a crashed stage 1 would
# still be followed by a stage-2 launch. -u catches misspelled variable
# names; pipefail makes a pipeline fail if any of its stages fails.
set -euo pipefail

GPUS=8                                   # torchrun --nproc_per_node
VAE_MODEL_PATH=PATH/vae_ckpt             # --model_path
LPIPS_CKPT=vgg_lpips.pth                 # --lpips_ckpt
OUTPUT_DIR=/PATH/output_dir              # --output_dir (both stages use it)
IMAGE_ANNO=annotation/image_text.jsonl   # --image_anno (stage 1 only)
VIDEO_ANNO=annotation/video_text.jsonl   # --video_anno
RESOLUTION=256                           # --resolution
# --max_frames for stage 1; the stage-2 section reassigns it to 33.
# Both values have the form 8k+1 - presumably a model constraint; confirm
# against the model before changing.
NUM_FRAMES=17
BATCH_SIZE=2                             # --batch_size
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Stage 1: image/video mixed training (--use_image_video_mixed_training,
# --image_mix_ratio 0.1) with --max_frames "$NUM_FRAMES".
#
# The arguments are collected in an array rather than a chain of trailing
# backslashes. A blank line after a "\" ends the command, which silently
# split the old invocation into many broken commands. Inside ( ... ), blank
# lines and comments are harmless. Each quoted expansion also stays a single
# word, so paths containing spaces survive.
#
# NOTE(review): 100 epochs x 2000 iters_per_epoch = 200000 iterations, which
# is below --disc_start 250000. If disc_start counts global iterations, the
# discriminator never activates in this stage - confirm this is intended.
# ---------------------------------------------------------------------------
stage1_args=(
  # runtime / model
  --num_workers 6
  --model_path "$VAE_MODEL_PATH"
  --model_dtype bf16
  --lpips_ckpt "$LPIPS_CKPT"
  --output_dir "$OUTPUT_DIR"

  # data
  --image_anno "$IMAGE_ANNO"
  --video_anno "$VIDEO_ANNO"
  --use_image_video_mixed_training
  --image_mix_ratio 0.1
  --resolution "$RESOLUTION"
  --max_frames "$NUM_FRAMES"

  # loss configuration
  --disc_start 250000
  --kl_weight 1e-12
  --pixelloss_weight 10.0
  --perceptual_weight 1.0
  --disc_weight 0.5

  # optimisation
  --batch_size "$BATCH_SIZE"
  --opt adamw
  --opt_betas 0.9 0.95
  --seed 42
  --weight_decay 1e-3
  --clip_grad 1.0
  --lr 1e-4
  --lr_disc 1e-4
  --warmup_epochs 1
  --epochs 100
  --iters_per_epoch 2000

  # logging / checkpointing
  --print_freq 40
  --save_ckpt_freq 1
)

torchrun --nproc_per_node "$GPUS" train/train_video_vae.py "${stage1_args[@]}"
|
|
|
|
|
|
|
# --- Stage-2 settings -------------------------------------------------------
# Weights the stage-2 run starts from (--pretrained_vae_weight). The value is
# a placeholder: point it at the checkpoint produced by stage 1.
VAE_CKPT_PATH=stage1_path
# --context_size, used together with --use_context_parallel.
CONTEXT_SIZE=2
# Overrides the stage-1 clip length (17) for the --max_frames of stage 2.
NUM_FRAMES=33
|
|
|
# ---------------------------------------------------------------------------
# Stage 2: video-only training (--image_mix_ratio 0.0, no --image_anno) with
# --use_context_parallel, initialised from the stage-1 weights in
# VAE_CKPT_PATH.
#
# The arguments are held in an array for the same reasons as in stage 1. A
# blank line after a trailing "\" would end the command, and quoted array
# elements keep each path a single word.
#
# NOTE(review): this stage writes to the same --output_dir as stage 1; if the
# trainer reuses checkpoint file names, stage-1 checkpoints may be
# overwritten - confirm.
# NOTE(review): as in stage 1, 100 x 2000 = 200000 iterations < --disc_start
# 250000 - confirm the discriminator is meant to stay off here as well.
# ---------------------------------------------------------------------------

# Fail fast rather than spin up every GPU worker with a weight path that does
# not exist (the shipped value of VAE_CKPT_PATH is only a placeholder).
if [[ ! -e "$VAE_CKPT_PATH" ]]; then
  printf 'error: stage-1 weights not found: %s (edit VAE_CKPT_PATH)\n' \
    "$VAE_CKPT_PATH" >&2
  exit 1
fi

stage2_args=(
  # runtime / model
  --num_workers 6
  --model_path "$VAE_MODEL_PATH"
  --model_dtype bf16
  --pretrained_vae_weight "$VAE_CKPT_PATH"
  --use_context_parallel
  --context_size "$CONTEXT_SIZE"
  --lpips_ckpt "$LPIPS_CKPT"
  --output_dir "$OUTPUT_DIR"

  # data (videos only)
  --video_anno "$VIDEO_ANNO"
  --image_mix_ratio 0.0
  --resolution "$RESOLUTION"
  --max_frames "$NUM_FRAMES"

  # loss configuration
  --disc_start 250000
  --kl_weight 1e-12
  --pixelloss_weight 10.0
  --perceptual_weight 1.0
  --disc_weight 0.5

  # optimisation
  --batch_size "$BATCH_SIZE"
  --opt adamw
  --opt_betas 0.9 0.95
  --seed 42
  --weight_decay 1e-3
  --clip_grad 1.0
  --lr 1e-4
  --lr_disc 1e-4
  --warmup_epochs 1
  --epochs 100
  --iters_per_epoch 2000

  # logging / checkpointing
  --print_freq 40
  --save_ckpt_freq 1
)

torchrun --nproc_per_node "$GPUS" train/train_video_vae.py "${stage2_args[@]}"