In [1]:
from model import *
from loss import *
from data import *
from torch import optim
from tqdm import tqdm

import pytorch_lightning as pl
from torchmetrics.detection import MeanAveragePrecision
from pytorch_lightning.loggers import TensorBoardLogger

In [2]:
# Seed Python's `random`, NumPy and torch (CPU + CUDA) before anything stochastic
# runs, so dataset splitting (if random), weight initialisation, loader shuffling
# and the randomly chosen preview image are reproducible across Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

# Only the test split is used directly in this notebook (as the source of the
# per-epoch detection preview); train/val data come from `get_loaders` below.
# Named placeholders avoid clobbering IPython's `_` output-history variable.
_train_dataset, _val_dataset, test_dataset = get_datasets()

In [3]:
class LitModel(pl.LightningModule):
    """LightningModule that trains the YOLOStamp detector and tracks detection mAP.

    Training and validation both report ``YOLOLoss``. Validation additionally
    accumulates COCO-style mAP/mAR with ``MeanAveragePrecision`` (boxes in
    ``xywh`` format, a single class with label 0). At the end of every
    validation epoch it logs one random ``test_dataset`` image, with its
    predicted boxes drawn on, to the experiment logger.

    Args:
        lr: Initial AdamW learning rate.
        scheduler_t_max: ``T_max`` of the cosine-annealing schedule, in scheduler
            steps. Lightning steps a bare scheduler once per epoch, so this is in
            epochs.
            NOTE(review): the default of 1000 is far longer than a typical run
            (e.g. ``max_epochs=100``), so the LR only traverses the beginning of
            the cosine curve. Consider passing the trainer's ``max_epochs``.
    """

    def __init__(self, lr=1e-3, scheduler_t_max=1000):
        super().__init__()
        self.lr = lr
        self.scheduler_t_max = scheduler_t_max
        self.model = YOLOStamp()
        self.criterion = YOLOLoss()
        self.val_map = MeanAveragePrecision(box_format='xywh', iou_type='bbox')

    def forward(self, x):
        """Run the wrapped YOLOStamp model on a batch of images."""
        return self.model(x)

    def configure_optimizers(self):
        """Return AdamW together with a cosine-annealing learning-rate schedule."""
        optimizer = optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, self.scheduler_t_max)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def _shared_step(self, batch):
        """Stack a collated (images, targets) batch, run the model and compute the loss.

        Returns:
            (output, loss): raw model output for the stacked images and the YOLOLoss value.
        """
        images, targets = batch
        # The loaders yield per-sample sequences; stack them into batch tensors.
        output = self.model(torch.stack(images))
        loss = self.criterion(output, torch.stack(targets))
        return output, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and feed per-image detections into the mAP metric."""
        output, loss = self._shared_step(batch)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        images, targets = batch
        preds, gts = [], []
        for i in range(len(images)):
            boxes = nonmax_suppression(output_tensor_to_boxes(output[i].detach().cpu()))
            if len(boxes) == 0:
                # No detections: MeanAveragePrecision accepts empty tensors, so don't
                # invent a zero-size placeholder box that would be scored as a detection.
                pred = torch.zeros((0, 5))
            else:
                # Columns 0-3 hold the xywh box, column 4 the confidence score.
                pred = torch.as_tensor(boxes, dtype=torch.float32)
            preds.append(dict(
                boxes=pred[:, :4],
                scores=pred[:, 4],
                # torchmetrics expects integer class ids; everything is class 0 here.
                labels=torch.zeros(len(pred), dtype=torch.long),
            ))

            # NOTE(review): keeps every BOX-th entry from target_tensor_to_boxes,
            # presumably to drop repeated boxes. Confirm against its definition.
            target_boxes = torch.as_tensor(target_tensor_to_boxes(targets[i])[::BOX], dtype=torch.float32)
            gts.append(dict(
                boxes=target_boxes,
                labels=torch.zeros(len(target_boxes), dtype=torch.long),
            ))
        # One update per batch accumulates the same per-image state as per-image updates.
        self.val_map.update(preds, gts)

    def on_validation_epoch_end(self):
        """Log the epoch's mAP/mAR summary, reset the metric and log a sample detection."""
        metrics = {"val_" + k: v for k, v in self.val_map.compute().items()}
        # Per-class entries are not logged. Newer torchmetrics versions also return
        # the class ids seen ("classes"), which is not a metric and may be non-scalar.
        for key in ("val_map_per_class", "val_mar_100_per_class", "val_classes"):
            metrics.pop(key, None)
        self.log_dict(metrics)
        self.val_map.reset()
        self._log_sample_detection()

    def _log_sample_detection(self):
        """Draw NMS-filtered predictions on one random test image and log it as an image."""
        if self.logger is None:
            return  # nothing to log to (e.g. Trainer(logger=False))
        image = test_dataset[randint(0, len(test_dataset) - 1)][0].to(self.device)
        with torch.no_grad():
            output = self.model(image.unsqueeze(0))
        boxes = nonmax_suppression(output_tensor_to_boxes(output[0].detach().cpu()))
        img = image.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for drawing
        img = visualize_bbox(img.copy(), boxes=boxes)
        # Presumably undoes the MEAN/STD input normalisation. Clip before casting
        # so values outside [0, 255] (possibly from drawn box colours) saturate
        # instead of wrapping around in uint8.
        img = np.clip(255. * (img * np.array(STD) + np.array(MEAN)), 0, 255).astype(np.uint8)
        self.logger.experiment.add_image("detected boxes", torch.tensor(img).permute(2, 0, 1), self.current_epoch)


In [4]:
# Instantiate the Lightning wrapper (YOLOStamp model, YOLOLoss criterion and mAP metric).
litmodel = LitModel()

In [5]:
# TensorBoard logger; event files for each training run are written under this directory.
logger = TensorBoardLogger(save_dir="detection_logs")

In [7]:
# Number of full passes over the training set.
epochs: int = 100

In [8]:
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 8
train_loader, val_loader = get_loaders(batch_size=BATCH_SIZE)

In [None]:
# Train for `epochs` epochs, running validation (loss, mAP, sample image) each epoch.
trainer = pl.Trainer(
    max_epochs=epochs,
    logger=logger,
    accelerator="auto",
)
trainer.fit(litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
%tensorboard