---
library_name: transformers
tags:
- chest_x_ray
- x_ray
- medical_imaging
- radiology
- segmentation
- classification
- lungs
- heart
base_model:
- timm/tf_efficientnetv2_s.in21k_ft_in1k
pipeline_tag: image-segmentation
---
This model performs both segmentation and classification on chest radiographs (X-rays).

The model uses a `tf_efficientnetv2_s` backbone with a U-Net decoder for segmentation and a linear layer for classification.

For frontal radiographs, the model segments three structures: 1) the right lung, 2) the left lung, and 3) the heart.

The model also predicts the chest X-ray view (AP, PA, lateral), patient age, and patient sex.

The [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) (small version) and [NIH Chest X-ray](https://nihcc.app.box.com/v/ChestXray-NIHCC) datasets were used to train the model.

Segmentation masks were obtained from the CheXmask [dataset](https://physionet.org/content/chexmask-cxr-segmentation-data/0.4/) ([paper](https://www.nature.com/articles/s41597-024-03358-1)).

The final dataset comprised 335,516 images from 96,385 patients and was split into 80% training / 20% validation. A holdout test set was not used, since minimal tuning was performed.

The view classifier was trained only on CheXpert images (NIH images were excluded from the loss function), given that lateral radiographs are present only in CheXpert.

This avoids unwanted bias in the model, which can occur when one class originates from only a single dataset.
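
For illustration, here is a rough sketch of how the backbone-plus-heads layout described above can be assembled with `timm`. The layer sizes and head names are assumptions for demonstration only; the actual implementation ships with the model via `trust_remote_code`.

```
import timm
import torch
import torch.nn as nn

# encoder that returns multi-scale feature maps, shallow to deep,
# which a U-Net decoder (omitted here) would upsample and fuse
# into segmentation logits
backbone = timm.create_model(
    "tf_efficientnetv2_s", pretrained=False, features_only=True, in_chans=1
)
channels = backbone.feature_info.channels()  # channels at each feature level

# classification/regression heads operate on the pooled deepest feature map
pool = nn.AdaptiveAvgPool2d(1)
view_head = nn.Linear(channels[-1], 3)    # AP / PA / lateral
age_head = nn.Linear(channels[-1], 1)     # age regression, in years
female_head = nn.Linear(channels[-1], 1)  # sigmoid -> probability female

x = torch.randn(1, 1, 256, 256)  # dummy single-channel input
feats = backbone(x)              # list of feature maps, shallow to deep
pooled = pool(feats[-1]).flatten(1)
view_logits = view_head(pooled)
```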
Validation performance is as follows:

```
Segmentation (Dice similarity coefficient):
  Right Lung: 0.957
  Left Lung:  0.948
  Heart:      0.943

Age Prediction:
  Mean Absolute Error: 5.25 years

Classification:
  View (AP, PA, lateral): 99.42% accuracy
  Female: 0.999 AUC
```
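
For reference, the Dice similarity coefficient reported above can be computed for a pair of binary masks as in this minimal NumPy sketch (an illustration, not part of the model's code):

```
import numpy as np


def dice_coefficient(pred, target, eps=1e-7):
    # pred, target: binary masks of the same shape
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    return 2.0 * intersection / (pred.sum() + target.sum() + eps)
```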
To use the model:

```
import cv2
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)

img = cv2.imread(..., 0)  # read image in grayscale; replace ... with your file path
x = model.preprocess(img)  # only takes a single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add channel, batch dims
x = x.float()

with torch.inference_mode():
    out = model(x.to(device))
```
The output is a dictionary containing 4 keys:

* `mask` contains the segmentation masks, with one channel per class (including background). Take the argmax over the channel dimension to create a single image mask (i.e., `out["mask"].argmax(1)`): 1 = right lung, 2 = left lung, 3 = heart.
* `age`, the predicted patient age in years.
* `view`, with 3 classes, one for each possible view. Take the argmax to select the predicted view (i.e., `out["view"].argmax(1)`): 0 = AP, 1 = PA, 2 = lateral.
* `female`, the predicted probability the patient is female; binarize with `out["female"] >= 0.5`.
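
A minimal sketch of decoding these outputs (assuming the `out` dictionary, with batch size of 1, from the example above):

```
VIEWS = {0: "AP", 1: "PA", 2: "lateral"}

mask = out["mask"].argmax(1).squeeze(0).cpu().numpy()  # (height, width): 1/2/3 = right lung/left lung/heart
age = out["age"].item()                                # predicted age, in years
view = VIEWS[out["view"].argmax(1).item()]             # predicted view
female = (out["female"] >= 0.5).item()                 # True if predicted female

print(f"view={view}, age={age:.1f}, female={female}")
```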
You can use the segmentation mask to crop the region containing the lungs from the rest of the X-ray.
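
For example, a minimal sketch of cropping to the bounding box of the lungs (assuming `img` from the usage example and a `mask` that has been argmax'd and resized to `img`'s height and width):

```
import numpy as np


def crop_lungs(img, mask):
    # mask values: 1 = right lung, 2 = left lung, 3 = heart
    ys, xs = np.where((mask == 1) | (mask == 2))  # lung pixel coordinates
    return img[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
```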
You can also calculate the [cardiothoracic ratio (CTR)](https://radiopaedia.org/articles/cardiothoracic-ratio?lang=us) using this function:
```
import numpy as np


def calculate_ctr(mask):
    # single mask with dims (height, width): 1 = right lung, 2 = left lung, 3 = heart
    lungs = np.zeros_like(mask)
    lungs[mask == 1] = 1
    lungs[mask == 2] = 1
    heart = (mask == 3).astype("int")
    # widest horizontal extent of the lungs
    y, x = np.stack(np.where(lungs == 1))
    lung_min = x.min()
    lung_max = x.max()
    # widest horizontal extent of the heart
    y, x = np.stack(np.where(heart == 1))
    heart_min = x.min()
    heart_max = x.max()
    lung_range = lung_max - lung_min
    heart_range = heart_max - heart_min
    return heart_range / lung_range
```
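
For example, using the decoded `mask` from above (a sketch; assumes a frontal radiograph on which all three structures were segmented):

```
ctr = calculate_ctr(mask)
print(f"CTR: {ctr:.2f}")  # on a PA radiograph, a CTR above ~0.5 is commonly used to suggest cardiomegaly
```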
If you have `pydicom` installed, you can also load a DICOM image directly:

```
img = model.load_image_from_dicom(path_to_dicom)
```
This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use.

The user assumes any and all responsibility regarding their own use of this model and its outputs.