Update app.py
app.py
CHANGED
@@ -1,6 +1,9 @@
# !pip install ultralytics
# !pip install gradio

+# !pip install ultralytics
+# !pip install gradio
+
import cv2
from ultralytics import YOLO
from PIL import Image
@@ -10,65 +13,65 @@ import numpy as np
import tempfile
import os

+# Check if the model file exists
+if os.path.exists('/content/best.pt'):
    print("Model file found.")
else:
    print("Model file not found. Please upload 'best.pt' to the Space.")

+# Load your trained model
+model = YOLO('best.pt')  # Adjust the model path accordingly

+# Define class names
class_names = {
+    0: 'plane',
+    1: 'ship',
+    2: 'storage tank',
+    3: 'baseball diamond',
+    4: 'tennis court',
+    5: 'basketball court',
+    6: 'ground track field',
+    7: 'harbor',
+    8: 'bridge',
+    9: 'large vehicle',
+    10: 'small vehicle',
+    11: 'helicopter',
+    12: 'roundabout',
+    13: 'soccer ball field',
+    14: 'swimming pool'
}

+# Colors for each class (RGB order here; note that OpenCV video frames are BGR)
colors = {
+    0: (255, 0, 0),     # Red
+    1: (0, 255, 0),     # Green
+    2: (0, 0, 255),     # Blue
+    3: (255, 255, 0),   # Yellow
+    4: (255, 0, 255),   # Magenta
+    5: (0, 255, 255),   # Cyan
+    6: (128, 0, 128),   # Purple
+    7: (255, 165, 0),   # Orange
+    8: (0, 128, 0),     # Dark Green
+    9: (128, 128, 0),   # Olive
+    10: (0, 255, 0),    # Green (same value as class 1)
+    11: (0, 128, 128),  # Teal
+    12: (0, 0, 128),    # Navy
+    13: (75, 0, 130),   # Indigo
+    14: (199, 21, 133)  # Medium Violet Red
}

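A note on channel order before the drawing code below: arrays converted from PIL images are RGB, while OpenCV's own capture and encode routines work in BGR, so the same (255, 0, 0) tuple reads as red on an uploaded image but blue on a video frame. A two-line illustration (the file name is hypothetical):

import cv2
import numpy as np
from PIL import Image

rgb = np.array(Image.open('photo.jpg'))     # PIL decodes to RGB
bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # OpenCV I/O expects BGR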
+# Function to detect objects in images
def detect_and_draw_image(input_image):
    try:
+        # Convert PIL image to NumPy array (RGB)
        input_image_np = np.array(input_image)
        print("Image converted to NumPy array.")

+        # Run the model on the image using the NumPy array (RGB)
        results = model.predict(source=input_image_np, conf=0.3)
        print("Model prediction completed.")

+        # Access OBB results
        if hasattr(results[0], 'obb') and results[0].obb is not None:
            obb_results = results[0].obb
            print("Accessed obb_results.")
@@ -76,49 +79,48 @@ def detect_and_draw_image(input_image):
            print("No 'obb' attribute found in results[0].")
            obb_results = None

+        # Check whether any detections were found
        if obb_results is None or len(obb_results.data) == 0:
+            print("No objects detected.")
            df = pd.DataFrame({
+                'Label': [],
-                'Label (Persian)': [],
                'Object Count': []
            })
            return input_image, df

        counts = {}
+        # Process results and draw bounding boxes
        for obb, conf, cls in zip(obb_results.data.cpu().numpy(), obb_results.conf.cpu().numpy(), obb_results.cls.cpu().numpy()):
            x_center, y_center, width, height, rotation = obb[:5]
            class_id = int(cls)
            confidence = float(conf)

+            # Draw the rotated bounding box with OpenCV
+            rect = ((x_center, y_center), (width, height), rotation * 180.0 / np.pi)  # Convert radians to degrees
            box_points = cv2.boxPoints(rect)
            box_points = box_points.astype(np.intp)  # np.int0 was removed in NumPy 2.0
            color = colors.get(class_id, (0, 255, 0))
+            cv2.drawContours(input_image_np, [box_points], 0, color, 1)  # Thickness reduced to 1
            print(f"Drawn OBB for class_id {class_id} with confidence {confidence}.")

+            # Draw the label just above the box
+            label = class_names.get(class_id, 'unknown')
+            text_position = (int(x_center), int(y_center) - int(height / 2) - 10)
+            cv2.putText(input_image_np, f'{label}: {confidence:.2f}',
+                        text_position,
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)  # Font thickness reduced to 1

+            # Count objects per label
+            counts[label] = counts.get(label, 0) + 1

+        # The array came from PIL and is already RGB, so wrap it directly for Gradio
+        output_image = Image.fromarray(input_image_np)
        print("Image converted back to PIL for Gradio.")

+        # Create DataFrame to display results
        df = pd.DataFrame({
+            'Label': list(counts.keys()),
-            'Label (Persian)': [class_names.get(k, ('unknown', 'ناشناخته'))[1] for k in counts.keys()],
            'Object Count': list(counts.values())
        })
        print("DataFrame created.")
@@ -128,13 +130,12 @@ def detect_and_draw_image(input_image):
    except Exception as e:
        print(f"Error in detect_and_draw_image: {e}")
        df = pd.DataFrame({
+            'Label': [],
-            'Label (Persian)': [],
            'Object Count': []
        })
        return input_image, df

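The drawing step above turns each OBB's center, size, and rotation into four corner points via cv2.boxPoints. A minimal standalone sketch of that conversion, assuming, as the code above does, that the model reports rotation in radians; the sample values are invented:

import cv2
import numpy as np

canvas = np.zeros((480, 640, 3), dtype=np.uint8)
cx, cy, w, h, rot = 320.0, 240.0, 120.0, 60.0, np.pi / 6  # hypothetical detection

rect = ((cx, cy), (w, h), np.degrees(rot))     # cv2.boxPoints takes the angle in degrees
corners = cv2.boxPoints(rect).astype(np.intp)  # np.int0 is gone in NumPy 2.0
cv2.drawContours(canvas, [corners], 0, (0, 255, 0), 1)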
+# Function to detect objects in videos
def detect_and_draw_video(video_path):
    try:
        cap = cv2.VideoCapture(video_path)
@@ -150,14 +151,14 @@
            frame_count += 1
            print(f"Processing frame {frame_count}")

+            # Resize the frame
            frame = cv2.resize(frame, (640, 480))

+            # Run the model on the frame
            results = model.predict(source=frame, conf=0.3)
            print(f"Model prediction completed for frame {frame_count}.")

+            # Access OBB results
            if hasattr(results[0], 'obb') and results[0].obb is not None:
                obb_results = results[0].obb
                print("Accessed obb_results for frame.")
@@ -171,22 +172,23 @@
                class_id = int(cls)
                confidence = float(conf)

+                # Draw the rotated bounding box with OpenCV
                rect = ((x_center, y_center), (width, height), rotation * 180.0 / np.pi)
                box_points = cv2.boxPoints(rect)
                box_points = box_points.astype(np.intp)  # np.int0 was removed in NumPy 2.0
                color = colors.get(class_id, (0, 255, 0))
+                cv2.drawContours(frame, [box_points], 0, color, 1)  # Thickness reduced to 1
                print(f"Drawn OBB for class_id {class_id} with confidence {confidence} in frame {frame_count}.")

+                # Draw the label just above the box
+                label = class_names.get(class_id, 'unknown')
+                text_position = (int(x_center), int(y_center) - int(height / 2) - 10)
+                cv2.putText(frame, f"{label}: {confidence:.2f}",
+                            text_position,
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)  # Font thickness reduced to 1

+                # Count objects per label
+                overall_counts[label] = overall_counts.get(label, 0) + 1

            frames.append(frame)
            print(f"Frame {frame_count} processed.")
@@ -194,7 +196,7 @@
        cap.release()
        print("Video processing completed.")

+        # Save the processed video to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmpfile:
            output_path = tmpfile.name
            print(f"Saving processed video to {output_path}")
@@ -208,10 +210,9 @@
        out.release()
        print("Video saved.")

+        # Create DataFrame to store results
        df = pd.DataFrame({
+            'Label': list(overall_counts.keys()),
-            'Label (Persian)': [class_names.get(k, ('unknown', 'ناشناخته'))[1] for k in overall_counts.keys()],
            'Object Count': list(overall_counts.values())
        })
        print("DataFrame created.")
@@ -220,40 +221,39 @@

    except Exception as e:
        print(f"Error in detect_and_draw_video: {e}")
+        # In case of an error, return the original video and an empty DataFrame
        return video_path, pd.DataFrame({
+            'Label': [],
-            'Label (Persian)': [],
            'Object Count': []
        })

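The hunks above elide the VideoWriter setup (new lines 203–209). A plausible self-contained sketch of what that block does; the codec, frame rate, and the stand-in frames and path are assumptions, not values taken from the diff:

import cv2
import numpy as np

frames = [np.zeros((480, 640, 3), dtype=np.uint8)] * 10  # stand-in for the frames list above
output_path = "out.mp4"                                  # stand-in for the tempfile path

fourcc = cv2.VideoWriter_fourcc(*'mp4v')                 # assumed codec for the .mp4 container
out = cv2.VideoWriter(output_path, fourcc, 20.0, (640, 480))  # assumed fps; size matches the resize above
for f in frames:
    out.write(f)                                         # VideoWriter expects BGR uint8 frames
out.release()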
+# Gradio interface for images
image_interface = gr.Interface(
    fn=detect_and_draw_image,
+    inputs=gr.Image(type="pil", label="Upload Image"),
+    outputs=[gr.Image(type="pil", label="Processed Image"), gr.Dataframe(label="Object Counts")],
+    title="Object Detection in Aerial Images",
+    description="Upload an aerial image to see detected objects and their counts.",
    examples=[
+        '/content/EXAMPLES/IMAGES/Examples_images_areial_car.jpg',
+        '/content/EXAMPLES/IMAGES/Examples_images_images.jpg',
+        '/content/EXAMPLES/IMAGES/Examples_images_t.jpg'
    ]
)

+
# Gradio interface for videos
|
245 |
video_interface = gr.Interface(
|
246 |
fn=detect_and_draw_video,
|
247 |
+
inputs=gr.Video(label="Upload Video"),
|
248 |
+
outputs=[gr.Video(label="Processed Video"), gr.Dataframe(label="Object Counts")],
|
249 |
+
title="Object Detection in Videos",
|
250 |
+
description="Upload a video to see detected objects and their counts.",
|
251 |
examples=[
|
252 |
+
'/content/EXAMPLES/VIDEO/airplane.mp4',
|
253 |
+
'/content/EXAMPLES/VIDEO/city.mp4'
|
254 |
]
|
255 |
)
|

+# Launch the app using a tabbed interface
+app = gr.TabbedInterface([image_interface, video_interface], ["Image Detection", "Video Detection"])
app.launch()
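To run this on a Space, the repository also needs the model weights and a requirements file. A hypothetical requirements.txt inferred from the imports above; the actual package set is not part of this diff:

# requirements.txt (illustrative)
ultralytics
gradio
opencv-python-headless
pandas
pillow

Note that 'best.pt' and the '/content/EXAMPLES/...' example paths must actually exist in the Space repository; the '/content/' prefix (also used in the existence check at the top) looks like a Colab leftover and will not resolve on a Space unless files are placed at those paths.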