blumenstiel
committed on
Commit
·
92ae9e8
1
Parent(s):
7e467ed
Switched from yaml to config.json
Browse files- README.md +1 -1
- config.json +1 -0
- config.yaml +0 -21
- inference.py +18 -26
- prithvi_mae.py +2 -2
README.md
CHANGED
@@ -36,7 +36,7 @@ We provide a **demo** running Prithvi-EO-2.0-300M-TL [here](https://huggingface.
|
|
36 |
There is also an inference script (`inference.py`) that allows running the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
|
37 |
|
38 |
```
|
39 |
-
python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --
|
40 |
```
|
41 |
|
42 |
## Finetuning
|
|
|
36 |
There is also an inference script (`inference.py`) that allows running the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
|
37 |
|
38 |
```
|
39 |
+
python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
|
40 |
```
|
41 |
|
42 |
## Finetuning
|
config.json
CHANGED
@@ -17,6 +17,7 @@
|
|
17 |
"coords_scale_learn": true,
|
18 |
"mask_ratio": 0.75,
|
19 |
"norm_pix_loss": false,
|
|
|
20 |
"mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
|
21 |
"std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
|
22 |
"origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
|
|
17 |
"coords_scale_learn": true,
|
18 |
"mask_ratio": 0.75,
|
19 |
"norm_pix_loss": false,
|
20 |
+
"bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
|
21 |
"mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
|
22 |
"std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
|
23 |
"origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
config.yaml
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
DATA:
|
2 |
-
BANDS: [B02, B03, B04, B05, B06, B07]
|
3 |
-
INPUT_SIZE: [4, 224, 224]
|
4 |
-
MASK_RATIO: 0.75
|
5 |
-
MEAN: [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0]
|
6 |
-
STD: [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0]
|
7 |
-
MODEL:
|
8 |
-
COORDS_DROP_RATE: 0.1
|
9 |
-
COORDS_ENCODING: [time, location]
|
10 |
-
COORDS_SCALE_LEARN: true
|
11 |
-
DECODER_DEPTH: 8
|
12 |
-
DECODER_EMBED_DIM: 512
|
13 |
-
DECODER_NUM_HEADS: 16
|
14 |
-
DEPTH: 24
|
15 |
-
DROP_CHANNELS_RATE: 0.0
|
16 |
-
EMBED_DIM: 1024
|
17 |
-
MLP_RATIO: 4.0
|
18 |
-
NAME: vit_l
|
19 |
-
NORM_PIX_LOSS: false
|
20 |
-
NUM_HEADS: 16
|
21 |
-
PATCH_SIZE: [1, 16, 16]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference.py
CHANGED
@@ -304,18 +304,18 @@ def main(
|
|
304 |
|
305 |
# Get parameters --------
|
306 |
|
|
|
307 |
with open(config_path, "r") as f:
|
308 |
-
config = yaml.safe_load(f)
|
309 |
|
310 |
batch_size = 1
|
311 |
-
bands = config['
|
312 |
num_frames = len(data_files)
|
313 |
-
mean = config['
|
314 |
-
std = config['
|
315 |
-
coords_encoding = config['
|
316 |
-
img_size = config['
|
317 |
-
|
318 |
-
mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']
|
319 |
|
320 |
print(
|
321 |
f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
|
@@ -345,21 +345,13 @@ def main(
|
|
345 |
|
346 |
# Create model and load checkpoint -------------------------------------------------------------
|
347 |
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
|
356 |
-
decoder_depth=config['MODEL']['DECODER_DEPTH'],
|
357 |
-
decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
|
358 |
-
mlp_ratio=config['MODEL']['MLP_RATIO'],
|
359 |
-
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
360 |
-
norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
|
361 |
-
coords_encoding=coords_encoding,
|
362 |
-
coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])
|
363 |
|
364 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
365 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
@@ -487,11 +479,11 @@ if __name__ == "__main__":
|
|
487 |
help="Path to the data files. Assumes multi-band files.",
|
488 |
)
|
489 |
parser.add_argument(
|
490 |
-
"--
|
491 |
"-c",
|
492 |
type=str,
|
493 |
-
default="config.
|
494 |
-
help="Path to
|
495 |
)
|
496 |
parser.add_argument(
|
497 |
"--checkpoint",
|
|
|
304 |
|
305 |
# Get parameters --------
|
306 |
|
307 |
+
import json
|
308 |
with open(config_path, "r") as f:
|
309 |
+
config = json.load(f)['pretrained_cfg']
|
310 |
|
311 |
batch_size = 1
|
312 |
+
bands = config['bands']
|
313 |
num_frames = len(data_files)
|
314 |
+
mean = config['mean']
|
315 |
+
std = config['std']
|
316 |
+
coords_encoding = config['coords_encoding']
|
317 |
+
img_size = config['img_size']
|
318 |
+
mask_ratio = mask_ratio or config['mask_ratio']
|
|
|
319 |
|
320 |
print(
|
321 |
f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
|
|
|
345 |
|
346 |
# Create model and load checkpoint -------------------------------------------------------------
|
347 |
|
348 |
+
config.update(
|
349 |
+
coords_encoding=coords_encoding,
|
350 |
+
num_frames=num_frames,
|
351 |
+
in_chans=len(bands),
|
352 |
+
)
|
353 |
+
|
354 |
+
model = PrithviMAE(**config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
357 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
|
|
479 |
help="Path to the data files. Assumes multi-band files.",
|
480 |
)
|
481 |
parser.add_argument(
|
482 |
+
"--config",
|
483 |
"-c",
|
484 |
type=str,
|
485 |
+
default="config.json",
|
486 |
+
help="Path to json file containing model training parameters.",
|
487 |
)
|
488 |
parser.add_argument(
|
489 |
"--checkpoint",
|
prithvi_mae.py
CHANGED
@@ -240,7 +240,7 @@ class PrithviViT(nn.Module):
|
|
240 |
depth: int = 24,
|
241 |
num_heads: int = 16,
|
242 |
mlp_ratio: float = 4.,
|
243 |
-
norm_layer: nn.Module = nn.LayerNorm,
|
244 |
coords_encoding: List[str] | None = None,
|
245 |
coords_scale_learn: bool = False,
|
246 |
encoder_only: bool = True, # needed for timm
|
@@ -598,7 +598,7 @@ class PrithviMAE(nn.Module):
|
|
598 |
decoder_depth: int = 8,
|
599 |
decoder_num_heads: int = 16,
|
600 |
mlp_ratio: float = 4.,
|
601 |
-
norm_layer: nn.Module = nn.LayerNorm,
|
602 |
norm_pix_loss: bool = False,
|
603 |
coords_encoding: List[str] | None = None,
|
604 |
coords_scale_learn: bool = False,
|
|
|
240 |
depth: int = 24,
|
241 |
num_heads: int = 16,
|
242 |
mlp_ratio: float = 4.,
|
243 |
+
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
|
244 |
coords_encoding: List[str] | None = None,
|
245 |
coords_scale_learn: bool = False,
|
246 |
encoder_only: bool = True, # needed for timm
|
|
|
598 |
decoder_depth: int = 8,
|
599 |
decoder_num_heads: int = 16,
|
600 |
mlp_ratio: float = 4.,
|
601 |
+
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
|
602 |
norm_pix_loss: bool = False,
|
603 |
coords_encoding: List[str] | None = None,
|
604 |
coords_scale_learn: bool = False,
|