Ashegh-Sad-Warrior committed on
Commit
ecbd1be
·
verified ·
1 Parent(s): b23026a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -98
app.py CHANGED
@@ -1,6 +1,9 @@
1
  # !pip install ultralytics
2
  # !pip install gradio
3
 
 
 
 
4
  import cv2
5
  from ultralytics import YOLO
6
  from PIL import Image
@@ -10,65 +13,65 @@ import numpy as np
10
  import tempfile
11
  import os
12
 
13
- # بررسی وجود فایل مدل
14
- if os.path.exists('weights/best.pt'):
15
  print("Model file found.")
16
  else:
17
  print("Model file not found. Please upload 'best.pt' to the Space.")
18
 
19
- # بارگذاری مدل آموزش‌دیده شما
20
- model = YOLO('best.pt') # مسیر مدل را به صورت نسبی تنظیم کنید
21
 
22
- # تعریف نام کلاس‌ها به انگلیسی و فارسی
23
  class_names = {
24
- 0: ('plane', 'هواپیما'),
25
- 1: ('ship', 'کشتی'),
26
- 2: ('storage tank', 'مخزن ذخیره'),
27
- 3: ('baseball diamond', 'زمین بیسبال'),
28
- 4: ('tennis court', 'زمین تنیس'),
29
- 5: ('basketball court', 'زمین بسکتبال'),
30
- 6: ('ground track field', 'زمین دو و میدانی'),
31
- 7: ('harbor', 'بندرگاه'),
32
- 8: ('bridge', 'پل'),
33
- 9: ('large vehicle', 'خودرو بزرگ'),
34
- 10: ('small vehicle', 'خودرو کوچک'),
35
- 11: ('helicopter', 'هلیکوپتر'),
36
- 12: ('roundabout', 'میدان'),
37
- 13: ('soccer ball field', 'زمین فوتبال'),
38
- 14: ('swimming pool', 'استخر شنا')
39
  }
40
 
41
- # رنگ‌ها برای هر کلاس (BGR برای OpenCV)
42
  colors = {
43
- 0: (255, 0, 0), # قرمز
44
- 1: (0, 255, 0), # سبز
45
- 2: (0, 0, 255), # آبی
46
- 3: (255, 255, 0), # زرد
47
- 4: (255, 0, 255), # مجنتا
48
- 5: (0, 255, 255), # فیروزه‌ای
49
- 6: (128, 0, 128), # بنفش
50
- 7: (255, 165, 0), # نارنجی
51
- 8: (0, 128, 0), # سبز تیره
52
- 9: (128, 128, 0), # زیتونی
53
- 10: (0, 255, 0), # سبز روشن برای class_id=10
54
- 11: (0, 128, 128), # سبز نفتی
55
- 12: (0, 0, 128), # نیوی
56
- 13: (75, 0, 130), # ایندیگو
57
- 14: (199, 21, 133) # رز متوسط
58
  }
59
 
60
- # تابع برای تشخیص اشیاء در تصاویر
61
  def detect_and_draw_image(input_image):
62
  try:
63
- # تبدیل تصویر PIL به آرایه NumPy (RGB)
64
  input_image_np = np.array(input_image)
65
  print("Image converted to NumPy array.")
66
 
67
- # اجرای مدل روی تصویر با استفاده از آرایه NumPy (RGB)
68
  results = model.predict(source=input_image_np, conf=0.3)
69
  print("Model prediction completed.")
70
 
71
- # دسترسی به نتایج OBB
72
  if hasattr(results[0], 'obb') and results[0].obb is not None:
73
  obb_results = results[0].obb
74
  print("Accessed obb_results.")
@@ -76,49 +79,48 @@ def detect_and_draw_image(input_image):
76
  print("No 'obb' attribute found in results[0].")
77
  obb_results = None
78
 
79
- # بررسی وجود جعبه‌های شناسایی شده
80
  if obb_results is None or len(obb_results.data) == 0:
81
- print("هیچ شیء شناسایی نشده است.")
82
  df = pd.DataFrame({
83
- 'Label (English)': [],
84
- 'Label (Persian)': [],
85
  'Object Count': []
86
  })
87
  return input_image, df
88
 
89
  counts = {}
90
- # پردازش نتایج و رسم جعبه‌ها
91
  for obb, conf, cls in zip(obb_results.data.cpu().numpy(), obb_results.conf.cpu().numpy(), obb_results.cls.cpu().numpy()):
92
  x_center, y_center, width, height, rotation = obb[:5]
93
  class_id = int(cls)
94
  confidence = float(conf)
95
 
96
- # رسم جعبه چرخان با استفاده از OpenCV
97
- rect = ((x_center, y_center), (width, height), rotation * 180.0 / np.pi) # تبدیل رادیان به درجه
98
  box_points = cv2.boxPoints(rect)
99
  box_points = np.int0(box_points)
100
  color = colors.get(class_id, (0, 255, 0))
101
- cv2.drawContours(input_image_np, [box_points], 0, color, 2)
102
  print(f"Drawn OBB for class_id {class_id} with confidence {confidence}.")
103
 
104
- # رسم برچسب
105
- label_en, label_fa = class_names.get(class_id, ('unknown', 'ناشناخته'))
106
- cv2.putText(input_image_np, f'{label_en}: {confidence:.2f}',
107
- (int(x_center), int(y_center)),
108
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2, cv2.LINE_AA)
 
109
 
110
- # شمارش اشیاء
111
- counts[label_en] = counts.get(label_en, 0) + 1
112
 
113
- # تبدیل تصویر به RGB برای Gradio
114
  image_rgb = cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB)
115
  output_image = Image.fromarray(image_rgb)
116
  print("Image converted back to RGB for Gradio.")
117
 
118
- # ایجاد DataFrame برای نمایش نتایج
119
  df = pd.DataFrame({
120
- 'Label (English)': list(counts.keys()),
121
- 'Label (Persian)': [class_names.get(k, ('unknown', 'ناشناخته'))[1] for k in counts.keys()],
122
  'Object Count': list(counts.values())
123
  })
124
  print("DataFrame created.")
@@ -128,13 +130,12 @@ def detect_and_draw_image(input_image):
128
  except Exception as e:
129
  print(f"Error in detect_and_draw_image: {e}")
130
  df = pd.DataFrame({
131
- 'Label (English)': [],
132
- 'Label (Persian)': [],
133
  'Object Count': []
134
  })
135
  return input_image, df
136
 
137
- # تابع برای تشخیص اشیاء در ویدئوها
138
  def detect_and_draw_video(video_path):
139
  try:
140
  cap = cv2.VideoCapture(video_path)
@@ -150,14 +151,14 @@ def detect_and_draw_video(video_path):
150
  frame_count +=1
151
  print(f"Processing frame {frame_count}")
152
 
153
- # تغییر اندازه فریم
154
  frame = cv2.resize(frame, (640, 480))
155
 
156
- # اجرای مدل روی فریم
157
  results = model.predict(source=frame, conf=0.3)
158
  print(f"Model prediction completed for frame {frame_count}.")
159
 
160
- # دسترسی به نتایج OBB
161
  if hasattr(results[0], 'obb') and results[0].obb is not None:
162
  obb_results = results[0].obb
163
  print("Accessed obb_results for frame.")
@@ -171,22 +172,23 @@ def detect_and_draw_video(video_path):
171
  class_id = int(cls)
172
  confidence = float(conf)
173
 
174
- # رسم جعبه چرخان با استفاده از OpenCV
175
  rect = ((x_center, y_center), (width, height), rotation * 180.0 / np.pi)
176
  box_points = cv2.boxPoints(rect)
177
  box_points = np.int0(box_points)
178
  color = colors.get(class_id, (0, 255, 0))
179
- cv2.drawContours(frame, [box_points], 0, color, 2)
180
  print(f"Drawn OBB for class_id {class_id} with confidence {confidence} in frame {frame_count}.")
181
 
182
- # رسم برچسب
183
- label_en, label_fa = class_names.get(class_id, ('unknown', 'ناشناخته'))
184
- cv2.putText(frame, f"{label_en}: {confidence:.2f}",
185
- (int(x_center), int(y_center)),
186
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
 
187
 
188
- # شمارش اشیاء
189
- overall_counts[label_en] = overall_counts.get(label_en, 0) + 1
190
 
191
  frames.append(frame)
192
  print(f"Frame {frame_count} processed.")
@@ -194,7 +196,7 @@ def detect_and_draw_video(video_path):
194
  cap.release()
195
  print("Video processing completed.")
196
 
197
- # ذخیره ویدئو پردازش‌شده در یک فایل موقت
198
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmpfile:
199
  output_path = tmpfile.name
200
  print(f"Saving processed video to {output_path}")
@@ -208,10 +210,9 @@ def detect_and_draw_video(video_path):
208
  out.release()
209
  print("Video saved.")
210
 
211
- # ایجاد DataFrame برای ذخیره نتایج
212
  df = pd.DataFrame({
213
- 'Label (English)': list(overall_counts.keys()),
214
- 'Label (Persian)': [class_names.get(k, ('unknown', 'ناشناخته'))[1] for k in overall_counts.keys()],
215
  'Object Count': list(overall_counts.values())
216
  })
217
  print("DataFrame created.")
@@ -220,40 +221,39 @@ def detect_and_draw_video(video_path):
220
 
221
  except Exception as e:
222
  print(f"Error in detect_and_draw_video: {e}")
223
- # در صورت بروز خطا، بازگرداندن ویدئوی اصلی بدون تغییر و یک DataFrame خالی
224
  return video_path, pd.DataFrame({
225
- 'Label (English)': [],
226
- 'Label (Persian)': [],
227
  'Object Count': []
228
  })
229
 
230
- # رابط کاربری Gradio برای تصاویر
231
  image_interface = gr.Interface(
232
  fn=detect_and_draw_image,
233
- inputs=gr.Image(type="pil", label="بارگذاری تصویر"),
234
- outputs=[gr.Image(type="pil", label="تصویر پردازش شده"), gr.Dataframe(label="تعداد اشیاء")],
235
- title="تشخیص اشیاء در تصاویر هوایی",
236
- description="یک تصویر هوایی بارگذاری کنید تا اشیاء شناسایی شده و تعداد آن‌ها را ببینید.",
237
  examples=[
238
- 'Examples/images/areial_car.jpg',
239
- 'Examples/images/arieal_car_1.jpg',
240
- 'Examples/images/t.jpg'
241
  ]
242
  )
243
 
244
- # رابط کاربری Gradio برای ویدئوها
245
  video_interface = gr.Interface(
246
  fn=detect_and_draw_video,
247
- inputs=gr.Video(label="بارگذاری ویدئو"),
248
- outputs=[gr.Video(label="ویدئوی پردازش شده"), gr.Dataframe(label="تعداد اشیاء")],
249
- title="تشخیص اشیاء در ویدئوها",
250
- description="یک ویدئو بارگذاری کنید تا اشیاء شناسایی شده و تعداد آن‌ها را ببینید.",
251
  examples=[
252
- 'Examples/video/city.mp4',
253
- 'Examples/video/airplane.mp4'
254
  ]
255
  )
256
 
257
- # اجرای برنامه با استفاده از رابط کاربری تب‌دار
258
- app = gr.TabbedInterface([image_interface, video_interface], ["تشخیص تصویر", "تشخیص ویدئو"])
259
  app.launch()
 
1
  # !pip install ultralytics
2
  # !pip install gradio
3
 
4
+ # !pip install ultralytics
5
+ # !pip install gradio
6
+
7
  import cv2
8
  from ultralytics import YOLO
9
  from PIL import Image
 
13
  import tempfile
14
  import os
15
 
16
+ # Check if the model file exists
17
+ if os.path.exists('/content/best.pt'):
18
  print("Model file found.")
19
  else:
20
  print("Model file not found. Please upload 'best.pt' to the Space.")
21
 
22
+ # Load your trained model
23
+ model = YOLO('best.pt') # Adjust the model path accordingly
24
 
25
+ # Define class names
26
  class_names = {
27
+ 0: 'plane',
28
+ 1: 'ship',
29
+ 2: 'storage tank',
30
+ 3: 'baseball diamond',
31
+ 4: 'tennis court',
32
+ 5: 'basketball court',
33
+ 6: 'ground track field',
34
+ 7: 'harbor',
35
+ 8: 'bridge',
36
+ 9: 'large vehicle',
37
+ 10: 'small vehicle',
38
+ 11: 'helicopter',
39
+ 12: 'roundabout',
40
+ 13: 'soccer ball field',
41
+ 14: 'swimming pool'
42
  }
43
 
44
+ # Colors for each class (BGR for OpenCV)
45
  colors = {
46
+ 0: (255, 0, 0), # Red
47
+ 1: (0, 255, 0), # Green
48
+ 2: (0, 0, 255), # Blue
49
+ 3: (255, 255, 0), # Yellow
50
+ 4: (255, 0, 255), # Magenta
51
+ 5: (0, 255, 255), # Cyan
52
+ 6: (128, 0, 128), # Purple
53
+ 7: (255, 165, 0), # Orange
54
+ 8: (0, 128, 0), # Dark Green
55
+ 9: (128, 128, 0), # Olive
56
+ 10: (0, 255, 0), # Light Green for class_id=10
57
+ 11: (0, 128, 128), # Teal
58
+ 12: (0, 0, 128), # Navy
59
+ 13: (75, 0, 130), # Indigo
60
+ 14: (199, 21, 133) # Medium Violet Red
61
  }
62
 
63
+ # Function to detect objects in images
64
  def detect_and_draw_image(input_image):
65
  try:
66
+ # Convert PIL image to NumPy array (RGB)
67
  input_image_np = np.array(input_image)
68
  print("Image converted to NumPy array.")
69
 
70
+ # Run the model on the image using NumPy array (RGB)
71
  results = model.predict(source=input_image_np, conf=0.3)
72
  print("Model prediction completed.")
73
 
74
+ # Access OBB results
75
  if hasattr(results[0], 'obb') and results[0].obb is not None:
76
  obb_results = results[0].obb
77
  print("Accessed obb_results.")
 
79
  print("No 'obb' attribute found in results[0].")
80
  obb_results = None
81
 
82
+ # Check if any detections are found
83
  if obb_results is None or len(obb_results.data) == 0:
84
+ print("No objects detected.")
85
  df = pd.DataFrame({
86
+ 'Label': [],
 
87
  'Object Count': []
88
  })
89
  return input_image, df
90
 
91
  counts = {}
92
+ # Process results and draw bounding boxes
93
  for obb, conf, cls in zip(obb_results.data.cpu().numpy(), obb_results.conf.cpu().numpy(), obb_results.cls.cpu().numpy()):
94
  x_center, y_center, width, height, rotation = obb[:5]
95
  class_id = int(cls)
96
  confidence = float(conf)
97
 
98
+ # Draw rotated bounding box using OpenCV
99
+ rect = ((x_center, y_center), (width, height), rotation * 180.0 / np.pi) # Convert radians to degrees
100
  box_points = cv2.boxPoints(rect)
101
  box_points = np.int0(box_points)
102
  color = colors.get(class_id, (0, 255, 0))
103
+ cv2.drawContours(input_image_np, [box_points], 0, color, 1) # Reduced thickness to 1
104
  print(f"Drawn OBB for class_id {class_id} with confidence {confidence}.")
105
 
106
+ # Draw label with less thickness and appropriate position
107
+ label = class_names.get(class_id, 'unknown')
108
+ text_position = (int(x_center), int(y_center) - int(height / 2) - 10)
109
+ cv2.putText(input_image_np, f'{label}: {confidence:.2f}',
110
+ text_position,
111
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA) # Reduced font thickness to 1
112
 
113
+ # Count objects
114
+ counts[label] = counts.get(label, 0) + 1
115
 
116
+ # Convert image to RGB for Gradio
117
  image_rgb = cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB)
118
  output_image = Image.fromarray(image_rgb)
119
  print("Image converted back to RGB for Gradio.")
120
 
121
+ # Create DataFrame to display results
122
  df = pd.DataFrame({
123
+ 'Label': list(counts.keys()),
 
124
  'Object Count': list(counts.values())
125
  })
126
  print("DataFrame created.")
 
130
  except Exception as e:
131
  print(f"Error in detect_and_draw_image: {e}")
132
  df = pd.DataFrame({
133
+ 'Label': [],
 
134
  'Object Count': []
135
  })
136
  return input_image, df
137
 
138
+ # Function to detect objects in videos
139
  def detect_and_draw_video(video_path):
140
  try:
141
  cap = cv2.VideoCapture(video_path)
 
151
  frame_count +=1
152
  print(f"Processing frame {frame_count}")
153
 
154
+ # Resize frame
155
  frame = cv2.resize(frame, (640, 480))
156
 
157
+ # Run the model on the frame
158
  results = model.predict(source=frame, conf=0.3)
159
  print(f"Model prediction completed for frame {frame_count}.")
160
 
161
+ # Access OBB results
162
  if hasattr(results[0], 'obb') and results[0].obb is not None:
163
  obb_results = results[0].obb
164
  print("Accessed obb_results for frame.")
 
172
  class_id = int(cls)
173
  confidence = float(conf)
174
 
175
+ # Draw rotated bounding box using OpenCV
176
  rect = ((x_center, y_center), (width, height), rotation * 180.0 / np.pi)
177
  box_points = cv2.boxPoints(rect)
178
  box_points = np.int0(box_points)
179
  color = colors.get(class_id, (0, 255, 0))
180
+ cv2.drawContours(frame, [box_points], 0, color, 1) # Reduced thickness to 1
181
  print(f"Drawn OBB for class_id {class_id} with confidence {confidence} in frame {frame_count}.")
182
 
183
+ # Draw label with less thickness and appropriate position
184
+ label = class_names.get(class_id, 'unknown')
185
+ text_position = (int(x_center), int(y_center) - int(height / 2) - 10)
186
+ cv2.putText(frame, f"{label}: {confidence:.2f}",
187
+ text_position,
188
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA) # Reduced font thickness to 1
189
 
190
+ # Count objects
191
+ overall_counts[label] = overall_counts.get(label, 0) + 1
192
 
193
  frames.append(frame)
194
  print(f"Frame {frame_count} processed.")
 
196
  cap.release()
197
  print("Video processing completed.")
198
 
199
+ # Save processed video to a temporary file
200
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmpfile:
201
  output_path = tmpfile.name
202
  print(f"Saving processed video to {output_path}")
 
210
  out.release()
211
  print("Video saved.")
212
 
213
+ # Create DataFrame to store results
214
  df = pd.DataFrame({
215
+ 'Label': list(overall_counts.keys()),
 
216
  'Object Count': list(overall_counts.values())
217
  })
218
  print("DataFrame created.")
 
221
 
222
  except Exception as e:
223
  print(f"Error in detect_and_draw_video: {e}")
224
+ # In case of an error, return the original video and an empty DataFrame
225
  return video_path, pd.DataFrame({
226
+ 'Label': [],
 
227
  'Object Count': []
228
  })
229
 
230
+ # Gradio interface for images
231
  image_interface = gr.Interface(
232
  fn=detect_and_draw_image,
233
+ inputs=gr.Image(type="pil", label="Upload Image"),
234
+ outputs=[gr.Image(type="pil", label="Processed Image"), gr.Dataframe(label="Object Counts")],
235
+ title="Object Detection in Aerial Images",
236
+ description="Upload an aerial image to see detected objects and their counts.",
237
  examples=[
238
+ '/content/EXAMPLES/IMAGES/Examples_images_areial_car.jpg',
239
+ '/content/EXAMPLES/IMAGES/Examples_images_images.jpg',
240
+ '/content/EXAMPLES/IMAGES/Examples_images_t.jpg'
241
  ]
242
  )
243
 
244
+ # Gradio interface for videos
245
  video_interface = gr.Interface(
246
  fn=detect_and_draw_video,
247
+ inputs=gr.Video(label="Upload Video"),
248
+ outputs=[gr.Video(label="Processed Video"), gr.Dataframe(label="Object Counts")],
249
+ title="Object Detection in Videos",
250
+ description="Upload a video to see detected objects and their counts.",
251
  examples=[
252
+ '/content/EXAMPLES/VIDEO/airplane.mp4',
253
+ '/content/EXAMPLES/VIDEO/city.mp4'
254
  ]
255
  )
256
 
257
+ # Launch the app using a tabbed interface
258
+ app = gr.TabbedInterface([image_interface, video_interface], ["Image Detection", "Video Detection"])
259
  app.launch()