File size: 2,037 Bytes
7a1d3a9
 
85c7830
308f32a
271c748
7a1d3a9
 
fb058b3
 
 
 
 
 
 
 
 
7a1d3a9
 
 
 
271c748
5d902d2
fb058b3
7a1d3a9
fb058b3
271c748
 
 
 
 
 
708ffff
271c748
708ffff
7a1d3a9
2492cff
7a1d3a9
 
 
 
 
 
 
0fbcbdf
fb058b3
 
 
 
 
 
271c748
 
 
 
 
 
51f706a
271c748
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from cellpose import models


# allow_output_mutation=True: without it Streamlit hashes the returned model
# on every cache hit to detect mutation, which is slow for a network object
# and can warn or fail; the model is meant to be shared as-is.
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """Build a CPU Cellpose model from pretrained weights, cached by Streamlit.

    Streamlit re-executes the whole script on every widget interaction, so
    the model is cached per ``model_path`` to avoid re-reading the weights
    on each rerun.

    Args:
        model_path: Path to the pretrained Cellpose weights.

    Returns:
        The ``models.CellposeModel`` created with ``gpu=False``.
    """
    return models.CellposeModel(gpu=False, pretrained_model=model_path)


if __name__ == "__main__":

    st.title("Sartorius Neuronal Cell Segmentation")
    model_path = "./cellpose_residual_on_style_on_concatenation_off_fold1_ep_649_cv_0.2834"
    # Use the cached loader so the weights are not reloaded on every rerun.
    inf_model = load_model(model_path)
    uploaded_img = st.file_uploader(label="Upload neuronal cell image")

    # Preview of the upload, or a prompt when nothing has been uploaded yet.
    with st.expander("View input image"):
        if uploaded_img is not None:
            st.image(uploaded_img)
        else:
            st.warning("Please upload an image")

    segment = st.button("Perform segmentation")

    # Only run inference when there is an image AND the button was pressed
    # during this rerun.
    if uploaded_img is not None and segment:
        img = np.array(Image.open(uploaded_img))

        # Keyword arguments forwarded unchanged to CellposeModel.eval.
        model_params = {
            "diameter": 19.0,
            "channels": [0, 0],
            "augment": True,
            "resample": True,
        }
        with st.spinner("Performing segmentation. This might take a while..."):
            # eval receives a one-element list, so index 0 of each result
            # below refers to the uploaded image.
            preds, flows, _ = inf_model.eval([img], **model_params)

        # Side-by-side panels: input, predicted segmentation, flow output.
        fig, axes = plt.subplots(1, 3)
        for ax in axes:
            ax.axis("off")
        ax1, ax2, ax3 = axes
        ax1.set_title("Original Image")
        ax1.imshow(img, cmap="gray")
        ax2.set_title("Segmented image")
        ax2.imshow(preds[0])
        ax3.set_title("Image flows")
        # NOTE(review): flows[0][0] is presumably the RGB flow rendering for
        # the first image -- confirm against the installed Cellpose version.
        ax3.imshow(flows[0][0])
        st.pyplot(fig)
        # pyplot keeps every figure it creates alive until explicitly closed;
        # close it so repeated segmentations do not accumulate figures.
        plt.close(fig)