"""
Network definition file
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
import numpy as np
from scipy.signal.windows import gaussian
import argparse
class Net(LightningModule):
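    """
    Encoder-decoder (U-Net style) network for ultrasound inversion.
    Hyperparameters are defined in add_model_specific_args() and can be
    overridden as keyword arguments, e.g. Net(net_type="tbme2", bias=[1500.0]).
    """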
def __init__(self, **kwargs):
super().__init__()
        # Build the hyperparameter parser, overriding its defaults with any
        # keyword arguments so the model can be configured programmatically
        parser = Net.add_model_specific_args()
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        args = parser.parse_args([])
        self.hparams.update(vars(args))
        if not hasattr(self, f"_init_{self.hparams.net_type}_net"):
            raise ValueError(f"Unknown net type {self.hparams.net_type}")
        # Dispatch to the matching _init_<net_type>_net factory
        self._net = getattr(self, f"_init_{self.hparams.net_type}_net")(
            n_inputs=self.hparams.n_inputs, n_outputs=self.hparams.n_outputs)
        # Optionally preset the bias of the final convolution (e.g. 1500 m/s
        # when training a speed-of-sound regression from scratch)
        if self.hparams.bias is not None:
            if hasattr(self.hparams.bias, "__iter__"):
                for i in range(len(self.hparams.bias)):
                    self._net[-1].c.bias[i].data.fill_(self.hparams.bias[i])
            else:
                self._net[-1].c.bias.data.fill_(self.hparams.bias)
@staticmethod
def _init_tbme2_net(n_inputs: int = 1, n_outputs: int = 1):
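        # Modules pass (features, stack) pairs through the nn.Sequential:
        # DownBlocks with push=True save their output onto the stack and
        # UpBlocks with pop=True concatenate it back in (skip connections)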
return nn.Sequential(
# Encoder
DownBlock(n_inputs, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=False, layers=3),
DownBlock(32, 32, 32, 3, stride=[1, 2], pool=None, push=True, layers=3),
DownBlock(32, 32, 64, 3, stride=1, pool=[2, 2], push=True, layers=3),
DownBlock(64, 64, 128, 3, stride=1, pool=[2, 2], push=True, layers=3),
DownBlock(128, 128, 512, 3, stride=1, pool=[2, 2], push=False, layers=3),
# Decoder
UpBlock(512, 128, 3, scale_factor=2, pop=False, layers=3),
UpBlock(256, 64, 3, scale_factor=2, pop=True, layers=3),
UpBlock(128, 32, 3, scale_factor=2, pop=True, layers=3),
UpBlock(64, 32, 3, scale_factor=2, pop=True, layers=3),
UpStep(32, 32, 3, scale_factor=1),
Compress(32, n_outputs))
@staticmethod
def _init_embc_net(n_inputs: int = 1, n_outputs: int = 1):
return nn.Sequential(
# Encoder
DownBlock(n_inputs, 32, 32, 15, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 13, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 11, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 9, [1, 2], None, True, layers=1),
DownBlock(32, 32, 64, 7, 1, [2, 2], True, layers=1),
DownBlock(64, 64, 128, 5, 1, [2, 2], True, layers=1),
DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
# Decoder
UpBlock(512, 128, 5, 2, layers=1),
UpBlock(256, 64, 7, 2, True, layers=1),
UpBlock(128, 32, 9, 2, True, layers=1),
UpBlock(64, 32, 11, 2, True, layers=1),
UpStep(32, 32, 3, 1),
Compress(32, n_outputs))
@staticmethod
def _init_tbme_net(n_inputs: int = 1, n_outputs: int = 1):
return nn.Sequential(
# Encoder
DownBlock(n_inputs, 32, 32, 3, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 3, [1, 2], None, layers=1),
DownBlock(32, 32, 32, 3, [1, 2], None, True, layers=1),
DownBlock(32, 32, 64, 3, 1, [2, 2], True, layers=1),
DownBlock(64, 64, 128, 3, 1, [2, 2], True, layers=1),
DownBlock(128, 128, 512, 3, 1, [2, 2], layers=1),
# Decoder
UpBlock(512, 128, 3, 2, layers=1),
UpBlock(256, 64, 3, 2, True, layers=1),
UpBlock(128, 32, 3, 2, True, layers=1),
UpBlock(64, 32, 3, 2, True, layers=1),
UpStep(32, 32, 3, 1),
Compress(32, n_outputs))
@staticmethod
def add_model_specific_args(parent_parser=None):
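        """Return an ArgumentParser defining all model hyperparameters
        Keyword Arguments:
            parent_parser -- optional parser whose arguments are inherited (default: {None})
        """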
parser = argparse.ArgumentParser(
prog="Net",
usage=Net.__doc__,
parents=[parent_parser] if parent_parser is not None else [],
add_help=False)
parser.add_argument("--random_mirror", type=int, nargs="?", default=1, help="Randomly mirror data to increase diversity when using flat plate wave")
parser.add_argument("--noise_std", type=float, nargs="*", help="range of std of random noise to add to the input signal [0 val] or [min max]")
parser.add_argument("--quantization", type=float, nargs="?", help="Quantization noise")
parser.add_argument("--rand_drop", type=int, nargs="*", help="Random drop lines, between 0 and value lines if single value, or between two values")
parser.add_argument("--normalize_net", type=float, default=0.0, help="Coefficient for normalizing network weights")
parser.add_argument("--learning_rate", type=float, default=5e-3, help="Learning rate to use for optimizer")
parser.add_argument("--lr_sched_step", type=int, default=15, help="Learning decay, update step size")
parser.add_argument("--lr_sched_gamma", type=float, default=0.65, help="Learning decay gamma")
parser.add_argument("--net_type", default="tbme2", help="The network to use [tbme2/embc/tbme]")
parser.add_argument("--bias", type=float, nargs="*", help="Set bias on last layer, set to 1500 when training from scratch on SoS output")
parser.add_argument("--decimation", type=int, help="Subsample phase signal")
parser.add_argument("--phase_inv", type=int, default=0, help="Use phase for inversion")
parser.add_argument("--center_freq", type=float, default=5e6, help="Matched filter and IQ demodulation frequency")
parser.add_argument("--n_periods", type=float, default=5, help="Matched filter length")
parser.add_argument("--matched_filter", type=int, nargs="?", default=0, help="Apply matched filter, set to 1 to run during forward pass, 2 to run during preprocessing phase (before adding noise)")
parser.add_argument("--rand_output_crop", type=int, help="Subsample phase signal")
parser.add_argument("--rand_scale", type=float, nargs="*", help="Random scaling range [min max] -- (10 ** rand_scale)")
parser.add_argument("--rand_gain", type=float, nargs="*", help="Random gain coefficient range [min max] -- (10 ** rand_gain)")
parser.add_argument("--n_inputs", type=int, default=1, help="Number of input layers")
parser.add_argument("--n_outputs", type=int, default=1, help="Number of output layers")
parser.add_argument("--scale_losses", type=float, nargs="*", help="Scale each layer of the loss function by given value")
return parser
def forward(self, x) -> torch.Tensor:
# Matched filter
if self.hparams.matched_filter == 1:
x = self._matched_filter(x)
# compute IQ phase if in phase_inv mode
if self.hparams.phase_inv:
x = self._phase(x)
# Decimation
if self.hparams.decimation != 1:
x = x[..., ::self.hparams.decimation]
# Apply network
x = self._net((x, []))
return x
def _matched_filter(self, x):
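        """Correlate the signal with a Gaussian-windowed sinusoid at the center
        frequency, i.e. a matched filter for the transmitted pulse (assumes the
        40 MHz sampling rate hardcoded below)"""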
        sampling_freq = 40e6  # Hz; acquisition rate assumed by the model
        samples_per_cycle = sampling_freq / self.hparams.center_freq
        n_samples = int(np.ceil(samples_per_cycle * self.hparams.n_periods + 1))  # window length must be an integer
        # Gaussian-windowed sinusoid at the transducer center frequency
        signal = torch.sin(torch.arange(n_samples, device=x.device) / samples_per_cycle * 2 * np.pi) * \
            torch.from_numpy(gaussian(n_samples, (n_samples - 1) / 6).astype(np.single)).to(x.device)
        return torch.nn.functional.conv1d(
            x.reshape(x.shape[:2] + (-1,)), signal.reshape(1, 1, -1), padding="same").reshape(x.shape)
def _phase(self, x):
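        """Extract the instantaneous (IQ) phase: keep a Gaussian-weighted band
        around the center frequency, demodulate it to baseband, and return the
        angle of the resulting complex signal"""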
        f = self.hparams.center_freq
        fs = 40e6  # sampling frequency (Hz)
        N = x.shape[-1]
        n = int(round(f * N / fs))
        X = torch.fft.fft(x, dim=-1)
        X[..., (2 * n + 1):] = 0  # keep only the band around the center frequency
        X[..., :(2 * n + 1)] *= torch.from_numpy(gaussian(2 * n + 1, 2 * n / 6).astype(np.single)).to(x.device)
        X = X.roll(-n, dims=-1)  # shift the band to baseband
        x = torch.fft.ifft(X, dim=-1)
        return x.angle()
def _preprocess(self, x):
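        """Augmentation / degradation chain applied to the raw input: optional
        matched filter, additive Gaussian noise, random scaling, random
        depth-dependent gain, ADC quantization, and random line dropout"""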
# Matched filter
if self.hparams.matched_filter == 2:
x = self._matched_filter(x)
# Gaussian (normal) noise - random scaling, normalized to signal STD
if (ns := self.hparams.noise_std) and len(ns):
scl = ns[0] if len(ns) == 1 else torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (ns[-1] - ns[-2]) + ns[-2]
scl *= x.std()
x += torch.empty_like(x).normal_() * scl
# Random multiplicative scaling
if (rs := self.hparams.rand_scale) and len(rs):
x *= 10 ** (torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (rs[-1] - rs[-2]) + rs[-2])
# Random exponential gain
if (gs := self.hparams.rand_gain) and len(gs):
            gain = torch.FloatTensor([10.0]).to(x.device) ** \
                ((torch.rand([x.shape[0]] + [1] * 3).to(x.device) * (gs[-1] - gs[-2]) + gs[-2]) *
                 torch.linspace(0, 1, x.shape[-1]).to(x.device).view(1, 1, 1, -1))
x *= gain
        # Quantization noise, to emulate an ADC
if (quantization := self.hparams.quantization) is not None:
x = (x * quantization).round() * (1.0 / quantization)
# Randomly zero out some of the channels
if (rand_drop := self.hparams.rand_drop) and len(rand_drop):
if len(rand_drop) == 1:
rand_drop = [0, ] + rand_drop
for i in range(x.shape[0]):
lines = np.random.randint(0, x.shape[2], np.random.randint(rand_drop[0], rand_drop[1] + 1))
x[i, :, lines, :] = 0.
return x
def _log_losses(self, outputs: torch.Tensor, labels: torch.Tensor, prefix: str = ""):
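        # Mean absolute error per output channel, overall and split along the
        # last (depth) axis into short / medium / long range thirds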
diff = torch.abs(labels.detach() - outputs.detach())
s1 = int(diff.shape[-1] * (1.0 / 3.0))
s2 = int(diff.shape[-1] * (2.0 / 3.0))
for i in range(diff.shape[1]):
tag = f"{i}_" if diff.shape[1] > 1 else ""
losses = {
f"{prefix + tag}rmse": torch.sqrt(torch.mean(diff[:, i, ...] * diff[:, i, ...])).item(),
f"{prefix + tag}mean": torch.mean(diff[:, i, ...]).item(),
f"{prefix + tag}short": torch.mean(diff[:, i, :, :s1]).item(),
f"{prefix + tag}med": torch.mean(diff[:, i, :, s1:s2]).item(),
f"{prefix + tag}long": torch.mean(diff[:, i, :, s2:]).item()}
self.log_dict(losses, prog_bar=True)
def training_step(self, batch, batch_idx):
if self.hparams.random_mirror:
mirror = np.random.randint(0, 2, batch[0].shape[0])
for b in batch:
for i, m in enumerate(mirror):
if not m:
continue
                    b[i, ...] = b[i, ...].flip(-2)  # mirror the line axis (PyTorch lacks negative slice steps)
loss = self._common_step(batch, batch_idx, "train_")
if self.hparams.normalize_net:
for W in self.parameters():
loss += self.hparams.normalize_net * W.norm(2)
return loss
def validation_step(self, batch, batch_idx):
return self._common_step(batch, batch_idx, "validate_")
def test_step(self, batch, batch_idx):
return self._common_step(batch, batch_idx, "test_")
def predict_step(self, batch, batch_idx):
x = batch[0]
x = self._preprocess(x)
z = self(x)
        if isinstance(z, (tuple, list)):
z = z[0]
return z
def _common_step(self, batch, batch_idx, prefix):
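        # rand_output_crop: shift the input up by a random c rows and the label
        # by 2 * c (the decoder doubles the row count), then trim both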
x, y = batch
if self.hparams.rand_output_crop:
crop = np.random.randint(0, self.hparams.rand_output_crop, batch[0].shape[0])
for i, c in enumerate(crop):
if not c:
continue
x[i, :, :-c, :] = x[i, :, c:, :].clone()
y[i, :, :-c*2, :] = \
y[i, :, c*2-1:-1, :].clone() if np.random.randint(2) else \
y[i, :, c*2:, :].clone()
x = x[..., :-self.hparams.rand_output_crop, :]
y = y[..., :-self.hparams.rand_output_crop*2, :]
x = self._preprocess(x)
z = self(x)
        outputs = z[0] if isinstance(z, (tuple, list)) else z
self._log_losses(outputs, y, prefix)
        if self.hparams.scale_losses and len(self.hparams.scale_losses):
            # Weigh each output channel before computing the MSE
            s = torch.FloatTensor(self.hparams.scale_losses).to(y.device).view(1, -1, 1, 1)
            loss = F.mse_loss(s * outputs, s * y)
        else:
            loss = F.mse_loss(outputs, y)
        self.log(prefix + "loss", np.sqrt(loss.item()))  # report RMSE
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.hparams.lr_sched_step, self.hparams.lr_sched_gamma)
return [optimizer], [scheduler]
class DownStep(nn.Module):
"""
Down scaling step in the encoder decoder network
"""
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, pool: tuple = None) -> None:
        """Constructor
        Arguments:
            in_channels {int} -- Number of input channels for 2D convolution
            out_channels {int} -- Number of output channels for 2D convolution
            kernel_size {int} -- Convolution kernel size (odd, so padding preserves size)
        Keyword Arguments:
            stride {int or list} -- Stride of the convolution, set to 1 to disable (default: {1})
            pool {tuple} -- Max pooling size, set to None to disable (default: {None})
        """
        super().__init__()
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)
self.n = nn.BatchNorm2d(out_channels)
self.pool = pool
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward step
        Arguments:
            x {torch.Tensor} -- input tensor
        Returns:
            torch.Tensor -- output tensor
        """
x = self.c(x)
x = F.relu(x)
if self.pool is not None:
x = F.max_pool2d(x, self.pool)
x = self.n(x)
return x
class UpStep(nn.Module):
"""
Up scaling step in the encoder decoder network
"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, scale_factor: int = 2) -> None:
"""Constructor
Arguments:
in_channels {int} -- Number of input channels for 2D convolution
out_channels {int} -- Number of output channels for 2D convolution
kernel_size {int} -- Convolution kernel size
Keyword Arguments:
scale_factor {int} -- Upsampling scaling factor (default: {2})
"""
        super().__init__()
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.n = nn.BatchNorm2d(out_channels)
self.scale_factor = scale_factor
    def forward(self, x) -> torch.Tensor:
        """Run the forward step
        Arguments:
            x {torch.Tensor or tuple} -- input tensor, or a (features, stack) pair from a preceding block
        Returns:
            torch.Tensor -- output tensor
        """
if isinstance(x, tuple):
x = x[0]
if self.scale_factor != 1:
x = F.interpolate(x, scale_factor=self.scale_factor)
x = self.c(x)
x = F.relu(x)
x = self.n(x)
return x
class Compress(nn.Module):
"""
Up scaling step in the encoder decoder network
"""
def __init__(self, in_channels: int, out_channels: int = 1, kernel_size: int = 1, scale_factor: int = 1) -> None:
"""Constructor
Arguments:
in_channels {int} -- [description]
Keyword Arguments:
out_channels {int} -- [description] (default: {1})
kernel_size {int} -- [description] (default: {1})
"""
super(Compress, self).__init__()
self.scale_factor = scale_factor
self.c = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
    def forward(self, x) -> torch.Tensor:
        """Run the forward step
        Arguments:
            x {torch.Tensor or tuple} -- input tensor, or a (features, stack) pair from a preceding block
        Returns:
            torch.Tensor -- output tensor
        """
        if isinstance(x, (tuple, list)):
x = x[0]
x = self.c(x)
if self.scale_factor != 1:
x = F.interpolate(x, scale_factor=self.scale_factor)
return x
class DownBlock(nn.Module):
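    """
    Encoder block: a stack of DownSteps where only the last step applies the
    stride / pooling; optionally pushes its output onto the skip-connection stack
    """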
def __init__(
self,
in_chan: int, inter_chan: int, out_chan: int,
kernel_size: int = 3, stride: int = 1, pool: tuple = None,
push: bool = False,
layers: int = 3):
super().__init__()
        self.s = []
        for i in range(layers):
            self.s.append(DownStep(
                in_chan if i == 0 else inter_chan,
                inter_chan if i < layers - 1 else out_chan,
                kernel_size,
                1 if i < layers - 1 else stride,    # stride only on the last step
                None if i < layers - 1 else pool))  # pooling only on the last step
        self.s = nn.Sequential(*self.s)
self.push = push
    def forward(self, x: tuple) -> tuple:
        i, s = x  # (features, skip-connection stack)
        i = self.s(i)
        if self.push:
            s.append(i)  # save for a later UpBlock with pop=True
        return i, s
class UpBlock(nn.Module):
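    """
    Decoder block: a stack of UpSteps where only the last step upsamples;
    optionally pops a skip connection off the stack and concatenates it first
    """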
def __init__(
self,
in_chan: int, out_chan: int,
kernel_size: int, scale_factor: int = 2,
pop: bool = False,
layers: int = 3):
super().__init__()
        self.s = []
        for i in range(layers):
            self.s.append(UpStep(
                in_chan if i == 0 else out_chan,
                out_chan,
                kernel_size,
                1 if i < layers - 1 else scale_factor))  # upsample only on the last step
        self.s = nn.Sequential(*self.s)
self.pop = pop
    def forward(self, x: tuple) -> tuple:
        i, s = x  # (features, skip-connection stack)
        if self.pop:
            i = torch.cat((i, s.pop()), dim=1)  # concatenate the matching encoder feature
        i = self.s(i)
        return i, s
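

if __name__ == "__main__":
    # Minimal smoke test, not part of the training pipeline. The input shape is
    # an assumption for illustration: with the tbme2 net, the row count must be
    # divisible by 8 and the sample count by 128; the decoder then returns twice
    # the rows and an eighth of the samples.
    net = Net(net_type="tbme2")
    net.eval()
    x = torch.randn(2, 1, 64, 1024)  # (batch, channels, lines, samples)
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # expected: torch.Size([2, 1, 128, 128])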