import contextlib
import gc
import os
import tempfile

import laspy
import torch
import numpy as np
import open3d as o3d
import streamlit as st
import plotly.graph_objs as go

import pointnet2_cls_msg as pn2
from utils import calculate_dbh, calc_canopy_volume, CLASSES
from SingleTreePointCloudLoader import SingleTreePointCloudLoader
# CPython enables the cyclic GC by default; this is a no-op unless something
# earlier disabled it.
gc.enable()


@st.cache_resource(show_spinner="Loading PointNet++ model...")
def _load_classifier():
    """Build the PointNet++ classifier and load its trained weights.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one model instance per server process, so the
    checkpoint is read from disk once instead of on every rerun.

    Returns:
        The classifier in eval mode, with weights mapped to the CPU.
    """
    checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device('cpu'))
    # NOTE(review): num_class=4 must match the checkpoint; presumably also
    # len(CLASSES) — confirm before changing either.
    model = pn2.get_model(num_class=4, normal_channel=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


classifier = _load_classifier()
st.title("Tree Species Identification")

# Only LAS, LAZ and PCD point-cloud files may be uploaded.
uploaded_file = st.file_uploader(
    "Upload Point Cloud Data",
    help="Please upload trees with ground points removed",
    type=['laz', 'las', 'pcd'],
)

# Integer percentage (5-100) later passed to calc_canopy_volume.
Z_THRESHOLD = st.slider(
    "Z-Threshold(%)",
    value=50,
    min_value=5,
    max_value=100,
    step=1,
    help="Please select a Z-Threshold for canopy volume calculation",
)

# Height (between 1.3 and 1.4, step 0.01) later passed to calculate_dbh.
DBH_HEIGHT = st.slider(
    "DBH Height(m)",
    value=1.4,
    min_value=1.3,
    max_value=1.4,
    step=0.01,
    help="Enter height used for DBH calculation",
)
def _read_points(path, file_type):
    """Read point coordinates from the point-cloud file at *path*.

    Args:
        path: Filesystem path of the uploaded ``.pcd``, ``.las`` or ``.laz`` file.
        file_type: Lower-case extension without the dot. ``'pcd'`` selects
            Open3D; anything else is handed to laspy.

    Returns:
        A NumPy array with one row per point and columns x, y, z.
    """
    if file_type == 'pcd':
        # NOTE(review): Open3D generally reports an unreadable file by
        # returning an empty cloud rather than raising, so callers must
        # check for zero points.
        return np.asarray(o3d.io.read_point_cloud(path).points)
    las = laspy.read(path)
    return np.vstack((las.x, las.y, las.z)).transpose()


def _build_figure(points, canopy_points, trunk_points):
    """Return a 3-D scatter of the tree with canopy and DBH points overlaid.

    Each argument is indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]`` for x, y, z.
    """
    layers = (
        # Whole tree, coloured by its z coordinate.
        ('Tree', points, dict(size=0.5, color=points[:, 2], colorscale='Viridis', opacity=1.0)),
        ('Canopy points', canopy_points, dict(size=2, color='blue', opacity=0.8)),
        ('DBH', trunk_points, dict(size=2, color='red', opacity=0.9)),
    )
    fig = go.Figure()
    for name, pts, marker in layers:
        fig.add_trace(go.Scatter3d(
            x=pts[:, 0],
            y=pts[:, 1],
            z=pts[:, 2],
            mode='markers',
            marker=marker,
            name=name,
        ))
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode='data',
        ),
    )
    return fig


def _classify(path, file_type):
    """Classify the tree stored at *path*; return ``(label, confidence_percent)``."""
    dataset = SingleTreePointCloudLoader(path, file_type)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
    point_set, _ = next(iter(loader))
    # Swap dims 1 and 2 — presumably (batch, points, xyz) -> (batch, xyz, points)
    # as the PointNet++ model expects; confirm against SingleTreePointCloudLoader.
    point_set = point_set.transpose(2, 1)
    with torch.no_grad():
        logits, _ = classifier(point_set)
    # Report only the first sample: the upload is a single tree. Indexing [0]
    # also avoids `.item()` failing should the loader yield a larger batch.
    probabilities = torch.softmax(logits, dim=-1)[0]
    predicted_class = int(torch.argmax(probabilities))
    confidence_score = probabilities[predicted_class].item() * 100
    return CLASSES[predicted_class], confidence_score


def _run_pipeline(points, path, file_type):
    """Compute inventory metrics, plot the cloud, classify it, and show results.

    Args:
        points: Array of point coordinates with x, y, z in columns 0-2.
        path: Path of the on-disk copy of the upload (re-read by the model loader).
        file_type: Lower-case file extension without the dot.
    """
    with st.spinner("Calculating tree inventory..."):
        dbh, trunk_points = calculate_dbh(points, DBH_HEIGHT)
        z_min = np.min(points[:, 2])
        z_max = np.max(points[:, 2])
        height = z_max - z_min
        canopy_volume, canopy_points = calc_canopy_volume(points, Z_THRESHOLD, height, z_min)
    with st.spinner("Visualizing point cloud..."):
        st.plotly_chart(_build_figure(points, canopy_points, trunk_points), use_container_width=True)
    with st.spinner("Running inference..."):
        predicted_label, confidence_score = _classify(path, file_type)
    st.write(f"**Predicted class: {predicted_label}**")
    st.write(f"**Confidence score: {confidence_score:.2f}%**")
    st.write(f"**Height of tree: {height:.2f}m**")
    st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
    st.write(f"**DBH: {dbh:.2f}m**")


proceed = None
if uploaded_file is not None:
    extension = uploaded_file.name.split('.')[-1]
    file_type = extension.lower()
    temp_file_path = None
    try:
        try:
            with st.spinner("Reading point cloud file..."):
                # delete=False: laspy/Open3D and the inference loader reopen
                # the file by path after this `with` closes it. The copy is
                # removed in the outer `finally`.
                with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as tmp:
                    temp_file_path = tmp.name
                    # getvalue() returns the whole upload regardless of the
                    # stream position an earlier read may have left.
                    tmp.write(uploaded_file.getvalue())
                points = _read_points(temp_file_path, file_type)
            if points.shape[0] == 0:
                st.error("No points could be read from the uploaded file.")
            else:
                proceed = st.button("Run model")
        except Exception as e:
            st.error(f"An error occurred: {e}")
        if proceed:
            try:
                _run_pipeline(points, temp_file_path, file_type)
            except Exception as e:
                st.error(f"An error occurred: {e}")
    finally:
        # Streamlit reruns this script on every interaction; without this,
        # each rerun would leave another copy of the upload on disk.
        if temp_file_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_file_path)