File size: 3,152 Bytes
a950ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import monai.transforms as transforms
import streamlit as st
import tempfile

class MinMaxNormalization(transforms.Transform):
    """Rescale selected entries of a data dict to the [0, 1] range.

    For each key, the array has its minimum subtracted and is then divided by
    its new maximum. The divisor is clipped to at least 1e-8, so a constant
    array becomes all zeros instead of triggering a division by zero. The
    input mapping is shallow-copied, so the caller's dict is not modified.

    Args:
        keys: names of the entries to rescale. Defaults to ``("image",)``,
            which reproduces the behaviour of calling ``MinMaxNormalization()``
            with no arguments; the parameter mirrors the ``keys`` argument of
            the other transforms in this module.
    """

    def __init__(self, keys=("image",)):
        self.keys = keys

    def __call__(self, data):
        d = dict(data)
        for k in self.keys:
            d[k] = d[k] - d[k].min()
            # Clip only from below: guards against max == 0 (constant input).
            d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
        return d

class DimTranspose(transforms.Transform):
    """Exchange axes -3 and -1 of the selected entries of a data dict.

    Each listed entry is replaced by ``np.swapaxes(value, -3, -1)``; axis -2
    and any leading axes keep their positions. A shallow copy of the input
    mapping is returned, so the caller's dict itself is left unmodified.
    """

    def __init__(self, keys):
        # Names of the dict entries whose trailing axes get swapped.
        self.keys = keys

    def __call__(self, data):
        result = dict(data)
        for name in self.keys:
            array = result[name]
            # swapaxes is symmetric, so (-3, -1) is the same swap as (-1, -3).
            result[name] = np.swapaxes(array, -3, -1)
        return result

class ForegroundNormalization(transforms.Transform):
    """Clip and standardise selected entries using their above-mean voxels.

    For each key in ``keys``, the voxels strictly greater than the array's
    global mean form the reference sample. The whole array is clipped to the
    0.05th..99.95th percentile of that sample and then z-scored with the
    sample's mean and standard deviation. The input mapping is shallow-copied;
    entries not listed in ``keys`` are passed through unchanged.
    """

    def __init__(self, keys):
        # Names of the dict entries to normalise.
        self.keys = keys

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            d[key] = self.normalize(d[key])
        return d

    def normalize(self, ct_narray):
        """Return a clipped, standardised copy of ``ct_narray``.

        Args:
            ct_narray: array supporting ``.flatten()`` (presumably a CT volume
                as a NumPy array — TODO confirm for other array types).

        Returns:
            Array with the same shape as ``ct_narray``.

        Notes:
            If no voxel exceeds the mean (e.g. a constant volume), the
            statistics are taken over all voxels instead. Taking percentiles
            of an empty sample would raise or yield NaN. A zero-size input is
            still not supported.
        """
        # flatten() already returns a copy, so the input is never mutated.
        ct_voxel_ndarray = ct_narray.flatten()
        thred = np.mean(ct_voxel_ndarray)
        voxel_filtered = ct_voxel_ndarray[ct_voxel_ndarray > thred]
        if voxel_filtered.size == 0:
            voxel_filtered = ct_voxel_ndarray
        lower_bound, upper_bound = np.percentile(voxel_filtered, [0.05, 99.95])
        mean = np.mean(voxel_filtered)
        std = np.std(voxel_filtered)
        ### transform ###
        ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
        # Floor the std so a zero-spread sample does not divide by zero.
        ct_narray = (ct_narray - mean) / max(std, 1e-8)
        return ct_narray
    
def _load_ct_volume(case_path, img_loader):
    """Load a volume and return it as a NumPy array with a leading axis of size 1.

    Args:
        case_path: a filesystem path (``str`` or ``os.PathLike``) passed to
            ``img_loader`` directly, or a file-like object whose ``read()``
            returns the raw bytes of a ``.nii.gz`` file.
        img_loader: callable used as ``data, _ = img_loader(path)``
            (here a ``transforms.LoadImage`` instance).

    Returns:
        ``np.array(data).squeeze()`` with a new axis 0 prepended.
    """
    import os  # function-scope import keeps the module's import block unchanged

    if isinstance(case_path, (str, os.PathLike)):
        ct_voxel_ndarray, _ = img_loader(case_path)
        ct_voxel_ndarray = np.array(ct_voxel_ndarray)
    else:
        bytes_data = case_path.read()
        # A NamedTemporaryFile cannot be reopened by name while it is still
        # open on Windows. Instead, write the bytes to a closed file inside a
        # private temporary directory and let the loader open that path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, 'upload.nii.gz')
            with open(tmp_path, 'wb') as tmp:
                tmp.write(bytes_data)
            ct_voxel_ndarray, _ = img_loader(tmp_path)
            # Materialise the data before the directory and file are deleted.
            ct_voxel_ndarray = np.array(ct_voxel_ndarray)
    ct_voxel_ndarray = ct_voxel_ndarray.squeeze()
    # Prepend a channel axis; MONAI's dictionary transforms expect channel-first data.
    return np.expand_dims(ct_voxel_ndarray, axis=0)


@st.cache_data
def process_ct_gt(case_path, spatial_size=(32,256,256)):
    """Load a CT volume and build the preprocessed image variants.

    Args:
        case_path: ``None``, a path (``str`` or ``os.PathLike``) readable by
            ``transforms.LoadImage``, or a file-like object whose ``read()``
            yields the bytes of a ``.nii.gz`` file (presumably a Streamlit
            upload — verify against callers).
        spatial_size: minimum size for ``SpatialPadd``, and the target size of
            the ``zoom_out_image`` resize.

    Returns:
        ``None`` when ``case_path`` is ``None``. Otherwise the dict produced by
        the transform pipeline, whose ``'image'`` holds the preprocessed volume
        (plus any extra keys the MONAI transforms may add), extended with:
        ``'zoom_out_image'``, ``'image'`` resized to ``spatial_size``; and
        ``'z_image'``, ``'image'`` resized to ``(325, 325, 325)``. Both resizes
        use ``'nearest-exact'`` interpolation.
    """
    if case_path is None:
        return None
    print('Data preprocessing...')
    # transform
    img_loader = transforms.LoadImage(dtype=np.float32)
    transform = transforms.Compose(
        [
            transforms.Orientationd(keys=["image"], axcodes="RAS"),
            ForegroundNormalization(keys=["image"]),
            DimTranspose(keys=["image"]),
            MinMaxNormalization(),
            # NOTE(review): padding is constant (zero) and MinMaxNormalization
            # maps the minimum to 0. CropForegroundd's default selector keeps
            # only positive voxels, so it will likely crop the padding back
            # off. Confirm whether crop-then-pad was intended.
            transforms.SpatialPadd(keys=["image"], spatial_size=spatial_size, mode='constant'),
            transforms.CropForegroundd(keys=["image"], source_key="image"),
            transforms.ToTensord(keys=["image"]),
        ]
    )
    zoom_out_transform = transforms.Resized(keys=["image"], spatial_size=spatial_size, mode='nearest-exact')
    z_transform = transforms.Resized(keys=["image"], spatial_size=(325,325,325), mode='nearest-exact')

    item = {'image': _load_ct_volume(case_path, img_loader)}

    item = transform(item)
    item['zoom_out_image'] = zoom_out_transform(item)['image']
    item['z_image'] = z_transform(item)['image']
    return item