import argparse
import os
from datetime import datetime

import lightning.pytorch
import torch
from datamodules.s2geo_dataset import S2GeoDataModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
from loss import SatCLIPLoss
from model import SatCLIP

# Trade a little float32 matmul precision for speed (enables TF32 on supported GPUs).
torch.set_float32_matmul_precision('high')


class SatCLIPLightningModule(lightning.pytorch.LightningModule):
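    """Lightning wrapper around SatCLIP: trains an image encoder and a location
    encoder contrastively on (image, coordinate) pairs."""
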
    def __init__(
        self,
        embed_dim=512,
        image_resolution=256,
        vision_layers=12,
        vision_width=768,
        vision_patch_size=32,
        in_channels=4,
        le_type="grid",
        pe_type="siren",
        frequency_num=16,
        max_radius=260,
        min_radius=1,
        legendre_polys=16,
        harmonics_calculation="analytic",
        sh_embedding_dims=32,
        learning_rate=1e-4,
        weight_decay=0.01,
        num_hidden_layers=2,
        capacity=256,
    ) -> None:
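        """Build the SatCLIP model and store the loss and optimizer hyperparameters."""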
        super().__init__()

        self.model = SatCLIP(
            embed_dim=embed_dim,
            image_resolution=image_resolution,
            vision_layers=vision_layers,
            vision_width=vision_width,
            vision_patch_size=vision_patch_size,
            in_channels=in_channels,
            le_type=le_type,
            pe_type=pe_type,
            frequency_num=frequency_num,
            max_radius=max_radius,
            min_radius=min_radius,
            legendre_polys=legendre_polys,
            harmonics_calculation=harmonics_calculation,
            sh_embedding_dims=sh_embedding_dims,
            num_hidden_layers=num_hidden_layers,
            capacity=capacity,
        )

        self.loss_fun = SatCLIPLoss()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def common_step(self, batch, batch_idx):
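        """Embed a batch of images and coordinates and return the contrastive loss."""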
        images = batch["image"]
        t_points = batch["point"].float()
        logits_per_image, logits_per_coord = self.model(images, t_points)
        return self.loss_fun(logits_per_image, logits_per_coord)

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
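        """Set up AdamW with biases, norm layers, and the logit scale excluded from weight decay."""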
        exclude = (
            lambda n, p: p.ndim < 2
            or "bn" in n
            or "ln" in n
            or "bias" in n
            or "logit_scale" in n
        )
        include = lambda n, p: not exclude(n, p)

        named_parameters = list(self.model.named_parameters())
        gain_or_bias_params = [
            p for n, p in named_parameters if exclude(n, p) and p.requires_grad
        ]
        rest_params = [
            p for n, p in named_parameters if include(n, p) and p.requires_grad
        ]

        optimizer = torch.optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.0},
                {"params": rest_params, "weight_decay": self.weight_decay},
            ],
            lr=self.learning_rate,
        )

        return optimizer


class MyLightningCLI(LightningCLI):
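    """LightningCLI subclass that registers an extra command-line flag."""
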
    def add_arguments_to_parser(self, parser):
        parser.add_argument("--watchmodel", action="store_true")


def cli_main(default_config_filename="/configs/default.yaml"):
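    """Build the model, datamodule, and trainer from the YAML config and launch training."""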
    save_config_fn = default_config_filename.replace(".yaml", "-latest.yaml")

    cli = MyLightningCLI(
        model_class=SatCLIPLightningModule,
        datamodule_class=S2GeoDataModule,
        save_config_kwargs=dict(
            config_filename=save_config_fn,
            overwrite=True,
        ),
        trainer_defaults={
            "accumulate_grad_batches": 16,
            "log_every_n_steps": 10,
        },
        parser_kwargs={"default_config_files": [default_config_filename]},
        seed_everything_default=0,
        run=False,
    )

    # Name the run with a timestamp and log the datamodule hyperparameters.
    ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    run_name = f"SatCLIP_S2_{ts}"
    if cli.trainer.logger is not None:
        cli.trainer.logger.experiment.name = run_name
        cli.trainer.logger.log_hyperparams(cli.datamodule.hparams)

    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
    )


if __name__ == "__main__":
    config_fn = "./configs/default.yaml"

    # Only enable TF32 matmuls on the A100 GPUs used for training; guard the
    # device query so CPU-only machines don't crash on the name lookup.
    if (
        torch.cuda.is_available()
        and torch.cuda.get_device_name(device=0) == 'NVIDIA A100 80GB PCIe'
    ):
        torch.backends.cuda.matmul.allow_tf32 = True
        print('Superfastmode! π')
    else:
        torch.backends.cuda.matmul.allow_tf32 = False

    cli_main(config_fn)