import numpy as np
import torch
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor, transforms

from findings_classifier.chexpert_train import LitIGClassifier


class ExpandChannels:
    """
    Transforms an image with one channel to an image with three channels by copying
    pixel intensities of the image along the first (channel) dimension.
    """

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """
        :param data: Tensor of shape [1, H, W].
        :return: Tensor with channel copied three times, shape [3, H, W].
        :raises ValueError: If the first dimension of ``data`` is not 1.
        """
        if data.shape[0] != 1:
            raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
        return torch.repeat_interleave(data, 3, dim=0)


def create_chest_xray_transform_for_inference(resize: int, center_crop_size: int) -> Compose:
    """
    Defines the image transformation pipeline for Chest-Xray datasets.

    The pipeline is: resize -> square center crop -> ``ToTensor`` -> replicate the single
    channel to three channels. Because of the last step, the input is expected to be a
    single-channel image (e.g. a grayscale PIL image); ``ExpandChannels`` raises otherwise.

    :param resize: The size to resize the image to. torchvision's default interpolation
        (bilinear) is used. Resizing is applied on the axis with smaller shape.
    :param center_crop_size: The size to center crop the image to. Square crop is applied.
    :return: A ``Compose`` of the steps above.
    """
    # Named ``transform_steps`` so it does not shadow the ``transforms`` module imported above.
    transform_steps = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()]
    return Compose(transform_steps)


def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
    """Remap values in input so the output range is :math:`[0, 255]`.

    Percentiles can be used to specify the range of values to remap.
    This is useful to discard outliers in the input data.

    :param array: Input array. It is not modified; a float copy is made.
    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
    :returns: Array with ``0`` and ``255`` as minimum and maximum values. If all (clipped)
        values are equal, there is no range to stretch, and an all-zero array is returned.
    :raises ValueError: If ``percentiles`` does not have length 2, is not strictly
        ascending, or lies outside ``[0, 100]``.
    """
    array = array.astype(float)
    if percentiles is not None:
        len_percentiles = len(percentiles)
        if len_percentiles != 2:
            message = (
                'The value for percentiles should be a sequence of length 2,'
                f' but has length {len_percentiles}'
            )
            raise ValueError(message)
        a, b = percentiles
        if a >= b:
            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
        if a < 0 or b > 100:
            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
        cutoff: np.ndarray = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)
    array -= array.min()
    value_range = array.max()
    if value_range == 0:
        # Constant input (or every value clipped to the same cutoff): dividing by zero would
        # produce NaNs, whose cast to uint8 is undefined. Map everything to 0 instead.
        return np.zeros_like(array, dtype=np.uint8)
    array /= value_range
    array *= 255
    return array.astype(np.uint8)


def init_chexpert_predictor():
    """Load the CheXpert findings classifier together with its inference transform.

    The classifier is restored from a fixed checkpoint and switched to eval mode. It is
    moved to the current CUDA device and cast to half precision, so a GPU is required.

    :return: Tuple ``(model, class_names, transform)``. ``class_names`` is
        ``model.class_names`` as a NumPy array. ``transform`` is the chest X-ray
        preprocessing pipeline (resize to 512, center crop to 488).
    """
    # Relative path, so it is resolved against the current working directory.
    ckpt_path = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier.ckpt"
    chexpert_cols = [
        "No Finding",
        "Enlarged Cardiomediastinum",
        "Cardiomegaly",
        "Lung Opacity",
        "Lung Lesion",
        "Edema",
        "Consolidation",
        "Pneumonia",
        "Atelectasis",
        "Pneumothorax",
        "Pleural Effusion",
        "Pleural Other",
        "Fracture",
        "Support Devices",
    ]
    # num_classes is derived from the label list so the two can never disagree.
    # strict=False presumably tolerates state-dict key mismatches in the checkpoint; verify intent.
    model = LitIGClassifier.load_from_checkpoint(
        ckpt_path, num_classes=len(chexpert_cols), class_names=chexpert_cols, strict=False
    )
    model.eval()
    model.cuda()
    model.half()
    # Same pipeline as Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()]).
    cp_transforms = create_chest_xray_transform_for_inference(resize=512, center_crop_size=488)
    return model, np.asarray(model.class_names), cp_transforms