# counterfactual-world-models/scripts/cwm/3frame_no_clumping_maskvit_gpu.sh
# Repository page metadata: rahulvenkk, "app.py updated", commit 6dfcb0f, 1.17 kB
# Launch CWM pretraining for the 3-frame / no-clumping / MaskViT ablation.
#
# Starts this node's 8 worker processes of a 2-node torch.distributed.launch
# job running run_cwm_pretraining.py. Every setting below keeps its original
# value by default but can be overridden from the environment, e.g. to start
# the other node of the job:
#
#   NODE_RANK=0 bash path/to/this_script.sh
#
# Environment overrides:
#   OUTPUT_DIR   directory passed as both --log_dir and --output_dir
#   DATA_PATH    training video file list passed as --data_path
#   MASTER_ADDR  value for --master_addr (must be identical on both nodes)
#   MASTER_PORT  value for --master_port (must be identical on both nodes)
#   NODE_RANK    value for --node_rank (rank of this node within the job)
OUTPUT_DIR="${OUTPUT_DIR:-/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_maskvit/}"
DATA_PATH="${DATA_PATH:-${HOME}/BBNet/bbnet/models/VideoMAE-main/video_file_lists/kinetics_400_train_list.txt}"
MASTER_ADDR="${MASTER_ADDR:-10.102.2.153}"
MASTER_PORT="${MASTER_PORT:-32240}"
NODE_RANK="${NODE_RANK:-1}"

# Path expansions are double-quoted so that a HOME or OUTPUT_DIR containing
# whitespace or glob characters is still passed as a single argument.
#
# NOTE(review): the spelling of --no_normlize_target is kept as-is; it
# presumably matches the flag declared in run_cwm_pretraining.py. Verify
# against that script before "correcting" it.
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
  --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
  --nnodes=2 --node_rank="${NODE_RANK}" \
  run_cwm_pretraining.py \
  --data_path "${DATA_PATH}" \
  --mask_type rotated_table_maskvit \
  --mask_ratio 0.99 \
  --mask_kwargs '{"tube_length": 1}' \
  --model vitbase_8x8patch_3frames_1tube \
  --context_frames 2 \
  --target_frames 1 \
  --temporal_units 'ms' \
  --sampling_rate 150 \
  --context_target_gap 150 150 \
  --batch_size 16 \
  --accum_iter 16 \
  --opt adamw \
  --opt_betas 0.9 0.95 \
  --warmup_epochs 40 \
  --save_ckpt_freq 10 \
  --epochs 800 \
  --no_normlize_target \
  --rescale_size 224 \
  --augmentation_type 'multiscale' \
  --augmentation_scales 1.0 0.875 0.75 0.66 \
  --log_dir "${OUTPUT_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --print_freq 20 \
  --num_workers 32 \
  --min_lr 1e-5