# Source: dl_us_sos_inversion / run_logger.py  (tags: medical imaging, ultrasound)
# Uploaded by laughingrice ("Upload 11 files", commit 6ce7d82, 4.52 kB)
"""
Support log functions
TODO: log model using mlflow.pytorch in parallel / addition to checkpointing
"""
import numpy as np
import h5py
import os
import argparse
import torch
import torchvision.utils as vutils
import pytorch_lightning as pl
class ImgCB(pl.Callback):
    """
    pytorch_lightning callback that logs label, output and error image grids to MLflow.

    Whenever batch_idx == 0 in a training or validation epoch, the module is run on that
    batch and, for every channel, three grids are logged: labels, outputs and the absolute
    error between them. The display ranges used to normalize the grids come from the
    img_ranges / err_ranges options: either one (low, high) pair applied to every channel,
    or one pair per channel flattened as low0 high0 low1 high1 ...
    """

    def __init__(self, **kwargs):
        """
        Create the callback with option values taken from add_argparse_args().

        :param kwargs: values overriding the parser defaults (img_ranges, err_ranges);
                       keys that do not name a parser option are silently ignored
        """
        super().__init__()
        parser = ImgCB.add_argparse_args()
        # Override defaults in place, then parse an empty command line so the resulting
        # namespace holds every option (default or overridden) as an attribute
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        args = parser.parse_args([])
        self.__dict__.update(vars(args))

    @staticmethod
    def add_argparse_args(parent_parser=None):
        """
        Build the argparse parser holding the ImgCB options.

        :param parent_parser: optional parser whose arguments are inherited
        :return: argparse.ArgumentParser created with add_help=False, so it can itself be
                 used as a parent parser
        """
        parser = argparse.ArgumentParser(
            prog='ImgCB',
            usage=ImgCB.__doc__,
            parents=[parent_parser] if parent_parser is not None else [],
            add_help=False)
        # type=float: without it, values given on the command line arrive as strings
        parser.add_argument('--img_ranges', default=[1300, 1800], nargs='*', type=float,
                            help='Scaling range on output image, either pair, or set of pairs')
        parser.add_argument('--err_ranges', default=[0, 50], nargs='*', type=float,
                            help='Scaling range on error images, either pair, or set of pairs')
        return parser

    @staticmethod
    def _channel_range(ranges, channel):
        """
        Return the (low, high) display range for one channel as a tuple of floats.

        A single pair applies to every channel; a longer sequence is read as flattened
        per-channel pairs [low0, high0, low1, high1, ...].

        :raises ValueError: if no complete pair exists for the requested channel
        """
        values = [float(v) for v in ranges]
        if len(values) > 2:
            values = values[2 * channel:2 * channel + 2]
        if len(values) != 2:
            raise ValueError(
                f'No (low, high) range for channel {channel} in {list(ranges)}; '
                'expected one pair or one pair per channel')
        return tuple(values)

    @staticmethod
    def _to_image(imgs, value_range):
        """
        Tile imgs into a single grid (6 per row) and return it as a 2-D uint8 array.

        make_grid(normalize=True, value_range=...) clamps to value_range and rescales to
        [0, 1]; the result is scaled to [0, 255]. The grid is channel-first, and only its
        first channel is kept.
        """
        grid = vutils.make_grid(imgs.detach(), normalize=True, value_range=value_range, nrow=6)
        return (grid.cpu().numpy()[0, ...] * 255.).astype(np.uint8)

    def log_images(self, mfl_logger, y, z, prefix):
        """
        Log label, output and absolute-error image grids for every channel of y / z.

        :param mfl_logger: logger exposing .run_id and .experiment.log_image(run_id, image, name)
                           (presumably a pytorch_lightning MLFlowLogger -- TODO confirm)
        :param y: label tensor indexed as (batch, channel, ...) -- assumed 4-D (N, C, H, W)
        :param z: model output tensor with the same layout as y
        :param prefix: prefix of the logged artifact file names
        """
        n_channels = y.shape[1]
        for i in range(n_channels):
            # Only multi-channel data gets a per-channel tag in the artifact name
            tag = f'_{i}_' if n_channels > 1 else ''
            img_range = self._channel_range(self.img_ranges, i)
            err_range = self._channel_range(self.err_ranges, i)
            # [i] (not i) keeps the channel dimension so make_grid sees (N, 1, ...)
            labels = y[:, [i], ...].detach()
            outputs = z[:, [i], ...].detach()
            for imgs, value_range, suffix in (
                    (labels, img_range, '_labels.png'),
                    (outputs, img_range, '_outputs.png'),
                    (torch.abs(labels - outputs), err_range, '_errors.png')):
                mfl_logger.experiment.log_image(
                    mfl_logger.run_id,
                    self._to_image(imgs, value_range),
                    prefix + tag + suffix)

    def _log_batch(self, pl_module, batch, prefix):
        """
        Run pl_module on an (x, y) batch without gradients and log the resulting images.

        When hparams.rand_output_crop is set, the last rand_output_crop rows (second-to-last
        dim) of x and twice that many of y are dropped before the forward pass.
        """
        with torch.no_grad():
            x, y = batch
            crop = pl_module.hparams.rand_output_crop
            if crop:
                x = x[..., :-crop, :]
                y = y[..., :-crop * 2, :]
            z = pl_module(x.to(pl_module.device))
            # Modules returning several outputs: the first one is the image prediction
            if isinstance(z, (tuple, list)):
                z = z[0]
            self.log_images(pl_module.logger, y.to(pl_module.device), z, prefix)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log images for the first training batch (batch_idx == 0) of each epoch."""
        # NOTE(review): the module is presumably still in train mode here, so dropout /
        # batch-norm behave (and batch-norm running stats update) as in training during
        # this extra forward pass -- confirm this is intended.
        if batch_idx == 0:
            self._log_batch(pl_module, batch, 'train_')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Log images for the first validation batch (batch_idx == 0) of each epoch.

        dataloader_idx defaults to 0 because Lightning 2.x only passes it when several
        validation dataloaders are in use.
        """
        if batch_idx == 0:
            self._log_batch(pl_module, batch, 'validate_')
class TestLogger(pl.Callback):
    """
    pytorch_lightning data saving logger for testing output

    After every test batch, the batch output -- and, when the batch is an
    (inputs, labels, ...) tuple/list, its labels -- are written to an HDF5 file as
    datasets batch_XXXXX / labels_XXXXX (XXXXX = zero-padded batch index). Batches from
    additional test dataloaders (dataloader_idx > 0) get a dl<idx>_ name prefix, because
    batch_idx restarts for each dataloader and the names would otherwise collide.

    Warning !!! : this function is not multi GPU / multi device safe -- only run on a single gpu / device
    """

    def __init__(self, fname: str = 'output.h5'):
        """
        :param fname: path of the HDF5 file to write. It is opened in append mode, so an
                      existing file is extended; writing a dataset name that already exists
                      in it raises an h5py error.
        """
        self.fname = fname

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Store this batch's outputs (and labels, when present) in self.fname.

        :param outputs: value returned by the module's test step -- assumed to be a tensor
        :param batch: the test batch; batch[1] is stored as labels when batch is a
                      tuple/list with more than one element
        :param dataloader_idx: defaults to 0 because Lightning 2.x only passes it when
                               several test dataloaders are in use
        """
        name_prefix = f'dl{dataloader_idx}_' if dataloader_idx else ''
        with h5py.File(self.fname, 'a') as f:
            f[f'{name_prefix}batch_{batch_idx:05}'] = outputs.detach().cpu().numpy()
            # A bare tensor batch also has len() > 1, but batch[1] would then be a sample,
            # not labels -- only treat sequences as (inputs, labels, ...)
            if isinstance(batch, (tuple, list)) and len(batch) > 1:
                f[f'{name_prefix}labels_{batch_idx:05}'] = batch[1].detach().cpu().numpy()