# mri-assistent / model.py
# Uploaded by mav735 ("Upload 3 files", commit 371ecdf, 5.53 kB).
import __main__
import inspect

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
from captum.attr import IntegratedGradients
class ConvNet(nn.Module):
    """Two-headed CNN classifier for 3-channel 180x180 images with 4 output logits.

    A shared convolutional trunk (conv1, conv2) feeds two heads whose logits
    are summed:

    * deep head:    conv3 -> conv4 -> fc1 -> fc2 on the trunk features;
    * shallow head: fc_2_1 -> fc_2_2 on the flattened trunk features.

    Each conv stage is conv -> ReLU -> [dropout] -> batch-norm -> [max-pool].

    The Linear input sizes are hard-coded for a 180x180 input:
      180 -conv1(k3,p1)-> 180 -pool-> 90 -conv2(k8,p1)-> 85 -pool-> 42
          => shallow head input: 16 * 42 * 42 = 28224
      42 -conv3(k2,p1)-> 43 -pool-> 21 -conv4(k16,p1)-> 8
          => deep head input:    16 * 8 * 8 = 1024

    NOTE: keep the submodule attribute names stable. Instances may be
    restored from a pickled checkpoint via ``torch.load``, which restores
    submodules by these names.
    """

    def __init__(self):
        super().__init__()
        # Input image size: 180x180.
        # Shared trunk, stage 1.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(num_features=8)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        # Shared trunk, stage 2.
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=8, padding=1)
        self.dropout2 = nn.Dropout(p=0.25)
        self.batchnorm2 = nn.BatchNorm2d(num_features=16)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        # Deep head, stage 3.
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, padding=1)
        self.dropout3 = nn.Dropout(p=0.25)
        self.batchnorm3 = nn.BatchNorm2d(num_features=32)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        # Deep head, stage 4.
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=16, padding=1)
        self.dropout4 = nn.Dropout(p=0.25)
        self.batchnorm4 = nn.BatchNorm2d(num_features=16)
        # Flatten (batch, C, H, W) -> (batch, C*H*W); shared by both heads.
        self.flatten = nn.Flatten()
        # Shallow head: classifies the flattened trunk output (16*42*42).
        self.fc_2_1 = nn.Linear(in_features=28224, out_features=512)
        self.fc_2_2 = nn.Linear(in_features=512, out_features=4)
        # Deep head, fully-connected part (16*8*8 inputs).
        self.fc1 = nn.Linear(in_features=1024, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=4)

    def _trunk(self, images):
        """Shared feature extractor (stages 1 and 2)."""
        stage1 = self.pool1(self.batchnorm1(func.relu(self.conv1(images))))
        activated = func.relu(self.conv2(stage1))
        return self.pool2(self.batchnorm2(self.dropout2(activated)))

    def _deep_head(self, features):
        """Stages 3-4 followed by fc1/fc2; returns 4 logits per sample."""
        stage3 = self.pool3(self.batchnorm3(self.dropout3(func.relu(self.conv3(features)))))
        stage4 = self.batchnorm4(self.dropout4(func.relu(self.conv4(stage3))))
        hidden = func.relu(self.fc1(self.flatten(stage4)))
        return self.fc2(hidden)

    def _shallow_head(self, features):
        """fc_2_1/fc_2_2 applied directly to the trunk output; returns 4 logits."""
        hidden = func.relu(self.fc_2_1(self.flatten(features)))
        return self.fc_2_2(hidden)

    def forward(self, x):
        """Return the sum of the deep-head and shallow-head logits for batch ``x``."""
        features = self._trunk(x)
        # The deep head is evaluated first, matching the dropout RNG order in training mode.
        return self._deep_head(features) + self._shallow_head(features)
# Expose ConvNet as __main__.ConvNet.  torch.load() below unpickles a whole
# model object; presumably the checkpoint was pickled from a script where the
# class lived in __main__, so unpickling looks it up there (NOTE(review): confirm).
setattr(__main__, "ConvNet", ConvNet)

# Device the checkpoint is loaded onto.
device = 'cpu'

# The checkpoint is a fully pickled nn.Module, not a state_dict.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
# - map_location: keep the model on ``device`` even if it was saved from a GPU.
# - weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which
#   refuses to unpickle full modules.  Releases older than 1.13 do not have the
#   parameter at all, so it is only passed when torch.load accepts it.
_load_kwargs = {'map_location': device}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
model_ = torch.load('model_5_7_14_27_0.993125_final', **_load_kwargs)
model_.eval()
def get_class_of_demension(idx):
    """Return the dementia-stage label for class index ``idx``.

    Index order: 0 -> 'NonDemented', 1 -> 'VeryMildDemented',
    2 -> 'MildDemented', 3 -> 'ModerateDemented'.  Plain list indexing is
    used, so negative indices count from the end and out-of-range indices
    raise IndexError.
    """
    stage_names = [
        'NonDemented',
        'VeryMildDemented',
        'MildDemented',
        'ModerateDemented',
    ]
    return stage_names[idx]
def get_segmented_map(image_attr: np.ndarray,
                      color_map: str = 'positive',
                      borders: tuple = (20, 20)) -> np.ndarray:
    """Zero out every pixel of an attribution map that does not match ``color_map``.

    ``image_attr`` is indexed as [row, col, channel].  For each pixel let
    ``peak`` be its largest channel value and ``rest = sum(channels) - peak``.

    * ``'positive'``: keep the pixel only if channel 1 holds the peak and
      ``rest <= borders[1]``.
    * ``'negative'``: keep it only if neither channel 1 nor channel 2 holds
      the peak and ``rest <= borders[0]``.
    * ``'all'`` (or any other value): return the map unchanged.

    Every pixel that is not kept is set to 0 in all channels.

    Args:
        image_attr: array of shape (H, W, C).  A NumPy array is modified IN
            PLACE; callers that need the original should pass a copy.
        color_map: 'positive', 'negative' or 'all'.
        borders: (negative_limit, positive_limit) tolerances on ``rest``.

    Returns:
        The filtered map.  For ndarray input this is ``image_attr`` itself.
    """
    if color_map not in ('positive', 'negative'):
        return image_attr

    attr = np.asarray(image_attr)
    # Do the channel arithmetic in a wide dtype.  Summing a uint8 pixel in
    # uint8 (what the builtin sum() does under NumPy >= 2) wraps past 255 and
    # can make a strongly mixed pixel look "pure".
    wide_dtype = np.int64 if np.issubdtype(attr.dtype, np.integer) else np.float64
    values = attr.astype(wide_dtype)
    peak = values.max(axis=-1)
    rest = values.sum(axis=-1) - peak

    if color_map == 'positive':
        keep = (values[..., 1] == peak) & (rest <= borders[1])
    else:  # 'negative'
        keep = ((values[..., 1] != peak) & (values[..., 2] != peak)
                & (rest <= borders[0]))

    # Boolean mask over (H, W) clears all channels of the rejected pixels.
    attr[~keep] = 0
    return attr
def show_pack_of_images(images, labels):
    """Display ``images`` side by side in one row, titling each with its label.

    Args:
        images: sequence of images accepted by ``Axes.imshow``.
        labels: titles; ``labels[i]`` is used for ``images[i]``, so it must
            be at least as long as ``images``.

    Shows the figure with ``plt.show()`` and returns nothing.  An empty
    ``images`` sequence shows nothing.
    """
    if len(images) == 0:
        return
    # squeeze=False always yields a 2-D array of Axes.  Without it a single
    # image gives a bare Axes object, which cannot be iterated.
    _, axes = plt.subplots(1, len(images), figsize=(30, 5), squeeze=False)
    for i, axis in enumerate(axes[0]):
        axis.imshow(images[i])
        axis.set_title(labels[i])
    plt.show()
def _load_image_tensor(img_path: str) -> torch.Tensor:
    """Read ``img_path`` as grayscale and return a (1, 3, 180, 180) float tensor scaled to [0, 1]."""
    gray = cv2.imread(img_path, 0)  # flag 0: load as single-channel grayscale
    if gray is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"could not read image: {img_path!r}")
    # Replicate the gray channel 3x because the network expects 3 input channels.
    rgb = cv2.cvtColor(cv2.resize(gray, (180, 180)), cv2.COLOR_GRAY2RGB)
    tensor = torch.from_numpy(rgb.astype(np.float32)).to('cpu')
    return (tensor.permute(2, 0, 1) / 255).unsqueeze(0)


def create_color_map_igrad(net, img_path: str, n_steps: int = 200) -> tuple:
    """Classify an image with ``net`` and explain it with Integrated Gradients.

    Args:
        net: model used both for the prediction and for the attributions.
        img_path: path of the input image; it is read as grayscale and
            resized to 180x180.
        n_steps: number of integration steps passed to
            ``IntegratedGradients.attribute``.

    Returns:
        ``(imgs, labels, probability)`` where
        ``imgs``: [input image, raw attribution map, 'negative' segmented map,
        'positive' segmented map], each a (180, 180, 3) uint8 array;
        ``labels``: [predicted class name, 'all', 'negative', 'positive'];
        ``probability``: the largest sigmoid-activated logit, as a float.

    Raises:
        FileNotFoundError: if ``img_path`` cannot be read.
    """
    img_tensor = _load_image_tensor(img_path)

    # Predict with ``net``, the same network that is being explained, so the
    # reported label and the attribution maps always agree.
    with torch.no_grad():
        output = net(img_tensor)
    # NOTE(review): sigmoid over multi-class logits is not a normalised class
    # probability (softmax would be); confirm which one the training intended.
    probability = float(torch.sigmoid(output).max())
    _, top_idx = torch.topk(output, 1)
    predicted_label = top_idx.item()

    integrated_gradients = IntegratedGradients(net)
    attributions = integrated_gradients.attribute(
        img_tensor, target=predicted_label, n_steps=n_steps)

    original = (img_tensor.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    # NOTE(review): IG attributions are signed, and values outside [0, 1/255*255]
    # do not fit uint8.  The float->uint8 cast of out-of-range values is not
    # well-defined in NumPy.  Left as-is because the segmentation thresholds
    # below were presumably tuned on this output -- confirm before changing.
    attr_map = (np.transpose(attributions.squeeze().cpu().detach().numpy(), (1, 2, 0))
                * 255).astype(np.uint8)
    imgs = [original,
            attr_map,
            get_segmented_map(attr_map.copy(), 'negative'),
            get_segmented_map(attr_map.copy(), 'positive')]
    labels = [get_class_of_demension(predicted_label), 'all', 'negative', 'positive']
    return imgs, labels, probability
def get_results_model(image_path, model):
    """Classify an image and draw boxes around the regions behind the prediction.

    Takes the 'positive' attribution map from ``create_color_map_igrad`` and
    processes it as follows:
    - converts it to grayscale;
    - blurs it with a 5x5 box filter;
    - thresholds it at 30% of its maximum;
    - draws a box in colour (255, 0, 0) around every contour with strictly
      between 35 and 150 points.

    Args:
        image_path: path of the input image.
        model: network passed to ``create_color_map_igrad``.

    Returns:
        ``(annotated image, predicted class name, probability)``.
    """
    images, labels, probability = create_color_map_igrad(model, image_path)
    positive_map = images[3]  # 'positive' segmented attribution map
    annotated = images[0].copy()  # the input image; boxes are drawn on a copy

    # NOTE(review): the map's channels presumably follow RGB order (the input
    # was built with COLOR_GRAY2RGB), yet BGR2GRAY is used, which swaps the R/B
    # weights.  Kept to preserve the tuned output.
    gray = cv2.cvtColor(positive_map, cv2.COLOR_BGR2GRAY)
    gray = cv2.blur(gray, (5, 5))
    _, max_val, _, _ = cv2.minMaxLoc(gray)
    _, mask = cv2.threshold(gray, 0.3 * max_val, 255, cv2.THRESH_BINARY)

    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in 4.x; index -2 selects the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2]
    box_color = (255, 0, 0)
    for contour in contours:
        # With CHAIN_APPROX_NONE, len() is the number of boundary points:
        # this skips tiny specks and very long outlines.
        if 35 < len(contour) < 150:
            x, y, w, h = cv2.boundingRect(contour)
            cv2.rectangle(annotated, (x - 2, y - 2), (x + w + 1, y + h + 1), box_color, 1)
    return annotated, labels[0], probability