Gladiator's picture
Update app.py
5d902d2
raw
history blame contribute delete
No virus
2.04 kB
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from cellpose import models
def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns without hashing it.

    The legacy ``st.cache`` decorator is deprecated in favour of
    ``st.cache_resource``, which is meant for long-lived, unhashable objects
    such as ML models. On older Streamlit versions that lack
    ``cache_resource``, fall back to ``st.cache(allow_output_mutation=True)``
    so the returned model is not hashed on every cache hit.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_model(model_path):
    """Load a CPU-only Cellpose model from ``model_path`` (cached across reruns).

    Args:
        model_path: Path to the pretrained Cellpose weights.

    Returns:
        A ``cellpose.models.CellposeModel`` built with ``gpu=False``.
    """
    return models.CellposeModel(gpu=False, pretrained_model=model_path)


def _plot_results(img, mask, flow):
    """Return a 1x3 figure showing the input image, predicted mask and flows."""
    fig, axes = plt.subplots(1, 3)
    panels = (
        ("Original Image", img, {"cmap": "gray"}),
        ("Segmented image", mask, {}),
        ("Image flows", flow, {}),
    )
    for ax, (title, data, imshow_kwargs) in zip(axes, panels):
        ax.axis("off")
        ax.set_title(title)
        ax.imshow(data, **imshow_kwargs)
    return fig


# Streamlit entry point: upload an image, run Cellpose on it and display the
# input, the predicted mask and the flow field side by side.
if __name__ == "__main__":
    st.title("Sartorius Neuronal Cell Segmentation")
    model_path = "./cellpose_residual_on_style_on_concatenation_off_fold1_ep_649_cv_0.2834"
    # Streamlit re-executes this whole script on every widget interaction, so
    # go through the cached loader instead of rebuilding the model each time.
    inf_model = load_model(model_path)
    uploaded_img = st.file_uploader(label="Upload neuronal cell image")
    with st.expander("View input image"):
        if uploaded_img is not None:
            st.image(uploaded_img)
        else:
            st.warning("Please upload an image")
    segment = st.button("Perform segmentation")
    if uploaded_img is not None and segment:
        try:
            # Image.open is lazy; np.array forces the decode, so both belong
            # inside the try. The context manager releases the PIL handle.
            with Image.open(uploaded_img) as pil_img:
                img = np.array(pil_img)
        except OSError as err:
            # The uploader accepts any file type; PIL signals undecodable
            # input with UnidentifiedImageError, an OSError subclass.
            st.error(f"Could not read the uploaded file as an image: {err}")
        else:
            model_params = {
                "diameter": 19.0,
                "channels": [0, 0],
                "augment": True,
                "resample": True,
            }
            with st.spinner("Performing segmentation. This might take a while..."):
                preds, flows, _ = inf_model.eval([img], **model_params)
            fig = _plot_results(img, preds[0], flows[0][0])
            st.pyplot(fig)
            # Figures created through pyplot stay registered until closed;
            # close it so repeated runs do not accumulate figures in memory.
            plt.close(fig)