|
import numpy as np |
|
import torch |
|
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, transforms |
|
|
|
from findings_classifier.chexpert_train import LitIGClassifier |
|
|
|
|
|
class ExpandChannels:
    """
    Turn a single-channel image tensor into a three-channel one.

    The lone channel along dimension 0 is replicated, so every output channel
    carries the same pixel intensities as the input channel.
    """

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """
        :param data: Tensor of shape [1, H, W].
        :return: New tensor of shape [3, H, W] whose three channels are copies of ``data[0]``.
        :raises ValueError: If the first dimension of ``data`` is not 1.
        """
        num_channels = data.shape[0]
        if num_channels == 1:
            # Concatenating the single-channel tensor with itself along dim 0
            # yields three identical channels in a freshly allocated tensor.
            return torch.cat([data, data, data], dim=0)
        raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
|
|
|
def create_chest_xray_transform_for_inference(resize: int, center_crop_size: int) -> Compose:
    """
    Build the inference-time preprocessing pipeline for chest X-ray images.

    :param resize: Target size for the smaller image edge (the aspect ratio is kept);
        torchvision's default (bilinear) interpolation is used.
    :param center_crop_size: Side length of the square center crop.
    :return: A ``Compose`` that applies, in order: resize, center crop, conversion
        to tensor, and expansion from one to three channels.
    """
    return Compose(
        [
            Resize(resize),
            CenterCrop(center_crop_size),
            ToTensor(),
            ExpandChannels(),
        ]
    )
|
|
|
def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
    """Remap values in input so the output range is :math:`[0, 255]`.

    Percentiles can be used to specify the range of values to remap.
    This is useful to discard outliers in the input data.

    :param array: Input array. It is not modified; a float copy is processed.
    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
    :raises ValueError: If ``percentiles`` does not have length 2, is not in
        ascending order, or is outside ``[0, 100]``.
    :returns: Array with ``0`` and ``255`` as minimum and maximum values
        (fractional values are truncated by the uint8 cast). If all values
        (after optional clipping) are equal, an all-zero array is returned.
    """
    # astype always returns a copy, so the in-place ops below never touch the caller's array.
    array = array.astype(float)
    if percentiles is not None:
        len_percentiles = len(percentiles)
        if len_percentiles != 2:
            message = (
                'The value for percentiles should be a sequence of length 2,'
                f' but has length {len_percentiles}'
            )
            raise ValueError(message)
        a, b = percentiles
        if a >= b:
            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
        if a < 0 or b > 100:
            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
        cutoff: np.ndarray = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)
    array -= array.min()
    value_range = array.max()
    if value_range == 0:
        # Constant input: dividing would be 0/0 -> NaN, and casting NaN to
        # uint8 is undefined, so map everything to 0 explicitly.
        return np.zeros(array.shape, dtype=np.uint8)
    array /= value_range
    array *= 255
    return array.astype(np.uint8)
|
|
|
def init_chexpert_predictor(
    ckpt_path: str = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier.ckpt",
):
    """
    Load the CheXpert findings classifier together with its inference transform.

    :param ckpt_path: Path to the classifier checkpoint. The default is the path that
        was previously hard-coded; it is relative, so it is resolved against the
        current working directory.
    :return: Tuple ``(model, class_names, transform)``:
        ``model`` is the loaded ``LitIGClassifier``, switched to eval mode, moved to
        CUDA and converted to half precision; ``class_names`` is ``model.class_names``
        wrapped in a NumPy array; ``transform`` is the chest X-ray inference pipeline
        (Resize(512), CenterCrop(488), ToTensor, ExpandChannels).
    """
    chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                     "Cardiomegaly", "Lung Opacity",
                     "Lung Lesion", "Edema",
                     "Consolidation", "Pneumonia",
                     "Atelectasis", "Pneumothorax",
                     "Pleural Effusion", "Pleural Other",
                     "Fracture", "Support Devices"]
    # strict=False tolerates state-dict keys that do not match the model definition.
    # num_classes is derived from the label list so the two can never drift apart.
    model = LitIGClassifier.load_from_checkpoint(
        ckpt_path,
        num_classes=len(chexpert_cols),
        class_names=chexpert_cols,
        strict=False,
    )
    model.eval()
    model.cuda()  # requires a CUDA-capable device
    model.half()
    # Reuse the shared pipeline builder rather than re-spelling the same transforms.
    cp_transforms = create_chest_xray_transform_for_inference(resize=512, center_crop_size=488)
    # NOTE(review): class names are read back from the model (presumably restored from the
    # checkpoint hyper-parameters) rather than from chexpert_cols — confirm they agree.
    return model, np.asarray(model.class_names), cp_transforms
|
|