|
import os |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from PIL import Image |
|
from skimage import io |
|
from torch.utils.data import Dataset |
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, transforms |
|
|
|
from local_config import VIS_ROOT, PATH_TO_MIMIC_CXR |
|
|
|
class Chexpert_Dataset(Dataset): |
|
def __init__(self, split='train', truncate=None, loss_weighting="none", use_augs=False): |
|
|
|
super().__init__() |
|
|
|
|
|
self.split = pd.read_csv(f'{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv') |
|
self.reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv') |
|
self.reports = self.reports.dropna(subset=['findings']) |
|
|
|
self.vis_root = VIS_ROOT |
|
self.img_ids = {img_id: i for i, img_id in enumerate(self.reports['dicom_id'])} |
|
self.split_ids = set(self.split.loc[self.split['split'] == split]['dicom_id']) |
|
self.chexpert = pd.read_csv(f'data/data_files/finding_chexbert_labels.csv') |
|
self.chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum", |
|
"Cardiomegaly", "Lung Opacity", |
|
"Lung Lesion", "Edema", |
|
"Consolidation", "Pneumonia", |
|
"Atelectasis", "Pneumothorax", |
|
"Pleural Effusion", "Pleural Other", |
|
"Fracture", "Support Devices"] |
|
|
|
|
|
self.annotation = self.reports.loc[self.reports['dicom_id'].isin(self.split_ids)] |
|
self.annotation['study_id'] = self.annotation['Note_file'].apply(lambda x: int(x.lstrip('s').rstrip('.txt'))) |
|
|
|
self.annotation = pd.merge(self.annotation, self.chexpert, how='left', left_on=['dicom_id'], right_on=['dicom_id']) |
|
if truncate is not None: |
|
self.annotation = self.annotation[:truncate] |
|
|
|
self.vis_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()]) |
|
if use_augs: |
|
aug_tfm = transforms.Compose([transforms.RandomAffine(degrees=30, shear=15), |
|
transforms.ColorJitter(brightness=0.2, contrast=0.2)]) |
|
|
|
self.vis_transforms = transforms.Compose([self.vis_transforms, aug_tfm]) |
|
self.loss_weighting = loss_weighting |
|
|
|
def get_class_weights(self): |
|
"""Compute class weights based on the inverse of class frequencies. |
|
|
|
Returns: |
|
Dict[str, float]: Class weights. |
|
""" |
|
if self.loss_weighting == "none": |
|
return torch.ones(len(self.chexpert_cols), dtype=torch.float32) |
|
|
|
label_counts = torch.zeros(len(self.chexpert_cols), dtype=torch.float32) |
|
|
|
for _, ann in self.annotation.iterrows(): |
|
chexpert_labels = self._extract_chexpert_labels_from_row(ann) |
|
label_counts += chexpert_labels |
|
|
|
|
|
if self.loss_weighting == "lin": |
|
class_weights = len(self.annotation) / label_counts |
|
elif self.loss_weighting == "log": |
|
class_weights = torch.log(len(self.annotation) / label_counts) |
|
|
|
return class_weights |
|
|
|
def remap_to_uint8(self, array: np.ndarray, percentiles=None) -> np.ndarray: |
|
"""Remap values in input so the output range is :math:`[0, 255]`. |
|
|
|
Percentiles can be used to specify the range of values to remap. |
|
This is useful to discard outliers in the input data. |
|
|
|
:param array: Input array. |
|
:param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``. |
|
Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster). |
|
:returns: Array with ``0`` and ``255`` as minimum and maximum values. |
|
""" |
|
array = array.astype(float) |
|
if percentiles is not None: |
|
len_percentiles = len(percentiles) |
|
if len_percentiles != 2: |
|
message = ( |
|
'The value for percentiles should be a sequence of length 2,' |
|
f' but has length {len_percentiles}' |
|
) |
|
raise ValueError(message) |
|
a, b = percentiles |
|
if a >= b: |
|
raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed') |
|
if a < 0 or b > 100: |
|
raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed') |
|
cutoff: np.ndarray = np.percentile(array, percentiles) |
|
array = np.clip(array, *cutoff) |
|
array -= array.min() |
|
array /= array.max() |
|
array *= 255 |
|
return array.astype(np.uint8) |
|
|
|
def load_image(self, path) -> Image.Image: |
|
"""Load an image from disk. |
|
|
|
The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers. |
|
|
|
:param path: Path to image. |
|
:returns: Image as ``Pillow`` ``Image``. |
|
""" |
|
|
|
if path.suffix in [".jpg", ".jpeg", ".png"]: |
|
image = io.imread(path) |
|
else: |
|
raise ValueError(f"Image type not supported, filename was: {path}") |
|
|
|
image = self.remap_to_uint8(image) |
|
return Image.fromarray(image).convert("L") |
|
|
|
def _extract_chexpert_labels_from_row(self, row: pd.Series) -> torch.Tensor: |
|
labels = torch.zeros(len(self.chexpert_cols), dtype=torch.float32) |
|
for i, col in enumerate(self.chexpert_cols): |
|
if row[col] == 1: |
|
labels[i] = 1 |
|
return labels |
|
|
|
def __getitem__(self, index): |
|
ann = self.annotation.iloc[index] |
|
image_path = os.path.join(self.vis_root, ann["Img_Folder"], ann["Img_Filename"]) |
|
image = self.load_image(Path(image_path)) |
|
image = self.vis_transforms(image) |
|
chexpert_labels = self._extract_chexpert_labels_from_row(ann) |
|
|
|
return { |
|
"image": image, |
|
"labels": chexpert_labels, |
|
"image_id": self.img_ids[ann["dicom_id"]], |
|
"report": ann["findings"], |
|
"study_id": ann["study_id"], |
|
"dicom_id": ann["dicom_id"], |
|
} |
|
|
|
def __len__(self): |
|
return len(self.annotation) |
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the default (train) split and inspect one sample.
    sample = Chexpert_Dataset()[0]
    print(sample)
|
|