lingchmao committed
Commit
52a9229
1 Parent(s): 212dd4f

create app.py

Files changed (1)
  1. app.py +335 -0
app.py ADDED
import os
import numpy as np
import gradio as gr
import torch
import morphsnakes as ms
from utils.sliding_window import sw_inference
from utils.tumor_features import generate_features
from monai.networks.nets import SegResNetVAE
from monai.transforms import (
    LoadImage, Orientation, Compose, ToTensor, Activations,
    FillHoles, KeepLargestConnectedComponent, AsDiscrete, ScaleIntensityRange
)


# global params
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
examples_path = [
    os.path.join(THIS_DIR, 'examples', 'HCC_003.nrrd'),
    os.path.join(THIS_DIR, 'examples', 'HCC_006.nrrd'),
    os.path.join(THIS_DIR, 'examples', 'HCC_007.nrrd'),
    os.path.join(THIS_DIR, 'examples', 'HCC_018.nrrd')
]
models_path = {
    "liver": os.path.join(THIS_DIR, 'checkpoints', 'liver_3DSegResNetVAE.pth'),
    "tumor": os.path.join(THIS_DIR, 'checkpoints', 'tumor_3DSegResNetVAE_weak_morp.pth')
}
cache_path = {
    "liver mask": "liver_mask.npy",
    "tumor mask": "tumor_mask.npy"
}
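# masks are cached on disk as f"cache/{image_name}_{cache_path[<mask>]}"
# (e.g. cache/HCC_003_liver_mask.npy; see segment_liver / segment_tumor below)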
device = "cpu"
mydict = {}


def render(image_name, x, selected_slice):
    if image_name is None:
        return np.zeros((512, 512)), []

    # accept either a bare image name or a gr.File object
    if not isinstance(image_name, str):
        image_name = image_name.name
    image_name = os.path.basename(image_name).replace(".nrrd", "")

    if image_name not in mydict or 'img' not in mydict[image_name]:
        return np.zeros((512, 512)), []

    # clamp the slice index to the valid z-range
    zmin, zmax = 0, mydict[image_name]['img'].shape[-1] - 1
    x = max(zmin, min(x, zmax))

    # image, rescaled to [0, 1] (epsilon guards against constant slices)
    img = mydict[image_name]['img'][:, :, x]
    img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)

    # masks
    annotations = []
    if 'liver mask' in mydict[image_name]:
        annotations.append((mydict[image_name]['liver mask'][:, :, x], "segmented liver"))
    if 'tumor mask' in mydict[image_name]:
        annotations.append((mydict[image_name]['tumor mask'][:, :, x], "segmented tumor"))

    return img, annotations
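
# render's output feeds gr.AnnotatedImage directly: the component expects a
# (base_image, [(mask, label), ...]) tuple, where each mask is a 2D array aligned
# with the image. A hypothetical call:
#   img, anns = render("HCC_003", 40, None)  # img: (512, 512), anns: [(mask, "segmented liver"), ...]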


def load_liver_model():
    liver_model = SegResNetVAE(
        input_image_size=(512, 512, 16),
        vae_estimate_std=False,
        vae_default_std=0.3,
        vae_nz=256,
        spatial_dims=3,
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=1,
        norm='instance',
        out_channels=2,
        dropout_prob=0.2,
    )

    liver_model.load_state_dict(torch.load(models_path['liver'], map_location=torch.device(device)))

    return liver_model
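
# Note: MONAI's SegResNetVAE forward pass returns a (segmentation, vae_loss) tuple,
# which is why sw_inference (a project utility in utils/sliding_window.py) is called
# with discard_second_output=True further below.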


def load_tumor_model():
    tumor_model = SegResNetVAE(
        input_image_size=(256, 256, 32),
        vae_estimate_std=False,
        vae_default_std=0.3,
        vae_nz=256,
        spatial_dims=3,
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=1,
        norm='instance',
        out_channels=3,
        dropout_prob=0.2,
    )

    tumor_model.load_state_dict(torch.load(models_path['tumor'], map_location=torch.device(device)))

    return tumor_model
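
# a minimal smoke-test sketch (assumes the checkpoint files exist locally):
#   model = load_tumor_model()
#   model.eval()
#   with torch.no_grad():
#       seg, _ = model(torch.zeros(1, 1, 256, 256, 32))  # logits of shape (1, 3, 256, 256, 32)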


def load_image(image, slider, selected_slice):
    global mydict

    image_name = os.path.basename(image.name).replace(".nrrd", "")
    mydict = {image_name: {}}

    preprocessing_liver = Compose([
        # load image
        LoadImage(reader="NrrdReader", ensure_channel_first=True),
        # ensure orientation
        Orientation(axcodes="PLI"),
        # convert to tensor
        ToTensor()
    ])

    volume = preprocessing_liver(image.name)
    mydict[image_name]["img"] = volume[0].numpy()  # first channel

    print("Loaded image", image_name)

    image, annotations = render(image_name, slider, selected_slice)

    return f"Your image has been loaded successfully! Please use the slider to view slices (zmin: 0, zmax: {volume.shape[-1] - 1}).", (image, annotations)


def segment_tumor(image_name):
    # reuse a cached mask when one exists
    if os.path.isfile(f"cache/{image_name}_{cache_path['tumor mask']}"):
        mydict[image_name]['tumor mask'] = np.load(f"cache/{image_name}_{cache_path['tumor mask']}")

    if mydict[image_name].get('tumor mask') is not None:
        return

    # the tumor model runs on the liver region only, so the liver mask must exist first
    if mydict[image_name].get('liver mask') is None:
        segment_liver(image_name)

    volume = torch.from_numpy(mydict[image_name]['img'])

    tumor_model = load_tumor_model()

    # HU windowing
    preprocessing_tumor = Compose([
        ScaleIntensityRange(a_min=-200, a_max=250, b_min=0.0, b_max=1.0, clip=True)
    ])

    postprocessing_tumor = Compose([
        # apply sigmoid activation to convert logits to probabilities
        Activations(sigmoid=True),
        # convert to discrete one-hot predictions (0=background, 1=liver, 2=tumor)
        AsDiscrete(argmax=True, to_onehot=3),
        # keep only the largest connected component of label 2 (tumor)
        KeepLargestConnectedComponent(applied_labels=[2]),
        # fill holes in the binary mask of label 2 (tumor)
        FillHoles(applied_labels=[2]),
        ToTensor()
    ])

    # Preprocessing
    volume = preprocessing_tumor(volume)
    volume = torch.multiply(volume, torch.from_numpy(mydict[image_name]['liver mask']))  # mask non-liver regions

    # Generate segmentation
    with torch.no_grad():
        segmented_mask = sw_inference(tumor_model, volume[None, None, :], (256, 256, 32), False, discard_second_output=True, overlap=0.2)[0]  # input dimensions [B,C,H,W,Z]
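    # sw_inference (utils/sliding_window.py) performs patch-based sliding-window inference,
    # presumably akin to monai.inferers.sliding_window_inference: the volume is tiled into
    # (256, 256, 32) patches with 20% overlap and the patch predictions are stitched together.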

    # Postprocess image (channel -1 = tumor)
    segmented_mask = postprocessing_tumor(segmented_mask)[-1].numpy()
    # smooth the boundary with a morphological active contour, seeded by the prediction itself
    segmented_mask = ms.morphological_chan_vese(segmented_mask, iterations=2, init_level_set=segmented_mask)
    segmented_mask = np.multiply(segmented_mask, mydict[image_name]['liver mask'])  # mask regions outside the liver
    mydict[image_name]["tumor mask"] = segmented_mask

    # Saving
    os.makedirs("cache", exist_ok=True)
    np.save(f"cache/{image_name}_{cache_path['tumor mask']}", mydict[image_name]["tumor mask"])
    print(f"tumor mask saved to cache/{image_name}_{cache_path['tumor mask']}")

    return


def segment_liver(image_name):
    # reuse a cached mask when one exists
    if os.path.isfile(f"cache/{image_name}_{cache_path['liver mask']}"):
        mydict[image_name]['liver mask'] = np.load(f"cache/{image_name}_{cache_path['liver mask']}")

    if mydict[image_name].get('liver mask') is not None:
        return

    volume = torch.from_numpy(mydict[image_name]['img'])

    # load model
    liver_model = load_liver_model()

    # HU windowing
    preprocessing_liver = Compose([
        ScaleIntensityRange(a_min=-150, a_max=250, b_min=0.0, b_max=1.0, clip=True)
    ])
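    # ScaleIntensityRange implements the HU windowing: intensities in [-150, 250] HU
    # (a common soft-tissue window) are mapped linearly to [0, 1]; values outside are clipped.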

    postprocessing_liver = Compose([
        # apply sigmoid activation to convert logits to probabilities
        Activations(sigmoid=True),
        # convert predicted probabilities to discrete values (0=background, 1=liver)
        AsDiscrete(argmax=True, to_onehot=None),
        # keep only the largest connected component of label 1 (liver)
        KeepLargestConnectedComponent(applied_labels=[1]),
        # fill holes in the binary mask of label 1 (liver)
        FillHoles(applied_labels=[1]),
        ToTensor()
    ])

    # Preprocessing
    volume = preprocessing_liver(volume)

    # Generate segmentation
    with torch.no_grad():
        segmented_mask = sw_inference(liver_model, volume[None, None, :], (512, 512, 16), False, discard_second_output=True, overlap=0.2)[0]  # input dimensions [B,C,H,W,Z]

    # Postprocess image
    segmented_mask = postprocessing_liver(segmented_mask)[0].numpy()  # first channel
    mydict[image_name]["liver mask"] = segmented_mask
    print(f"liver mask shape: {segmented_mask.shape}")

    # Saving
    os.makedirs("cache", exist_ok=True)
    np.save(f"cache/{image_name}_{cache_path['liver mask']}", mydict[image_name]["liver mask"])
    print(f"liver mask saved to cache/{image_name}_{cache_path['liver mask']}")

    return


def segment(image, selected_mask, slider, selected_slice):
    image_name = os.path.basename(image.name).replace(".nrrd", "")
    download_liver = gr.DownloadButton(label="Download liver mask", visible=False)
    download_tumor = gr.DownloadButton(label="Download tumor mask", visible=False)

    if 'liver mask' in selected_mask:
        print('Segmenting liver...')
        segment_liver(image_name)
        download_liver = gr.DownloadButton(label="Download liver mask", value=f"cache/{image_name}_{cache_path['liver mask']}", visible=True)

    if 'tumor mask' in selected_mask:
        print('Segmenting tumor...')
        segment_tumor(image_name)
        download_tumor = gr.DownloadButton(label="Download tumor mask", value=f"cache/{image_name}_{cache_path['tumor mask']}", visible=True)

    image, annotations = render(image_name, slider, selected_slice)

    return "Segmentation is completed!", download_liver, download_tumor, (image, annotations)


def generate_summary(image):
    image_name = os.path.basename(image.name).replace(".nrrd", "")
    features = generate_features(mydict[image_name]["img"], mydict[image_name]["liver mask"], mydict[image_name]["tumor mask"])
    print(features)

    return str(features)


with gr.Blocks() as app:
    with gr.Column():
        gr.Markdown(
            """
            # Liver Tumor Segmentation App

            This tool assists in the identification and segmentation of the liver and liver tumors in medical images. Upload a CT scan, and a pre-trained machine learning model will automatically segment the liver and tumor regions. The segmented tumor's characteristics, such as shape, size, and location, are then analyzed to produce an AI-generated report on the liver cancer.

            ⚠️ Important disclaimer: these model outputs should NOT replace the medical diagnosis of healthcare professionals. For your reference, our model was trained on the [HCC-TACE-Seg dataset](https://www.cancerimagingarchive.net/collection/hcc-tace-seg/) and achieved a 0.954 Dice score for liver segmentation and a 0.570 Dice score for tumor segmentation. Improving tumor segmentation is still an active area of research!
            """)

        with gr.Row():
            comment = gr.Textbox(label='Your tool guide:', value="👋 Hi there, welcome to explore the power of AI for automated medical image analysis with our user-friendly app! Start by uploading a CT scan. Note that for now we accept the .nrrd format only.")

    with gr.Row():
        with gr.Column(scale=2):
            image_file = gr.File(label="Step 1: Upload a CT image (.nrrd)", file_count='single', file_types=['.nrrd'], type='filepath')
            btn_upload = gr.Button("Upload")

        with gr.Column(scale=2):
            selected_mask = gr.CheckboxGroup(label='Step 2: Select mask(s) to produce', choices=['liver mask', 'tumor mask'], value=['liver mask'])
            btn_segment = gr.Button("Segment")

    with gr.Row():
        slider = gr.Slider(0, 100, step=1, label="Step 3: Browse slices (z)")
        selected_slice = gr.State(value=0)

    with gr.Row():
        myimage = gr.AnnotatedImage(label="Image Viewer", height=1000, width=1000, color_map={"segmented liver": "#0373fc", "segmented tumor": "#eb5334"})

    with gr.Row():
        with gr.Column(scale=2):
            btn_download_liver = gr.DownloadButton("Download liver mask", visible=False)
        with gr.Column(scale=2):
            btn_download_tumor = gr.DownloadButton("Download tumor mask", visible=False)

    with gr.Row():
        report = gr.Textbox(label='Step 4: Generate a summary report using AI:')

    with gr.Row():
        btn_report = gr.Button("Generate summary")

    gr.Examples(
        examples_path,
        [image_file],
    )
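    # Clicking an example only fills the File input; the user still presses Upload
    # (no fn is passed to gr.Examples, so nothing is pre-computed or cached).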

    btn_upload.click(fn=load_image,
                     inputs=[image_file, slider, selected_slice],
                     outputs=[comment, myimage],
                     )

    btn_segment.click(fn=segment,
                      inputs=[image_file, selected_mask, slider, selected_slice],
                      outputs=[comment, btn_download_liver, btn_download_tumor, myimage],
                      )

    slider.change(
        render,
        inputs=[image_file, slider, selected_slice],
        outputs=[myimage]
    )

    # generate_summary needs the uploaded file, so pass it as an input
    btn_report.click(fn=generate_summary,
                     inputs=[image_file],
                     outputs=report
                     )


app.launch()