__all__ = [
    "ORGAN",
    "IMAGE_SIZE",
    "MODEL_NAME",
    "THRESHOLD",
    "CODES",
    "learn",
    "title",
    "description",
    "examples",
    "interpretation",
    "demo",
    "x_getter",
    "y_getter",
    "splitter",
    "make3D",
    "predict",
    "infer",
    "remove_small_segs",
    "to_overlay_image",
]


import numpy as np
import pandas as pd
import skimage
from fastai.vision.all import *

# PIL's Image and torch.nn.functional (F) are used below; they are imported
# explicitly here even though fastai's star import typically provides them.
from PIL import Image
import torch.nn.functional as F

# Not referenced directly, but required so the pickled UNet++ model inside the
# learner can be unpickled.
import segmentation_models_pytorch as smp

import gradio as gr


ORGAN = "kidney"
IMAGE_SIZE = 512
MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
# The prediction threshold is encoded in the model filename: "th60" -> 60 -> 0.60.
THRESHOLD = float(MODEL_NAME.split("_")[2][2:]) / 100.0
CODES = ["Background", "FTU"]


def x_getter(r):
    "Return the image file name from a dataframe row (required to unpickle the learner)."
    return r["fnames"]


def y_getter(r):
    "Decode the run-length-encoded mask for a dataframe row (required to unpickle the learner)."
    rle = r["rle"]
    shape = (int(r["img_height"]), int(r["img_width"]))
    return rle_decode(rle, shape).T
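

# y_getter refers to rle_decode, which is neither defined nor imported in this
# file. The sketch below is a standard Kaggle-style run-length decoder
# (1-indexed starts over the flattened pixel order) and is an assumption about
# the helper the original training code used; it only matters if y_getter is
# actually called.
def rle_decode(mask_rle: str, shape: tuple) -> np.ndarray:
    "Decode a run-length-encoded string into a binary mask of the given shape."
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
    starts -= 1
    ends = starts + lengths
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        mask[lo:hi] = 1
    return mask.reshape(shape)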


def splitter(model):
    "Split parameters into encoder vs. decoder/head groups for discriminative fine-tuning (required to unpickle the learner)."
    enc_params = L(model.encoder.parameters())
    dec_params = L(model.decoder.parameters())
    sg_params = L(model.segmentation_head.parameters())
    untrained_params = L([*dec_params, *sg_params])
    return L([enc_params, untrained_params])


# The getters and splitter above must exist in this namespace so the pickled
# learner can be loaded.
learn = load_learner(MODEL_NAME)


def make3D(t: np.ndarray) -> np.ndarray:
    "Stack a single-channel 2D array into a 3-channel array."
    t = np.expand_dims(t, axis=2)
    t = np.concatenate((t, t, t), axis=2)
    return t


def predict(fn, cutoff_area=200):
    "Segment an image, drop small components, and return the overlay image and the per-glomerulus area table."
    data = infer(fn)
    data = remove_small_segs(data, cutoff_area=cutoff_area)
    return to_overlay_image(data), data["df"]


def infer(fn):
    "Run the model on an image and return the transformed image and its binary FTU mask."
    img = PILImage.create(fn)
    # with_input=True also returns the transformed input image; only the input
    # and the raw per-class outputs are used here.
    tf_img, _, _, preds = learn.predict(img, with_input=True)
    # Softmax over the class dimension, then threshold the FTU channel (index 1).
    mask = (F.softmax(preds.float(), dim=0) > THRESHOLD).int()[1]
    mask = np.array(mask, dtype=np.uint8)
    return {
        "tf_image": tf_img.numpy().transpose(1, 2, 0).astype(np.uint8),
        "tf_mask": mask,
    }


def remove_small_segs(data, cutoff_area=250):
    "Remove connected components smaller than cutoff_area pixels and tabulate the remaining ones."
    labeled_mask = skimage.measure.label(data["tf_mask"])
    props = skimage.measure.regionprops(labeled_mask)
    df = {"Glomerulus": [], "Area (in px)": []}
    for prop in props:
        if prop.area < cutoff_area:
            # Erase components below the area cutoff.
            labeled_mask[labeled_mask == prop.label] = 0
            continue
        df["Glomerulus"].append(len(df["Glomerulus"]) + 1)
        df["Area (in px)"].append(prop.area)
    # Collapse the surviving labels back into a binary mask.
    labeled_mask[labeled_mask > 0] = 1
    data["tf_mask"] = labeled_mask.astype(np.uint8)
    data["df"] = pd.DataFrame(df)
    return data


def to_overlay_image(data):
    "Overlay the binary mask on the image as a semi-transparent red layer."
    img, msk = data["tf_image"], data["tf_mask"]

    # Solid red layer; the mask below acts as its alpha channel.
    msk_im = np.zeros_like(img)
    msk_im[:, :, 0] = 255
    msk_im[:, :, 1] = 80
    msk_im[:, :, 2] = 80

    img = Image.fromarray(img).convert("RGBA")
    msk_im = Image.fromarray(msk_im).convert("RGBA")
    # Binary mask -> alpha map with roughly 50% opacity where the mask is positive.
    msk = Image.fromarray((msk * 255 * 0.5).astype(np.uint8))

    img.paste(msk_im, (0, 0), msk)
    return img


title = "Glomerulus Segmentation"
description = """
A web app that segments glomeruli in histological kidney slices!

The model deployed here is a [UNet++](https://arxiv.org/abs/1807.10165) with an [efficientnet-b4](https://arxiv.org/abs/1905.11946) encoder from the [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch) library.

The provided example images are a random subset of kidney slices from the [Human Protein Atlas](https://www.proteinatlas.org/). They were collected separately from model training and were not part of the training, validation, or test sets.

Here is my corresponding [blog post](https://fhatje.github.io/posts/glomseg/train_model.html).
"""


examples = [str(p) for p in get_image_files("example_images")]
interpretation = "default"  # exported in __all__ but not passed to the Interface below

demo = gr.Interface(
    fn=predict,
    inputs=gr.components.Image(width=IMAGE_SIZE, height=IMAGE_SIZE),
    outputs=[gr.components.Image(), gr.components.DataFrame()],
    title=title,
    description=description,
    examples=examples,
)
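
# For a quick local sanity check without launching the web UI, predict can be
# called directly on one of the bundled example images (assuming the
# example_images/ directory is present):
#
#     overlay, table = predict(examples[0])
#     overlay.show()
#     print(table)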


demo.launch()