This model has been trained and validated on 14,036 pediatric hand radiographs from the RSNA Pediatric Bone Age Challenge dataset, which is publicly available. It can be loaded using:
from transformers import AutoModel
model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
The model is a 3-fold ensemble using the convnextv2_tiny backbone.
The individual single-fold models can be accessed through model.net0, model.net1, and model.net2. Each of these models was trained for 20,000 iterations using a batch size of 64 across 2 NVIDIA RTX 3090 GPUs.
Originally, the model was trained with both a regression head and a classification head.
However, this model only loads the classification head, as its stand-alone performance was slightly better. The classification head also generates better GradCAMs.
To produce a prediction, softmax is applied to the output logits, each class probability is multiplied by its class index (class i = i months), and the products are summed.
The result is a scalar float value representing the predicted bone age in months.
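The decoding step described above can be sketched in plain Python so the arithmetic is visible. This is a minimal illustration, not the model's actual implementation: the function name expected_bone_age_months and the toy logits are hypothetical, and the only facts taken from this card are that softmax is applied and that class i corresponds to i months.

```python
import math

def expected_bone_age_months(logits):
    """Softmax over logits, then the probability-weighted mean of class indices."""
    # Subtract the max logit for numerical stability before exponentiating.
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Class i is taken to mean i months, so the expectation is sum(i * p_i).
    return sum(i * p for i, p in enumerate(probs))

# Toy example with 4 classes (0 to 3 months); made-up logits for illustration only.
toy_logits = [0.0, 0.0, 0.0, 0.0]
print(expected_bone_age_months(toy_logits))  # uniform probabilities give the midpoint
```

Because the output is an expectation rather than an argmax, predictions are not restricted to whole months.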
Beyond standard data augmentation, two additional augmentations were applied during training:
- Using a cropped radiograph (from the model https://huggingface.co./ianpan/bone-age-crop) with probability 0.5
- Histogram matching with a reference image (available in this repo under Files, ref_img.png) with probability 0.5
Note that both of the above augmentations could be applied simultaneously and in conjunction with standard data augmentations. Thus, the model accommodates a wide range of variability in the appearance of a hand radiograph.
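As a rough sketch of how two such augmentations might be sampled per training image, the snippet below draws each one with probability 0.5. It assumes the two draws are independent, which this card does not state explicitly (it only says both can be applied at once); the function name and the simulation are hypothetical and are not taken from the training code.

```python
import random

def sample_extra_augmentations(rng, p_crop=0.5, p_hist=0.5):
    """Return (use_crop, use_hist_match), each drawn independently (an assumption)."""
    use_crop = rng.random() < p_crop
    use_hist_match = rng.random() < p_hist
    return use_crop, use_hist_match

# Simulate many training samples to see how often each combination occurs.
rng = random.Random(0)
n_samples = 10_000
counts = {}
for _ in range(n_samples):
    combo = sample_extra_augmentations(rng)
    counts[combo] = counts.get(combo, 0) + 1

# Under the independence assumption, each combination occurs roughly 25% of the time.
for combo, n in sorted(counts.items()):
    print(combo, n / n_samples)
```

Under that assumption the model would see raw, cropped-only, matched-only, and cropped-and-matched inputs in roughly equal proportion.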
On the original challenge test set comprising 200 multi-annotated pediatric hand radiographs, this model achieves a mean absolute error of 4.16 months (when applying both cropping and histogram matching to the input radiograph), which surpasses the top solutions from the original challenge.
Specific results (mean absolute error) are as follows, with single-model performance using model.net0 in brackets:
Crop (-) / Histogram Matching (-): 4.42 [4.67] months
Crop (+) / Histogram Matching (-): 4.47 [4.84] months
Crop (-) / Histogram Matching (+): 4.34 [4.59] months
Crop (+) / Histogram Matching (+): 4.16 [4.45] months
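The comparison can be tabulated with a few lines of Python. The numbers are copied from the results above; the dictionary layout and variable names are only illustrative.

```python
# (crop, histogram matching) -> (ensemble MAE, single-model MAE), in months,
# copied from the results listed above.
results = {
    (False, False): (4.42, 4.67),
    (True, False): (4.47, 4.84),
    (False, True): (4.34, 4.59),
    (True, True): (4.16, 4.45),
}

# Configuration with the lowest error for the ensemble and for model.net0 alone.
best_ensemble = min(results, key=lambda k: results[k][0])
best_single = min(results, key=lambda k: results[k][1])

# How much the 3-fold ensemble improves on the single model in each configuration.
ensemble_gain = {k: round(single - ens, 2) for k, (ens, single) in results.items()}

print(best_ensemble, best_single)
print(ensemble_gain)
```

In this tabulation the ensemble improves on model.net0 by 0.25 to 0.37 months in every configuration, and cropping combined with histogram matching gives the lowest error for both.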
Thus it is preferable to both crop and histogram match the image to obtain the optimal results. See https://huggingface.co./ianpan/bone-age-crop for how to crop a bone age radiograph with a pretrained model. To histogram match with a reference image:
import cv2
from skimage.exposure import match_histograms
x = cv2.imread("target_radiograph.png", 0)
ref = cv2.imread("ref_img.png", 0) # download ref_img.png from this repo
x = match_histograms(x, ref)
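For intuition about what histogram matching does, the sketch below maps each source intensity to the reference intensity at the same cumulative-distribution quantile. It is a simplified, pure-Python illustration on flat lists of pixel values, not a reproduction of skimage's match_histograms, which operates on arrays and may differ in detail; the helper names are hypothetical.

```python
from bisect import bisect_left
from collections import Counter

def cumulative_distribution(values):
    """Return sorted unique intensity levels and their cumulative fractions."""
    counts = Counter(values)
    levels = sorted(counts)
    total = len(values)
    running, cdf = 0, []
    for level in levels:
        running += counts[level]
        cdf.append(running / total)
    return levels, cdf

def match_histogram_1d(source, reference):
    """Map each source level to the reference level at the same CDF quantile."""
    src_levels, src_cdf = cumulative_distribution(source)
    ref_levels, ref_cdf = cumulative_distribution(reference)
    mapping = {}
    for level, quantile in zip(src_levels, src_cdf):
        # First reference level whose cumulative fraction reaches this quantile.
        idx = min(bisect_left(ref_cdf, quantile), len(ref_levels) - 1)
        mapping[level] = ref_levels[idx]
    return [mapping[v] for v in source]

# A dark toy "image" is remapped onto the brighter reference intensity range.
dark = [0, 0, 1, 2]
bright = [10, 20, 30, 40]
print(match_histogram_1d(dark, bright))
```

The relative ordering of pixel intensities is preserved; only their distribution is reshaped to resemble that of the reference image.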
Patient sex is an important variable affecting the model's prediction. This is passed to the model's forward() function using the female argument:
# 1 indicates female, 0 male
model(x, female=torch.tensor([1, 0, 1, 0])) # assuming batch size of 4
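If sex is stored as text in your metadata, a small helper can convert it to the 1 = female, 0 = male convention shown above. This is a hypothetical convenience function, not part of the model's API, and the "F"/"M" label format is an assumption about your data.

```python
def encode_female(sexes):
    """Map sex labels to the 1 = female, 0 = male convention used by the model."""
    mapping = {"F": 1, "M": 0}  # assumed label format; adjust to your metadata
    encoded = []
    for label in sexes:
        key = label.strip().upper()
        if key not in mapping:
            # Fail loudly rather than silently guessing a value the model depends on.
            raise ValueError(f"unrecognized sex label: {label!r}")
        encoded.append(mapping[key])
    return encoded

batch = encode_female(["F", "M", "f", "m"])
print(batch)  # wrap with torch.tensor(batch) before passing it as the female argument
```

Raising on unknown labels avoids feeding the model an arbitrary default for a variable that affects its prediction.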
Example usage for a single image:
import cv2
import torch
from skimage.exposure import match_histograms
from transformers import AutoModel
device = "cuda" if torch.cuda.is_available() else "cpu"
crop_model = AutoModel.from_pretrained("ianpan/bone-age-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)
img = cv2.imread(..., 0)
img_shape = torch.tensor([img.shape[:2]])
x = crop_model.preprocess(img) # only takes single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0) # add channel, batch dims
x = x.float()
# if you do not provide img_shape
# model will return normalized coordinates
with torch.inference_mode():
    coords = crop_model(x.to(device), img_shape.to(device))
# only 1 sample in batch
coords = coords[0].cpu().numpy()
x, y, w, h = coords
# coords already rescaled with img_shape
cropped_img = img[y: y + h, x: x + w]
# histogram matching
ref = cv2.imread("ref_img.png", 0) # download ref_img.png from this repo
cropped_img = match_histograms(cropped_img, ref)
model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
model = model.eval().to(device)
x = model.preprocess(cropped_img)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
x = x.float()
female = torch.tensor([1])
with torch.inference_mode():
    bone_age = model(x.to(device), female.to(device))
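Since the prediction is expressed in months, it can be convenient to report it as years and months. The helper below is a hypothetical formatting sketch; it assumes you have already converted the prediction to a Python float (for a single-image batch, something like bone_age.item() would presumably do this), and the value 130.4 is made up.

```python
def months_to_years_months(months):
    """Round to the nearest whole month and split into (years, months)."""
    total = round(months)
    return divmod(total, 12)

# 130.4 is a made-up prediction used only to show the conversion.
years, remaining_months = months_to_years_months(130.4)
print(f"{years} years {remaining_months} months")
```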
If you want the raw logits (class i = i months), you can pass return_logits=True to forward():
bone_age_logits = model(x, female, return_logits=True)
To run single model inference, simply access one of the nets:
bone_age = model.net0(x, female)
If you have pydicom installed, you can also load a DICOM image directly:
img = model.load_image_from_dicom(path_to_dicom)
This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes any and all responsibility regarding their own use of this model and its outputs.
Base model: timm/convnextv2_tiny.fcmae_ft_in22k_in1k