---
library_name: transformers
tags:
- radiology
- mammo_crop
- mammography
- medical_imaging
license: apache-2.0
base_model:
- timm/mobilenetv3_small_100.lamb_in1k
pipeline_tag: object-detection
---

This model crops mammography images to eliminate unnecessary background. The model uses a lightweight `mobilenetv3_small_100` backbone and predicts normalized `xywh` coordinates.

The model was trained and validated on 54,706 screening mammography images from the [RSNA Screening Mammography Breast Cancer Detection](https://www.kaggle.com/competitions/rsna-breast-cancer-detection/) challenge, with a 90%/10% train/validation split. On single-fold validation, the model achieved mean absolute errors (normalized coordinates) of:
```
x: 0.0032
y: 0.0030
w: 0.0054
h: 0.0088
```

The ground-truth coordinates were generated using the following code:
```
import cv2


def crop_roi(img):
    # trim a 5-pixel border before thresholding
    img = img[5:-5, 5:-5]
    # label connected regions of pixels brighter than 10
    output = cv2.connectedComponentsWithStats((img > 10).astype("uint8")[:, :], 8, cv2.CV_32S)
    stats = output[2]
    # largest component by area, skipping label 0 (background)
    idx = stats[1:, 4].argmax() + 1
    x1, y1, w, h = stats[idx][:4]
    # shift the top-left corner by 5 pixels, clamped at 0
    x1 = max(0, x1 - 5)
    y1 = max(0, y1 - 5)
    img_h, img_w = img.shape[:2]
    return x1, y1, w, h
```

This heuristic is not guaranteed to be foolproof, but a cursory review of a sample of cropped images showed excellent performance. The model was trained with a large batch size (256) to mitigate noise.

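
To make the reported metric concrete, the sketch below shows one way pixel-space boxes such as those returned by `crop_roi` could be normalized and compared against predictions. It is an illustration only, not the training code: the helper names `normalize_xywh` and `mean_absolute_error` are hypothetical, the sample boxes and predictions are made up, and it assumes that `x` and `w` are divided by the image width and `y` and `h` by the image height.

```python
def normalize_xywh(box, img_h, img_w):
    """Scale a pixel-space (x, y, w, h) box to the [0, 1] range.

    Assumption: x and w are divided by the width, y and h by the height.
    """
    x, y, w, h = box
    return [x / img_w, y / img_h, w / img_w, h / img_h]


def mean_absolute_error(preds, targets):
    """Per-coordinate mean absolute error over a list of xywh boxes."""
    n = len(targets)
    return [
        sum(abs(p[i] - t[i]) for p, t in zip(preds, targets)) / n
        for i in range(4)
    ]


# Made-up ground-truth boxes for two 2000 x 1000 (height x width) images.
targets = [
    normalize_xywh((100, 50, 800, 1600), img_h=2000, img_w=1000),
    normalize_xywh((0, 20, 900, 1900), img_h=2000, img_w=1000),
]
# Made-up normalized predictions for the same two images.
preds = [
    [0.102, 0.024, 0.804, 0.794],
    [0.004, 0.012, 0.898, 0.952],
]

mae = mean_absolute_error(preds, targets)
for name, value in zip("xywh", mae):
    print(f"{name}: {value:.4f}")
```

Under this assumed normalization, an error of 0.0088 in `h` corresponds to roughly 18 pixels on an image that is 2000 pixels tall.
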
To use the model:
```
import cv2
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
model = model.eval()

# read the image as grayscale
img = cv2.imread(..., 0)
img_shape = torch.tensor([img.shape[:2]])
x = model.preprocess(img)
# add batch and channel dimensions
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
x = x.float()

# if you do not provide img_shape
# model will return normalized coordinates
with torch.inference_mode():
    coords = model(x, img_shape)

# only 1 sample in batch
coords = coords[0].numpy()
# cast to int so the values can be used as slice indices,
# in case they are returned as floats
x, y, w, h = (int(v) for v in coords)
# coords already rescaled with img_shape
cropped_img = img[y: y + h, x: x + w]
```

If you have `pydicom` installed, you can also load a DICOM image directly:
```
img = model.load_image_from_dicom(path_to_dicom)
```
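
If you call the model without `img_shape`, it returns normalized coordinates that you need to rescale yourself before cropping. The following is a minimal sketch of that step, not part of the model's API: `denormalize_xywh` is a hypothetical helper, the image and prediction are stand-in values, and it assumes `x` and `w` are normalized by the image width and `y` and `h` by the image height.

```python
import numpy as np


def denormalize_xywh(coords, img_h, img_w):
    """Convert normalized (x, y, w, h) to integer pixels, clamped to the image.

    Assumption: x and w are scaled by the width, y and h by the height.
    """
    x, y, w, h = (float(c) for c in coords)
    # keep the top-left corner inside the image
    x = min(max(round(x * img_w), 0), img_w - 1)
    y = min(max(round(y * img_h), 0), img_h - 1)
    # keep the box at least 1 pixel wide/tall and inside the image
    w = min(max(round(w * img_w), 1), img_w - x)
    h = min(max(round(h * img_h), 1), img_h - y)
    return x, y, w, h


# Stand-ins for a loaded grayscale image and a normalized prediction.
img = np.zeros((2000, 1000), dtype=np.uint8)
norm_coords = np.array([0.1, 0.025, 0.8, 0.8], dtype=np.float32)

x, y, w, h = denormalize_xywh(norm_coords, *img.shape[:2])
cropped_img = img[y : y + h, x : x + w]
print(cropped_img.shape)
```

Rounding and clamping keep the slice indices valid integers even when a prediction lands slightly outside the `[0, 1]` range.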