File size: 4,523 Bytes
6ce7d82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
"""
Logging support callbacks.
TODO: log the model using mlflow.pytorch in parallel with / in addition to checkpointing
"""
import numpy as np
import h5py
import os
import argparse
import torch
import torchvision.utils as vutils
import pytorch_lightning as pl
class ImgCB(pl.Callback):
    """Callback that logs label, output and absolute-error image grids.

    Whenever a training or validation batch with index 0 finishes, the module
    is run on that batch and, for every index along dimension 1 of the labels,
    three image grids (labels, outputs, errors) are sent to the module's logger.

    Options:
        img_ranges: scaling range for label/output images, one pair or one pair per channel.
        err_ranges: scaling range for error images, one pair or one pair per channel.
    """

    def __init__(self, **kwargs):
        """Create the callback, overriding argparse defaults with ``kwargs``.

        Args:
            **kwargs: Values keyed by argument destination (``img_ranges``,
                ``err_ranges``). Keys that match no parser argument are ignored.
        """
        super().__init__()
        parser = ImgCB.add_argparse_args()
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        # Parsing an empty command line returns exactly the (overridden) defaults.
        self.__dict__.update(vars(parser.parse_args([])))

    @staticmethod
    def add_argparse_args(parent_parser=None):
        """Return an ``ArgumentParser`` holding the callback's options.

        Args:
            parent_parser: Optional parser whose arguments are inherited.

        Returns:
            A parser defining ``--img_ranges`` and ``--err_ranges``.
        """
        # NOTE: the class docstring doubles as the usage text, and argparse applies
        # %-formatting to it, so the docstring must not contain percent signs.
        parser = argparse.ArgumentParser(
            prog='ImgCB',
            usage=ImgCB.__doc__,
            parents=[parent_parser] if parent_parser is not None else [],
            add_help=False)
        # type=float so values given on the command line are numeric rather than
        # strings; the list defaults are used unchanged.
        parser.add_argument('--img_ranges', type=float, default=[1300, 1800], nargs='*', help='Scaling range on output image, either pair, or set of pairs')
        parser.add_argument('--err_ranges', type=float, default=[0, 50], nargs='*', help='Scaling range on error images, either pair, or set of pairs')
        return parser

    @staticmethod
    def _channel_range(ranges, channel, n_channels):
        """Return the scaling range to use for ``channel``.

        With several channels and more than one pair in ``ranges``, the pair at
        positions ``2 * channel`` and ``2 * channel + 1`` is used; otherwise all
        of ``ranges`` is returned as a tuple.
        """
        if n_channels > 1 and len(ranges) > 2:
            return tuple(ranges[2 * channel:2 * channel + 2])
        return tuple(ranges)

    @staticmethod
    def _to_image(batch, value_range):
        """Tile ``batch`` into a normalised grid and return its first plane as uint8."""
        grid = vutils.make_grid(batch, normalize=True, value_range=value_range, nrow=6)
        # normalize=True scales the grid into [0, 1], so the scaled values fit
        # in uint8 (np.int, used previously, no longer exists in current NumPy).
        return (grid.cpu().numpy()[0, ...] * 255.).astype(np.uint8)

    def log_images(self, mfl_logger, y, z, prefix):
        """Log label, output and absolute-error grids for every channel.

        Args:
            mfl_logger: Logger exposing ``run_id`` and ``experiment.log_image``
                (presumably an MLflow logger -- confirm against the trainer setup).
            y: Label tensor; dimension 1 is iterated as the channel axis.
            z: Model output tensor, indexed the same way as ``y``.
            prefix: String prepended to every artifact file name.
        """
        n_channels = y.shape[1]
        for i in range(n_channels):
            tag = f'_{i}_' if n_channels > 1 else ''
            img_range = self._channel_range(self.img_ranges, i, n_channels)
            err_range = self._channel_range(self.err_ranges, i, n_channels)
            # Index with [i] (a list) to keep the channel dimension.
            labels = y[:, [i], ...].detach()
            outputs = z[:, [i], ...].detach()
            panels = (
                (labels, img_range, '_labels.png'),
                (outputs, img_range, '_outputs.png'),
                (torch.abs(labels - outputs), err_range, '_errors.png'),
            )
            for panel, value_range, suffix in panels:
                mfl_logger.experiment.log_image(
                    mfl_logger.run_id,
                    self._to_image(panel, value_range),
                    prefix + tag + suffix)

    def _log_first_batch(self, pl_module, batch, batch_idx, prefix):
        """Run the module on ``batch`` and log its images when ``batch_idx`` is 0."""
        if batch_idx != 0:
            return
        with torch.no_grad():
            x, y = batch
            crop = pl_module.hparams.rand_output_crop
            if crop:
                # Drop the trailing rows of the second-to-last axis; the labels
                # lose twice as many rows as the inputs.
                x = x[..., :-crop, :]
                y = y[..., :-crop * 2, :]
            z = pl_module(x.to(pl_module.device))
            # Modules returning several outputs: only the first one is logged.
            if isinstance(z, (tuple, list)):
                z = z[0]
            self.log_images(pl_module.logger, y.to(pl_module.device), z, prefix)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log images for the training batch with index 0."""
        self._log_first_batch(pl_module, batch, batch_idx, 'train_')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Log images for the validation batch with index 0."""
        self._log_first_batch(pl_module, batch, batch_idx, 'validate_')
class TestLogger(pl.Callback):
    """
    pytorch_lightning data-saving logger for testing output.

    Each test batch's outputs (and labels, when the batch has a second element)
    are written as datasets to an HDF5 file.

    Warning !!! : this class is not multi GPU / multi device safe -- only run on a single gpu / device
    """

    def __init__(self, fname: str = 'output.h5'):
        """
        Args:
            fname: Path of the HDF5 file the datasets are appended to.
        """
        self.fname = fname

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Append the outputs (and labels) of one test batch to the HDF5 file.

        Datasets are named ``batch_<idx>`` and ``labels_<idx>`` with a zero-padded
        five-digit batch index.
        """
        # NOTE(review): dataloader_idx is not part of the dataset names, so several
        # test dataloaders would write to the same keys -- confirm single-loader use.
        with h5py.File(self.fname, 'a') as f:
            # detach() so tensors that still track gradients can be converted.
            f[f'batch_{batch_idx:05}'] = outputs.detach().to('cpu').numpy()
            # NOTE(review): len(batch) > 1 is also true for a bare tensor batch with
            # more than one sample -- presumably batch is an (inputs, labels) pair.
            if len(batch) > 1:
                f[f'labels_{batch_idx:05}'] = batch[1].detach().to('cpu').numpy()