"""Streamlit front end for LSST-Former mangrove loss detection.

The user uploads a MAT file; the app previews the two images stored in it,
runs the pretrained SSTViT model on 5x5 patches of the two acquisitions and
reports the predicted class maps together with per-class pixel statistics.
"""

import argparse
import base64
import collections
import io
import json
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import st_aggrid
import streamlit as st
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as Data
from matplotlib import colors
from matplotlib import colors as mcolors
from patchify import patchify, unpatchify
from PIL import Image
from scipy.io import loadmat
from scipy.io import savemat
from sklearn.metrics import confusion_matrix
from torch import optim
from torch.autograd import Variable

from sstvit import SSTViT

css = ''' '''
st.markdown(css, unsafe_allow_html=True)


class Args(dict):
    """Dictionary whose items can also be read and written as attributes."""

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Raise AttributeError rather than KeyError so that hasattr() and
        # getattr(obj, name, default) behave as they do for normal objects.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


args = {
    'dataset': 'mg',
    'flag_test': 'train',
    'gpu_id': 0,
    'seed': 0,
    'batch_size': 64,
    'test_freq': 10,
    'patches': 5,
    'band_patches': 1,
    'epoches': 2000,
    'learning_rate': 5e-4,
    'gamma': 0.9,
    'weight_decay': 0.0,
    'train_number': 500,
}
args = Args(args)  # dict2object
obj = args.copy()  # object2dict
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

# Spatial size of the square patches fed to the model.
PATCH_SIZE = args.patches
# Number of image bands copied from data_t1 / data_t2 into each input cube.
NUM_IMAGE_BANDS = 10
# Image bands plus one extra channel holding the rescaled label map.
NUM_CHANNELS = NUM_IMAGE_BANDS + 1
# Batch size used when running the model over all patches.
INFERENCE_BATCH_SIZE = 100

# RGB colour for each predicted class id (index 0, 1, 2).
CLASS_RGB = ((0, 0, 255), (255, 0, 0), (0, 255, 0))

# MAT variables needed for the preview and for the prediction step.
PREVIEW_KEYS = ('data_img1', 'data_img2')
PREDICTION_KEYS = ('data_t1', 'data_t2', 'L_post', 'L_pre')

# Sizing of the AgGrid tables, in pixels.
MIN_HEIGHT = 100
MAX_HEIGHT = 180
ROW_HEIGHT = 50


def test_epoch(model, test_loader):
    """Run ``model`` over ``test_loader`` and return the top-1 class ids.

    Args:
        model: callable taking the two batch tensors and returning per-class
            scores with classes along dimension 1.
        test_loader: iterable yielding ``(batch_data_t1, batch_data_t2)``.

    Returns:
        1-D float64 ``numpy`` array with one predicted class id per sample,
        in loader order (empty if the loader yields nothing).
    """
    chunks = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch_data_t1, batch_data_t2 in test_loader:
            batch_pred = model(batch_data_t1, batch_data_t2)
            _, pred = batch_pred.topk(1, 1, True, True)
            # reshape(-1) keeps a 1-D result even when the batch has size 1.
            chunks.append(pred.reshape(-1).cpu().numpy())
    if not chunks:
        return np.array([])
    # float64 matches what appending to an empty np.array([]) produced.
    return np.concatenate(chunks).astype(np.float64)


mdic = ['Before', 'After', 'Before', 'After']
colors = ['#3b68f8', '#ff0201', '#23fe01']  # -1,0,1,2,3
cmap = mcolors.ListedColormap(colors)

# Parameter Setting
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
cudnn.deterministic = True
cudnn.benchmark = False


def encode_masks_to_rgb(masks):
    """Convert a 2-D array of class ids into an RGB ``uint8`` image.

    Pixels equal to 0, 1 or 2 get the matching ``CLASS_RGB`` colour; any
    other value is left black.
    """
    height, width = masks.shape
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, color in enumerate(CLASS_RGB):
        rgb_image[masks == class_id] = color
    return rgb_image


def _count_pixels_by_label(pred, color2label):
    """Count pixels of each known colour in an ``(H, W, 3)`` RGB array.

    Colours that are not keys of ``color2label`` are ignored for the
    per-class rows but still count towards the total used for percentages.
    """
    pixels = np.asarray(pred).reshape(-1, 3)
    total_pixels = pixels.shape[0]
    labels, percentages, counts = [], [], []
    for color, label in color2label.items():
        count = int(np.all(pixels == np.asarray(color), axis=1).sum())
        labels.append(label)
        counts.append(count)
        percentages.append((count / total_pixels) * 100 if total_pixels else 0)
    return pd.DataFrame({
        "class": labels,
        "percentage": percentages,
        "pixel_count": counts,
    })


def count_pixel(pred):
    """Return per-class pixel statistics for the "before" RGB prediction.

    Returns:
        DataFrame with columns ``class``, ``percentage`` and ``pixel_count``
        and one row each for "Non Mangrove", "Mangrove Loss" and
        "Mangrove Before" (zero when a class is absent).
    """
    return _count_pixels_by_label(pred, {
        (0, 0, 255): "Non Mangrove",
        (255, 0, 0): "Mangrove Loss",
        (0, 255, 0): "Mangrove Before",
    })


def count_pixel1(pred):
    """Return per-class pixel statistics for the "after" RGB prediction.

    Same as :func:`count_pixel`, but the green class is labelled
    "Mangrove After".
    """
    return _count_pixels_by_label(pred, {
        (0, 0, 255): "Non Mangrove",
        (255, 0, 0): "Mangrove Loss",
        (0, 255, 0): "Mangrove After",
    })


def _require_keys(mat, keys):
    """Stop the script with a visible error if ``mat`` lacks any of ``keys``."""
    missing = [key for key in keys if key not in mat]
    if missing:
        st.error(
            "Uploaded MAT file is missing required variable(s): "
            + ", ".join(missing)
        )
        st.stop()


def _rescale_label(label):
    """Return ``label`` as float64 with 0 mapped to -0.8 and 1 mapped to -0.2.

    Other values are left unchanged.
    """
    scaled = np.double(label)
    # Compute both masks first so the two replacements cannot interfere.
    zeros = scaled == 0
    ones = scaled == 1
    scaled[zeros] = -0.8
    scaled[ones] = -0.2
    return scaled


def _stack_bands_with_label(bands, label, rows, cols):
    """Build a ``(rows, cols, NUM_CHANNELS)`` float32 cube: bands + label."""
    cube = np.zeros(shape=(rows, cols, NUM_CHANNELS), dtype=np.float32)
    cube[:, :, :NUM_IMAGE_BANDS] = bands[:rows, :cols, :]
    cube[:, :, NUM_IMAGE_BANDS] = label
    return cube


def _extract_patches(cube):
    """Return all stride-1 patches of ``cube`` flattened to ``(N, P*P, C)``."""
    patches = patchify(cube, (PATCH_SIZE, PATCH_SIZE, NUM_CHANNELS), step=1)
    return patches.reshape(-1, PATCH_SIZE * PATCH_SIZE, NUM_CHANNELS)


def _load_model():
    """Create the SSTViT network and load the pretrained weights on CPU."""
    model = SSTViT(
        image_size=PATCH_SIZE,
        near_band=args.band_patches,
        num_patches=NUM_CHANNELS,
        num_classes=3,
        dim=32,
        depth=2,
        heads=4,
        dim_head=16,
        mlp_dim=8,
        b_dim=512,
        b_depth=3,
        b_heads=8,
        b_dim_head=32,
        b_mlp_head=8,
        dropout=0.2,
        emb_dropout=0.1,
    )
    model.load_state_dict(
        torch.load("model/lsstformer.pth", map_location=torch.device("cpu"))
    )
    model.eval()
    return model


def _predict_map(model, patches_a, patches_b, height, width):
    """Classify every patch pair and reshape the result to ``(height, width)``."""
    tensor_a = torch.from_numpy(patches_a.transpose(0, 2, 1)).type(torch.FloatTensor)
    tensor_b = torch.from_numpy(patches_b.transpose(0, 2, 1)).type(torch.FloatTensor)
    loader = Data.DataLoader(
        Data.TensorDataset(tensor_a, tensor_b),
        batch_size=INFERENCE_BATCH_SIZE,
        shuffle=False,
    )
    return test_epoch(model, loader).reshape(height, width)


def _show_before_after(image_before, image_after, figsize, cmap=None, fontsize=None):
    """Render two images side by side, titled 'Before' and 'After'."""
    title_kwargs = {"fontweight": "bold"}
    if fontsize is not None:
        title_kwargs["fontsize"] = fontsize
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    for axis, image, title in zip(
        axes, (image_before, image_after), ('Before', 'After')
    ):
        axis.imshow(image, cmap=cmap)
        axis.set_title(title, **title_kwargs)
        axis.set_xticks([])
        axis.set_yticks([])
    st.pyplot(fig)
    # Close the figure so pyplot does not keep it alive across reruns.
    plt.close(fig)


def _grid_height(n_rows):
    """Return the AgGrid height for ``n_rows`` rows, capped at MAX_HEIGHT."""
    return min(MIN_HEIGHT + n_rows * ROW_HEIGHT, MAX_HEIGHT)


file = st.file_uploader("Upload file", type=['mat'])
if file:
    # Parse the upload once and reuse it for the preview and the prediction.
    mat = loadmat(file)
    _require_keys(mat, PREVIEW_KEYS)
    data_img1 = mat['data_img1']
    data_img2 = mat['data_img2']

    st.subheader("Preview Dataset")
    col1, col2 = st.columns(2)
    with col1:
        _show_before_after(data_img1, data_img2, figsize=(5, 5))

    holder = st.empty()
    if holder.button("Start Prediction"):
        start = time.time()
        holder.empty()
        with st.spinner("Processing, please wait around 7-15 minute"):
            _require_keys(mat, PREDICTION_KEYS)
            L_post = _rescale_label(mat['L_post'])
            L_pre = _rescale_label(mat['L_pre'])

            # Both cubes are cropped to the extent of the post label map.
            rows, cols = L_post.shape[0], L_post.shape[1]
            data_cb1 = _stack_bands_with_label(mat['data_t1'], L_pre, rows, cols)
            data_cb2 = _stack_bands_with_label(mat['data_t2'], L_post, rows, cols)

            # Stride-1 patches leave (size - PATCH_SIZE + 1) positions per axis.
            height = rows - PATCH_SIZE + 1
            width = cols - PATCH_SIZE + 1
            x1 = _extract_patches(data_cb1)
            x2 = _extract_patches(data_cb2)

            model = _load_model()

            # NOTE(review): the "Before" map feeds the t1 patches as both
            # inputs; presumably an intentional no-change baseline - confirm.
            prediction_matrix = _predict_map(model, x1, x1, height, width)
            prediction_matrix2 = _predict_map(model, x1, x2, height, width)

            mangrove_before = int(np.count_nonzero(prediction_matrix == 2))
            mangrove_after = int(np.count_nonzero(prediction_matrix2 == 2))
            mangrove_loss = int(np.count_nonzero(prediction_matrix2 == 1))

            class_counts = count_pixel(encode_masks_to_rgb(prediction_matrix))
            class_counts1 = count_pixel1(encode_masks_to_rgb(prediction_matrix2))

        with st.container():
            st.subheader("Prediction Result")
            col1, col2 = st.columns(2)
            with col1:
                with st.container():
                    _show_before_after(
                        prediction_matrix,
                        prediction_matrix2,
                        figsize=(10, 10),
                        cmap=cmap,
                        fontsize=25,
                    )
            with col2:
                with st.container():
                    # Each pixel is reported as 100 m^2 (presumably the 10 m
                    # Sentinel-2 ground resolution - confirm).
                    table_data = {
                        "Total mangrove before": f"{mangrove_before * 100} m\u00B2",
                        "Total mangrove after": f"{mangrove_after * 100} m\u00B2",
                        "Total mangrove loss": f"{mangrove_loss * 100} m\u00B2",
                    }
                    df = pd.DataFrame(
                        list(table_data.items()), columns=['Key', 'Value']
                    )
                    st_aggrid.AgGrid(
                        df,
                        fit_columns_on_grid_load=True,
                        height=_grid_height(len(df)),
                    )
                with st.container():
                    st.subheader("Pixel Distribution")
                    # Rows shown, in order: mangrove before, mangrove after,
                    # mangrove loss (loss is taken from the "after" map).
                    before_row = class_counts[
                        class_counts["class"] == "Mangrove Before"
                    ]
                    after_row = class_counts1[
                        class_counts1["class"] == "Mangrove After"
                    ]
                    loss_row = class_counts1[
                        class_counts1["class"] == "Mangrove Loss"
                    ]
                    vertical_concat = pd.concat(
                        [before_row, after_row, loss_row], axis=0
                    )
                    st_aggrid.AgGrid(
                        vertical_concat,
                        fit_columns_on_grid_load=True,
                        height=_grid_height(len(vertical_concat)),
                    )
                    fig = px.bar(
                        vertical_concat,
                        x='percentage',
                        y='class',
                        color='class',
                        orientation='h',
                        color_discrete_sequence=["green", "green", "red", "blue"],
                        category_orders={
                            "class": [
                                "Mangrove Before",
                                "Mangrove After",
                                "Mangrove Loss",
                                "Non Mangrove",
                            ]
                        },
                    )
                    st.plotly_chart(fig, use_container_width=False)

        end = time.time()
        process = end - start
        st.write('process', process)

show_file = st.empty()
if not file:
    url = "https://drive.usercontent.google.com/download?id=1u48pMzRWQ2Etfjaq5A0CUjRtGKZaJoJy&export=download&authuser=2&confirm=t&uuid=52b0e01e-377f-42cb-8412-c84aa38a1740&at=APZUnTXslmuCCV1drJ2WWtkZr9BR%3A1710357675310"
    show_file.info(
        " The model was trained using Sentinel-2 imagery, users can upload"
        " MAT files to perform LSST-Former for mangrove loss detection models"
        " that have been trained in this research. Tool for generate from"
        " Sentinel-2 to MAT file i will create later, please download demo"
        " dataset bellow. for better in mobile phone, use desktop mode. "
    )
    st.write("download demo datasets this [link](%s)" % url)