adarsh committed on
Commit
3f31a1d
·
1 Parent(s): 1968bb1
Files changed (3) hide show
  1. app.py +128 -0
  2. requirements.txt +94 -0
  3. yolov8_model/best_model.pt +3 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ultralytics import YOLO
3
+ import os
4
+ from PIL import Image
5
+ import cv2
6
+ import numpy as np
7
+ import uuid
8
+
9
class TreeDetectionModel:
    """Wraps a trained YOLOv8 model for tree detection plus EXIF geotag extraction."""

    def __init__(self, model_path):
        """
        Initialize the YOLOv8 model for tree detection.

        Args:
            model_path (str): Path to the trained YOLOv8 model (.pt or .onnx).
        """
        self.model = YOLO(model_path)

    def detect(self, image_path):
        """
        Perform inference on an image.

        Args:
            image_path (str): Path to the input image.

        Returns:
            list[dict]: One dict per detection with keys 'class' (label name),
                'confidence' (float score), and 'bbox' ([x1, y1, x2, y2]).
        """
        results = self.model(image_path)
        detections = []
        for result in results:
            for box in result.boxes:
                detections.append({
                    'class': result.names[int(box.cls)],
                    'confidence': float(box.conf),
                    'bbox': box.xyxy.tolist()[0],
                })
        return detections

    def extract_geotag(self, image_path):
        """
        Extract GPS coordinates from the image's EXIF data.

        Args:
            image_path (str): Path to the input image.

        Returns:
            dict | None: {'lat', 'lon', 'alt'} raw EXIF values, or None when
                the image carries no EXIF data.
        """
        # Context manager ensures the underlying file handle is closed even if
        # EXIF parsing raises (the original leaked the open file).
        with Image.open(image_path) as img:
            exif_data = img._getexif()
        if exif_data:
            # 34853 is the EXIF GPSInfo IFD tag.
            gps_info = exif_data.get(34853, {})
            # NOTE(review): hemisphere refs (tags 1/3) are not read, so lat/lon
            # come back unsigned — confirm whether callers need S/W negation.
            return {
                'lat': gps_info.get(2),   # GPSLatitude
                'lon': gps_info.get(4),   # GPSLongitude
                'alt': gps_info.get(6),   # GPSAltitude
            }
        return None
57
+
58
+
59
# Load the trained YOLOv8 model. Resolve the weights relative to this file:
# the original pointed at an absolute /home/kalie/... path (and a "best.pt"
# filename) that only existed on the author's machine, while the repo actually
# ships yolov8_model/best_model.pt.
model_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "yolov8_model", "best_model.pt"
)
tree_detector = TreeDetectionModel(model_path)
62
+
63
def detect_trees(image):
    """
    Perform tree detection on the uploaded image and display results.

    Args:
        image (PIL.Image): Input image.

    Returns:
        tuple: (annotated_image, detections_table, geotag_info) where
            annotated_image is a PIL.Image with boxes drawn, detections_table
            is a list of [class, confidence, bbox] rows, and geotag_info is a
            human-readable string.
    """
    # Convert PIL image (RGB) to OpenCV format (BGR).
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Use a unique temp filename so concurrent requests can't clobber each
    # other's files (this is what the `uuid` import is for — it was unused).
    temp_image_path = f"temp_{uuid.uuid4().hex}.jpg"
    cv2.imwrite(temp_image_path, image_cv)

    try:
        # Perform detection on the saved file.
        detections = tree_detector.detect(temp_image_path)

        # Draw bounding boxes and class/confidence labels on the image.
        for detection in detections:
            x1, y1, x2, y2 = map(int, detection['bbox'])
            cv2.rectangle(image_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{detection['class']} ({detection['confidence']:.2f})"
            cv2.putText(image_cv, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Convert back to PIL format for Gradio.
        annotated_image = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))

        # Prepare detections table rows for the Dataframe component.
        detections_table = [
            [detection['class'], detection['confidence'], detection['bbox']]
            for detection in detections
        ]

        # Extract geotag from the temp file.
        # NOTE(review): cv2.imwrite does not preserve EXIF, so this will almost
        # always report "No geotag found" — consider reading EXIF from the
        # original upload instead; confirm intended behavior.
        geotag = tree_detector.extract_geotag(temp_image_path)
        geotag_info = (
            f"Latitude: {geotag['lat']}\nLongitude: {geotag['lon']}\nAltitude: {geotag['alt']}"
            if geotag else "No geotag found."
        )
    finally:
        # Always remove the temp file (the original leaked one per request).
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)

    return annotated_image, detections_table, geotag_info
103
+
104
+ # Gradio Interface
105
# Gradio Interface: image upload + button on the left; annotated image,
# detection table, and geotag text on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🌳 Tree Detection App")
    gr.Markdown("Upload an image to detect trees and view results.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            submit_button = gr.Button("Detect Trees")
        with gr.Column():
            image_output = gr.Image(label="Annotated Image")
            detections_output = gr.Dataframe(
                headers=["Class", "Confidence", "Bounding Box"],
                label="Detection Results"
            )
            geotag_output = gr.Textbox(label="Geotag Information")

    submit_button.click(
        detect_trees,
        inputs=image_input,
        outputs=[image_output, detections_output, geotag_output]
    )

# Run the app only when executed as a script (standard entry-point guard, so
# importing this module no longer launches a server as a side effect).
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aniso8601==10.0.0
3
+ annotated-types==0.7.0
4
+ anyio==4.8.0
5
+ blinker==1.9.0
6
+ certifi==2025.1.31
7
+ charset-normalizer==3.4.1
8
+ click==8.1.8
9
+ contourpy==1.3.1
10
+ cycler==0.12.1
11
+ exceptiongroup==1.2.2
12
+ fastapi==0.115.8
13
+ ffmpy==0.5.0
14
+ filelock==3.17.0
15
+ Flask==3.1.0
16
+ Flask-Cors==5.0.0
17
+ Flask-JWT-Extended==4.7.1
18
+ Flask-RESTful==0.3.10
19
+ fonttools==4.55.8
20
+ fsspec==2025.2.0
21
+ gradio==5.14.0
22
+ gradio_client==1.7.0
23
+ h11==0.14.0
24
+ httpcore==1.0.7
25
+ httpx==0.28.1
26
+ huggingface-hub==0.28.1
27
+ idna==3.10
28
+ itsdangerous==2.2.0
29
+ Jinja2==3.1.5
30
+ kiwisolver==1.4.8
31
+ markdown-it-py==3.0.0
32
+ MarkupSafe==2.1.5
33
+ matplotlib==3.10.0
34
+ mdurl==0.1.2
35
+ mpmath==1.3.0
36
+ networkx==3.4.2
37
+ numpy==2.1.1
38
+ nvidia-cublas-cu12==12.4.5.8
39
+ nvidia-cuda-cupti-cu12==12.4.127
40
+ nvidia-cuda-nvrtc-cu12==12.4.127
41
+ nvidia-cuda-runtime-cu12==12.4.127
42
+ nvidia-cudnn-cu12==9.1.0.70
43
+ nvidia-cufft-cu12==11.2.1.3
44
+ nvidia-curand-cu12==10.3.5.147
45
+ nvidia-cusolver-cu12==11.6.1.9
46
+ nvidia-cusparse-cu12==12.3.1.170
47
+ nvidia-cusparselt-cu12==0.6.2
48
+ nvidia-nccl-cu12==2.21.5
49
+ nvidia-nvjitlink-cu12==12.4.127
50
+ nvidia-nvtx-cu12==12.4.127
51
+ opencv-python==4.11.0.86
52
+ orjson==3.10.15
53
+ packaging==24.2
54
+ pandas==2.2.3
55
+ pillow==11.1.0
56
+ psutil==6.1.1
57
+ py-cpuinfo==9.0.0
58
+ pydantic==2.10.6
59
+ pydantic_core==2.27.2
60
+ pydub==0.25.1
61
+ Pygments==2.19.1
62
+ PyJWT==2.10.1
63
+ pyparsing==3.2.1
64
+ python-dateutil==2.9.0.post0
65
+ python-dotenv==1.0.1
66
+ python-multipart==0.0.20
67
+ pytz==2024.2
68
+ PyYAML==6.0.2
69
+ requests==2.32.3
70
+ rich==13.9.4
71
+ ruff==0.9.4
72
+ safehttpx==0.1.6
73
+ scipy==1.15.1
74
+ seaborn==0.13.2
75
+ semantic-version==2.10.0
76
+ shellingham==1.5.4
77
+ six==1.17.0
78
+ sniffio==1.3.1
79
+ starlette==0.45.3
80
+ sympy==1.13.1
81
+ tomlkit==0.13.2
82
+ torch==2.6.0
83
+ torchvision==0.21.0
84
+ tqdm==4.67.1
85
+ triton==3.2.0
86
+ typer==0.15.1
87
+ typing_extensions==4.12.2
88
+ tzdata==2025.1
89
+ ultralytics==8.3.70
90
+ ultralytics-thop==2.0.14
91
+ urllib3==2.3.0
92
+ uvicorn==0.34.0
93
+ websockets==14.2
94
+ Werkzeug==3.1.3
yolov8_model/best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a6058d690917cdbef87cb4684f3d45f4b91b3a68e23cb59edbd5a63d6df912c
3
+ size 52010507