---
library_name: transformers
tags:
- radiology
- mammo_crop
- mammography
- medical_imaging
license: apache-2.0
base_model:
- timm/mobilenetv3_small_100.lamb_in1k
pipeline_tag: object-detection
---

This model crops mammography images to remove empty background. It uses a lightweight `mobilenetv3_small_100` backbone and predicts the crop region as normalized `xywh` coordinates.
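
The card does not spell out the architecture beyond the backbone, but conceptually this is a standard image regressor. A minimal sketch under that assumption (the class name is hypothetical; this is not the actual implementation):

```
import timm
import torch.nn as nn

class MammoCropNet(nn.Module):
    """Sketch: MobileNetV3-Small backbone with a 4-unit regression head."""

    def __init__(self):
        super().__init__()
        # num_classes=4 replaces the classifier with a 4-unit linear head;
        # in_chans=1 accepts single-channel (grayscale) mammograms
        self.backbone = timm.create_model(
            "mobilenetv3_small_100", pretrained=True, num_classes=4, in_chans=1
        )

    def forward(self, x):
        # sigmoid keeps the predicted xywh inside [0, 1]
        return self.backbone(x).sigmoid()
```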

The model was trained and validated on 54,706 screening mammograms from the [RSNA Screening Mammography Breast Cancer Detection](https://www.kaggle.com/competitions/rsna-breast-cancer-detection/) challenge, split 90%/10% between training and validation.
On single-fold validation, the model achieved the following mean absolute errors (in normalized coordinates):
```
x: 0.0032
y: 0.0030
w: 0.0054
h: 0.0088
```
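
To put these numbers in context: on a mammogram roughly 3,000 pixels tall, the `h` error of 0.0088 corresponds to about 26 pixels.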

The ground-truth coordinates were generated using the following code:

```
import cv2

def crop_roi(img):
    # trim a 5-pixel border to suppress edge artifacts
    img = img[5:-5, 5:-5]

    # connected components of the thresholded image; the largest
    # foreground component corresponds to the breast
    output = cv2.connectedComponentsWithStats((img > 10).astype("uint8"), 8, cv2.CV_32S)
    stats = output[2]

    # skip label 0 (background), take the component with the largest area
    idx = stats[1:, 4].argmax() + 1
    x1, y1, w, h = stats[idx][:4]
    x1 = max(0, x1 - 5)
    y1 = max(0, y1 - 5)

    # pad the box back out by the trimmed margin, clamped to the image bounds
    img_h, img_w = img.shape[:2]
    w = min(img_w, x1 + w + 10) - x1
    h = min(img_h, y1 + h + 10) - y1

    return x1, y1, w, h
```
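
As a concrete illustration, a normalized training target for one image could be derived from this function as follows (the file path is hypothetical, and normalizing by the image dimensions is an assumption made to match the model's normalized `xywh` output):

```
img = cv2.imread("mammogram.png", cv2.IMREAD_GRAYSCALE)  # hypothetical path
x1, y1, w, h = crop_roi(img)

# normalize to [0, 1] to match the model's output format
img_h, img_w = img.shape[:2]
target = (x1 / img_w, y1 / img_h, w / img_w, h / img_h)
```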

While this heuristic is not foolproof, a cursory review of a sample of cropped images showed excellent results. The model was also trained with a large batch size (256) to mitigate the effect of noise in the automatically generated labels.
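
Using the `MammoCropNet` sketch above, a single training step might look like the following (the L1 loss and AdamW optimizer are assumptions; the card only specifies the batch size):

```
import torch
import torch.nn.functional as F

net = MammoCropNet()
optimizer = torch.optim.AdamW(net.parameters())

def train_step(imgs, targets):
    # imgs: (256, 1, H, W) batch; targets: normalized xywh boxes
    preds = net(imgs)
    loss = F.l1_loss(preds, targets)  # L1 matches the reported MAE metric
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```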

To use the model:
```
import cv2
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
model = model.eval()

img = cv2.imread(..., cv2.IMREAD_GRAYSCALE)
img_shape = torch.tensor([img.shape[:2]])

x = model.preprocess(img)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions
x = x.float()

# if you do not provide img_shape,
# the model will return normalized coordinates
with torch.inference_mode():
    coords = model(x, img_shape)

coords = coords[0].numpy()  # only 1 sample in batch
# coords are already rescaled using img_shape
x, y, w, h = coords
cropped_img = img[y: y + h, x: x + w]
```
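
If you omit `img_shape`, the model returns normalized coordinates, which you can rescale yourself, e.g.:

```
with torch.inference_mode():
    norm_coords = model(x)  # normalized xywh in [0, 1]

nx, ny, nw, nh = norm_coords[0].numpy()
img_h, img_w = img.shape[:2]
x0, y0 = int(nx * img_w), int(ny * img_h)
w0, h0 = int(nw * img_w), int(nh * img_h)
cropped_img = img[y0: y0 + h0, x0: x0 + w0]
```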

If you have `pydicom` installed, you can also load a DICOM image directly:
```
img = model.load_image_from_dicom(path_to_dicom)
```