mbreuss commited on
Commit
e18392c
·
verified ·
1 Parent(s): 0fa3b88

Create config.yaml

Browse files
Files changed (1) hide show
  1. config.yaml +254 -0
config.yaml ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ callbacks:
2
+ rollout_lh:
3
+ _target_: mode.rollout.libero_rollout.RolloutLibero
4
+ _recursive_: false
5
+ env_cfg:
6
+ _target_: mode.wrappers.hulc_wrapper.HulcWrapper
7
+ skip_epochs: ${rollout_lh_skip_epochs}
8
+ benchmark_name: ${libero_benchmark}
9
+ rollout_freq: 10
10
+ num_videos: 0
11
+ num_sequences: 50
12
+ max_steps: 600
13
+ empty_cache: false
14
+ debug: false
15
+ n_eval: 20
16
+ num_procs: 10
17
+ use_mp: false
18
+ task_embedding_format: clip
19
+ device: ${device}
20
+ checkpoint:
21
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
22
+ save_top_k: 1
23
+ verbose: true
24
+ monitor: eval_lh/avg_seq_len
25
+ mode: max
26
+ dirpath: saved_models
27
+ filename: '{epoch:02d}_{eval_lh/avg_seq_len:.2f}'
28
+ every_n_epochs: ${callbacks.rollout_lh.rollout_freq}
29
+ ema:
30
+ _target_: mode.callbacks.ema.EMA
31
+ decay: 0.999
32
+ start_step: 0
33
+ save_ema_weights_in_callback_state: true
34
+ evaluate_ema_weights_instead: true
35
+ power: 0.6666666666666666
36
+ inv_gamma: 1.0
37
+ min_value: 0.0
38
+ max_value: 0.9999
39
+ datamodule:
40
+ datasets:
41
+ lang_dataset:
42
+ _target_: mode.datasets.libero_dataset.LiberoMultitaskDataset
43
+ key: lang
44
+ benchmark_name: ${libero_benchmark}
45
+ batch_size: ${batch_size}
46
+ proprio_state: ${datamodule.proprioception_dims}
47
+ obs_space: ${datamodule.observation_space}
48
+ num_workers: ${num_workers}
49
+ action_seq_len: ${act_seq_len}
50
+ obs_seq_len: ${obs_seq_len}
51
+ split_ratio: 0.0
52
+ transforms:
53
+ train:
54
+ rgb_static:
55
+ - _target_: torchvision.transforms.Resize
56
+ size: 224
57
+ antialias: true
58
+ - _target_: mode.utils.transforms.RandomShiftsAug
59
+ pad: 10
60
+ - _target_: mode.utils.transforms.ScaleImageTensor
61
+ - _target_: torchvision.transforms.Normalize
62
+ mean:
63
+ - 0.48145466
64
+ - 0.4578275
65
+ - 0.40821073
66
+ std:
67
+ - 0.26862954
68
+ - 0.26130258
69
+ - 0.27577711
70
+ rgb_gripper:
71
+ - _target_: torchvision.transforms.Resize
72
+ size: 112
73
+ antialias: true
74
+ - _target_: mode.utils.transforms.RandomShiftsAug
75
+ pad: 4
76
+ - _target_: mode.utils.transforms.ScaleImageTensor
77
+ - _target_: torchvision.transforms.Normalize
78
+ mean:
79
+ - 0.48145466
80
+ - 0.4578275
81
+ - 0.40821073
82
+ std:
83
+ - 0.26862954
84
+ - 0.26130258
85
+ - 0.27577711
86
+ val:
87
+ rgb_static:
88
+ - _target_: torchvision.transforms.Resize
89
+ size: 224
90
+ antialias: true
91
+ - _target_: mode.utils.transforms.ScaleImageTensor
92
+ - _target_: torchvision.transforms.Normalize
93
+ mean:
94
+ - 0.48145466
95
+ - 0.4578275
96
+ - 0.40821073
97
+ std:
98
+ - 0.26862954
99
+ - 0.26130258
100
+ - 0.27577711
101
+ rgb_gripper:
102
+ - _target_: torchvision.transforms.Resize
103
+ size: 112
104
+ antialias: true
105
+ - _target_: mode.utils.transforms.ScaleImageTensor
106
+ - _target_: torchvision.transforms.Normalize
107
+ mean:
108
+ - 0.48145466
109
+ - 0.4578275
110
+ - 0.40821073
111
+ std:
112
+ - 0.26862954
113
+ - 0.26130258
114
+ - 0.27577711
115
+ _target_: mode.datasets.libero_data_module.LiberoDataModule
116
+ _recursive_: false
117
+ root_data_dir: ${root_data_dir}
118
+ action_space: 7
119
+ shuffle_val: false
120
+ benchmark_name: ${libero_benchmark}
121
+ observation_space:
122
+ rgb_obs:
123
+ - agentview_rgb
124
+ - eye_in_hand_rgb
125
+ depth_obs: []
126
+ state_obs:
127
+ - gripper_states
128
+ - joint_states
129
+ actions:
130
+ - rel_actions
131
+ language:
132
+ - language
133
+ proprioception_dims: None
134
+ model:
135
+ language_goal:
136
+ _target_: mode.models.networks.clip_lang_encoder.LangClip
137
+ _recursive_: false
138
+ model_name: ${clip_lang_model_name}
139
+ model:
140
+ _target_: mode.models.edm_diffusion.score_wrappers.GCDenoiser
141
+ _recursive_: false
142
+ sigma_data: ${model.sigma_data}
143
+ inner_model:
144
+ _target_: mode.models.networks.modedit.MoDeDiT
145
+ action_dim: ${datamodule.action_space}
146
+ goal_dim: ${model.cond_dim}
147
+ obs_dim: ${obs_dim}
148
+ goal_conditioned: true
149
+ causal: true
150
+ use_custom_attn_mask: false
151
+ use_proprio: ${model.use_proprio}
152
+ state_dim: ${proprio_dims}
153
+ embed_dim: ${model.latent_dim}
154
+ n_layers: 12
155
+ goal_seq_len: 1
156
+ obs_seq_len: ${obs_seq_len}
157
+ action_seq_len: ${act_seq_len}
158
+ embed_pdrob: 0
159
+ goal_drop: 0.1
160
+ attn_pdrop: 0.3
161
+ mlp_pdrop: 0.1
162
+ n_heads: 8
163
+ device: ${device}
164
+ linear_output: true
165
+ cond_router: true
166
+ num_experts: 4
167
+ top_k: 2
168
+ router_normalize: true
169
+ use_goal_in_routing: false
170
+ use_argmax: false
171
+ use_shared_expert: false
172
+ use_noise_token_as_input: true
173
+ init_style: olmoe
174
+ _target_: mode.models.mode_agent.MoDEAgent
175
+ _recursive_: false
176
+ multistep: ${multistep}
177
+ use_lr_scheduler: true
178
+ entropy_gamma: 0.0
179
+ router_z_delta: 0.0
180
+ use_proprio: false
181
+ seed: ${seed}
182
+ sampler_type: ddim
183
+ num_sampling_steps: 5
184
+ sigma_data: 0.5
185
+ sigma_min: 0.001
186
+ sigma_max: 80
187
+ noise_scheduler: exponential
188
+ sigma_sample_density_type: loglogistic
189
+ ckpt_path: /home/reuss/code/MeDiT_Policy/convert_weights/mode_first_run
190
+ start_from_pretrained: true
191
+ act_window_size: ${act_seq_len}
192
+ latent_dim: 1024
193
+ obs_enc_dim: ${obs_dim}
194
+ cond_dim: 512
195
+ resnet_type: '50'
196
+ optimizer:
197
+ _target_: torch.optim.AdamW
198
+ transformer_weight_decay: 0.05
199
+ obs_encoder_weight_decay: 0.05
200
+ learning_rate: 0.0001
201
+ betas:
202
+ - 0.9
203
+ - 0.95
204
+ lr_scheduler:
205
+ lr_scheduler:
206
+ init_lr: 0.0001
207
+ init_lr_scale: 0.1
208
+ final_lr_scale: 1.0e-06
209
+ total_steps: 40000
210
+ phase_ratio: (0.02, 0.08, 0.9)
211
+ lr: 0.0001
212
+ root_data_dir: /home/yagmurlu/code/MoDE_Calvin/dataset/task_ABC_D
213
+ lang_folder: lang_clip_resnet50
214
+ vis_clip_model_name: ViT-B/16
215
+ clip_lang_model_name: ViT-B/32
216
+ log_dir: ./logs
217
+ slurm: false
218
+ future_range: 29
219
+ seed: 242
220
+ device: cuda
221
+ batch_size: 128
222
+ devices: 2
223
+ goal_window_size: 1
224
+ act_dim: 7
225
+ proprio_dims: 9
226
+ obs_dim: 512
227
+ goal_dim: 512
228
+ obs_seq_len: 1
229
+ act_seq_len: 10
230
+ multistep: ${act_seq_len}
231
+ p_last_state: 0
232
+ gen_img_res: 112
233
+ max_epochs: 10
234
+ rollout_lh_skip_epochs: 9
235
+ num_workers: 1
236
+ benchmark_name: ${libero_benchmark}
237
+ libero_benchmark: libero_10
238
+ trainer:
239
+ gpus: ${devices}
240
+ precision: bf16
241
+ max_epochs: ${max_epochs}
242
+ sync_batchnorm: false
243
+ accelerator: auto
244
+ limit_train_batches: 1000
245
+ limit_val_batches: 4
246
+ logger:
247
+ _target_: pytorch_lightning.loggers.WandbLogger
248
+ save_dir: .
249
+ name: logger
250
+ group: mode
251
+ log_model: false
252
+ project: ${libero_benchmark}
253
+ entity: bennoq
254
+ id: ???