mav735 committed on
Commit
371ecdf
1 Parent(s): ca10bab

Upload 3 files

Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +42 -0
  3. model.py +166 -0
  4. model_5_7_14_27_0.993125_final +3 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model_5_7_14_27_0.993125_final filter=lfs diff=lfs merge=lfs -text
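(This new rule routes model_5_7_14_27_0.993125_final through Git LFS, alongside the existing *.zip / *.zst / *tfevents* patterns, so the repository itself stores only a small pointer stub for the roughly 60 MB checkpoint added below.)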
app.py ADDED
@@ -0,0 +1,42 @@
+ import os
+
+ import cv2
+ import gradio as gr
+
+ from model import get_results_model, model_
+
+ IMAGES = 0  # running counter so each uploaded image gets a unique filename
+
+
+ def predict_image(image):
+     global IMAGES
+     os.makedirs('images', exist_ok=True)  # cv2.imwrite fails silently if the directory is missing
+     path = f'images/image_{IMAGES}.jpg'
+     # Gradio delivers RGB while cv2.imwrite expects BGR; for typical
+     # single-channel MRI slices (R == G == B) the mismatch is harmless.
+     cv2.imwrite(path, image)
+     IMAGES += 1
+     result = get_results_model(path, model_)
+     if result[2] < 0.001:
+         label_img = 'Unrecognised'
+         pred_acc = ''
+     else:
+         label_img = result[1]
+         pred_acc = f'Probability: &nbsp; **{(result[2] * 100):.2f} %**'
+     return result[0], f'<font size="10"> Class: &nbsp; **{label_img}** &nbsp;&nbsp;&nbsp;&nbsp; {pred_acc}</font>'
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown('**<font size="10">MRI Assistant</font>**')
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(label='MRI')
+             label = gr.Markdown("")
+         image_output = gr.Image(label='AI results')
+
+     image_button = gr.Button("Predict results")
+
+     gr.Markdown(r"""
+ <font size="10">Social:</font>\
+ &nbsp;&nbsp; <font size="7">*1.*</font>&nbsp;&nbsp; <font size="6"> [*Developers*](https://t.me/HenSolaris) </font>\
+ &nbsp;&nbsp; <font size="7">*2.*</font>&nbsp;&nbsp; <font size="6"> [*Telegram bot*](https://t.me/Altsheimer_AI_bot) </font>
+ """)
+
+     image_button.click(predict_image, inputs=image_input, outputs=[image_output, label])
+
+ demo.launch()
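A quick way to sanity-check the pipeline without launching the Gradio UI is to call the model-side entry point directly. A minimal sketch, assuming the LFS checkpoint has been pulled and that 'sample_mri.jpg' is a hypothetical stand-in for any local MRI slice:

    from model import get_results_model, model_

    # 'sample_mri.jpg' is a hypothetical path; substitute any local MRI image
    boxed_image, label, probability = get_results_model('sample_mri.jpg', model_)
    print(label, f'{probability:.4f}')  # one of the four dementia classes and its score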
model.py ADDED
@@ -0,0 +1,166 @@
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as func
+ from captum.attr import IntegratedGradients
+ import __main__
+
+
+ class ConvNet(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         # input images are 180x180
+
+         self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
+         self.batchnorm1 = nn.BatchNorm2d(8)
+         self.pool1 = nn.MaxPool2d((2, 2))
+
+         self.conv2 = nn.Conv2d(8, 16, 8, padding=1)
+         self.dropout2 = nn.Dropout(0.25)
+         self.batchnorm2 = nn.BatchNorm2d(16)
+         self.pool2 = nn.MaxPool2d((2, 2))
+
+         self.conv3 = nn.Conv2d(16, 32, 2, padding=1)
+         self.dropout3 = nn.Dropout(0.25)
+         self.batchnorm3 = nn.BatchNorm2d(32)
+         self.pool3 = nn.MaxPool2d((2, 2))
+
+         self.conv4 = nn.Conv2d(32, 16, 16, padding=1)
+         self.dropout4 = nn.Dropout(0.25)
+         self.batchnorm4 = nn.BatchNorm2d(16)
+
+         self.flatten = nn.Flatten()
+
+         # head on the shallow branch (flattened features after pool2)
+         self.fc_2_1 = nn.Linear(28224, 512)
+         self.fc_2_2 = nn.Linear(512, 4)
+
+         # head on the deep branch (flattened features after conv4)
+         self.fc1 = nn.Linear(1024, 512)
+         self.fc2 = nn.Linear(512, 4)
+
+     def forward(self, x):
+         x = func.relu(self.conv1(x))
+         x = self.batchnorm1(x)
+         x = self.pool1(x)
+
+         x = func.relu(self.conv2(x))
+         x = self.dropout2(x)
+         x = self.batchnorm2(x)
+         x = self.pool2(x)
+
+         x_1 = func.relu(self.conv3(x))
+         x_1 = self.dropout3(x_1)
+         x_1 = self.batchnorm3(x_1)
+         x_1 = self.pool3(x_1)
+
+         x_1 = func.relu(self.conv4(x_1))
+         x_1 = self.dropout4(x_1)
+         x_1 = self.batchnorm4(x_1)
+
+         x_1 = self.flatten(x_1)
+         x_1 = func.relu(self.fc1(x_1))
+         x_1 = self.fc2(x_1)
+
+         x_2 = self.flatten(x)
+         x_2 = func.relu(self.fc_2_1(x_2))
+         x_2 = self.fc_2_2(x_2)
+
+         # the two branches vote by summing their logits
+         return x_1 + x_2
+
+
+ # The checkpoint was pickled from a __main__ module, so ConvNet must be
+ # visible there for torch.load to unpickle it.
+ setattr(__main__, "ConvNet", ConvNet)
+
+ device = 'cpu'
+ model_ = torch.load('model_5_7_14_27_0.993125_final', map_location=device)
+ model_.eval()
+
+
+ def get_class_of_demension(idx):
+     classes = ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']
+     return classes[idx]
+
+
+ def get_segmented_map(image_attr: np.ndarray,
+                       color_map: str = 'positive',
+                       borders: tuple = (20, 20)) -> np.ndarray:
+     """Zero out pixels that do not match the requested attribution sign.
+
+     color_map: one of ['positive', 'negative', 'all'].
+     """
+     if color_map != 'all':
+         for i in range(len(image_attr)):
+             for j in range(len(image_attr[i])):
+                 flag_zero = False
+                 if color_map == 'positive':
+                     # keep pixels whose dominant channel is index 1 (treated as positive attribution)
+                     if max(image_attr[i][j]) != image_attr[i][j][1]:
+                         flag_zero = True
+                     else:
+                         if sum(image_attr[i][j]) - max(image_attr[i][j]) > borders[1]:
+                             flag_zero = True
+                 elif color_map == 'negative':
+                     if max(image_attr[i][j]) == image_attr[i][j][1] or max(image_attr[i][j]) == image_attr[i][j][2]:
+                         flag_zero = True
+                     else:
+                         if sum(image_attr[i][j]) - max(image_attr[i][j]) > borders[0]:
+                             flag_zero = True
+                 if flag_zero:
+                     image_attr[i][j] = [0, 0, 0]
+     return image_attr
+
+
+ def show_pack_of_images(images, labels):
+     _, axes = plt.subplots(1, len(images), figsize=(30, 5))
+     for i, axis in enumerate(axes):
+         axis.imshow(images[i])
+         axis.set_title(labels[i])
+     plt.show()
+
+
+ def create_color_map_igrad(net, img_path: str) -> tuple:
+     integrated_gradients = IntegratedGradients(net)
+     img = cv2.cvtColor(cv2.resize(cv2.imread(img_path, 0), (180, 180)), cv2.COLOR_GRAY2RGB)
+     img_tensor = torch.from_numpy(np.array(img).astype(np.float32)).to(device)
+     img_tensor = img_tensor.permute(2, 0, 1) / 255
+     img_tensor = img_tensor.unsqueeze(0)
+
+     output = net(img_tensor)  # use the model passed in, not the global
+     prob = torch.sigmoid(output)  # func.sigmoid is deprecated
+     probability = float(np.max(prob.detach().numpy()))
+     prediction_score, pred_label_idx = torch.topk(output, 1)
+     pred_label_idx.squeeze_()
+     predicted_label = pred_label_idx.item()
+
+     attributions_ig = integrated_gradients.attribute(img_tensor, target=pred_label_idx, n_steps=200)
+
+     imgs = [(img_tensor.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8),
+             (np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)) * 255).astype(np.uint8)]
+     imgs.extend([get_segmented_map(imgs[1].copy(), 'negative'), get_segmented_map(imgs[1].copy(), 'positive')])
+     labels = [get_class_of_demension(predicted_label), 'all', 'negative', 'positive']
+
+     return imgs, labels, probability
+
+
+ def get_results_model(image_path, model):
+     images, labels, probability = create_color_map_igrad(model, image_path)
+
+     img = images[3].copy()
+     original = images[0].copy()
+
+     # threshold the positive-attribution map and box the hottest regions
+     result = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+     result = cv2.blur(result, (5, 5))
+
+     min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
+     ret, result = cv2.threshold(result, 0.3 * max_val, 255, cv2.THRESH_BINARY)
+
+     contours, hierarchy = cv2.findContours(result, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+
+     for element in contours:
+         if 150 > len(element) > 35:
+             color = (255, 0, 0)
+             x, y, w, h = cv2.boundingRect(element)
+             cv2.rectangle(original, (x - 2, y - 2), (x + w + 1, y + h + 1), color, 1)
+
+     return original, labels[0], probability
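The hard-coded Linear sizes (28224 for the shallow head, 1024 for the deep head) only work out for 180x180 inputs. A minimal shape check, assuming the checkpoint file is present so that importing model succeeds:

    import torch
    from model import ConvNet  # note: importing model also runs torch.load on the checkpoint

    net = ConvNet().eval()
    with torch.no_grad():
        logits = net(torch.zeros(1, 3, 180, 180))
    print(logits.shape)  # torch.Size([1, 4]), one logit per dementia class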
model_5_7_14_27_0.993125_final ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f4e58c22eb0c9888135ed5aa62f53fae9aca6d45c4d156b38098d9adf3bb3bd6
+ size 60504053
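These three lines are the Git LFS pointer stub, not the network weights: version identifies the pointer format, oid is the SHA-256 of the real file, and size is its length in bytes (60504053, about 60 MB). A plain clone checks out only this stub; running `git lfs pull` in the repository downloads the actual checkpoint that model.py loads at import time.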