medical imaging
ultrasound
File size: 4,523 Bytes
6ce7d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
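pytorch_lightning logging callbacks: image-grid logging through an MLflow logger, and an HDF5 dump of test outputs

TODO: log the model with mlflow.pytorch in addition to (or in parallel with) checkpointing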
"""

import numpy as np
import h5py
import os
import argparse

import torch
import torchvision.utils as vutils
import pytorch_lightning as pl


class ImgCB(pl.Callback):
    """
    pytorch_lightning callback that logs image grids of labels, model outputs and
    absolute errors through an MLflow-style logger for the batch with batch_idx == 0
    in training and in validation.

    Options (see add_argparse_args) may be given as keyword arguments, e.g.
    ImgCB(img_ranges=[0, 1], err_ranges=[0, 0.1]).
    """

    def __init__(self, **kwargs):
        super().__init__()
        parser = ImgCB.add_argparse_args()
        # Keyword arguments override the argparse defaults of matching options only;
        # keys that are not options of this parser are ignored.
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        args = parser.parse_args([])
        self.__dict__.update(vars(args))

    @staticmethod
    def add_argparse_args(parent_parser=None):
        """
        Build the argument parser holding this callback's options.

        :param parent_parser: optional parser whose arguments are inherited
        :return: argparse.ArgumentParser defining --img_ranges and --err_ranges
        """
        parser = argparse.ArgumentParser(
            prog='ImgCB',
            usage=ImgCB.__doc__,
            parents=[parent_parser] if parent_parser is not None else [],
            add_help=False)

        # type=float so values given on the command line are numbers, not strings
        # (list defaults are not passed through `type` and stay as written).
        parser.add_argument('--img_ranges', default=[1300, 1800], nargs='*', type=float, help='Scaling range on output image, either pair, or set of pairs')
        parser.add_argument('--err_ranges', default=[0, 50], nargs='*', type=float, help='Scaling range on error images, either pair, or set of pairs')

        return parser

    @staticmethod
    def _channel_range(ranges, channel):
        """
        Return the (low, high) scaling pair to use for ``channel``.

        ``ranges`` is either one pair shared by all channels, or a flat list of
        consecutive pairs [low0, high0, low1, high1, ...], one pair per channel.

        :raises ValueError: if no complete pair exists for ``channel``
        """
        if len(ranges) > 2:
            pair = ranges[2 * channel:2 * channel + 2]
        else:
            pair = ranges
        if len(pair) != 2:
            raise ValueError(
                f'need a (low, high) scaling pair for channel {channel}, got {list(ranges)}')
        # torchvision's make_grid requires value_range to be a tuple
        return tuple(pair)

    @staticmethod
    def _to_image(images, value_range):
        """
        Tile ``images`` into a 6-column grid, scale ``value_range`` to [0, 1] (values
        outside are clamped by make_grid) and return the grid's first channel as a
        uint8 numpy array in [0, 255].
        """
        grid = vutils.make_grid(images.detach(), normalize=True, value_range=value_range, nrow=6)
        return (grid.cpu().numpy()[0, ...] * 255.).astype(np.uint8)

    def log_images(self, mfl_logger, y, z, prefix):
        """
        Log label, output and |label - output| grids for every channel of ``y`` / ``z``.

        :param mfl_logger: logger exposing ``.run_id`` and
            ``.experiment.log_image(run_id, image, artifact_file)``
            (presumably a pytorch_lightning MLFlowLogger -- TODO confirm)
        :param y: label tensor; channels are indexed along dim 1
        :param z: model output tensor; assumed to share y's layout and device
        :param prefix: artifact filename prefix, e.g. 'train_'
        """
        n_channels = y.shape[1]
        for i in range(n_channels):
            # Only multi-channel data gets a per-channel tag in the artifact name.
            tag = f'_{i}_' if n_channels > 1 else ''
            img_range = self._channel_range(self.img_ranges, i)
            err_range = self._channel_range(self.err_ranges, i)

            label = y[:, [i], ...]
            output = z[:, [i], ...]
            grids = (
                (label, img_range, '_labels.png'),
                (output, img_range, '_outputs.png'),
                (torch.abs(label - output), err_range, '_errors.png'),
            )
            for images, value_range, suffix in grids:
                mfl_logger.experiment.log_image(
                    mfl_logger.run_id,
                    self._to_image(images, value_range),
                    prefix + tag + suffix)

    def _log_batch(self, pl_module, batch, prefix):
        """Run ``pl_module`` on the batch inputs and log images against the batch labels."""
        with torch.no_grad():
            x, y = batch

            crop = pl_module.hparams.rand_output_crop
            if crop:
                # Inputs lose `crop` entries along dim -2, labels lose twice as many.
                x = x[..., :-crop, :]
                y = y[..., :-crop * 2, :]

            # NOTE(review): the forward pass runs in the module's current mode; during
            # training that means dropout is active and BatchNorm running stats update.
            z = pl_module(x.to(pl_module.device))

            # Models returning several outputs: only the first one is logged.
            if isinstance(z, (tuple, list)):
                z = z[0]

            self.log_images(pl_module.logger, y.to(pl_module.device), z, prefix)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx == 0:
            self._log_batch(pl_module, batch, 'train_')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0 as in pl.Callback; Lightning 2.x omits it when
        # there is a single validation dataloader.
        if batch_idx == 0:
            self._log_batch(pl_module, batch, 'validate_')


class TestLogger(pl.Callback):
    """
    pytorch_lightning callback that saves test-step outputs (and labels) to an HDF5 file.

    Warning !!! : this callback is not multi GPU / multi device safe -- only run on a single gpu / device.

    NOTE(review): the file is opened in append mode, so re-running against an existing
    file will collide with dataset names written previously -- delete or rename it first.
    """
    def __init__(self, fname: str = 'output.h5'):
        super().__init__()
        # path of the HDF5 file the test outputs are appended to
        self.fname = fname

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Write this batch's outputs as dataset 'batch_NNNNN' and, when the batch is a
        (inputs, labels, ...) sequence, its labels as 'labels_NNNNN'.

        ``outputs`` is assumed to be the tensor returned by ``test_step`` -- TODO confirm.
        ``dataloader_idx`` defaults to 0 as in pl.Callback; Lightning 2.x omits it when
        there is a single test dataloader.
        """
        # Dataloader 0 keeps the original dataset names; other dataloaders get a prefix
        # because batch_idx restarts at 0 for each of them and names would collide.
        prefix = f'dl{dataloader_idx}_' if dataloader_idx else ''
        with h5py.File(self.fname, 'a') as f:
            f[f'{prefix}batch_{batch_idx:05}'] = outputs.detach().to('cpu').numpy()
            # Only a list/tuple batch carries labels at index 1; indexing a bare tensor
            # batch would pick its second sample instead.
            if isinstance(batch, (list, tuple)) and len(batch) > 1:
                f[f'{prefix}labels_{batch_idx:05}'] = batch[1].to('cpu').numpy()