# seminar-demo / app.py
# Last commit b42c3a8 ("Update app.py") by ilhamap.
import streamlit as st
import io
import collections
from scipy.io import loadmat
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
import argparse
import torch.nn as nn
import torch.utils.data as Data
import torch.backends.cudnn as cudnn
from scipy.io import loadmat
from scipy.io import savemat
from torch import optim
from torch.autograd import Variable
from sstvit import SSTViT
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
from patchify import patchify, unpatchify
import time
from matplotlib import colors as mcolors
import base64
import pandas as pd
import st_aggrid
import os
import json
import plotly.express as px
# Injected stylesheet: cap the width of Streamlit's main content column at 60rem.
css = "\n".join([
    "",
    "<style>",
    "section.main > div {max-width:60rem}",
    "</style>",
    "",
])
st.markdown(css, unsafe_allow_html=True)
class Args(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    ``args.seed`` reads ``args['seed']`` and ``args.seed = 1`` stores into
    ``args['seed']``. Reading a missing key as an attribute raises
    ``AttributeError`` (not ``KeyError``), so ``hasattr``,
    ``getattr(args, name, default)`` and ``copy.deepcopy`` behave as they do
    for ordinary objects.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal lookup fails, so dict methods such as
        # ``copy``/``items`` still resolve to the class attributes first.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
# Run configuration, readable as attributes (args.seed, args.gpu_id, ...).
args = Args(
    dataset='mg',
    flag_test='train',
    gpu_id=0,
    seed=0,
    batch_size=64,
    test_freq=10,
    patches=5,
    band_patches=1,
    epoches=2000,
    learning_rate=5e-4,
    gamma=0.9,
    weight_decay=0.0,
    train_number=500,
)
# Plain-dict snapshot of the same settings.
obj = dict(args)
# Restrict CUDA to the configured device index.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"
def test_epoch(model, test_loader):
    """Return the top-1 class index the model predicts for every sample.

    Args:
        model: callable taking ``(batch_t1, batch_t2)`` and returning a score
            tensor; assumed to be shaped (batch, num_classes) since the top-1
            index is taken along dim 1 — TODO confirm for other models.
        test_loader: iterable yielding ``(batch_t1, batch_t2)`` tensor pairs.

    Returns:
        1-D float64 NumPy array of predicted class indices, in loader order
        (empty when the loader yields nothing).
    """
    batch_predictions = []
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        for batch_data_t1, batch_data_t2 in test_loader:
            batch_pred = model(batch_data_t1, batch_data_t2)
            _, pred = batch_pred.topk(1, 1, True, True)
            batch_predictions.append(pred.reshape(-1).cpu().numpy())
    if not batch_predictions:
        return np.array([])
    # Concatenate once (np.append per batch is quadratic); keep float64 as before.
    return np.concatenate(batch_predictions).astype(np.float64)


# The ``test_`` prefix would otherwise make pytest collect this helper as a test.
test_epoch.__test__ = False
# Panel titles; appears unused elsewhere in this file.
mdic = ['Before', 'After'] * 2
# Display colours intended for predicted class ids 0, 1, 2 (blue, red, green).
# Note: rebinding `colors` shadows the `from matplotlib import colors` import;
# the matplotlib module is reached through the `mcolors` alias instead.
colors = ['#3b68f8', '#ff0201', '#23fe01']
cmap = mcolors.ListedColormap(colors)

# Parameter Setting: seed NumPy and PyTorch (CPU + CUDA) and request
# deterministic cuDNN kernels for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
cudnn.benchmark = False
cudnn.deterministic = True
def encode_masks_to_rgb(masks):
    """Render a 2-D class-index mask as an RGB image.

    Class 0 becomes blue (0, 0, 255), class 1 red (255, 0, 0) and class 2
    green (0, 255, 0); any other value is left black (0, 0, 0).

    Args:
        masks: 2-D array-like of class ids (integer or float values).

    Returns:
        ``uint8`` array of shape (height, width, 3).

    Raises:
        ValueError: if ``masks`` is not 2-D.
    """
    palette = ((0, 0, 255), (255, 0, 0), (0, 255, 0))
    masks = np.asarray(masks)
    height, width = masks.shape
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, rgb in enumerate(palette):
        rgb_image[masks == class_id] = rgb
    return rgb_image
def _pixel_distribution(pred, color2label):
    """Tabulate how many pixels of an RGB image match each colour in ``color2label``.

    Args:
        pred: array-like of shape (H, W, 3), e.g. the output of
            ``encode_masks_to_rgb``.
        color2label: mapping from an (R, G, B) tuple to its class name; the
            output rows follow this mapping's order.

    Returns:
        DataFrame with columns ``class``, ``percentage`` (share of *all*
        pixels, 0-100) and ``pixel_count``; one row per entry of
        ``color2label`` with a default RangeIndex. Classes that do not occur
        get 0. Pixels whose colour is not in ``color2label`` count toward the
        total but get no row (they used to raise ``KeyError``).

    Raises:
        ValueError: if ``pred`` is not a 3-channel (H, W, 3) image.
    """
    pixels = np.asarray(pred)
    if pixels.ndim != 3 or pixels.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB array, got shape {pixels.shape}")
    pixels = pixels.reshape(-1, 3)
    total_pixels = len(pixels)

    classes, percentages, pixel_counts = [], [], []
    for rgb, label in color2label.items():
        # Vectorised colour match instead of a Python list of per-pixel tuples.
        count = int(np.all(pixels == np.asarray(rgb), axis=1).sum())
        classes.append(label)
        pixel_counts.append(count)
        percentages.append((count / total_pixels) * 100 if total_pixels else 0.0)

    return pd.DataFrame({
        "class": classes,
        "percentage": percentages,
        "pixel_count": pixel_counts,
    })


def count_pixel(pred):
    """Per-class pixel counts/percentages for the *before* prediction image.

    Rows, in order: "Non Mangrove" (blue), "Mangrove Loss" (red),
    "Mangrove Before" (green). See ``_pixel_distribution`` for details.
    """
    return _pixel_distribution(pred, {
        (0, 0, 255): "Non Mangrove",
        (255, 0, 0): "Mangrove Loss",
        (0, 255, 0): "Mangrove Before",
    })


def count_pixel1(pred):
    """Per-class pixel counts/percentages for the *after* prediction image.

    Rows, in order: "Non Mangrove" (blue), "Mangrove Loss" (red),
    "Mangrove After" (green). See ``_pixel_distribution`` for details.
    """
    return _pixel_distribution(pred, {
        (0, 0, 255): "Non Mangrove",
        (255, 0, 0): "Mangrove Loss",
        (0, 255, 0): "Mangrove After",
    })
# Side length of the square spatial window fed to SSTViT (matches image_size below).
_PATCH_SIZE = 5
# Area credited to one predicted pixel in the summary table; presumably a
# 10 m x 10 m Sentinel-2 pixel — TODO confirm against the input resolution.
_PIXEL_AREA_M2 = 100


def _encode_label_band(labels):
    """Return a float64 copy of ``labels`` remapped for use as the extra input band.

    Equivalent to the in-place sequence ``L[L==0]=-0.8; L[L==1]=0; L[L==0]=-0.2``:
    0 becomes -0.8, 1 becomes -0.2, every other value is kept unchanged.
    """
    labels = np.double(labels)
    encoded = labels.copy()
    encoded[labels == 0] = -0.8
    encoded[labels == 1] = -0.2
    return encoded


def _predict_map(model, patches_t1, patches_t2, height, width):
    """Classify every (t1, t2) patch pair with ``model``; return a (height, width) class map."""
    # (n, pixels, channels) -> (n, channels, pixels), as the model is fed here.
    band_first_t1 = torch.from_numpy(patches_t1.transpose(0, 2, 1)).float()
    band_first_t2 = torch.from_numpy(patches_t2.transpose(0, 2, 1)).float()
    loader = Data.DataLoader(
        Data.TensorDataset(band_first_t1, band_first_t2), batch_size=100, shuffle=False
    )
    model.eval()
    with torch.no_grad():  # inference only: no autograd bookkeeping
        predictions = test_epoch(model, loader)
    return predictions.reshape(height, width)


def _before_after_figure(before, after, figsize, title_kwargs, **imshow_kwargs):
    """Draw ``before`` and ``after`` side by side without ticks and return the figure."""
    fig = plt.figure(figsize=figsize)
    panels = ((before, "Before"), (after, "After"))
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image, **imshow_kwargs)
        plt.title(title, **title_kwargs)
        plt.xticks([])
        plt.yticks([])
    return fig


def _grid_height(n_rows, min_height=100, max_height=180, row_height=50):
    """Pixel height for an AgGrid table with ``n_rows`` rows, capped at ``max_height``."""
    return min(min_height + n_rows * row_height, max_height)


file = st.file_uploader("Upload file", type=['mat'])
if file:
    # Parse the upload once per script run; every array below comes from it.
    mat = loadmat(file)
    data_img1 = mat['data_img1']
    data_img2 = mat['data_img2']
    st.subheader("Preview Dataset")
    col1, col2 = st.columns(2)
    with col1:
        fig = _before_after_figure(data_img1, data_img2, (5, 5), {"fontweight": "bold"})
        st.pyplot(fig)
        plt.close(fig)  # pyplot otherwise keeps every figure alive across reruns
    holder = st.empty()
    if holder.button("Start Prediction"):
        start = time.time()
        holder.empty()
        with st.spinner("Processing, please wait around 7-15 minute"):
            L_pre = _encode_label_band(mat['L_pre'])
            L_post = _encode_label_band(mat['L_post'])

            # 11-channel inputs: channels 0-9 hold data_t1/data_t2 (cropped to the
            # L_post grid), channel 10 holds the encoded label band.
            rows, cols = L_post.shape
            data_cb1 = np.zeros((rows, cols, 11), dtype=np.float32)
            data_cb2 = np.zeros((rows, cols, 11), dtype=np.float32)
            data_cb1[:, :, :10] = mat['data_t1'][:rows, :cols, :]
            data_cb1[:, :, 10] = L_pre
            data_cb2[:, :, :10] = mat['data_t2'][:rows, :cols, :]
            data_cb2[:, :, 10] = L_post

            # Every full window at stride 1 becomes one sample, so the output
            # grid shrinks by _PATCH_SIZE - 1 pixels in each dimension.
            n_channels = data_cb1.shape[2]
            window = (_PATCH_SIZE, _PATCH_SIZE, n_channels)
            height = rows - _PATCH_SIZE + 1
            width = cols - _PATCH_SIZE + 1
            x1 = patchify(data_cb1, window, step=1).reshape(-1, _PATCH_SIZE * _PATCH_SIZE, n_channels)
            x2 = patchify(data_cb2, window, step=1).reshape(-1, _PATCH_SIZE * _PATCH_SIZE, n_channels)

            # create model
            model = SSTViT(
                image_size=_PATCH_SIZE,
                near_band=args.band_patches,
                num_patches=n_channels,
                num_classes=3,
                dim=32,
                depth=2,
                heads=4,
                dim_head=16,
                mlp_dim=8,
                b_dim=512,
                b_depth=3,
                b_heads=8,
                b_dim_head=32,
                b_mlp_head=8,
                dropout=0.2,
                emb_dropout=0.1,
            )
            model.load_state_dict(torch.load("model/lsstformer.pth", map_location=torch.device("cpu")))

            # "Before" map: the t1 patches are passed as both inputs (an unchanged
            # pair). "After" map: t1 paired with t2.
            prediction_matrix = _predict_map(model, x1, x1, height, width)
            prediction_matrix2 = _predict_map(model, x1, x2, height, width)

            # Class ids as reported below: 2 = mangrove, 1 = mangrove loss.
            mangrove_before_px = int(np.count_nonzero(prediction_matrix == 2))
            mangrove_after_px = int(np.count_nonzero(prediction_matrix2 == 2))
            mangrove_loss_px = int(np.count_nonzero(prediction_matrix2 == 1))
            class_counts = count_pixel(encode_masks_to_rgb(prediction_matrix))
            class_counts1 = count_pixel1(encode_masks_to_rgb(prediction_matrix2))

        with st.container():
            st.subheader("Prediction Result")
            col1, col2 = st.columns(2)
            with col1:
                with st.container():
                    # Pin the colour scale to class ids 0..2. Without vmin/vmax imshow
                    # rescales to each map's own min/max, so a map with no class 2
                    # would paint class 1 in the class-2 colour.
                    fig = _before_after_figure(
                        prediction_matrix,
                        prediction_matrix2,
                        (10, 10),
                        {"fontsize": 25, "fontweight": "bold"},
                        cmap=cmap,
                        vmin=0,
                        vmax=2,
                    )
                    st.pyplot(fig)
                    plt.close(fig)
            with col2:
                with st.container():
                    table_data = {
                        "Total mangrove before": f"{mangrove_before_px * _PIXEL_AREA_M2} m\u00B2",
                        "Total mangrove after": f"{mangrove_after_px * _PIXEL_AREA_M2} m\u00B2",
                        "Total mangrove loss": f"{mangrove_loss_px * _PIXEL_AREA_M2} m\u00B2",
                    }
                    df = pd.DataFrame(list(table_data.items()), columns=['Key', 'Value'])
                    st_aggrid.AgGrid(df, fit_columns_on_grid_load=True, height=_grid_height(len(df)))

        with st.container():
            st.subheader("Pixel Distribution")
            # Mangrove from the before map, then mangrove and loss from the after
            # map; selected by class label rather than by row position.
            vertical_concat = pd.concat([
                class_counts[class_counts["class"] == "Mangrove Before"],
                class_counts1[class_counts1["class"] == "Mangrove After"],
                class_counts1[class_counts1["class"] == "Mangrove Loss"],
            ])
            st_aggrid.AgGrid(
                vertical_concat,
                fit_columns_on_grid_load=True,
                height=_grid_height(len(vertical_concat)),
            )
            bar_fig = px.bar(
                vertical_concat, x='percentage', y='class', color='class', orientation='h',
                color_discrete_sequence=["green", "green", "red", "blue"],
                category_orders={"class": ["Mangrove Before", "Mangrove After", "Mangrove Loss", "Non Mangrove"]},
            )
            st.plotly_chart(bar_fig, use_container_width=False)

        end = time.time()
        process = end - start
        st.write('process', process)
# Landing content shown until a .mat file is uploaded.
show_file = st.empty()
if not file:
    show_file.info("""
The model was trained using Sentinel-2 imagery, users can upload MAT files to perform LSST-Former for mangrove loss detection models that have been trained in this research. Tool for generate from Sentinel-2 to MAT file i will create later, please download demo dataset bellow. for better in mobile phone, use desktop mode.
""")
    # Demo dataset on Google Drive. NOTE(review): the query string carries
    # session-looking parameters (uuid, at) and may expire — confirm it still works.
    url = "https://drive.usercontent.google.com/download?id=1u48pMzRWQ2Etfjaq5A0CUjRtGKZaJoJy&export=download&authuser=2&confirm=t&uuid=52b0e01e-377f-42cb-8412-c84aa38a1740&at=APZUnTXslmuCCV1drJ2WWtkZr9BR%3A1710357675310"
    st.write("download demo datasets this [link](%s)" % url)