# Uploaded by yunusserhat ("Create APP", commit 894bc0c).
import os
from typing import Any
import pytorch_lightning as L
import torch
import torch.nn as nn
from hydra.utils import instantiate
import copy
import pandas as pd
import numpy as np
class Geolocalizer(L.LightningModule):
    """Lightning module that trains and evaluates a geolocation network.

    All components are built from the Hydra config ``cfg`` with
    ``hydra.utils.instantiate``:

    - ``cfg.network.instance``: the model, called as ``self.model(batch)``;
      its prediction supports item assignment (dict-like).
    - ``cfg.text_network.instance``: an extra text model, only built and run
      (in training and validation) when ``cfg.text_tuning`` is true; its
      output is stored in ``pred["text_features"]``.
    - ``cfg.loss``: called as ``self.loss(pred, batch, average=True)`` and
      returning a dict of named loss terms that includes ``"loss"``.
    - ``cfg.val_metrics`` / ``cfg.test_metrics``: objects exposing
      ``update(pred, batch)`` and ``compute()``, the latter returning a dict
      of named values.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = instantiate(cfg.network.instance)
        if cfg.text_tuning:
            self.text_model = instantiate(cfg.text_network.instance)
        self.loss = instantiate(cfg.loss)
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)
        self.text_tuning = cfg.text_tuning

    def training_step(self, batch, batch_idx):
        """Forward one batch, log every loss term and return the loss dict.

        Each entry of the loss dict is logged as ``train/<name>`` both per
        step and aggregated per epoch. The returned dict's ``"loss"`` entry
        is what Lightning back-propagates.
        """
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        loss = self.loss(pred, batch, average=True)
        for metric_name, metric_value in loss.items():
            self.log(
                f"train/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=True,
                on_epoch=True,
            )
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log the validation loss (epoch-level) and accumulate val metrics."""
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        loss = self.loss(pred, batch, average=True)["loss"]
        self.val_metrics.update(pred, batch)
        self.log("val/loss", loss, sync_dist=True, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        """Compute the accumulated validation metrics and log each as ``val/<name>``."""
        # NOTE(review): val_metrics is never reset here; confirm that its
        # compute() clears its state, otherwise values accumulate across epochs.
        metrics = self.val_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Accumulate test metrics for one batch.

        Unlike training/validation, the text model is not run here, so
        ``pred`` carries no ``"text_features"`` entry.
        """
        pred = self.model(batch)
        self.test_metrics.update(pred, batch)

    def on_test_epoch_end(self):
        """Compute the accumulated test metrics and log each as ``test/<name>``."""
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
        """Build the optimizer and a per-step LR scheduler.

        ``self.model`` parameters are bucketed into learning-rate groups by
        name (see ``_named_lr_groups``). When
        ``cfg.optimizer.exclude_ln_and_biases_from_weight_decay`` is set,
        each of those groups is further split into a weight-decayed part and
        a part with ``weight_decay=0.0``. The split covers LayerNorm
        parameters and parameters whose name contains ``"bias"``.

        The per-group learning rates and the ``unfreeze_lr`` selection are
        kept in both modes. Previously, the weight-decay option replaced all
        groups with two groups over every model parameter, discarding
        ``backbone_lr`` / ``last_block_lr`` / ``lora_lr``.

        Returns:
            ``([optimizer], [{"scheduler": scheduler, "interval": "step"}])``.
        """
        # NOTE(review): only self.model parameters are collected; when
        # text_tuning is on, self.text_model is not optimized. Confirm intended.
        named_groups = self._named_lr_groups()
        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            param_groups = self._split_weight_decay(named_groups)
        else:
            param_groups = [
                {"params": [param for _, param in named], **overrides}
                for named, overrides in named_groups
            ]
        optimizer = instantiate(self.cfg.optimizer.optim, param_groups)
        scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def _named_lr_groups(self):
        """Bucket ``self.model``'s named parameters into learning-rate groups.

        Returns a list of ``(named_params, overrides)`` pairs, where
        ``named_params`` is a list of ``(name, param)`` tuples and
        ``overrides`` holds extra optimizer options (e.g. ``{"lr": ...}``):

        - everything else: always present, no overrides (optimizer default lr);
        - ``"backbone"`` names: only when ``unfreeze_lr``, at ``backbone_lr``;
        - ``"backbone"`` names containing ``".11."``: a separate group at
          ``last_block_lr`` when ``diff_backbone_last``;
        - ``"lora"`` names: at ``lora_lr``, only if any exist.
        """
        opt_cfg = self.cfg.optimizer
        lora, backbone, last_block, other = [], [], [], []
        for name, param in self.model.named_parameters():
            if "lora" in name:
                lora.append((name, param))
            elif "backbone" in name:
                # ".11." presumably selects the last block of a 12-block
                # backbone; confirm for backbones of other depths.
                if opt_cfg.diff_backbone_last and ".11." in name:
                    last_block.append((name, param))
                else:
                    backbone.append((name, param))
            else:
                other.append((name, param))

        groups = [(other, {})]
        if opt_cfg.unfreeze_lr:
            groups.append((backbone, {"lr": opt_cfg.backbone_lr}))
        if opt_cfg.diff_backbone_last:
            groups.append((last_block, {"lr": opt_cfg.last_block_lr}))
        if lora:
            # LoRA params sometimes train better with a different lr (~1e-4 for CLIP)
            groups.append((lora, {"lr": opt_cfg.lora_lr}))
        return groups

    def _split_weight_decay(self, named_groups):
        """Split every LR group into a decayed and a non-decayed param group.

        LayerNorm parameters and parameters whose name contains ``"bias"``
        get ``weight_decay=0.0``. All others use
        ``cfg.optimizer.optim.weight_decay``. Each LR group's overrides
        (e.g. its ``lr``) are copied onto both halves.
        """
        # A set makes the per-parameter membership test O(1).
        decay_names = {
            name
            for name in get_parameter_names(self.model, [nn.LayerNorm])
            if "bias" not in name
        }
        weight_decay = self.cfg.optimizer.optim.weight_decay
        param_groups = []
        for named, overrides in named_groups:
            decay = [param for name, param in named if name in decay_names]
            no_decay = [param for name, param in named if name not in decay_names]
            param_groups.append(
                {"params": decay, **overrides, "weight_decay": weight_decay}
            )
            param_groups.append(
                {"params": no_decay, **overrides, "weight_decay": 0.0}
            )
        return param_groups

    def lr_scheduler_step(self, scheduler, metric):
        """Step the scheduler, passing the current global step explicitly."""
        scheduler.step(self.global_step)
def get_parameter_names(model, forbidden_layer_types):
    """Return the names of ``model``'s parameters that are not inside a forbidden layer.

    Adapted from HuggingFace transformers.

    Args:
        model: the ``nn.Module`` to walk.
        forbidden_layer_types: an iterable of module classes (e.g.
            ``[nn.LayerNorm]``). Parameters owned by a child module that is an
            instance of one of them, directly or anywhere below it, are
            left out. Any iterable works, including a single-pass generator.

    Returns:
        A list of dotted parameter names relative to ``model``. Children come
        first, in ``named_children`` order, followed by ``model``'s own
        parameters. The root module's own parameters are always included,
        because only children are checked against the forbidden types.
    """
    # Materialize once. This avoids rebuilding the tuple per element and
    # avoids exhausting a generator across recursive calls.
    forbidden = tuple(forbidden_layer_types)
    result = []
    for name, child in model.named_children():
        # Skip forbidden subtrees outright instead of walking them and
        # discarding the names afterwards.
        if isinstance(child, forbidden):
            continue
        result += [f"{name}.{n}" for n in get_parameter_names(child, forbidden)]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result