---
library_name: transformers
tags:
- radiology
- mammo_crop
- mammography
- medical_imaging
license: apache-2.0
base_model:
- timm/mobilenetv3_small_100.lamb_in1k
pipeline_tag: object-detection
---
|
|
|
This model crops mammography images to eliminate unnecessary background. The model uses a lightweight `mobilenetv3_small_100` backbone and predicts normalized `xywh` coordinates. |
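For clarity, each of the predicted `(x, y, w, h)` values is in `[0, 1]`, relative to the image width and height. Below is a minimal, illustrative sketch of mapping such a prediction back to pixel coordinates (this helper is not part of the model's API):

```
def denormalize_xywh(coords, img_height, img_width):
    # coords: normalized (x, y, w, h), each in [0, 1]
    x, y, w, h = coords
    return (
        int(x * img_width),
        int(y * img_height),
        int(w * img_width),
        int(h * img_height),
    )

# e.g., (0.25, 0.10, 0.40, 0.80) on a 2048 x 1024 image -> (256, 204, 409, 1638)
```

In practice, the model can do this rescaling for you if you pass the image shape at inference time (see the usage example below).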
|
|
|
The model was trained and validated on 54,706 screening mammography images from the [RSNA Screening Mammography Breast Cancer Detection](https://www.kaggle.com/competitions/rsna-breast-cancer-detection/) challenge, using a 90%/10% train/validation split.
|
On single-fold validation, the model achieved mean absolute errors (normalized coordinates) of: |
|
```
x: 0.0032
y: 0.0030
w: 0.0054
h: 0.0088
```
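These are per-coordinate mean absolute errors between the normalized predictions and the heuristic ground truth; a minimal sketch of the computation (the arrays below are illustrative placeholders):

```
import numpy as np

# preds, targets: (N, 4) arrays of normalized (x, y, w, h)
preds = np.array([[0.10, 0.05, 0.60, 0.90]])
targets = np.array([[0.11, 0.05, 0.59, 0.91]])

mae = np.abs(preds - targets).mean(axis=0)
print(dict(zip("xywh", mae.round(4))))  # per-coordinate MAE
```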
|
|
|
The ground-truth coordinates were generated using the following code: |
|
|
|
```
import cv2


def crop_roi(img):
    # trim a 5-pixel border to remove edge artifacts
    img = img[5:-5, 5:-5]

    # connected components on a binary mask of non-background pixels
    output = cv2.connectedComponentsWithStats((img > 10).astype("uint8"), 8, cv2.CV_32S)
    stats = output[2]

    # largest component, excluding the background label at index 0
    idx = stats[1:, 4].argmax() + 1
    x1, y1, w, h = stats[idx][:4]
    x1 = max(0, x1 - 5)
    y1 = max(0, y1 - 5)

    img_h, img_w = img.shape[:2]

    # normalize coordinates to [0, 1]
    return x1 / img_w, y1 / img_h, w / img_w, h / img_h
```
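As a usage sketch (the file path is a hypothetical placeholder), the heuristic is applied to a grayscale mammogram to produce the normalized training target:

```
img = cv2.imread("mammogram.png", cv2.IMREAD_GRAYSCALE)  # hypothetical path
x, y, w, h = crop_roi(img)  # normalized ground-truth box
```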
|
|
|
While this heuristic is not guaranteed to be foolproof, a cursory review of a sample of cropped images showed that it produces accurate crops.

The model was trained with a large batch size (256) to mitigate the effect of noise in the generated ground truth.
|
|
|
To use the model: |
|
```
import cv2
import torch

from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
model = model.eval()

# load the mammogram as a grayscale image
img = cv2.imread(..., cv2.IMREAD_GRAYSCALE)
img_shape = torch.tensor([img.shape[:2]])

x = model.preprocess(img)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add batch and channel dims
x = x.float()

# if you do not provide img_shape,
# the model will return normalized coordinates
with torch.inference_mode():
    coords = model(x, img_shape)

# only 1 sample in batch
coords = coords[0].numpy()
x, y, w, h = coords.astype(int)  # cast for slicing
# coords already rescaled with img_shape
cropped_img = img[y: y + h, x: x + w]
```
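If you do not pass `img_shape`, the model returns normalized coordinates, which you can rescale yourself. A minimal sketch, reusing `model`, `x`, and `img` from the example above:

```
with torch.inference_mode():
    norm_coords = model(x)  # normalized (x, y, w, h) in [0, 1]

norm_coords = norm_coords[0].numpy()
img_h, img_w = img.shape[:2]
x1, y1 = int(norm_coords[0] * img_w), int(norm_coords[1] * img_h)
w, h = int(norm_coords[2] * img_w), int(norm_coords[3] * img_h)
cropped_img = img[y1: y1 + h, x1: x1 + w]
```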
|
|
|
If you have `pydicom` installed, you can also load a DICOM image directly: |
|
```
img = model.load_image_from_dicom(path_to_dicom)
```
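The DICOM-loaded image can then go through the same preprocessing and inference steps shown above, for example (the path is a placeholder):

```
img = model.load_image_from_dicom("screening_mammo.dcm")  # hypothetical path
img_shape = torch.tensor([img.shape[:2]])
x = torch.from_numpy(model.preprocess(img)).unsqueeze(0).unsqueeze(0).float()

with torch.inference_mode():
    coords = model(x, img_shape)
```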