{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from model import *\n", "from loss import *\n", "from data import *\n", "from torch import optim\n", "from tqdm import tqdm\n", "\n", "import pytorch_lightning as pl\n", "from torchmetrics.detection import MeanAveragePrecision\n", "from pytorch_lightning.loggers import TensorBoardLogger" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "_, _, test_dataset = get_datasets()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class LitModel(pl.LightningModule):\n", " def __init__(self):\n", " super().__init__()\n", " self.model = YOLOStamp()\n", " self.criterion = YOLOLoss()\n", " self.val_map = MeanAveragePrecision(box_format='xywh', iou_type='bbox')\n", " \n", " def forward(self, x):\n", " return self.model(x)\n", "\n", " def configure_optimizers(self):\n", " optimizer = optim.AdamW(self.parameters(), lr=1e-3)\n", " # return optimizer\n", " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)\n", " return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n", "\n", " def training_step(self, batch, batch_idx):\n", " images, targets = batch\n", " tensor_images = torch.stack(images)\n", " tensor_targets = torch.stack(targets)\n", " output = self.model(tensor_images)\n", " loss = self.criterion(output, tensor_targets)\n", " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " images, targets = batch\n", " tensor_images = torch.stack(images)\n", " tensor_targets = torch.stack(targets)\n", " output = self.model(tensor_images)\n", " loss = self.criterion(output, tensor_targets)\n", " self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", "\n", " for i in range(len(images)):\n", " boxes = output_tensor_to_boxes(output[i].detach().cpu())\n", " boxes = nonmax_suppression(boxes)\n", " target = target_tensor_to_boxes(targets[i])[::BOX]\n", " if not boxes:\n", " boxes = torch.zeros((1, 5))\n", " preds = [\n", " dict(\n", " boxes=torch.tensor(boxes)[:, :4].clone().detach(),\n", " scores=torch.tensor(boxes)[:, 4].clone().detach(),\n", " labels=torch.zeros(len(boxes)),\n", " )\n", " ]\n", " target = [\n", " dict(\n", " boxes=torch.tensor(target),\n", " labels=torch.zeros(len(target)),\n", " )\n", " ]\n", " self.val_map.update(preds, target)\n", " \n", " def on_validation_epoch_end(self):\n", " mAPs = {\"val_\" + k: v for k, v in self.val_map.compute().items()}\n", " mAPs_per_class = mAPs.pop(\"val_map_per_class\")\n", " mARs_per_class = mAPs.pop(\"val_mar_100_per_class\")\n", " self.log_dict(mAPs)\n", " self.val_map.reset()\n", "\n", " image = test_dataset[randint(0, len(test_dataset) - 1)][0].to(self.device)\n", " output = self.model(image.unsqueeze(0))\n", " boxes = output_tensor_to_boxes(output[0].detach().cpu())\n", " boxes = nonmax_suppression(boxes)\n", " img = image.permute(1, 2, 0).cpu().numpy()\n", " img = visualize_bbox(img.copy(), boxes=boxes)\n", " img = (255. * (img * np.array(STD) + np.array(MEAN))).astype(np.uint8)\n", " \n", " self.logger.experiment.add_image(\"detected boxes\", torch.tensor(img).permute(2, 0, 1), self.current_epoch)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "litmodel = LitModel()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "logger = TensorBoardLogger(\"detection_logs\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "epochs = 100" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "train_loader, val_loader = get_loaders(batch_size=8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer = pl.Trainer(accelerator=\"auto\", max_epochs=epochs, logger=logger)\n", "trainer.fit(model=litmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%tensorboard" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }