Paras Shah
Change model load to CPU
c4592bc
raw
history blame
5.24 kB
import gc
import os
import laspy
import torch
import tempfile
import numpy as np
import open3d as o3d
import streamlit as st
import plotly.graph_objs as go
import pointnet2_cls_msg as pn2
from utils import calculate_dbh, calc_canopy_volume, CLASSES
from SingleTreePointCloudLoader import SingleTreePointCloudLoader
gc.enable()


@st.cache_resource(show_spinner="Loading PointNet++ model...")
def _load_classifier():
    """Build the 4-class PointNet++ model and load the trained CPU weights.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded model per server process, so the
    checkpoint is read from disk once instead of on every rerun.

    Returns:
        The classifier from ``pn2.get_model`` with ``model_state_dict`` loaded
        and switched to eval mode.
    """
    # map_location forces CPU tensors even if the checkpoint was saved on GPU.
    checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device('cpu'))
    model = pn2.get_model(num_class=4, normal_channel=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


classifier = _load_classifier()
st.title("Tree Species Identification")

# Single point-cloud upload; only LAS/LAZ/PCD extensions are accepted.
uploaded_file = st.file_uploader(
    "Upload Point Cloud Data",
    type=['laz', 'las', 'pcd'],
    help="Please upload trees with ground points removed",
)

# Integer percentage forwarded to calc_canopy_volume as its threshold argument.
Z_THRESHOLD = st.slider(
    "Z-Threshold(%)",
    min_value=5, max_value=100, step=1,
    value=50,
    help="Please select a Z-Threshold for canopy volume calculation",
)

# Height in metres (1.30-1.40) at which calculate_dbh measures the trunk.
DBH_HEIGHT = st.slider(
    "DBH Height(m)",
    min_value=1.3, max_value=1.4, step=0.01,
    value=1.4,
    help="Enter height used for DBH calculation",
)
proceed = None

# Streamlit re-executes the whole script on every widget interaction, and each
# run writes a fresh temporary copy of the upload. delete=False is needed
# because the inference step re-opens the file by path. Remove the copy left
# behind by the previous run so these files do not accumulate on disk. The
# current run's copy is still in use further down this script.
_previous_tmp_path = st.session_state.pop("_point_cloud_tmp_path", None)
if _previous_tmp_path:
    try:
        os.remove(_previous_tmp_path)
    except OSError:
        pass  # Best-effort cleanup: the file may already be gone.

if uploaded_file:
    try:
        with st.spinner("Reading point cloud file..."):
            # The uploader restricts extensions to laz/las/pcd, so the last
            # dotted component is the format; lowercase it for dispatch.
            file_type = uploaded_file.name.split('.')[-1].lower()
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
                temp_file_path = tmp.name
                # Record the path before writing so a failed write is still cleaned up next run.
                st.session_state["_point_cloud_tmp_path"] = temp_file_path
                # getvalue() returns the whole upload regardless of the buffer's read position.
                tmp.write(uploaded_file.getvalue())

            if file_type == 'pcd':
                pcd = o3d.io.read_point_cloud(temp_file_path)
                points = np.asarray(pcd.points)
            else:
                point_cloud = laspy.read(temp_file_path)
                # Stack the x, y, z coordinate arrays into an (N, 3) array.
                points = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()

        if len(points) == 0:
            # o3d.io.read_point_cloud may return an empty cloud instead of
            # raising on unreadable input; don't offer inference on no data.
            st.error("No points could be read from the uploaded file.")
        else:
            proceed = st.button("Run model")
    except Exception as e:
        # Top-level UI boundary: surface any read/parse failure to the user.
        st.error(f"An error occurred: {str(e)}")
def _build_point_cloud_figure(points, canopy_points, trunk_points):
    """Return a Plotly 3-D scatter of the tree with canopy and DBH points overlaid.

    Args:
        points: the full point cloud; columns 0, 1, 2 are plotted as x, y, z.
        canopy_points: points returned by ``calc_canopy_volume`` (same column layout).
        trunk_points: points returned by ``calculate_dbh`` (same column layout).
    """
    fig = go.Figure()
    # Whole tree, coloured by height (z).
    fig.add_trace(go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(size=0.5, color=points[:, 2], colorscale='Viridis', opacity=1.0),
        name='Tree',
    ))
    # Points used for the canopy-volume estimate.
    fig.add_trace(go.Scatter3d(
        x=canopy_points[:, 0],
        y=canopy_points[:, 1],
        z=canopy_points[:, 2],
        mode='markers',
        marker=dict(size=2, color='blue', opacity=0.8),
        name='Canopy points',
    ))
    # Trunk slice used for the DBH estimate.
    fig.add_trace(go.Scatter3d(
        x=trunk_points[:, 0],
        y=trunk_points[:, 1],
        z=trunk_points[:, 2],
        mode='markers',
        marker=dict(size=2, color='red', opacity=0.9),
        name='DBH',
    ))
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode='data',  # keep true metric proportions
        ),
    )
    return fig


def _classify_point_cloud(file_path, file_type):
    """Classify the saved point cloud with the module-level ``classifier``.

    Returns:
        ``(predicted_label, confidence_score)``. ``predicted_label`` is taken from
        ``CLASSES``; ``confidence_score`` is the softmax probability of that
        class, in percent.
    """
    dataset = SingleTreePointCloudLoader(file_path, file_type)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
    point_set, _ = next(iter(loader))
    # NOTE(review): presumably the loader yields (batch, num_points, channels)
    # and the model expects channels first; confirm against the loader.
    point_set = point_set.transpose(2, 1)
    with torch.no_grad():
        logits, _ = classifier(point_set)
    # Only the first sample is reported. Index it explicitly: calling .item()
    # on a batched argmax raises when the batch holds more than one sample.
    probabilities = torch.softmax(logits[0], dim=-1)
    predicted_class = int(torch.argmax(probabilities))
    confidence_score = probabilities[predicted_class].item() * 100
    return CLASSES[predicted_class], confidence_score


if proceed:
    try:
        with st.spinner("Calculating tree inventory..."):
            dbh, trunk_points = calculate_dbh(points, DBH_HEIGHT)
            z_min = np.min(points[:, 2])
            z_max = np.max(points[:, 2])
            height = z_max - z_min
            canopy_volume, canopy_points = calc_canopy_volume(points, Z_THRESHOLD, height, z_min)
        with st.spinner("Visualizing point cloud..."):
            fig = _build_point_cloud_figure(points, canopy_points, trunk_points)
            st.plotly_chart(fig, use_container_width=True)
        with st.spinner("Running inference..."):
            predicted_label, confidence_score = _classify_point_cloud(temp_file_path, file_type)
        st.write(f"**Predicted class: {predicted_label}**")
        st.write(f"**Confidence score: {confidence_score:.2f}%**")
        st.write(f"**Height of tree: {height:.2f}m**")
        st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
        st.write(f"**DBH: {dbh:.2f}m**")
    except Exception as e:
        # Top-level UI boundary: surface any processing failure to the user.
        st.error(f"An error occurred: {str(e)}")