ga89tiy
module device
0a8703d
import numpy as np
import torch
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, transforms
from findings_classifier.chexpert_train import LitIGClassifier
class ExpandChannels:
    """
    Turn a single-channel image tensor into a three-channel one.

    The lone channel along dimension 0 is replicated three times; the values
    of every output channel are identical to the input channel.
    """

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """
        :param data: Tensor of shape [1, H, W].
        :return: Tensor with channel copied three times, shape [3, H, W].
        :raises ValueError: If the first dimension of ``data`` is not 1.
        """
        channel_count = data.shape[0]
        if channel_count == 1:
            return data.repeat_interleave(3, dim=0)
        raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
def create_chest_xray_transform_for_inference(resize: int, center_crop_size: int) -> Compose:
    """
    Defines the image transformation pipeline for Chest-Xray datasets.

    The steps run in this order: resize, square center crop, conversion to a
    tensor, and replication of the single channel to three channels via
    ``ExpandChannels``. Because ``ExpandChannels`` rejects tensors whose first
    dimension is not 1, the input image must convert to a single-channel
    tensor (presumably a grayscale image — confirm at call sites).

    :param resize: The size to resize the image to. Linear resampling is used.
        Resizing is applied on the axis with smaller shape.
    :param center_crop_size: The size to center crop the image to. Square crop is applied.
    :return: A ``Compose`` that applies the steps above in order.
    """
    # Named ``pipeline`` rather than ``transforms`` so it does not shadow the
    # module-level ``transforms`` import from torchvision.transforms.
    pipeline = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
    return Compose(pipeline)
def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
    """Remap values in input so the output range is :math:`[0, 255]`.

    Percentiles can be used to specify the range of values to remap.
    This is useful to discard outliers in the input data.

    The input is first converted to a float copy, so the caller's array is
    never modified. Values are scaled linearly and then truncated (not
    rounded) when cast to ``uint8``.

    :param array: Input array.
    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
    :raises ValueError: If ``percentiles`` does not have length 2, is not in
        ascending order, or lies outside ``[0, 100]``.
    :returns: Array of the same shape with ``0`` and ``255`` as minimum and maximum values.
        If all (clipped) values are equal, an all-zero array is returned. An empty
        input yields an empty ``uint8`` array.
    """
    # ``astype`` copies, so the in-place arithmetic below is safe for the caller.
    array = array.astype(float)
    if percentiles is not None:
        len_percentiles = len(percentiles)
        if len_percentiles != 2:
            message = (
                'The value for percentiles should be a sequence of length 2,'
                f' but has length {len_percentiles}'
            )
            raise ValueError(message)
        a, b = percentiles
        if a >= b:
            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
        if a < 0 or b > 100:
            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
    if array.size == 0:
        # min()/max()/percentile() raise on empty input; there is nothing to rescale.
        return array.astype(np.uint8)
    if percentiles is not None:
        cutoff: np.ndarray = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)
    array -= array.min()
    value_range = array.max()
    # A constant input has zero range; dividing would produce NaN, whose uint8
    # cast is undefined. Leaving the (all-zero) array as-is maps it to 0.
    if value_range > 0:
        array /= value_range
        array *= 255
    return array.astype(np.uint8)
def init_chexpert_predictor(
    *,
    ckpt_path: str = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier.ckpt",
):
    """Load the CheXpert findings classifier and its preprocessing pipeline.

    :param ckpt_path: Path of the checkpoint passed to
        ``LitIGClassifier.load_from_checkpoint``. The default is the original
        relative path, resolved against the current working directory.
    :return: Tuple ``(model, class_names, transforms)``:

        - ``model`` is in eval mode, moved to CUDA and converted to float16.
        - ``class_names`` is ``model.class_names`` as a NumPy array.
        - ``transforms`` resizes to 512, center-crops to 488, converts to a
          tensor and expands to three channels.
    """
    chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                     "Cardiomegaly", "Lung Opacity",
                     "Lung Lesion", "Edema",
                     "Consolidation", "Pneumonia",
                     "Atelectasis", "Pneumothorax",
                     "Pleural Effusion", "Pleural Other",
                     "Fracture", "Support Devices"]
    # NOTE(review): strict=False presumably relaxes checkpoint key matching
    # (Lightning semantics); it can hide a stale/mismatched checkpoint — confirm intended.
    model = LitIGClassifier.load_from_checkpoint(
        ckpt_path,
        num_classes=len(chexpert_cols),  # derived so it can't drift from the label list
        class_names=chexpert_cols,
        strict=False,
    )
    model.eval()
    model.cuda()
    model.half()
    # Reuse the shared inference pipeline instead of re-declaring the same steps.
    cp_transforms = create_chest_xray_transform_for_inference(resize=512, center_crop_size=488)
    return model, np.asarray(model.class_names), cp_transforms