blumenstiel committed on
Commit
e25d381
·
1 Parent(s): 36f6e0e

Add inference code

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif filter=lfs diff=lfs merge=lfs -text
Prithvi_EO_V2_300_TL_config.yaml ADDED
@@ -0,0 +1,21 @@
1
+ DATA:
2
+ BANDS: [B02, B03, B04, B05, B06, B07]
3
+ INPUT_SIZE: [4, 224, 224]
4
+ MASK_RATIO: 0.75
5
+ MEAN: [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0]
6
+ STD: [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0]
7
+ MODEL:
8
+ COORDS_DROP_RATE: 0.1
9
+ COORDS_ENCODING: [time, location]
10
+ COORDS_SCALE_LEARN: true
11
+ DECODER_DEPTH: 8
12
+ DECODER_EMBED_DIM: 512
13
+ DECODER_NUM_HEADS: 16
14
+ DEPTH: 24
15
+ DROP_CHANNELS_RATE: 0.0
16
+ EMBED_DIM: 1024
17
+ MLP_RATIO: 4.0
18
+ NAME: vit_l
19
+ NORM_PIX_LOSS: false
20
+ NUM_HEADS: 16
21
+ PATCH_SIZE: [1, 16, 16]
README.md CHANGED
@@ -1,3 +1,57 @@
1
  ---
2
  license: apache-2.0
3
  ---
4
+
5
+ ### Model and Inputs
6
+ Prithvi is a first-of-its-kind temporal Vision Transformer pre-trained by the IBM and NASA team on contiguous US Harmonized Landsat and Sentinel-2 (HLS) data. The model uses a self-supervised ViT encoder trained with a Masked Autoencoder (MAE) strategy and an MSE reconstruction loss. It includes spatial attention across multiple patches as well as temporal attention for each patch.
7
+
8
+ ![](GFM.png)
9
+
10
+ The model accepts remote sensing data in a video format (B, C, T, H, W). The temporal dimension (T) is central to this application and absent from most other remote sensing models. The ability to handle a time series of remote sensing images benefits a variety of downstream tasks (e.g., burn scar segmentation, flood segmentation, land cover classification). The model also handles static imagery, which can be fed in with T=1. A minimal usage sketch follows.
11
+
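+ As an illustration of the expected input and output shapes, here is a minimal sketch that builds a small, randomly initialized `PrithviMAE` from `prithvi_mae.py` and runs a dummy (B, C, T, H, W) tensor through it. The reduced embedding and depth values are assumptions for the example only, not the released model configuration:
+
+ ```
+ import torch
+ from prithvi_mae import PrithviMAE
+
+ # Tiny random model, for shape illustration only (not the pre-trained weights)
+ model = PrithviMAE(
+     img_size=224, patch_size=(1, 16, 16), num_frames=4, in_chans=6,
+     embed_dim=128, depth=2, num_heads=8,
+     decoder_embed_dim=64, decoder_depth=1, decoder_num_heads=4,
+ )
+ x = torch.randn(1, 6, 4, 224, 224)  # B, C, T, H, W
+ loss, pred, mask = model(x, mask_ratio=0.75)
+ print(pred.shape, mask.shape)  # (1, 784, 1536) and (1, 784) for this configuration
+ ```
+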
12
+ ### Pre-training
13
+ The model was pre-trained with NASA's HLS V2 L30 product (30 m resolution) over the contiguous United States, using the following bands (see the normalization sketch after the list):
14
+
15
+ 1. Blue
16
+ 2. Green
17
+ 3. Red
18
+ 4. Narrow NIR
19
+ 5. SWIR 1
20
+ 6. SWIR 2
21
+
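+ The per-band statistics used to standardize these inputs are stored in `Prithvi_EO_V2_300_TL_config.yaml` (`DATA.MEAN` and `DATA.STD`). A minimal sketch of the normalization applied by the inference code (the random array stands in for a real HLS image in the band order above):
+
+ ```
+ import numpy as np
+ import yaml
+
+ with open("Prithvi_EO_V2_300_TL_config.yaml") as f:
+     cfg = yaml.safe_load(f)
+
+ mean = np.array(cfg["DATA"]["MEAN"])  # one value per band (B02..B07)
+ std = np.array(cfg["DATA"]["STD"])
+
+ img = np.random.rand(224, 224, 6) * 3000  # placeholder for a (H, W, bands) image
+ img_norm = (img - mean) / std             # what load_example() applies per pixel
+ ```
+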
22
+ ### Code
23
+ The model follows the [original MAE repo](https://github.com/facebookresearch/mae) with some modifications including:
24
+
25
+ 1. replacing the 2D patch embedding with a 3D patch embedding;
26
+ 2. replacing the 2D positional embedding with a 3D positional embedding;
27
+ 3. replacing 2D patchify and unpatchify with 3D versions (sketched below);
28
+ 4. adding the infrared bands in addition to RGB.
29
+
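+ For instance, the 3D patchify mentioned in point 3 uses a single `einops` rearrange over the temporal, height and width axes (the same pattern as `PrithviMAE.patchify` in `prithvi_mae.py`); a small sketch with an assumed 1x16x16 patch size:
+
+ ```
+ import torch
+ from einops import rearrange
+
+ x = torch.randn(1, 6, 4, 224, 224)  # B, C, T, H, W
+ patches = rearrange(x, "b c (t s) (h p) (w q) -> b (t h w) (s p q c)", s=1, p=16, q=16)
+ print(patches.shape)  # (1, 784, 1536): 4*14*14 patches of 1*16*16*6 values
+ ```
+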
30
+ ### Inference and demo
31
+ There is an inference script (`inference.py`) that allows you to run image reconstruction on a set of HLS images, assumed to be from the same location at different time steps (see the example below). These should be provided in chronological order in GeoTIFF format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
32
+
33
+ ```
34
+ python inference.py --data_files t1.tif t2.tif t3.tif --config_path Prithvi_EO_V2_300_TL_config.yaml --checkpoint Prithvi_EO_V2_300_TL.pt --output_dir /path/to/out/dir/ --input_indices <space-separated 0-based indices of the channels to select from the input> --mask_ratio 0.5 --rgb_outputs
35
+ ```
36
+
37
+ This script is a starting point that can be generalized to different input shapes and types.
38
+
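+ The same entry point can also be called from Python (a sketch; the GeoTIFF names are placeholders for your own time series, and the config/checkpoint names refer to the files in this repository):
+
+ ```
+ from inference import main
+
+ main(
+     data_files=["t1.tif", "t2.tif", "t3.tif"],
+     config_path="Prithvi_EO_V2_300_TL_config.yaml",
+     checkpoint="Prithvi_EO_V2_300_TL.pt",
+     output_dir="output",
+     rgb_outputs=True,
+     mask_ratio=0.75,
+ )
+ ```
+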
39
+ ### Finetuning examples
40
+ Examples of finetuning the model for image segmentation using the mmsegmentation library are available through Hugging Face (e.g. [burn scars segmentation](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar), [flood mapping](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11), and [multi-temporal crop classification](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification)), with the code used for the experiments available on [GitHub](https://github.com/NASA-IMPACT/hls-foundation-os/tree/main/fine-tuning-examples). That repository also contains instructions to finetune the model for flood detection on the popular open-access [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).
41
+
42
+ ### Feedback
43
+
44
+ Your feedback is invaluable to us. If you have any feedback about the model, please share it with us by submitting issues on our open-source repository, [hls-foundation-os](https://github.com/NASA-IMPACT/hls-foundation-os/issues), on GitHub.
45
+
46
+ ### Citation
47
+
48
+ If this model helped your research, please cite `Prithvi-V2` in your publications. Here is an example BibTeX entry:
49
+
50
+ ```
51
+ @article{Prithvi-2-preprint,
52
+ author = {},
53
+ title = {{Title}},
54
+ journal = {Preprint Available on arxiv},
55
+ year = {2024}
56
+ }
57
+ ```
inference.py ADDED
@@ -0,0 +1,525 @@
1
+ import argparse
2
+ import functools
3
+ import os
4
+ from typing import List, Union
5
+ import re
6
+ import datetime
7
+ import numpy as np
8
+ import pandas as pd
9
+ import rasterio
10
+ import torch
11
+ import yaml
12
+ from einops import rearrange
13
+
14
+ from functools import partial
15
+ from prithvi_mae import PrithviMAE
16
+
17
+ NO_DATA = -9999
18
+ NO_DATA_FLOAT = 0.0001
19
+ OFFSET = 0
20
+ PERCENTILE = 99.9
21
+
22
+
23
+ def process_channel_group(orig_img, new_img, channels, mean, std):
24
+ """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
25
+ original range using *data_mean* and *data_std* and then lowest and highest percentiles are
26
+ removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
27
+
28
+ Args:
29
+ orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
30
+ new_img: torch.Tensor representing image with shape = (bands, H, W).
31
+ channels: list of indices representing RGB channels.
32
+ mean: list of mean values for each band.
33
+ std: list of std values for each band.
34
+
35
+ Returns:
36
+ torch.Tensor with shape (num_channels, height, width) for original image
37
+ torch.Tensor with shape (num_channels, height, width) for the other image
38
+ """
39
+
40
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
41
+ std = torch.tensor(np.asarray(std)[:, None, None])
42
+ orig_img = orig_img[channels, ...]
43
+ valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
44
+ valid_mask[orig_img == NO_DATA_FLOAT] = False
45
+
46
+ # Back to original data range
47
+ orig_img = (orig_img * std[channels]) + mean[channels]
48
+ new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
49
+
50
+ # Rescale (enhancing contrast)
51
+ max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
52
+ min_value = OFFSET
53
+
54
+ orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
55
+ new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
56
+
57
+ # No data as zeros
58
+ orig_img[~valid_mask] = 0
59
+ new_img[~valid_mask] = 0
60
+
61
+ return orig_img, new_img
62
+
63
+
64
+ def read_geotiff(file_path: str):
65
+ """Read all bands from *file_path* and return image + meta info.
66
+
67
+ Args:
68
+ file_path: path to image file.
69
+
70
+ Returns:
71
+ np.ndarray with shape (bands, height, width)
72
+ meta info dict
73
+ """
74
+
75
+ with rasterio.open(file_path) as src:
76
+ img = src.read()
77
+ meta = src.meta
78
+ coords = src.lnglat()
79
+
80
+ return img, meta, coords
81
+
82
+
83
+ def save_geotiff(image, output_path: str, meta: dict):
84
+ """Save multi-band image in Geotiff file.
85
+
86
+ Args:
87
+ image: np.ndarray with shape (bands, height, width)
88
+ output_path: path where to save the image
89
+ meta: dict with meta info.
90
+ """
91
+
92
+ with rasterio.open(output_path, "w", **meta) as dest:
93
+ for i in range(image.shape[0]):
94
+ dest.write(image[i, :, :], i + 1)
95
+
96
+ return
97
+
98
+
99
+ def _convert_np_uint8(float_image: torch.Tensor):
100
+ image = float_image.numpy() * 255.0
101
+ image = image.astype(dtype=np.uint8)
102
+
103
+ return image
104
+
105
+
106
+ def load_example(
107
+ file_paths: List[str],
108
+ mean: List[float],
109
+ std: List[float],
110
+ indices: Union[list[int], None] = None,
111
+ ):
112
+ """Build an input example by loading images in *file_paths*.
113
+
114
+ Args:
115
+ file_paths: list of file paths.
116
+ mean: list containing mean values for each band in the images in *file_paths*.
117
+ std: list containing std values for each band in the images in *file_paths*.
+ indices: optional list of 0-based band indices to select from each image.
118
+
119
+ Returns:
120
+ np.array containing created example
121
+ list of meta info for each image in *file_paths*
122
+ """
123
+
124
+ imgs = []
125
+ metas = []
126
+ temporal_coords = []
127
+ location_coords = []
128
+
129
+ for file in file_paths:
130
+ img, meta, coords = read_geotiff(file)
131
+
132
+ # Rescaling (don't normalize on nodata)
133
+ img = np.moveaxis(img, 0, -1) # channels last for rescaling
134
+ if indices is not None:
135
+ img = img[..., indices]
136
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
137
+
138
+ imgs.append(img)
139
+ metas.append(meta)
140
+ location_coords.append(coords)
141
+
142
+ try:
143
+ match = re.search(r'(\d{7}T\d{6})', file)
144
+ if match:
145
+ year = int(match.group(1)[:4])
146
+ julian_day = match.group(1).split('T')[0][4:]
147
+ if len(julian_day) == 3:
148
+ julian_day = int(julian_day)
149
+ else:
150
+ julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
151
+ temporal_coords.append([year, julian_day])
152
+ except Exception as e:
153
+ print(f'Could not extract timestamp for {file} ({e})')
154
+
155
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
156
+ imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
157
+ imgs = np.expand_dims(imgs, axis=0) # add batch dim
158
+
159
+ return imgs, temporal_coords, location_coords, metas
160
+
161
+
162
+ def run_model(
163
+ model: torch.nn.Module,
164
+ input_data: torch.Tensor,
165
+ temporal_coords: None | torch.Tensor,
166
+ location_coords: None | torch.Tensor,
167
+ mask_ratio: float,
168
+ device: torch.device,
169
+ ):
170
+ """Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
171
+
172
+ Args:
173
+ model: MAE model to run.
174
+ input_data: torch.Tensor with shape (B, C, T, H, W).
+ temporal_coords: year and day-of-year for each frame with shape (B, T, 2), or None.
+ location_coords: lat/lon of the location with shape (B, 2), or None.
175
+ mask_ratio: mask ratio to use.
176
+ device: device where model should run.
177
+
178
+ Returns:
179
+ 2 torch.Tensor with shape (B, C, T, H, W): the reconstructed image and the mask.
180
+ """
181
+
182
+ with torch.no_grad():
183
+ x = input_data.to(device)
184
+
185
+ _, pred, mask = model(x, temporal_coords, location_coords, mask_ratio)
186
+
187
+ # Create mask and prediction images (un-patchify)
188
+ mask_img = (
189
+ model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
190
+ )
191
+ pred_img = model.unpatchify(pred).detach().cpu()
192
+
193
+ # Mix visible and predicted patches
194
+ rec_img = input_data.clone()
195
+ rec_img[mask_img == 1] = pred_img[
196
+ mask_img == 1
197
+ ] # binary mask: 0 is keep, 1 is remove
198
+
199
+ # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
200
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
201
+
202
+ return rec_img, mask_img
203
+
204
+
205
+ def save_rgb_imgs(
206
+ input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
207
+ ):
208
+ """Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
209
+
210
+ Args:
211
+ input_img: input torch.Tensor with shape (C, T, H, W).
212
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
213
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
214
+ channels: list of indices representing RGB channels.
215
+ mean: list of mean values for each band.
216
+ std: list of std values for each band.
217
+ output_dir: directory where to save outputs.
218
+ meta_data: list of dicts with geotiff meta info.
219
+ """
220
+
221
+ for t in range(input_img.shape[1]):
222
+ rgb_orig, rgb_pred = process_channel_group(
223
+ orig_img=input_img[:, t, :, :],
224
+ new_img=rec_img[:, t, :, :],
225
+ channels=channels,
226
+ mean=mean,
227
+ std=std,
228
+ )
229
+
230
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
231
+
232
+ # Saving images
233
+
234
+ save_geotiff(
235
+ image=_convert_np_uint8(rgb_orig),
236
+ output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
237
+ meta=meta_data[t],
238
+ )
239
+
240
+ save_geotiff(
241
+ image=_convert_np_uint8(rgb_pred),
242
+ output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
243
+ meta=meta_data[t],
244
+ )
245
+
246
+ save_geotiff(
247
+ image=_convert_np_uint8(rgb_mask),
248
+ output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
249
+ meta=meta_data[t],
250
+ )
251
+
252
+
253
+ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
254
+ """Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
255
+
256
+ Args:
257
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
258
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
259
+ mean: list of mean values for each band.
260
+ std: list of std values for each band.
261
+ output_dir: directory where to save outputs.
262
+ meta_data: list of dicts with geotiff meta info.
263
+ """
264
+
265
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
266
+ std = torch.tensor(np.asarray(std)[:, None, None])
267
+
268
+ for t in range(rec_img.shape[1]):
269
+ # Back to original data range
270
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
271
+
272
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
273
+
274
+ # Saving images
275
+
276
+ save_geotiff(
277
+ image=rec_img_t,
278
+ output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
279
+ meta=meta_data[t],
280
+ )
281
+
282
+ save_geotiff(
283
+ image=mask_img_t,
284
+ output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
285
+ meta=meta_data[t],
286
+ )
287
+
288
+
289
+ def main(
290
+ data_files: List[str],
291
+ config_path: str,
292
+ checkpoint: str,
293
+ output_dir: str,
294
+ rgb_outputs: bool,
295
+ mask_ratio: float = None,
296
+ input_indices: list[int] = None,
297
+ ):
298
+ os.makedirs(output_dir, exist_ok=True)
299
+
300
+ # Get parameters --------
301
+
302
+ with open(config_path, "r") as f:
303
+ config = yaml.safe_load(f)
304
+
305
+ batch_size = 1
306
+ bands = config['DATA']['BANDS']
307
+ num_frames = len(data_files)
308
+ mean = config['DATA']['MEAN']
309
+ std = config['DATA']['STD']
310
+ coords_encoding = config['MODEL']['COORDS_ENCODING']
311
+ img_size = config['DATA']['INPUT_SIZE'][-1]
312
+
313
+ mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']
314
+
315
+ print(
316
+ f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
317
+ )
318
+ if len(data_files) != 3:
319
+ print(
320
+ "The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary"
321
+ )
322
+
323
+ if torch.cuda.is_available():
324
+ device = torch.device("cuda")
325
+ else:
326
+ device = torch.device("cpu")
327
+
328
+ print(f"Using {device} device.\n")
329
+
330
+ # Loading data ---------------------------------------------------------------------------------
331
+
332
+ input_data, temporal_coords, location_coords, meta_data = load_example(
333
+ file_paths=data_files, indices=input_indices, mean=mean, std=std
334
+ )
335
+
336
+ if not temporal_coords and 'time' in coords_encoding:
337
+ coords_encoding.remove('time')
338
+ if not location_coords and 'location' in coords_encoding:
339
+ coords_encoding.remove('location')
340
+
341
+ # Create model and load checkpoint -------------------------------------------------------------
342
+
343
+ model = PrithviMAE(img_size=config['DATA']['INPUT_SIZE'][-2:],
344
+ patch_size=config['MODEL']['PATCH_SIZE'],
345
+ num_frames=num_frames,
346
+ in_chans=len(bands),
347
+ embed_dim=config['MODEL']['EMBED_DIM'],
348
+ depth=config['MODEL']['DEPTH'],
349
+ num_heads=config['MODEL']['NUM_HEADS'],
350
+ decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
351
+ decoder_depth=config['MODEL']['DECODER_DEPTH'],
352
+ decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
353
+ mlp_ratio=config['MODEL']['MLP_RATIO'],
354
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
355
+ norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
356
+ coords_encoding=coords_encoding,
357
+ coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])
358
+
359
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
360
+ print(f"\n--> Model has {total_params:,} parameters.\n")
361
+
362
+ model.to(device)
363
+
364
+ state_dict = torch.load(checkpoint, map_location=device)
365
+ # discard fixed pos_embedding weight
366
+ for k in list(state_dict.keys()):
367
+ if 'pos_embed' in k:
368
+ del state_dict[k]
369
+ model.load_state_dict(state_dict, strict=False)
370
+ print(f"Loaded checkpoint from {checkpoint}")
371
+
372
+ # Running model --------------------------------------------------------------------------------
373
+
374
+ model.eval()
375
+ channels = [bands.index(b) for b in ["B04", "B03", "B02"]] # BGR -> RGB
376
+
377
+ # Reflect pad if not divisible by img_size
378
+ original_h, original_w = input_data.shape[-2:]
379
+ pad_h = (img_size - original_h % img_size) % img_size
380
+ pad_w = (img_size - original_w % img_size) % img_size
381
+ input_data = np.pad(
382
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
383
+ )
384
+
385
+ # Build sliding window
386
+ batch = torch.tensor(input_data, device="cpu")
387
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
388
+ h1, w1 = windows.shape[3:5]
389
+ windows = rearrange(
390
+ windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
391
+ )
392
+
393
+ # Split into batches if number of windows > batch_size
394
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
395
+ windows = torch.tensor_split(windows, num_batches, dim=0)
396
+
397
+ temporal_coords = torch.tensor(temporal_coords, dtype=torch.float32, device=device).unsqueeze(0)
398
+ location_coords = torch.tensor(location_coords[0], dtype=torch.float32, device=device).unsqueeze(0)
399
+
400
+ # Run model
401
+ rec_imgs = []
402
+ mask_imgs = []
403
+ for x in windows:
404
+ rec_img, mask_img = run_model(model, x, temporal_coords, location_coords, mask_ratio, device)
405
+ rec_imgs.append(rec_img)
406
+ mask_imgs.append(mask_img)
407
+
408
+ rec_imgs = torch.concat(rec_imgs, dim=0)
409
+ mask_imgs = torch.concat(mask_imgs, dim=0)
410
+
411
+ # Build images from patches
412
+ rec_imgs = rearrange(
413
+ rec_imgs,
414
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
415
+ h=img_size,
416
+ w=img_size,
417
+ b=1,
418
+ c=len(bands),
419
+ t=num_frames,
420
+ h1=h1,
421
+ w1=w1,
422
+ )
423
+ mask_imgs = rearrange(
424
+ mask_imgs,
425
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
426
+ h=img_size,
427
+ w=img_size,
428
+ b=1,
429
+ c=len(bands),
430
+ t=num_frames,
431
+ h1=h1,
432
+ w1=w1,
433
+ )
434
+
435
+ # Cut padded images back to original size
436
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
437
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
438
+ batch_full = batch[..., :original_h, :original_w]
439
+
440
+ # Build output images
441
+ if rgb_outputs:
442
+ for d in meta_data:
443
+ d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
444
+
445
+ save_rgb_imgs(
446
+ batch_full[0, ...],
447
+ rec_imgs_full[0, ...],
448
+ mask_imgs_full[0, ...],
449
+ channels,
450
+ mean,
451
+ std,
452
+ output_dir,
453
+ meta_data,
454
+ )
455
+ else:
456
+ for d in meta_data:
457
+ d.update(compress="lzw", nodata=0)
458
+
459
+ save_imgs(
460
+ rec_imgs_full[0, ...],
461
+ mask_imgs_full[0, ...],
462
+ mean,
463
+ std,
464
+ output_dir,
465
+ meta_data,
466
+ )
467
+
468
+ print("Done!")
469
+
470
+
471
+ if __name__ == "__main__":
472
+ parser = argparse.ArgumentParser("MAE run inference", add_help=False)
473
+
474
+ parser.add_argument(
475
+ "--data_files",
476
+ type=str,
477
+ nargs="+",
478
+ default=["examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
479
+ "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
480
+ "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"
481
+ ],
482
+ help="Path to the data files. Assumes multi-band files.",
483
+ )
484
+ parser.add_argument(
485
+ "--config_path",
486
+ "-c",
487
+ type=str,
488
+ default="Prithvi_EO_V2_300_TL_config.yaml",
489
+ help="Path to yaml file containing model training parameters.",
490
+ )
491
+ parser.add_argument(
492
+ "--checkpoint",
493
+ type=str,
494
+ default="Prithvi_EO_V2_300_TL.pt",
495
+ help="Path to a checkpoint file to load from.",
496
+ )
497
+ parser.add_argument(
498
+ "--output_dir",
499
+ type=str,
500
+ default="output",
501
+ help="Path to the directory where to save outputs.",
502
+ )
503
+ parser.add_argument(
504
+ "--mask_ratio",
505
+ default=None,
506
+ type=float,
507
+ help="Masking ratio (percentage of removed patches). "
508
+ "If None (default) use same value used for pretraining.",
509
+ )
510
+ parser.add_argument(
511
+ "--input_indices",
512
+ default=None,
513
+ type=int,
514
+ nargs="+",
515
+ help="0-based indices of channels to be selected from the input. By default takes all.",
516
+ )
517
+ parser.add_argument(
518
+ "--rgb_outputs",
519
+ action="store_true",
520
+ help="If present, output files will only contain RGB channels. "
521
+ "Otherwise, all bands will be saved.",
522
+ )
523
+ args = parser.parse_args()
524
+
525
+ main(**vars(args))
prithvi_mae.py ADDED
@@ -0,0 +1,736 @@
1
+ # Copyright (c) IBM Corp. 2024. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # References:
16
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
17
+ # transformers: https://github.com/huggingface/transformers
18
+ # --------------------------------------------------------
19
+
20
+ from functools import partial
21
+ from typing import List, Tuple
22
+
23
+ import logging
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn as nn
27
+ from einops import rearrange
28
+ from timm.layers import to_2tuple
29
+ from timm.models.vision_transformer import Block
30
+
31
+
32
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
33
+ """
34
+ Create 3D sin/cos positional embeddings.
35
+
36
+ Args:
37
+ embed_dim (int):
38
+ Embedding dimension.
39
+ grid_size (tuple[int, int, int] | list[int]):
40
+ The grid depth, height and width.
41
+ add_cls_token (bool, *optional*, defaults to False):
42
+ Whether or not to add a classification (CLS) token.
43
+
44
+ Returns:
45
+ (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
46
+ (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
47
+ """
48
+
49
+ assert embed_dim % 16 == 0
50
+
51
+ t_size, h_size, w_size = grid_size
52
+
53
+ w_embed_dim = embed_dim // 16 * 6
54
+ h_embed_dim = embed_dim // 16 * 6
55
+ t_embed_dim = embed_dim // 16 * 4
56
+
57
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
58
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
59
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
60
+
61
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
62
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
63
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
64
+
65
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
66
+
67
+ if add_cls_token:
68
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
69
+ return pos_embed
70
+
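+ # Example (illustrative): for grid_size=(3, 14, 14) and embed_dim=1024 with
+ # add_cls_token=True, the returned embedding has shape (1 + 3*14*14, 1024) = (589, 1024).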
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
73
+ """
74
+ embed_dim: output dimension for each position; pos: a list of positions to be encoded, size (M,); out: (M, D)
75
+ """
76
+ if embed_dim % 2 != 0:
77
+ raise ValueError("embed_dim must be even")
78
+
79
+ omega = np.arange(embed_dim // 2, dtype=float)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = np.sin(out) # (M, D/2)
87
+ emb_cos = np.cos(out) # (M, D/2)
88
+
89
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
90
+ return emb
91
+
92
+
93
+ def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
94
+ """ This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
95
+ it was modified to cast omega values to pos.dtype which must be float (and not int as in
96
+ regular positional embeddings). This was required in order to allow for native FSDP mixed
97
+ precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
98
+ instead of manually forcing float32.
99
+
100
+ embed_dim: output dimension for each position
101
+ pos: a list of positions to be encoded: size (M,) - must be float dtype!
102
+ out: (M, D)
103
+ """
104
+ assert embed_dim % 2 == 0
105
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
106
+
107
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
108
+ omega /= embed_dim / 2.0
109
+ omega = 1.0 / 10000**omega # (D/2,)
110
+
111
+ pos = pos.reshape(-1) # (M,)
112
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
113
+
114
+ emb_sin = torch.sin(out) # (M, D/2)
115
+ emb_cos = torch.cos(out) # (M, D/2)
116
+
117
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
118
+
119
+ return emb
120
+
121
+
122
+ def _init_weights(module):
123
+ """Initialize the weights"""
124
+ if isinstance(module, nn.Linear):
125
+ nn.init.xavier_uniform_(module.weight)
126
+ if module.bias is not None:
127
+ module.bias.data.zero_()
128
+ elif isinstance(module, nn.LayerNorm):
129
+ module.bias.data.zero_()
130
+ module.weight.data.fill_(1.0)
131
+
132
+
133
+ class PatchEmbed(nn.Module):
134
+ """3D version of timm.models.vision_transformer.PatchEmbed"""
135
+ def __init__(
136
+ self,
137
+ input_size: Tuple[int, int, int] = (1, 224, 224),
138
+ patch_size: Tuple[int, int, int] = (1, 16, 16),
139
+ in_chans: int = 3,
140
+ embed_dim: int = 768,
141
+ norm_layer: nn.Module | None = None,
142
+ flatten: bool = True,
143
+ bias: bool = True,
144
+ ):
145
+ super().__init__()
146
+ self.input_size = input_size
147
+ self.patch_size = patch_size
148
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
149
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
150
+ self.flatten = flatten
151
+
152
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
153
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
154
+
155
+ def forward(self, x):
156
+ B, C, T, H, W = x.shape
157
+
158
+ if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
159
+ logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
160
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
161
+
162
+ x = self.proj(x)
163
+ if self.flatten:
164
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
165
+ x = self.norm(x)
166
+ return x
167
+
168
+
169
+ class TemporalEncoder(nn.Module):
170
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
171
+ super().__init__()
172
+ self.embed_dim = embed_dim
173
+ self.year_embed_dim = embed_dim // 2
174
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
175
+
176
+ # If trainable, initialize scale with small number
177
+ if trainable_scale:
178
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
179
+ else:
180
+ self.register_buffer('scale', torch.ones(1))
181
+
182
+ def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
183
+ """
184
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
185
+ tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
186
+ repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
187
+ """
188
+ shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
189
+
190
+ year = _get_1d_sincos_embed_from_grid_torch(
191
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
192
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
193
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
194
+
195
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
196
+
197
+ if tokens_per_frame is not None:
198
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
199
+
200
+ return embedding # B, T*tokens_per_frame, embed_dim
201
+
202
+
203
+ class LocationEncoder(nn.Module):
204
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
205
+ super().__init__()
206
+ self.embed_dim = embed_dim
207
+ self.lat_embed_dim = embed_dim // 2
208
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
209
+
210
+ # If trainable, initialize scale with small number
211
+ if trainable_scale:
212
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
213
+ else:
214
+ self.register_buffer('scale', torch.ones(1))
215
+
216
+ def forward(self, location_coords: torch.Tensor):
217
+ """
218
+ location_coords: lat and lon info with shape (B, 2).
219
+ """
220
+ shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
221
+
222
+ lat = _get_1d_sincos_embed_from_grid_torch(
223
+ self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
224
+ lon = _get_1d_sincos_embed_from_grid_torch(
225
+ self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
226
+
227
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
228
+
229
+ return embedding # B, 1, embed_dim
230
+
231
+
232
+ class PrithviViT(nn.Module):
233
+ """ Prithvi ViT Encoder"""
234
+ def __init__(self,
235
+ img_size: int | Tuple[int, int] = 224,
236
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
237
+ num_frames: int = 1,
238
+ in_chans: int = 3,
239
+ embed_dim: int = 1024,
240
+ depth: int = 24,
241
+ num_heads: int = 16,
242
+ mlp_ratio: float = 4.,
243
+ norm_layer: nn.Module = nn.LayerNorm,
244
+ coords_encoding: List[str] | None = None,
245
+ coords_scale_learn: bool = False,
246
+ encoder_only: bool = True, # needed for timm
247
+ **kwargs,
248
+ ):
249
+ super().__init__()
250
+
251
+ self.feature_info = []
252
+ self.encoder_only = encoder_only
253
+ self.in_chans = in_chans
254
+ self.num_frames = num_frames
255
+ self.embed_dim = embed_dim
256
+ self.img_size = to_2tuple(img_size)
257
+ if isinstance(patch_size, int):
258
+ patch_size = (1, patch_size, patch_size)
259
+
260
+ # 3D patch embedding
261
+ self.patch_embed = PatchEmbed(
262
+ input_size=(num_frames,) + self.img_size,
263
+ patch_size=patch_size,
264
+ in_chans=in_chans,
265
+ embed_dim=embed_dim,
266
+ )
267
+
268
+ # Optional temporal and location embedding
269
+ coords_encoding = coords_encoding or []
270
+ self.temporal_encoding = 'time' in coords_encoding
271
+ self.location_encoding = 'location' in coords_encoding
272
+ if self.temporal_encoding:
273
+ assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
274
+ self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
275
+ if self.location_encoding:
276
+ self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
277
+
278
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
279
+ self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
280
+
281
+ # Transformer layers
282
+ self.blocks = []
283
+ for i in range(depth):
284
+ self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
285
+ self.feature_info.append(
286
+ {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
287
+ )
288
+ self.blocks = nn.ModuleList(self.blocks)
289
+
290
+ self.norm = norm_layer(embed_dim)
291
+
292
+ self.initialize_weights()
293
+
294
+ def initialize_weights(self):
295
+ # initialize (and freeze) position embeddings by sin-cos embedding
296
+ pos_embed = get_3d_sincos_pos_embed(
297
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
298
+ )
299
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
300
+
301
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
302
+ w = self.patch_embed.proj.weight.data
303
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
304
+
305
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
306
+ torch.nn.init.normal_(self.cls_token, std=0.02)
307
+ self.apply(_init_weights)
308
+
309
+ def random_masking(self, sequence, mask_ratio, noise=None):
310
+ """
311
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
312
+ noise.
313
+
314
+ Args:
315
+ sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
316
+ mask_ratio (float): mask ratio to use.
317
+ noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
318
+ mainly used for testing purposes to control randomness and maintain the reproducibility
319
+ """
320
+ batch_size, seq_length, dim = sequence.shape
321
+ len_keep = int(seq_length * (1 - mask_ratio))
322
+
323
+ if noise is None:
324
+ noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
325
+
326
+ # sort noise for each sample
327
+ ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
328
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
329
+
330
+ # keep the first subset
331
+ ids_keep = ids_shuffle[:, :len_keep]
332
+ sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
333
+
334
+ # generate the binary mask: 0 is keep, 1 is remove
335
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
336
+ mask[:, :len_keep] = 0
337
+ # unshuffle to get the binary mask
338
+ mask = torch.gather(mask, dim=1, index=ids_restore)
339
+
340
+ return sequence_unmasked, mask, ids_restore
341
+
342
+ def _get_pos_embed(self, x):
343
+ t, h, w = x.shape[-3:]
344
+
345
+ pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
346
+ self.embed_dim,
347
+ (
348
+ t // self.patch_embed.patch_size[0],
349
+ h // self.patch_embed.patch_size[1],
350
+ w // self.patch_embed.patch_size[2],
351
+ ),
352
+ add_cls_token=True,
353
+ )).float().unsqueeze(0).to(x)
354
+
355
+ return pos_embed
356
+
357
+
358
+ def forward(
359
+ self, x: torch.Tensor,
360
+ temporal_coords: None | torch.Tensor = None,
361
+ location_coords: None | torch.Tensor = None,
362
+ mask_ratio=0.75
363
+ ):
364
+ if x.shape[-3:] != self.patch_embed.input_size:
365
+ # changed input size
366
+ pos_embed = self._get_pos_embed(x)
367
+ else:
368
+ pos_embed = self.pos_embed
369
+
370
+ # embed patches
371
+ x = self.patch_embed(x)
372
+
373
+ # add pos embed w/o cls token
374
+ x = x + pos_embed[:, 1:, :]
375
+
376
+ if self.temporal_encoding:
377
+ num_tokens_per_frame = x.shape[1] // self.num_frames
378
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
379
+ x = x + temporal_encoding
380
+ if self.location_encoding:
381
+ location_encoding = self.location_embed_enc(location_coords)
382
+ x = x + location_encoding
383
+
384
+ # masking: length -> length * mask_ratio
385
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
386
+
387
+ # append cls token
388
+ cls_token = self.cls_token + pos_embed[:, :1, :]
389
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
390
+ x = torch.cat((cls_tokens, x), dim=1)
391
+
392
+ # apply Transformer blocks
393
+ for block in self.blocks:
394
+ x = block(x)
395
+ x = self.norm(x)
396
+
397
+ return x, mask, ids_restore
398
+
399
+ def forward_features(
400
+ self,
401
+ x: torch.Tensor,
402
+ temporal_coords: None | torch.Tensor = None,
403
+ location_coords: None | torch.Tensor = None,
404
+ ) -> list[torch.Tensor]:
405
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
406
+ # add time dim
407
+ x = x.unsqueeze(2)
408
+
409
+ if x.shape[-3:] != self.patch_embed.input_size:
410
+ pos_embed = self._get_pos_embed(x)
411
+ else:
412
+ pos_embed = self.pos_embed
413
+
414
+ # embed patches
415
+ x = self.patch_embed(x)
416
+
417
+ # add pos embed w/o cls token
418
+ x = x + pos_embed[:, 1:, :]
419
+
420
+ if self.temporal_encoding:
421
+ num_tokens_per_frame = x.shape[1] // self.num_frames
422
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
423
+ x = x + temporal_encoding
424
+ if self.location_encoding:
425
+ location_encoding = self.location_embed_enc(location_coords)
426
+ x = x + location_encoding
427
+
428
+ # append cls token
429
+ cls_token = self.cls_token + pos_embed[:, :1, :]
430
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
431
+ x = torch.cat((cls_tokens, x), dim=1)
432
+
433
+ # apply Transformer blocks
434
+ out = []
435
+ for block in self.blocks:
436
+ x = block(x)
437
+ out.append(x.clone())
438
+
439
+ x = self.norm(x)
440
+ out[-1] = x
441
+ return out
442
+
443
+ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
444
+ out = []
445
+ effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
446
+ for x in features:
447
+ x_no_token = x[:, 1:, :]
448
+ number_of_tokens = x_no_token.shape[1]
449
+ tokens_per_timestep = number_of_tokens // effective_time_dim
450
+ h = int(np.sqrt(tokens_per_timestep))
451
+ encoded = rearrange(
452
+ x_no_token,
453
+ "batch (t h w) e -> batch (t e) h w",
454
+ e=self.embed_dim,
455
+ t=effective_time_dim,
456
+ h=h,
457
+ )
458
+ out.append(encoded)
459
+ return out
460
+
461
+
462
+ class MAEDecoder(nn.Module):
463
+ """ Transformer Decoder used in the Prithvi MAE"""
464
+ def __init__(self,
465
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
466
+ grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
467
+ in_chans: int = 3,
468
+ encoder_embed_dim: int = 1024,
469
+ decoder_embed_dim: int = 512,
470
+ depth: int = 8,
471
+ num_heads: int = 16,
472
+ mlp_ratio: float = 4.,
473
+ norm_layer: nn.Module = nn.LayerNorm,
474
+ coords_encoding: List[str] | None = None,
475
+ coords_scale_learn: bool = False,
476
+ ):
477
+ super().__init__()
478
+
479
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
480
+ self.decoder_embed_dim = decoder_embed_dim
481
+ self.grid_size = grid_size
482
+ if isinstance(patch_size, int):
483
+ patch_size = (1, patch_size, patch_size)
484
+ self.patch_size = patch_size
485
+ self.num_frames = self.grid_size[0] * patch_size[0]
486
+ num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
487
+
488
+ # Optional temporal and location embedding
489
+ coords_encoding = coords_encoding or []
490
+ self.temporal_encoding = 'time' in coords_encoding
491
+ self.location_encoding = 'location' in coords_encoding
492
+ if self.temporal_encoding:
493
+ self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
494
+ if self.location_encoding:
495
+ self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
496
+
497
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
498
+
499
+ self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
500
+
501
+ self.decoder_blocks = nn.ModuleList(
502
+ [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
503
+ )
504
+
505
+ self.decoder_norm = norm_layer(decoder_embed_dim)
506
+ self.decoder_pred = nn.Linear(decoder_embed_dim,
507
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
508
+ bias=True)
509
+
510
+ self.initialize_weights()
511
+
512
+ def initialize_weights(self):
513
+ # initialize (and freeze) position embeddings by sin-cos embedding
514
+ decoder_pos_embed = get_3d_sincos_pos_embed(
515
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
516
+ )
517
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
518
+
519
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
520
+ torch.nn.init.normal_(self.mask_token, std=0.02)
521
+ self.apply(_init_weights)
522
+
523
+ def forward(
524
+ self,
525
+ hidden_states: torch.Tensor,
526
+ ids_restore: torch.Tensor,
527
+ temporal_coords: None | torch.Tensor = None,
528
+ location_coords: None | torch.Tensor = None,
529
+ input_size: list[int] = None,
530
+ ):
531
+ # embed tokens
532
+ x = self.decoder_embed(hidden_states)
533
+
534
+ t, h, w = input_size[-3:]
535
+ decoder_pos_embed = torch.from_numpy(
536
+ get_3d_sincos_pos_embed(
537
+ self.decoder_embed_dim,
538
+ (
539
+ t // self.patch_size[0],
540
+ h // self.patch_size[1],
541
+ w // self.patch_size[2],
542
+ ),
543
+ add_cls_token=True,
544
+ )
545
+ ).to(x)
546
+
547
+ # append mask tokens to sequence
548
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
549
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
550
+ # unshuffle
551
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
552
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
553
+ # add pos embed
554
+ x = x + decoder_pos_embed
555
+
556
+ # remove cls token
557
+ x_ = x[:, 1:, :]
558
+
559
+ if self.temporal_encoding:
560
+ num_tokens_per_frame = x_.shape[1] // self.num_frames
561
+ temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
562
+ # Add temporal encoding w/o cls token
563
+ x_ = x_ + temporal_encoding
564
+ if self.location_encoding:
565
+ location_encoding = self.location_embed_dec(location_coords)
566
+ # Add location encoding w/o cls token
567
+ x_ = x_ + location_encoding
568
+
569
+ # append cls token
570
+ x = torch.cat([x[:, :1, :], x_], dim=1)
571
+
572
+ # apply Transformer layers (blocks)
573
+ for block in self.decoder_blocks:
574
+ x = block(x)
575
+ x = self.decoder_norm(x)
576
+
577
+ # predictor projection
578
+ pred = self.decoder_pred(x)
579
+
580
+ # remove cls token
581
+ pred = pred[:, 1:, :]
582
+
583
+ return pred
584
+
585
+
586
+ class PrithviMAE(nn.Module):
587
+ """ Prithvi Masked Autoencoder"""
588
+
589
+ def __init__(self,
590
+ img_size: int | Tuple[int, int] = 224,
591
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
592
+ num_frames: int = 3,
593
+ in_chans: int = 3,
594
+ embed_dim: int = 1024,
595
+ depth: int = 24,
596
+ num_heads: int = 16,
597
+ decoder_embed_dim: int = 512,
598
+ decoder_depth: int = 8,
599
+ decoder_num_heads: int = 16,
600
+ mlp_ratio: float = 4.,
601
+ norm_layer: nn.Module = nn.LayerNorm,
602
+ norm_pix_loss: bool = False,
603
+ coords_encoding: List[str] | None = None,
604
+ coords_scale_learn: bool = False,
605
+ encoder_only: bool = False,
606
+ **kwargs,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.encoder = PrithviViT(
611
+ img_size=img_size,
612
+ num_frames=num_frames,
613
+ patch_size=patch_size,
614
+ in_chans=in_chans,
615
+ embed_dim=embed_dim,
616
+ depth=depth,
617
+ num_heads=num_heads,
618
+ mlp_ratio=mlp_ratio,
619
+ norm_layer=norm_layer,
620
+ coords_encoding=coords_encoding,
621
+ coords_scale_learn=coords_scale_learn,
622
+ )
623
+
624
+ self.encoder_only = encoder_only
625
+
626
+ if not encoder_only:
627
+ self.decoder = MAEDecoder(
628
+ patch_size=patch_size,
629
+ grid_size=self.encoder.patch_embed.grid_size,
630
+ in_chans=in_chans,
631
+ encoder_embed_dim=embed_dim,
632
+ decoder_embed_dim=decoder_embed_dim,
633
+ depth=decoder_depth,
634
+ num_heads=decoder_num_heads,
635
+ mlp_ratio=mlp_ratio,
636
+ norm_layer=norm_layer,
637
+ coords_encoding=coords_encoding,
638
+ coords_scale_learn=coords_scale_learn,
639
+ )
640
+ else:
641
+ self.decoder = nn.Identity()
642
+
643
+ self.norm_pix_loss = norm_pix_loss
644
+
645
+ def patchify(self, pixel_values):
646
+ """
647
+ Args:
648
+ pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
649
+ Pixel values.
650
+
651
+ Returns:
652
+ torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
653
+ Patchified pixel values.
654
+ """
655
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
656
+ num_channels = self.encoder.in_chans
657
+
658
+ # patchify
659
+ patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
660
+ c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
661
+
662
+
663
+ return patchified_pixel_values
664
+
665
+ def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
666
+ """
667
+ Args:
668
+ patchified_pixel_values (`torch.FloatTensor` of shape
669
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
670
+ Patchified pixel values.
671
+ image_size (`Tuple[int, int]`, *optional*):
672
+ Original image size.
673
+
674
+ Returns:
675
+ `torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`:
676
+ Pixel values.
677
+ """
678
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
679
+ image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
680
+ original_height, original_width = image_size
681
+ num_patches_h = original_height // patch_size_h
682
+ num_patches_w = original_width // patch_size_w
683
+ num_channels = self.encoder.in_chans
684
+
685
+ pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
686
+ c=num_channels, h=num_patches_h, w=num_patches_w,
687
+ s=patch_size_t, p=patch_size_h, q=patch_size_w)
688
+ return pixel_values
689
+
690
+ def forward_loss(self, pixel_values, pred, mask):
691
+ """
692
+ Args:
693
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
694
+ Pixel values.
695
+ pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
696
+ Predicted pixel values.
697
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
698
+ Tensor indicating which patches are masked (1) and which are not (0).
699
+
700
+ Returns:
701
+ `torch.FloatTensor`: Pixel reconstruction loss.
702
+ """
703
+ target = self.patchify(pixel_values)
704
+ if self.norm_pix_loss:
705
+ mean = target.mean(dim=-1, keepdim=True)
706
+ var = target.var(dim=-1, keepdim=True)
707
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
708
+
709
+ loss = (pred - target) ** 2
710
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
711
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
712
+ return loss
713
+
714
+ def forward(
715
+ self,
716
+ pixel_values: torch.Tensor,
717
+ temporal_coords: None | torch.Tensor = None,
718
+ location_coords: None | torch.Tensor = None,
719
+ mask_ratio: float = 0.75
720
+ ):
721
+ if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
722
+ # add time dim
723
+ pixel_values = pixel_values.unsqueeze(2)
724
+
725
+ latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
726
+ pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
727
+ loss = self.forward_loss(pixel_values, pred, mask)
728
+ return loss, pred, mask
729
+
730
+ def forward_features(
731
+ self,
732
+ x: torch.Tensor,
733
+ temporal_coords: None | torch.Tensor = None,
734
+ location_coords: None | torch.Tensor = None,
735
+ ) -> List[torch.Tensor]:
736
+ return self.encoder.forward_features(x, temporal_coords, location_coords)
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch
2
+ torchvision
3
+ timm
4
+ einops
5
+ rasterio