add config file
Browse files- config.yaml +142 -0
config.yaml
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# lightning.pytorch==2.1.1
|
2 |
+
seed_everything: 42
|
3 |
+
|
4 |
+
### Trainer configuration
|
5 |
+
trainer:
|
6 |
+
accelerator: auto
|
7 |
+
strategy: auto
|
8 |
+
devices: auto
|
9 |
+
num_nodes: 1
|
10 |
+
# precision: 16-mixed
|
11 |
+
logger:
|
12 |
+
class_path: TensorBoardLogger
|
13 |
+
init_args:
|
14 |
+
save_dir: ./experiments
|
15 |
+
name: finetune_region
|
16 |
+
callbacks:
|
17 |
+
- class_path: RichProgressBar
|
18 |
+
- class_path: LearningRateMonitor
|
19 |
+
init_args:
|
20 |
+
logging_interval: epoch
|
21 |
+
- class_path: EarlyStopping
|
22 |
+
init_args:
|
23 |
+
monitor: val/loss
|
24 |
+
patience: 100
|
25 |
+
max_epochs: 300
|
26 |
+
check_val_every_n_epoch: 1
|
27 |
+
log_every_n_steps: 20
|
28 |
+
enable_checkpointing: true
|
29 |
+
default_root_dir: ./experiments
|
30 |
+
|
31 |
+
### Data configuration
|
32 |
+
data:
|
33 |
+
class_path: GenericNonGeoPixelwiseRegressionDataModule
|
34 |
+
init_args:
|
35 |
+
batch_size: 64
|
36 |
+
num_workers: 8
|
37 |
+
train_transform:
|
38 |
+
- class_path: albumentations.HorizontalFlip
|
39 |
+
init_args:
|
40 |
+
p: 0.5
|
41 |
+
- class_path: albumentations.Rotate
|
42 |
+
init_args:
|
43 |
+
limit: 30
|
44 |
+
border_mode: 0 # cv2.BORDER_CONSTANT
|
45 |
+
value: 0
|
46 |
+
# mask_value: 1
|
47 |
+
p: 0.5
|
48 |
+
- class_path: ToTensorV2
|
49 |
+
# Specify all bands which are in the input data.
|
50 |
+
# -1 are placeholders for bands that are in the data but that we will discard
|
51 |
+
dataset_bands:
|
52 |
+
- -1
|
53 |
+
- BLUE
|
54 |
+
- GREEN
|
55 |
+
- RED
|
56 |
+
- NIR_NARROW
|
57 |
+
- SWIR_1
|
58 |
+
- SWIR_2
|
59 |
+
- -1
|
60 |
+
- -1
|
61 |
+
- -1
|
62 |
+
- -1
|
63 |
+
output_bands: #Specify the bands which are used from the input data.
|
64 |
+
- BLUE
|
65 |
+
- GREEN
|
66 |
+
- RED
|
67 |
+
- NIR_NARROW
|
68 |
+
- SWIR_1
|
69 |
+
- SWIR_2
|
70 |
+
rgb_indices:
|
71 |
+
- 2
|
72 |
+
- 1
|
73 |
+
- 0
|
74 |
+
# Directory roots to training, validation and test datasplits:
|
75 |
+
train_data_root: train_images
|
76 |
+
train_label_data_root: train_labels
|
77 |
+
val_data_root: val_images
|
78 |
+
val_label_data_root: val_labels
|
79 |
+
test_data_root: test_images
|
80 |
+
test_label_data_root: test_labels
|
81 |
+
means: # Mean value of the training dataset per band
|
82 |
+
- 547.36707
|
83 |
+
- 898.5121
|
84 |
+
- 1020.9082
|
85 |
+
- 2665.5352
|
86 |
+
- 2340.584
|
87 |
+
- 1610.1407
|
88 |
+
stds: # Standard deviation of the training dataset per band
|
89 |
+
- 411.4701
|
90 |
+
- 558.54065
|
91 |
+
- 815.94025
|
92 |
+
- 812.4403
|
93 |
+
- 1113.7145
|
94 |
+
- 1067.641
|
95 |
+
# Nodata value in label data
|
96 |
+
no_label_replace: -1
|
97 |
+
# Nodata value in the input data
|
98 |
+
no_data_replace: 0
|
99 |
+
|
100 |
+
### Model configuration
|
101 |
+
model:
|
102 |
+
class_path: terratorch.tasks.PixelwiseRegressionTask
|
103 |
+
init_args:
|
104 |
+
model_args:
|
105 |
+
decoder: UperNetDecoder
|
106 |
+
pretrained: false
|
107 |
+
backbone: prithvi_swin_B
|
108 |
+
backbone_drop_path_rate: 0.3
|
109 |
+
decoder_channels: 32
|
110 |
+
in_channels: 6
|
111 |
+
bands:
|
112 |
+
- BLUE
|
113 |
+
- GREEN
|
114 |
+
- RED
|
115 |
+
- NIR_NARROW
|
116 |
+
- SWIR_1
|
117 |
+
- SWIR_2
|
118 |
+
num_frames: 1
|
119 |
+
head_dropout: 0.16194593880230534
|
120 |
+
head_final_act: torch.nn.ReLU
|
121 |
+
head_learned_upscale_layers: 2
|
122 |
+
loss: rmse
|
123 |
+
ignore_index: -1
|
124 |
+
freeze_backbone: false
|
125 |
+
freeze_decoder: false
|
126 |
+
model_factory: PrithviModelFactory
|
127 |
+
# uncomment this block for tiled inference
|
128 |
+
# tiled_inference_parameters:
|
129 |
+
# h_crop: 224
|
130 |
+
# h_stride: 192
|
131 |
+
# w_crop: 224
|
132 |
+
# w_stride: 192
|
133 |
+
# average_patches: true
|
134 |
+
optimizer:
|
135 |
+
class_path: torch.optim.AdamW
|
136 |
+
init_args:
|
137 |
+
lr: 0.00031406904191973693
|
138 |
+
weight_decay: 0.03283253068408954
|
139 |
+
lr_scheduler:
|
140 |
+
class_path: ReduceLROnPlateau
|
141 |
+
init_args:
|
142 |
+
monitor: val/loss
|