|
import tensorflow as tf |
|
from simple_unet_model import simple_unet_model |
|
from tensorflow.keras.utils import normalize |
|
import os |
|
from PIL import Image, ImageOps |
|
import numpy as np |
|
import gradio as gr |
|
|
|
|
|
def get_model():
    """Build and return the network produced by ``simple_unet_model``.

    The model is created with the fixed arguments ``(256, 256, 1)``.
    """
    # NOTE(review): presumably (height, width, channels); this matches the
    # (1, 256, 256, 1) array that predict() feeds the model -- confirm
    # against simple_unet_model's signature.
    unet_args = (256, 256, 1)
    return simple_unet_model(*unet_args)
|
|
|
# Build the network once at import time and load its weights from
# 'mitochondria.hdf5' (path is relative to the working directory), so every
# call to predict() reuses the same in-memory model.
model = get_model()
model.load_weights('mitochondria.hdf5')
|
|
|
def predict(input_image):
    """Run the loaded model on *input_image* and return the predicted mask.

    Steps:
      1. Wrap the incoming array in a PIL image, convert it to grayscale
         and resize it to 256x256.
      2. Add a leading and a trailing axis, then apply Keras' ``normalize``
         along axis 1.
      3. Call ``model.predict`` and return the ``[0, :, :, 0]`` slice of its
         output.

    Args:
        input_image: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        The ``[0, :, :, 0]`` slice of the model output.
    """
    grayscale = ImageOps.grayscale(Image.fromarray(input_image))
    pixels = np.array(grayscale.resize((256, 256)))

    # (256, 256) -> (1, 256, 256, 1): axis 0 and axis 3 are inserted.
    batch = np.expand_dims(pixels, axis=(0, 3))
    batch = normalize(batch, axis=1)

    prediction = model.predict(batch)
    return prediction[0, :, :, 0]
|
|
|
def load_examples():
    """Return the names of the ``.jpg`` files in the current directory.

    Only entries whose name *ends* with ``.jpg`` are kept. The previous
    substring test (``'.jpg' in name``) also accepted names such as
    ``notes.jpg.txt``, which are not images.

    Returns:
        A list of file names (not full paths), sorted alphabetically so the
        result does not depend on the arbitrary order of ``os.listdir``.
    """
    return sorted(name for name in os.listdir() if name.endswith('.jpg'))
|
|
|
# Example inputs offered in the UI: whatever load_examples() returns.
examples = load_examples()

# Expose predict() through a Gradio interface: one image input, one image
# output declared with shape (256, 256).
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=gr.Image(shape=(256, 256)),
    title="Mitochondria Detection",
    examples=examples,
)

# Start the web server for the interface.
demo.launch()