import argparse
import os
from datetime import datetime

import lightning.pytorch
import torch
from datamodules.s2geo_dataset import S2GeoDataModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
from loss import SatCLIPLoss
from model import SatCLIP

torch.set_float32_matmul_precision("high")


class SatCLIPLightningModule(lightning.pytorch.LightningModule):
    def __init__(
        self,
        embed_dim=512,
        image_resolution=256,
        vision_layers=12,
        vision_width=768,
        vision_patch_size=32,
        in_channels=4,
        le_type="grid",
        pe_type="siren",
        frequency_num=16,
        max_radius=260,
        min_radius=1,
        legendre_polys=16,
        harmonics_calculation="analytic",
        sh_embedding_dims=32,
        learning_rate=1e-4,
        weight_decay=0.01,
        num_hidden_layers=2,
        capacity=256,
    ) -> None:
        super().__init__()

        self.model = SatCLIP(
            embed_dim=embed_dim,
            image_resolution=image_resolution,
            vision_layers=vision_layers,
            vision_width=vision_width,
            vision_patch_size=vision_patch_size,
            in_channels=in_channels,
            le_type=le_type,
            pe_type=pe_type,
            frequency_num=frequency_num,
            max_radius=max_radius,
            min_radius=min_radius,
            legendre_polys=legendre_polys,
            harmonics_calculation=harmonics_calculation,
            sh_embedding_dims=sh_embedding_dims,
            num_hidden_layers=num_hidden_layers,
            capacity=capacity,
        )

        self.loss_fun = SatCLIPLoss()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def common_step(self, batch, batch_idx):
        # Contrastive step shared by training and validation: embed images and
        # coordinates, then compute the CLIP-style loss on their similarities.
        images = batch["image"]
        t_points = batch["point"].float()
        logits_per_image, logits_per_coord = self.model(images, t_points)
        return self.loss_fun(logits_per_image, logits_per_coord)

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        # Exclude 1-D parameters (biases, norm gains) and the logit scale from
        # weight decay; apply decay only to the remaining weight matrices.
        exclude = (
            lambda n, p: p.ndim < 2
            or "bn" in n
            or "ln" in n
            or "bias" in n
            or "logit_scale" in n
        )
        include = lambda n, p: not exclude(n, p)

        named_parameters = list(self.model.named_parameters())
        gain_or_bias_params = [
            p for n, p in named_parameters if exclude(n, p) and p.requires_grad
        ]
        rest_params = [
            p for n, p in named_parameters if include(n, p) and p.requires_grad
        ]

        optimizer = torch.optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.0},
                {
                    "params": rest_params,
                    "weight_decay": self.weight_decay,  # specify in configs/default.yaml
                },
            ],
            lr=self.learning_rate,  # specify in configs/default.yaml
        )

        return optimizer


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Extra CLI flag on top of the standard LightningCLI arguments.
        parser.add_argument("--watchmodel", action="store_true")


def cli_main(default_config_filename="./configs/default.yaml"):
    save_config_fn = default_config_filename.replace(".yaml", "-latest.yaml")
    # modify configs/default.yaml for learning rate etc.
    cli = MyLightningCLI(
        model_class=SatCLIPLightningModule,
        datamodule_class=S2GeoDataModule,
        save_config_kwargs=dict(
            config_filename=save_config_fn,
            overwrite=True,
        ),
        trainer_defaults={
            "accumulate_grad_batches": 16,
            "log_every_n_steps": 10,
        },
        parser_kwargs={"default_config_files": [default_config_filename]},
        seed_everything_default=0,
        run=False,
    )

    ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    run_name = f"SatCLIP_S2_{ts}"
    if cli.trainer.logger is not None:
        cli.trainer.logger.experiment.name = run_name
        # this seems to be necessary to force logging of datamodule hyperparams
        cli.trainer.logger.log_hyperparams(cli.datamodule.hparams)

    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
    )


if __name__ == "__main__":
    config_fn = "./configs/default.yaml"

    # A100 go vroom vroom 🚗💨
    # Only query the device name when CUDA is available, so the script still
    # starts on CPU-only machines.
    if (
        torch.cuda.is_available()
        and torch.cuda.get_device_name(device=0) == "NVIDIA A100 80GB PCIe"
    ):
        torch.backends.cuda.matmul.allow_tf32 = True
        print("Superfastmode! 🚀")
    else:
        torch.backends.cuda.matmul.allow_tf32 = False

    cli_main(config_fn)
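
# Usage sketch (not part of the original script; file name and override keys
# are assumptions). Assuming this file is saved as main.py next to
# configs/default.yaml, training starts with:
#
#   python main.py
#
# Because MyLightningCLI is constructed with run=False and with
# default_config_files pointing at configs/default.yaml, hyperparameters are
# read from that config and can be overridden on the command line in the usual
# LightningCLI form, for example:
#
#   python main.py --model.learning_rate 1e-5 --trainer.max_epochs 100
#
# The available keys follow the SatCLIPLightningModule, S2GeoDataModule, and
# Trainer signatures as parsed by LightningCLI.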