# glomseg / app.py
# fhatje — commit bd79682: "Removed unnecessary comment"
# Public names of this module (what a star-import of it exposes).
__all__ = [
    # configuration constants
    "ORGAN",
    "IMAGE_SIZE",
    "MODEL_NAME",
    "THRESHOLD",
    "CODES",
    # the loaded fastai learner
    "learn",
    # Gradio interface text, examples and app object
    "title",
    "description",
    "examples",
    "interpretation",
    "demo",
    # data getters and parameter-group splitter
    "x_getter",
    "y_getter",
    "splitter",
    # inference / post-processing pipeline
    "make3D",
    "predict",
    "infer",
    "remove_small_segs",
    "to_oberlay_image",  # (sic) the function really is spelled "oberlay"
]
import numpy as np
import pandas as pd
import skimage
from fastai.vision.all import *
import segmentation_models_pytorch as smp
import gradio as gr
# Organ this app is built for (kidney slices); not referenced elsewhere in this file.
ORGAN = "kidney"

# Width and height handed to the Gradio input Image component.
IMAGE_SIZE = 512

# Exported learner file. Its underscore-separated name carries metadata:
# the third field ("th60") is the class-probability cut-off in percent.
MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
THRESHOLD = float(MODEL_NAME.split("_", 3)[2][2:]) / 100.0  # "th60" -> 0.6

# Segmentation class names; presumably index == class id, matching the
# class-1 channel selected in infer. FTU = functional tissue unit.
CODES = ["Background", "FTU"]
def x_getter(r):
    """Return the ``"fnames"`` field of data row *r*.

    Presumably this is the DataBlock ``get_x`` used when the learner was
    trained; it must exist under this name for the exported learner to load
    (confirm).
    """
    fname = r["fnames"]
    return fname
def y_getter(r):
    """Decode the run-length-encoded mask stored in data row *r*.

    Reads ``r["rle"]`` plus the image dimensions ``r["img_height"]`` and
    ``r["img_width"]`` (cast to int), decodes with ``rle_decode`` using the
    shape ``(height, width)`` and returns the transposed result.
    NOTE(review): the transpose presumably compensates for a column-major
    RLE convention in ``rle_decode`` — confirm against its definition.
    """
    encoded = r["rle"]
    height, width = int(r["img_height"]), int(r["img_width"])
    decoded = rle_decode(encoded, (height, width))
    return decoded.T
def splitter(model):
    """Split *model*'s parameters into two groups for the fastai Learner.

    Group 0 holds ``model.encoder`` parameters; group 1 holds the
    ``model.decoder`` parameters followed by the ``model.segmentation_head``
    parameters. The original naming ("untrained") suggests group 1 starts
    from random initialisation while the encoder is pretrained — confirm.
    """
    encoder_group = L(model.encoder.parameters())
    head_group = L(
        [*model.decoder.parameters(), *model.segmentation_head.parameters()]
    )
    return L([encoder_group, head_group])
# Load the exported fastai Learner from the file MODEL_NAME (a relative
# path, so resolved against the working directory). load_learner unpickles
# the file — only load trusted model files. Presumably the pickled Learner
# refers to x_getter, y_getter and splitter by name, which would be why they
# are defined above this line; confirm before moving them.
learn = load_learner(MODEL_NAME)
def make3D(t: np.ndarray) -> np.ndarray:
    """Replicate *t* three times along a new axis inserted at position 2.

    For a 2-D array of shape ``(H, W)`` (e.g. a single-channel mask) the
    result has shape ``(H, W, 3)``, with every channel equal to *t*.
    The dtype of *t* is preserved.
    """
    # np.stack(..., axis=2) is exactly expand_dims(t, 2) followed by
    # concatenation along that new axis, done in one call.
    return np.stack((t, t, t), axis=2)
def predict(fn, cutoff_area=200):
    """Run the full pipeline on one image (the Gradio callback).

    Segments *fn* with ``infer``, discards components smaller than
    *cutoff_area* pixels via ``remove_small_segs`` and renders the overlay.

    Returns:
        ``(overlay_image, table)`` — the image from ``to_oberlay_image`` and
        the per-component area DataFrame stored under ``"df"``.
    """
    cleaned = remove_small_segs(infer(fn), cutoff_area=cutoff_area)
    overlay = to_oberlay_image(cleaned)
    return overlay, cleaned["df"]
def infer(fn):
    """Segment one image with the loaded learner.

    Args:
        fn: anything ``PILImage.create`` accepts — a file path, or (from the
            Gradio Image component) presumably a numpy array; confirm.

    Returns:
        dict with
          ``"tf_image"``: the learner's decoded input image with its first
              axis moved last (channel-first -> channel-last, assuming fastai's
              CHW layout), cast to uint8;
          ``"tf_mask"``: uint8 array of 0/1, 1 where the class-1 ("FTU")
              softmax probability exceeds THRESHOLD.
    """
    img = PILImage.create(fn)
    tf_img, _, _, preds = learn.predict(img, with_input=True)

    # Softmax over dim 0 and selecting [1] assumes preds is
    # (n_classes, H, W) — confirm.
    # NOTE(review): if learn.predict already returns activated probabilities
    # (fastai applies the loss function's activation), this softmax is a
    # second one; THRESHOLD (from MODEL_NAME) may have been tuned with that
    # in mind, so verify before changing it.
    probs = F.softmax(preds.float(), dim=0)
    mask = (probs > THRESHOLD).int()[1]

    # The transformed image is converted once. The previous version also built
    # a resized copy that was never used (and passed fastai's (h, w) shape to
    # PIL's (w, h) resize); that dead work has been dropped.
    return {
        "tf_image": tf_img.numpy().transpose(1, 2, 0).astype(np.uint8),
        "tf_mask": np.array(mask, dtype=np.uint8),
    }
def remove_small_segs(data, cutoff_area=250):
    """Drop connected components smaller than *cutoff_area* pixels.

    Args:
        data: dict holding the segmentation mask under ``"tf_mask"``.
            It is updated in place and also returned.
        cutoff_area: components whose pixel area is below this value are
            removed; components with area >= cutoff_area are kept.

    Returns:
        *data*, with ``"tf_mask"`` replaced by a uint8 0/1 mask containing
        only the kept components. A new ``"df"`` entry holds a DataFrame with
        columns ``"Glomerulus"`` (1-based running index of kept components)
        and ``"Area (in px)"``.
    """
    labeled_mask = skimage.measure.label(data["tf_mask"])
    kept = [
        prop
        for prop in skimage.measure.regionprops(labeled_mask)
        if prop.area >= cutoff_area
    ]

    # A single vectorised pass replaces one full-image scan per dropped
    # component. Keying on prop.label avoids assuming regionprops order ==
    # label order. Background (0) is never a region label, so it stays 0.
    keep = np.isin(labeled_mask, [prop.label for prop in kept])
    data["tf_mask"] = keep.astype(np.uint8)

    data["df"] = pd.DataFrame(
        {
            "Glomerulus": list(range(1, len(kept) + 1)),
            "Area (in px)": [prop.area for prop in kept],
        }
    )
    return data
def to_oberlay_image(data):
    """Tint the masked pixels of the image with the colour (255, 80, 80).

    Args:
        data: dict with ``"tf_image"`` (uint8 array indexed as [row, col,
            channel] with at least 3 channels) and ``"tf_mask"`` (same
            height/width; expected to be 0/1 as produced by remove_small_segs).

    Returns:
        An RGBA PIL image in which masked pixels are blended roughly 50/50
        with the tint and all other pixels are unchanged.
    """
    base, mask = data["tf_image"], data["tf_mask"]

    # Solid-colour layer shaped like the image; channels beyond the first
    # three (if any) stay zero.
    tint = np.zeros_like(base)
    for channel, level in enumerate((255, 80, 80)):
        tint[:, :, channel] = level

    composed = Image.fromarray(base).convert("RGBA")
    tint_layer = Image.fromarray(tint).convert("RGBA")
    # Paste mask value 127 (= int(255 * 0.5)) mixes tint and image about
    # half and half; 0 leaves the underlying pixel untouched.
    alpha = Image.fromarray((mask * 255 * 0.5).astype(np.uint8))

    composed.paste(tint_layer, (0, 0), alpha)
    return composed
# Header text of the Gradio interface (passed as title= / description=).
title = "Glomerulus Segmentation"
description = """
A web app that segments glomeruli in histological kidney slices!
The model deployed here is a [UNet++](https://arxiv.org/abs/1807.10165) with an [efficientnet-b4](https://arxiv.org/abs/1905.11946) encoder from the [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch) library.
The provided example images are a random subset of kidney slices from the [Human Protein Atlas](https://www.proteinatlas.org/). These were collected separately from model training and were not part of the training, validation, or test sets.
Here is my corresponding [blog post](https://fhatje.github.io/posts/glomseg/train_model.html).
"""
# Paths (as plain strings) of every image file found under ./example_images,
# offered as clickable examples in the UI.
examples = list(map(str, get_image_files("example_images")))
# NOTE(review): never passed to gr.Interface below; kept only because it is
# listed in __all__.
interpretation = "default"

# One image in; the overlay image plus the per-glomerulus area table out,
# matching the (image, DataFrame) tuple returned by predict.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    inputs=gr.components.Image(width=IMAGE_SIZE, height=IMAGE_SIZE),
    outputs=[
        gr.components.Image(),
        gr.components.DataFrame(),
    ],
    examples=examples,
)
# Starts the Gradio server; note this runs at module import time, not only
# when the file is executed as a script.
demo.launch()