## Daniel Buscombe, Marda Science LLC 2023
# This file contains many functions originally from Doodleverse https://github.com/Doodleverse programs

import gradio as gr
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.io import imsave, imread
from skimage.filters import threshold_otsu
# from skimage.measure import EllipseModel, CircleModel, ransac
from glob import glob
import json

from transformers import TFSegformerForSemanticSegmentation
##========================================================
def segformer(
    id2label,
    num_classes=2,
):
    """
    https://keras.io/examples/vision/segformer/
    https://huggingface.co./nvidia/mit-b0
    """
    label2id = {label: id for id, label in id2label.items()}
    model_checkpoint = "nvidia/mit-b0"

    model = TFSegformerForSemanticSegmentation.from_pretrained(
        model_checkpoint,
        num_labels=num_classes,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )
    return model
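# Illustrative sketch (not executed): how segformer() would build a small
# two-class model. The label names below are made up for demonstration; the
# app further down derives id2label from the model's .json config instead.
# demo_id2label = {0: "background", 1: "sand"}
# demo_model = segformer(demo_id2label, num_classes=2)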
##========================================================
def fromhex(n):
    """hexadecimal to integer"""
    return int(n, base=16)
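# For example, fromhex("33") == 51, fromhex("66") == 102 and fromhex("cc") == 204,
# so the hex color "#3366CC" decodes to the RGB tuple (51, 102, 204) in
# label_to_colors() below.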
##========================================================
def label_to_colors(
    img,
    mask,
    alpha,  # =128,
    colormap,  # =class_label_colormap, #px.colors.qualitative.G10,
    color_class_offset,  # =0,
    do_alpha,  # =True
):
    """
    Take an MxN matrix containing integers representing labels and return an
    MxNx3 (or MxNx4 when do_alpha is True) matrix where each label has been
    replaced by a color looked up in colormap.
    colormap entries must be strings like plotly.express style colormaps.
    alpha is the value of the 4th channel.
    color_class_offset allows adding a value to the color class index to force
    use of a particular range of colors in the colormap. This is useful, for
    example, if 0 means 'no class' but we want the color of class 1 to be
    colormap[0].
    """

    # decode "#RRGGBB" strings into (R, G, B) tuples
    colormap = [
        tuple([fromhex(h[s : s + 2]) for s in range(0, len(h), 2)])
        for h in [c.replace("#", "") for c in colormap]
    ]

    cimg = np.zeros(img.shape[:2] + (3,), dtype="uint8")
    minc = np.min(img)
    maxc = np.max(img)

    for c in range(minc, maxc + 1):
        cimg[img == c] = colormap[(c + color_class_offset) % len(colormap)]

    # pixels flagged in the mask are painted black
    cimg[mask == 1] = (0, 0, 0)

    if do_alpha is True:
        return np.concatenate(
            (cimg, alpha * np.ones(img.shape[:2] + (1,), dtype="uint8")), axis=2
        )
    else:
        return cimg
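# Illustrative sketch (not executed): colorizing a tiny two-class label image.
# The arrays are made up purely for demonstration.
# toy_label = np.array([[0, 1], [1, 0]])
# toy_mask = np.zeros_like(toy_label)   # nothing masked out
# toy_rgb = label_to_colors(
#     toy_label, toy_mask, alpha=128,
#     colormap=["#3366CC", "#DC3912"], color_class_offset=0, do_alpha=False,
# )
# toy_rgb.shape == (2, 2, 3)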
##====================================
def standardize(img):
    # standardization using adjusted standard deviation
    N = np.shape(img)[0] * np.shape(img)[1]
    s = np.maximum(np.std(img), 1.0 / np.sqrt(N))
    m = np.mean(img)
    img = (img - m) / s
    del m, s, N

    if np.ndim(img) == 2:
        img = np.dstack((img, img, img))

    return img
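# Illustrative sketch (not executed): standardize() maps an image to roughly
# zero mean / unit variance and promotes greyscale input to 3 identical bands.
# The random array is made up for demonstration.
# grey = np.random.randint(0, 255, size=(768, 768)).astype('float32')
# std_img = standardize(grey)
# std_img.shape == (768, 768, 3); abs(std_img.mean()) is close to 0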
############################################################
############################################################

# load model
filepath = './weights/ct_NAIP_8class_768_segformer_v3_fullmodel.h5'

configfile = filepath.replace('_fullmodel.h5', '.json')

with open(configfile) as f:
    config = json.load(f)

# Promote every key in the config file (e.g. NCLASSES, TARGET_SIZE, MODEL,
# N_DATA_BANDS, TESTTIMEAUG) to a module-level variable of the same name.
# This is why variables that are never explicitly defined below are available.
for k in config.keys():
    exec(k + '=config["' + k + '"]')

id2label = {}
for k in range(NCLASSES):
    id2label[k] = str(k)

model = segformer(id2label, num_classes=NCLASSES)
# model.compile(optimizer='adam')

model.load_weights(filepath)
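# Illustrative sketch (not executed): a quick sanity check that the weights
# loaded and the network produces per-class logits. The input is random data,
# purely for demonstration; SegFormer takes channels-first input and returns
# logits at one quarter of the input resolution.
# dummy = tf.random.uniform((1, 3, TARGET_SIZE[0], TARGET_SIZE[1]))
# logits = model(dummy).logits   # shape (1, NCLASSES, TARGET_SIZE[0]//4, TARGET_SIZE[1]//4)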
############################################################
############################################################

# #-----------------------------------
def est_label_multiclass(image, Mc, MODEL, TESTTIMEAUG, NCLASSES, TARGET_SIZE):

    est_label = np.zeros((TARGET_SIZE[0], TARGET_SIZE[1], NCLASSES))

    for counter, model in enumerate(Mc):
        # heatmap = make_gradcam_heatmap(tf.expand_dims(image, 0) , model)
        try:
            if MODEL == 'segformer':
                est_label = model(tf.expand_dims(image, 0)).logits
            else:
                est_label = tf.squeeze(model(tf.expand_dims(image, 0)))
        except Exception:
            # fall back to using just the first band if the full forward pass fails
            if MODEL == 'segformer':
                est_label = model(tf.expand_dims(image[:, :, 0], 0)).logits
            else:
                est_label = tf.squeeze(model(tf.expand_dims(image[:, :, 0], 0)))

        if TESTTIMEAUG == True:
            # test-time augmentation: also predict on up-down and left-right flipped copies
            if MODEL == 'segformer':
                est_label2 = np.flipud(
                    model(tf.expand_dims(np.flipud(image), 0)).logits
                )
            else:
                est_label2 = np.flipud(
                    tf.squeeze(model(tf.expand_dims(np.flipud(image), 0)))
                )

            if MODEL == 'segformer':
                est_label3 = np.fliplr(
                    model(tf.expand_dims(np.fliplr(image), 0)).logits
                )
            else:
                est_label3 = np.fliplr(
                    tf.squeeze(model(tf.expand_dims(np.fliplr(image), 0)))
                )

            if MODEL == 'segformer':
                est_label4 = np.flipud(
                    np.fliplr(
                        tf.squeeze(model(tf.expand_dims(np.flipud(np.fliplr(image)), 0)).logits))
                )
            else:
                est_label4 = np.flipud(
                    np.fliplr(
                        tf.squeeze(model(
                            tf.expand_dims(np.flipud(np.fliplr(image)), 0)))
                    ))

            # soft voting - sum the softmax scores to return the new TTA estimated softmax scores
            est_label = est_label + est_label2 + est_label3 + est_label4

    return est_label, counter
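# Illustrative sketch (not executed): how this routine is called below with the
# single loaded model. `prepared` stands in for the output of get_image();
# with one model and TTA off, dividing by counter + 1 divides by 1 and leaves
# the logits unchanged.
# logits, counter = est_label_multiclass(prepared, [model], 'segformer', False, NCLASSES, TARGET_SIZE)
# averaged = logits / (counter + 1)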
# #-----------------------------------
def seg_file2tensor_3band(bigimage, TARGET_SIZE):
    """
    "seg_file2tensor_3band(bigimage, TARGET_SIZE)"
    This function resizes an in-memory image array into a tensor of shape
    TARGET_SIZE, for use in prediction with a trained segmentation model
    INPUTS:
        * bigimage [ndarray]: the image to be segmented
        * TARGET_SIZE [list or tuple]: (height, width) expected by the model
    OPTIONAL INPUTS: None
    OUTPUTS:
        * smallimage [tensor array]: unstandardized resized image
        * w, h: the original image dimensions
        * bigimage: the original image, unchanged
    GLOBAL INPUTS: None
    """
    smallimage = resize(
        bigimage, (TARGET_SIZE[0], TARGET_SIZE[1]), preserve_range=True, clip=True
    )
    smallimage = np.array(smallimage)
    smallimage = tf.cast(smallimage, tf.uint8)

    # note: w is the row count and h the column count, matching their use in segment()
    w = tf.shape(bigimage)[0]
    h = tf.shape(bigimage)[1]

    return smallimage, w, h, bigimage
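# Illustrative sketch (not executed): resizing a hypothetical 1024x1024 RGB
# array down to the model's 768x768 input while keeping the original around.
# big = np.random.randint(0, 255, size=(1024, 1024, 3), dtype='uint8')
# small, w, h, big_out = seg_file2tensor_3band(big, (768, 768))
# small.shape == (768, 768, 3); int(w) == 1024; int(h) == 1024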
# #-----------------------------------
def get_image(f, N_DATA_BANDS, TARGET_SIZE, MODEL):

    image, w, h, bigimage = seg_file2tensor_3band(f, TARGET_SIZE)
    image = standardize(image.numpy()).squeeze()

    if MODEL == 'segformer':
        if np.ndim(image) == 2:
            image = np.dstack((image, image, image))
        # SegFormer expects channels-first input, so move bands to the leading axis
        image = tf.transpose(image, (2, 0, 1))

    return image, w, h, bigimage
# #-----------------------------------
# segmentation
def segment(input_img, use_tta, use_otsu, dims=(768, 768)):

    if use_otsu:
        print("Use Otsu threshold")
    else:
        print("No Otsu threshold")

    if use_tta:
        print("Use TTA")
    else:
        print("Do not use TTA")

    image, w, h, bigimage = get_image(input_img, N_DATA_BANDS, TARGET_SIZE, MODEL)

    # note: in this demo both checkboxes are only echoed to the console;
    # TTA is controlled by the TESTTIMEAUG setting in the model config, and no
    # Otsu thresholding is applied to the multiclass output.
    est_label, counter = est_label_multiclass(image, [model], 'segformer', TESTTIMEAUG, NCLASSES, TARGET_SIZE)
    print(est_label.shape)

    est_label /= counter + 1

    # est_label cannot be float16 so convert to float32
    est_label = est_label.numpy().astype('float32')

    # resize the channels-first logits back up to the model input size, move
    # classes to the last axis, resize to the original image size, then take
    # the most likely class per pixel
    est_label = resize(est_label, (1, NCLASSES, TARGET_SIZE[0], TARGET_SIZE[1]), preserve_range=True, clip=True).squeeze()
    est_label = np.transpose(est_label, (1, 2, 0))
    est_label = resize(est_label, (w, h))
    est_label = np.argmax(est_label, -1)
    print(est_label.shape)

    imsave("greyscale_download_me.png", est_label.astype('uint8'))

    class_label_colormap = [
        "#3366CC",
        "#DC3912",
        "#FF9900",
        "#109618",
        "#990099",
        "#0099C6",
        "#DD4477",
        "#66AA00",
        "#B82E2E",
        "#316395",
    ]

    # trim the colormap to the number of classes
    class_label_colormap = class_label_colormap[:NCLASSES]

    color_label = label_to_colors(
        est_label,
        input_img[:, :, 0] == 0,
        alpha=128,
        colormap=class_label_colormap,
        color_class_offset=0,
        do_alpha=False,
    )

    imsave("color_download_me.png", color_label)

    return color_label, "greyscale_download_me.png", "color_download_me.png"
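# Illustrative sketch (not executed): calling segment() directly on one of the
# bundled example images. The filename is hypothetical; any 3-band JPEG works.
# rgb = imread('examples/some_naip_tile.jpg')
# color_label, grey_png, color_png = segment(rgb, use_tta=False, use_otsu=False)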
title = "Mapping sand in high-res. imagery"

description = "This simple model demonstration segments NAIP RGB (visible spectrum) imagery into the following classes: 1. water (unbroken water); 2. whitewater (surf, active wave breaking); 3. sediment (natural deposits of sand, gravel, mud, etc.); 4. other_bare_natural_terrain; 5. marsh_vegetation; 6. terrestrial_vegetation; 7. agricultural; 8. development. Please note that ensemble models are ordinarily used in predictive mode; here a single model is used, without ensembling. Upload one 3-band JPEG image at a time; the greyscale and color label images are then available for download."
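# Assumed mapping between pixel values in the greyscale output PNG and the
# classes listed above (0-based, in the order given). This follows from how
# id2label is built with keys 0..NCLASSES-1, but the exact class order is set
# by the training configuration rather than by this file:
# 0=water, 1=whitewater, 2=sediment, 3=other_bare_natural_terrain,
# 4=marsh_vegetation, 5=terrestrial_vegetation, 6=agricultural, 7=development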
examples = [[l] for l in glob('examples/*.jpg')]

inp = gr.Image()
out1 = gr.Image(type='numpy')
# out2 = gr.Plot(type='matplotlib')
out3 = gr.File()
out4 = gr.File()

inp2 = gr.Checkbox(value=False, label="Use TTA")
inp3 = gr.Checkbox(value=False, label="Use Otsu")

Segapp = gr.Interface(segment,
                      [inp, inp2, inp3],
                      [out1, out3, out4],  # out2
                      title=title, description=description, examples=examples,
                      theme="grass")

Segapp.launch(enable_queue=True)
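# To run the demo locally (a sketch, assuming the weights file, its .json
# config, and an examples/ folder of JPEGs sit alongside this script):
#   python <name_of_this_script>.py
# Gradio then prints a local URL to open in a browser.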