# SPADESegResNet / app.py
import colorsys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from metrics import *
import numpy as np  # np is used throughout; previously it was only available via the wildcard import above
import torchvision.transforms as T
import gradio as gr
import matplotlib.pyplot as plt
import tempfile
import os
import spaces
from huggingface_hub import snapshot_download
from huggingface_hub import login
login(token=os.getenv('HF_TOKEN'))
model_dir = snapshot_download(
repo_id="srijaydeshpande/spadesegresnet"
)
class SPADE(nn.Module):
def __init__(self, norm_nc, label_nc, norm):
super().__init__()
if norm == 'instance':
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
elif norm == 'batch':
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
# The dimension of the intermediate embedding space. Yes, hardcoded.
nhidden = 128
ks = 3
pw = ks // 2
self.mlp_shared = nn.Sequential(
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
nn.ReLU()
)
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
def forward(self, x, segmap):
# Part 1. generate parameter-free normalized activations
normalized = self.param_free_norm(x)
# Part 2. produce scaling and bias conditioned on semantic map
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
actv = self.mlp_shared(segmap)
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
# apply scale and bias
out = normalized * (1 + gamma) + beta
return out
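# A minimal shape sketch (illustrative only, not executed here): SPADE
# normalizes the features, then re-modulates them with per-pixel gamma/beta
# computed from the conditioning map, which in this app is the RGB image
# itself (label_nc=3).
#
#   spade = SPADE(norm_nc=512, label_nc=3, norm='instance')
#   x = torch.randn(1, 512, 96, 96)      # downsampled feature map
#   seg = torch.randn(1, 3, 768, 768)    # resized to 96x96 inside forward()
#   assert spade(x, seg).shape == x.shape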
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# define normalization layers
self.norm_0 = SPADE(fin, 3, norm='instance')
self.norm_1 = SPADE(fmiddle, 3, norm='instance')
if self.learned_shortcut:
self.norm_s = SPADE(fin, 3, norm='instance')
def forward(self, x, seg):
x_s = self.shortcut(x, seg)
dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
out = x_s + dx
return out
def shortcut(self, x, seg):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
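# Usage sketch (illustrative): the residual branch applies two
# SPADE-normalized convolutions and is added to the shortcut. With
# fin == fout the shortcut is the identity; otherwise a 1x1 conv projects
# the input to the output width.
#
#   block = SPADEResnetBlock(fin=512, fout=512)
#   y = block(torch.randn(1, 512, 96, 96), torch.randn(1, 3, 768, 768))
#   # y.shape == (1, 512, 96, 96)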
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
activation]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
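# Note: this plain (unconditioned) residual block is defined but not used by
# SPADEResNet below, which relies on SPADEResnetBlock instead. Illustrative
# usage:
#
#   block = ResnetBlock(64, padding_type='reflect', norm_layer=nn.BatchNorm2d)
#   y = block(torch.randn(1, 64, 32, 32))   # y = x + conv_block(x), same shape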
class SPADEResNet(torch.nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=5, norm_layer=nn.BatchNorm2d,
padding_type='reflect'):
assert (n_blocks >= 0)
super(SPADEResNet, self).__init__()
activation = nn.ReLU(True)
downsampler = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2 ** i
downsampler += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), activation]
self.downsampler = nn.Sequential(*downsampler)
### resnet blocks
mult = 2 ** n_downsampling
self.resnetblocks1 = SPADEResnetBlock(ngf * mult, ngf * mult)
self.resnetblocks2 = SPADEResnetBlock(ngf * mult, ngf * mult)
self.resnetblocks3 = SPADEResnetBlock(ngf * mult, ngf * mult)
self.resnetblocks4 = SPADEResnetBlock(ngf * mult, ngf * mult)
self.resnetblocks5 = SPADEResnetBlock(ngf * mult, ngf * mult)
### upsample
upsampler = []
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
upsampler += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
output_padding=1),
norm_layer(int(ngf * mult / 2)), activation]
upsampler += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
self.upsampler = nn.Sequential(*upsampler)
    def forward(self, input):
        downsampled = self.downsampler(input)
        # Each SPADE block is re-conditioned on the original input image.
        resnet1 = self.resnetblocks1(downsampled, input)
        resnet2 = self.resnetblocks2(resnet1, input)
        resnet3 = self.resnetblocks3(resnet2, input)
        resnet4 = self.resnetblocks4(resnet3, input)
        resnet5 = self.resnetblocks5(resnet4, input)
        upsampled = self.upsampler(resnet5)
        return upsampled
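# Shape flow sketch (illustrative), with the defaults ngf=64, n_downsampling=3:
# a 1x3x768x768 input is downsampled to 1x512x96x96, passed through the five
# SPADE residual blocks (each re-conditioned on the full-resolution input
# image), then upsampled back to 1x6x768x768 (5 tissue classes + background).
#
#   model = SPADEResNet(input_nc=3, output_nc=6)
#   out = model(torch.randn(1, 3, 768, 768))   # -> (1, 6, 768, 768)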
def generate_colors(n):
brightness = 0.7
hsv = [(i / n, 1, brightness) for i in range(n)]
    colors = [colorsys.hsv_to_rgb(*c) for c in hsv]
    colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors]
return colors
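# Example: generate_colors(3) returns hues evenly spaced around the HSV wheel
# at 0.7 brightness, i.e. [(178, 0, 0), (0, 178, 0), (0, 0, 178)].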
def generate_colored_image(labels):
    # Map each integer label to its RGB colour via palette fancy-indexing:
    # for labels of shape (H, W) and a (6, 3) palette, palette[labels]
    # yields an (H, W, 3) uint8 image in one vectorised lookup.
    palette = np.array(generate_colors(6), dtype=np.uint8)
    return Image.fromarray(palette[labels])
def predict_wsi(image):
patch_size = 768
    stride = 700  # stride is kept smaller than the patch size so that neighbouring patches overlap and their predictions can be averaged
generator_output_size = patch_size
    num_classes = 5
pred_labels = torch.zeros(1, num_classes+1, image.shape[2], image.shape[3]).cuda()
counter_tensor = torch.zeros(1, 1, image.shape[2], image.shape[3]).cuda()
    # The ranges are extended by one stride so that the clamped indices below
    # always reach the bottom/right borders; otherwise border pixels would
    # never be predicted and the division by counter_tensor would yield NaNs.
    for i in range(0, image.shape[2] - patch_size + stride, stride):
        for j in range(0, image.shape[3] - patch_size + stride, stride):
            i_lowered = min(i, image.shape[2] - patch_size)
            j_lowered = min(j, image.shape[3] - patch_size)
patch = image[:, :, i_lowered:i_lowered + patch_size, j_lowered:j_lowered + patch_size]
pred_labels_patch = model(patch.float())
update_region_i = i_lowered + (patch_size - generator_output_size) // 2
update_region_j = j_lowered + (patch_size - generator_output_size) // 2
pred_labels[:, :, update_region_i:update_region_i + generator_output_size,
update_region_j:update_region_j + generator_output_size] += pred_labels_patch
counter_tensor[:, :, update_region_i:update_region_i + generator_output_size,
update_region_j:update_region_j + generator_output_size] += 1
pred_labels /= counter_tensor
return pred_labels
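# Overlap-averaging sketch (illustrative; assumes a CUDA device and the
# global `model` defined further down): every pixel accumulates class scores
# from each patch that covers it, counter_tensor records how many patches
# that was, and the final division yields the mean prediction per pixel.
#
#   x = torch.rand(1, 3, 1400, 1400).cuda()
#   scores = predict_wsi(x)     # -> (1, 6, 1400, 1400)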
@spaces.GPU(duration=120)
def segment_image(image):
    img = np.asarray(image)
    if np.max(img) > 100:
        # Heuristic: values above 100 imply 8-bit pixel data, so rescale to [0, 1]
        img = img / 255.0
transform = T.Compose([T.ToTensor()])
image = transform(img)
image = image[None, :]
with torch.no_grad():
pred_labels = predict_wsi(image.float())
pred_labels = F.softmax(pred_labels, dim=1)
pred_labels_probs = pred_labels.cpu().numpy()
pred_labels = np.argmax(pred_labels_probs, axis=1)
pred_labels = pred_labels[0]
image = generate_colored_image(pred_labels)
class_labels = ['tumor', 'stroma', 'inflammatory', 'necrosis', 'others']
pixels_counts = []
total=0
print(np.unique(pred_labels))
for i in range(1,len(class_labels)+1):
current_count=np.sum(pred_labels == i)
pixels_counts.append(current_count)
total+=current_count
pixels_counts = [(value / total) * 100 for value in pixels_counts]
print(pixels_counts)
    plt.figure(figsize=(10, 6))
    bar_width = 0.15
    plt.bar(class_labels, pixels_counts, color='blue', width=bar_width)
    plt.xticks(rotation=45, ha='right', fontsize=16)
    plt.yticks(fontsize=16)
    plt.xlabel('Tissue types', fontsize=17)
    plt.ylabel('Class Percentage', fontsize=17)
    plt.title('Class distribution', fontsize=18)
    plt.tight_layout()
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmpfile:
plt.savefig(tmpfile.name)
temp_filename = tmpfile.name
stats = Image.open(temp_filename)
legend = Image.open('legend.png')
return image, legend, stats
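# End-to-end sketch (illustrative; assumes a CUDA device, that the model
# below has been loaded, and that 'sample1.png' exists in the Space):
#
#   mask, legend, stats = segment_image(Image.open('sample1.png'))
#   mask.save('mask.png')   # colour-coded tissue map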
model_path = os.path.join(model_dir, 'spaderesnet.pt')
model = SPADEResNet(input_nc=3, output_nc=6)
model = nn.DataParallel(model)
model = model.cuda()
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()  # inference only: freeze norm-layer statistics and disable dropout
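# Note on the DataParallel wrapper: checkpoints saved from a DataParallel
# model prefix every key with 'module.', which is presumably why the model is
# wrapped before loading here. A sketch for loading into a bare (unwrapped)
# model instead (hypothetical, not needed in this app):
#
#   state = torch.load(model_path)
#   state = {k.removeprefix('module.'): v for k, v in state.items()}
#   bare_model.load_state_dict(state, strict=True)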
examples = [
["sample1.png"],
["sample2.png"]
]
demo = gr.Interface(
    segment_image,
    inputs=gr.Image(),
    outputs=["image", "image", "image"],
    examples=examples,
    title="Breast Cancer Semantic Segmentation"
)
demo.launch()