# Web-viewer residue (not part of the config): uploaded by cgomes98,
# commit a82c831 "add config file", file size 3.4 kB.
# lightning.pytorch==2.1.1
# Global random seed for the run; presumably consumed by LightningCLI's
# `seed_everything` option -- confirm against the CLI entry point in use.
seed_everything: 42
### Trainer configuration
# Settings under `trainer`. Each `class_path` entry names a class to
# instantiate and its sibling `init_args` mapping holds that class's
# constructor arguments.
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  # Logger output goes to ./experiments/finetune_region (save_dir + name).
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: ./experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # Early stopping watches the same metric as the LR scheduler (val/loss).
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: ./experiments
### Data configuration
# Data module definition: `class_path` names the data module class and
# `init_args` holds its constructor arguments.
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Training-time augmentations, listed in order; each entry is a
    # class_path plus optional init_args.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0  # cv2.BORDER_CONSTANT
          value: 0
          # mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    # Specify all bands which are in the input data.
    # -1 are placeholders for bands that are in the data but that we will
    # discard.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # Specify the bands which are used from the input data.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Positions of R, G, B within output_bands (RED=2, GREEN=1, BLUE=0).
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots to training, validation and test datasplits:
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Mean value of the training dataset per band (one entry per output band).
    means:
      - 547.36707
      - 898.5121
      - 1020.9082
      - 2665.5352
      - 2340.584
      - 1610.1407
    # Standard deviation of the training dataset per band.
    stds:
      - 411.4701
      - 558.54065
      - 815.94025
      - 812.4403
      - 1113.7145
      - 1067.641
    # Nodata value in label data
    no_label_replace: -1
    # Nodata value in the input data
    no_data_replace: 0
### Model configuration
# Task definition: `class_path` names the task class and `init_args` holds
# its constructor arguments. `model_args` groups the backbone, decoder and
# head settings; NOTE(review): presumably forwarded to the class named by
# `model_factory` -- confirm against the TerraTorch version in use.
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      # in_channels matches the six entries listed under `bands`.
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16194593880230534
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Same value as the data module's no_label_replace (-1).
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # uncomment this block for tiled inference
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
# Optimizer: `class_path` names the optimizer class and `init_args` holds
# its constructor arguments.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00031406904191973693
    weight_decay: 0.03283253068408954
# Learning-rate scheduler; it monitors val/loss, the same metric the
# EarlyStopping callback watches.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss