blumenstiel committed on
Commit
92ae9e8
·
1 Parent(s): 7e467ed

Switched from yaml to config.json

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. config.json +1 -0
  3. config.yaml +0 -21
  4. inference.py +18 -26
  5. prithvi_mae.py +2 -2
README.md CHANGED
@@ -36,7 +36,7 @@ We provide a **demo** running Prithvi-EO-2.0-300M-TL [here](https://huggingface.
36
  There is also an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
37
 
38
  ```
39
- python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --output_dir output/ --input_indices <space separated 0-based indices of channels to select from input>
40
  ```
41
 
42
  ## Finetuning
 
36
  There is also an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
37
 
38
  ```
39
+ python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
40
  ```
41
 
42
  ## Finetuning
config.json CHANGED
@@ -17,6 +17,7 @@
17
  "coords_scale_learn": true,
18
  "mask_ratio": 0.75,
19
  "norm_pix_loss": false,
 
20
  "mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
21
  "std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
22
  "origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
 
17
  "coords_scale_learn": true,
18
  "mask_ratio": 0.75,
19
  "norm_pix_loss": false,
20
+ "bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
21
  "mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
22
  "std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
23
  "origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
config.yaml DELETED
@@ -1,21 +0,0 @@
1
- DATA:
2
- BANDS: [B02, B03, B04, B05, B06, B07]
3
- INPUT_SIZE: [4, 224, 224]
4
- MASK_RATIO: 0.75
5
- MEAN: [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0]
6
- STD: [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0]
7
- MODEL:
8
- COORDS_DROP_RATE: 0.1
9
- COORDS_ENCODING: [time, location]
10
- COORDS_SCALE_LEARN: true
11
- DECODER_DEPTH: 8
12
- DECODER_EMBED_DIM: 512
13
- DECODER_NUM_HEADS: 16
14
- DEPTH: 24
15
- DROP_CHANNELS_RATE: 0.0
16
- EMBED_DIM: 1024
17
- MLP_RATIO: 4.0
18
- NAME: vit_l
19
- NORM_PIX_LOSS: false
20
- NUM_HEADS: 16
21
- PATCH_SIZE: [1, 16, 16]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference.py CHANGED
@@ -304,18 +304,18 @@ def main(
304
 
305
  # Get parameters --------
306
 
 
307
  with open(config_path, "r") as f:
308
- config = yaml.safe_load(f)
309
 
310
  batch_size = 1
311
- bands = config['DATA']['BANDS']
312
  num_frames = len(data_files)
313
- mean = config['DATA']['MEAN']
314
- std = config['DATA']['STD']
315
- coords_encoding = config['MODEL']['COORDS_ENCODING']
316
- img_size = config['DATA']['INPUT_SIZE'][-1]
317
-
318
- mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']
319
 
320
  print(
321
  f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
@@ -345,21 +345,13 @@ def main(
345
 
346
  # Create model and load checkpoint -------------------------------------------------------------
347
 
348
- model = PrithviMAE(img_size=config['DATA']['INPUT_SIZE'][-2:],
349
- patch_size=config['MODEL']['PATCH_SIZE'],
350
- num_frames=num_frames,
351
- in_chans=len(bands),
352
- embed_dim=config['MODEL']['EMBED_DIM'],
353
- depth=config['MODEL']['DEPTH'],
354
- num_heads=config['MODEL']['NUM_HEADS'],
355
- decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
356
- decoder_depth=config['MODEL']['DECODER_DEPTH'],
357
- decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
358
- mlp_ratio=config['MODEL']['MLP_RATIO'],
359
- norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
360
- norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
361
- coords_encoding=coords_encoding,
362
- coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])
363
 
364
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
365
  print(f"\n--> Model has {total_params:,} parameters.\n")
@@ -487,11 +479,11 @@ if __name__ == "__main__":
487
  help="Path to the data files. Assumes multi-band files.",
488
  )
489
  parser.add_argument(
490
- "--config_path",
491
  "-c",
492
  type=str,
493
- default="config.yaml",
494
- help="Path to yaml file containing model training parameters.",
495
  )
496
  parser.add_argument(
497
  "--checkpoint",
 
304
 
305
  # Get parameters --------
306
 
307
+ import json
308
  with open(config_path, "r") as f:
309
+ config = json.load(f)['pretrained_cfg']
310
 
311
  batch_size = 1
312
+ bands = config['bands']
313
  num_frames = len(data_files)
314
+ mean = config['mean']
315
+ std = config['std']
316
+ coords_encoding = config['coords_encoding']
317
+ img_size = config['img_size']
318
+ mask_ratio = mask_ratio or config['mask_ratio']
 
319
 
320
  print(
321
  f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
 
345
 
346
  # Create model and load checkpoint -------------------------------------------------------------
347
 
348
+ config.update(
349
+ coords_encoding=coords_encoding,
350
+ num_frames=num_frames,
351
+ in_chans=len(bands),
352
+ )
353
+
354
+ model = PrithviMAE(**config)
 
 
 
 
 
 
 
 
355
 
356
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
357
  print(f"\n--> Model has {total_params:,} parameters.\n")
 
479
  help="Path to the data files. Assumes multi-band files.",
480
  )
481
  parser.add_argument(
482
+ "--config",
+ dest="config_path",
483
  "-c",
484
  type=str,
485
+ default="config.json",
486
+ help="Path to json file containing model training parameters.",
487
  )
488
  parser.add_argument(
489
  "--checkpoint",
prithvi_mae.py CHANGED
@@ -240,7 +240,7 @@ class PrithviViT(nn.Module):
240
  depth: int = 24,
241
  num_heads: int = 16,
242
  mlp_ratio: float = 4.,
243
- norm_layer: nn.Module = nn.LayerNorm,
244
  coords_encoding: List[str] | None = None,
245
  coords_scale_learn: bool = False,
246
  encoder_only: bool = True, # needed for timm
@@ -598,7 +598,7 @@ class PrithviMAE(nn.Module):
598
  decoder_depth: int = 8,
599
  decoder_num_heads: int = 16,
600
  mlp_ratio: float = 4.,
601
- norm_layer: nn.Module = nn.LayerNorm,
602
  norm_pix_loss: bool = False,
603
  coords_encoding: List[str] | None = None,
604
  coords_scale_learn: bool = False,
 
240
  depth: int = 24,
241
  num_heads: int = 16,
242
  mlp_ratio: float = 4.,
243
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
244
  coords_encoding: List[str] | None = None,
245
  coords_scale_learn: bool = False,
246
  encoder_only: bool = True, # needed for timm
 
598
  decoder_depth: int = 8,
599
  decoder_num_heads: int = 16,
600
  mlp_ratio: float = 4.,
601
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
602
  norm_pix_loss: bool = False,
603
  coords_encoding: List[str] | None = None,
604
  coords_scale_learn: bool = False,