capradeepgujaran committed
Update app.py
app.py CHANGED
Old version of the changed sections (removed lines marked with -):
@@ -2,277 +2,24 @@ import gradio as gr
 import cv2
 import numpy as np
 from groq import Groq
 from PIL import Image as PILImage
 import io
-import base64
-import torch
-import warnings
-from typing import Tuple, List, Dict, Optional
 import os
-
-
-warnings.filterwarnings('ignore', category=FutureWarning)
-warnings.filterwarnings('ignore', category=UserWarning)
-
-class RobustSafetyMonitor:
-    def __init__(self):
-        """Initialize the safety detection tool with improved configuration."""
-        self.client = Groq()
-        self.model_name = "llama-3.2-11b-vision-preview"
-        self.max_image_size = (800, 800)
-        self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
-
-        # Load YOLOv5 with optimized settings
-        self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
-        self.yolo_model.conf = 0.25  # Lower confidence threshold
-        self.yolo_model.iou = 0.45  # Adjusted IOU threshold
-        self.yolo_model.classes = None  # Detect all classes
-        self.yolo_model.max_det = 50  # Increased maximum detections
-        self.yolo_model.cpu()
-        self.yolo_model.eval()
-
-        # Construction-specific keywords
-        self.construction_keywords = [
-            'person', 'worker', 'helmet', 'tool', 'machine', 'equipment',
-            'brick', 'block', 'pile', 'stack', 'surface', 'floor', 'ground',
-            'construction', 'building', 'structure'
-        ]
-
-    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
-        """Process image for analysis."""
-        if frame is None:
-            raise ValueError("No image provided")
-
-        if len(frame.shape) == 2:
-            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
-        elif len(frame.shape) == 3 and frame.shape[2] == 4:
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
-
-        return self.resize_image(frame)
-
-    def resize_image(self, image: np.ndarray) -> np.ndarray:
-        """Resize image while maintaining aspect ratio."""
-        height, width = image.shape[:2]
-        if height > self.max_image_size[1] or width > self.max_image_size[0]:
-            aspect = width / height
-            if width > height:
-                new_width = self.max_image_size[0]
-                new_height = int(new_width / aspect)
-            else:
-                new_height = self.max_image_size[1]
-                new_width = int(new_height * aspect)
-            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
-        return image
-
-    def encode_image(self, frame: np.ndarray) -> str:
-        """Convert image to base64 encoding."""
-        try:
-            frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            buffered = io.BytesIO()
-            frame_pil.save(buffered, format="JPEG", quality=95)
-            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-            return f"data:image/jpeg;base64,{img_base64}"
-        except Exception as e:
-            raise ValueError(f"Error encoding image: {str(e)}")
-
-    def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
-        """Enhanced object detection using YOLOv5."""
-        try:
-            # Ensure proper image format
-            if len(frame.shape) == 2:
-                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
-            elif frame.shape[2] == 4:
-                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
-
-            # Run inference with augmentation
-            with torch.no_grad():
-                results = self.yolo_model(frame, augment=True)
-
-            # Get detections
-            bbox_data = results.xyxy[0].cpu().numpy()
-            labels = results.names
-
-            # Filter and process detections
-            processed_boxes = []
-            for box in bbox_data:
-                x1, y1, x2, y2, conf, cls = box
-                if conf > 0.25:  # Keep lower confidence threshold
-                    processed_boxes.append(box)
-
-            return np.array(processed_boxes), labels
-        except Exception as e:
-            print(f"Error in object detection: {str(e)}")
-            return np.array([]), {}
-
-    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
-        """Perform safety analysis using Llama Vision."""
-        if frame is None:
-            return [], "No frame received"
-
-        try:
-            frame = self.preprocess_image(frame)
-            image_base64 = self.encode_image(frame)
-
-            completion = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": """Analyze this workplace image for safety risks. Focus on:
-1. Worker posture and positioning
-2. Equipment and tool safety
-3. Environmental hazards
-4. PPE compliance
-5. Material handling
-
-List each risk on a new line starting with 'Risk:'.
-Format: Risk: [Object/Area] - [Detailed description of hazard]"""
-                            },
-                            {
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": image_base64
-                                }
-                            }
-                        ]
-                    }
-                ],
-                temperature=0.7,
-                max_tokens=1024,
-                stream=False
-            )
-
-            try:
-                response = completion.choices[0].message.content
-            except AttributeError:
-                response = str(completion.choices[0].message)
-
-            safety_issues = self.parse_safety_analysis(response)
-            return safety_issues, response
-
-        except Exception as e:
-            print(f"Analysis error: {str(e)}")
-            return [], f"Analysis Error: {str(e)}"
-
-    def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
-                            labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
-        """Improved bounding box visualization."""
-        image_copy = image.copy()
-        font = cv2.FONT_HERSHEY_SIMPLEX
-        font_scale = 0.5
-        thickness = 2
-
-        for idx, bbox in enumerate(bboxes):
-            try:
-                x1, y1, x2, y2, conf, class_id = bbox
-                label = labels[int(class_id)]
-
-                # Check if object is construction-related
-                is_relevant = any(keyword in label.lower() for keyword in self.construction_keywords)
-
-                if is_relevant or conf > 0.35:
-                    color = self.colors[idx % len(self.colors)]
-
-                    # Convert coordinates to integers
-                    x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
-
-                    # Draw bounding box
-                    cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
-
-                    # Check for associated safety issues
-                    risk_found = False
-                    for safety_issue in safety_issues:
-                        issue_keywords = safety_issue.get('object', '').lower().split()
-                        if any(keyword in label.lower() for keyword in issue_keywords):
-                            label_text = f"Risk: {safety_issue.get('description', '')}"
-                            y_pos = max(y1 - 10, 20)
-                            cv2.putText(image_copy, label_text, (x1, y_pos), font,
-                                        font_scale, (0, 0, 255), thickness)
-                            risk_found = True
-                            break
-
-                    if not risk_found:
-                        label_text = f"{label} {conf:.2f}"
-                        y_pos = max(y1 - 10, 20)
-                        cv2.putText(image_copy, label_text, (x1, y_pos), font,
-                                    font_scale, color, thickness)
-
-                    # Mark high-risk areas
-                    if conf > 0.5 and any(risk_word in label.lower() for risk_word in
-                                          ['worker', 'person', 'equipment', 'machine']):
-                        cv2.circle(image_copy, (int((x1 + x2)/2), int((y1 + y2)/2)),
-                                   5, (0, 0, 255), -1)
-
-            except Exception as e:
-                print(f"Error drawing box: {str(e)}")
-                continue
-
-        return image_copy
-
-    def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
-        """Main processing pipeline for safety analysis."""
-        if frame is None:
-            return None, "No image provided"
-
-        try:
-            # Detect objects
-            bbox_data, labels = self.detect_objects(frame)
-
-            # Get safety analysis
-            safety_issues, analysis = self.analyze_frame(frame)
-
-            # Draw annotations
-            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)
-
-            return annotated_frame, analysis
-
-        except Exception as e:
-            print(f"Processing error: {str(e)}")
-            return None, f"Error processing image: {str(e)}"
-
-    def parse_safety_analysis(self, analysis: str) -> List[Dict]:
-        """Parse the safety analysis text."""
-        safety_issues = []
-
-        if not isinstance(analysis, str):
-            return safety_issues
-
-        for line in analysis.split('\n'):
-            if "risk:" in line.lower():
-                try:
-                    parts = line.lower().split('risk:', 1)[1].strip()
-                    if '-' in parts:
-                        obj, desc = parts.split('-', 1)
-                    else:
-                        obj, desc = parts, parts
-
-                    safety_issues.append({
-                        "object": obj.strip(),
-                        "description": desc.strip()
-                    })
-                except Exception as e:
-                    print(f"Error parsing line: {line}, Error: {str(e)}")
-                    continue
-
-        return safety_issues
-

 def create_monitor_interface():
     api_key = os.getenv("GROQ_API_KEY")

     class SafetyMonitor:
         def __init__(self):
-            """Initialize Safety Monitor with configuration."""
             self.client = Groq()
             self.model_name = "llama-3.2-90b-vision-preview"
-            self.max_image_size = (800, 800)
-            self.colors = [(
-
         def resize_image(self, image):
-            """Resize image while maintaining aspect ratio."""
             height, width = image.shape[:2]
             aspect = width / height

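For context, a minimal sketch of how the removed RobustSafetyMonitor pipeline above would have been driven; the file name is a placeholder, and it assumes GROQ_API_KEY is set and the ultralytics/yolov5 hub model can be downloaded:

    import cv2

    monitor = RobustSafetyMonitor()                   # class deleted by this commit
    frame = cv2.imread("site_photo.jpg")              # BGR array, as the old pipeline expected
    annotated, report = monitor.process_frame(frame)  # YOLOv5 boxes plus Llama Vision analysis
    print(report)
    cv2.imwrite("annotated.jpg", annotated)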
@@ -286,7 +33,6 @@ def create_monitor_interface():
             return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

         def analyze_frame(self, frame: np.ndarray) -> str:
-            """Analyze frame for safety concerns."""
             if frame is None:
                 return "No frame received"

@@ -299,11 +45,11 @@ def create_monitor_interface():
                frame = self.resize_image(frame)
                frame_pil = PILImage.fromarray(frame)

-               # Convert to base64
                buffered = io.BytesIO()
                frame_pil.save(buffered,
                               format="JPEG",
-                              quality=
                               optimize=True)
                img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                image_url = f"data:image/jpeg;base64,{img_base64}"
@@ -317,21 +63,9 @@ def create_monitor_interface():
                        "content": [
                            {
                                "type": "text",
-                               "text": """Analyze this workplace image
-
-
-
-Format each finding as:
-- <location>position:detailed safety description</location>
-
-Consider:
-- PPE usage and compliance
-- Ergonomic risks
-- Equipment safety
-- Environmental hazards
-- Work procedures
-- Material handling
-"""
                            },
                            {
                                "type": "image_url",
@@ -340,139 +74,78 @@ def create_monitor_interface():
                                }
                            }
                        ]
                    }
                ],
-               temperature=0.
-               max_tokens=
-
                )
                return completion.choices[0].message.content
            except Exception as e:
-               print(f"
                return f"Analysis Error: {str(e)}"

         def draw_observations(self, image, observations):
-            """Draw safety observations with accurate locations."""
             height, width = image.shape[:2]
             font = cv2.FONT_HERSHEY_SIMPLEX
             font_scale = 0.5
             thickness = 2

-            def get_region_coordinates(location_text):
-                """Get coordinates based on location description."""
-                location_text = location_text.lower()
-                regions = {
-                    # Basic positions
-                    'center': (width//3, height//3, 2*width//3, 2*height//3),
-                    'top': (width//4, 0, 3*width//4, height//3),
-                    'bottom': (width//4, 2*height//3, 3*width//4, height),
-                    'left': (0, height//4, width//3, 3*height//4),
-                    'right': (2*width//3, height//4, width, 3*height//4),
-                    'top-left': (0, 0, width//3, height//3),
-                    'top-right': (2*width//3, 0, width, height//3),
-                    'bottom-left': (0, 2*height//3, width//3, height),
-                    'bottom-right': (2*width//3, 2*height//3, width, height),
-
-                    # Work areas
-                    'workspace': (width//4, height//4, 3*width//4, 3*height//4),
-                    'machine': (2*width//3, 0, width, height),
-                    'equipment': (2*width//3, height//3, width, 2*height//3),
-                    'material': (0, 2*height//3, width//3, height),
-                    'ground': (0, 2*height//3, width, height),
-                    'floor': (0, 3*height//4, width, height),
-
-                    # Body regions
-                    'body': (width//3, height//3, 2*width//3, 2*height//3),
-                    'hands': (width//2, height//2, 3*width//4, 2*height//3),
-                    'head': (width//3, 0, 2*width//3, height//4),
-                    'feet': (width//3, 3*height//4, 2*width//3, height),
-                    'back': (width//3, height//3, 2*width//3, 2*height//3),
-                    'knees': (width//3, 2*height//3, 2*width//3, height),
-
-                    # Special areas
-                    'workspace': (width//4, height//4, 3*width//4, 3*height//4),
-                    'working-area': (width//4, height//4, 3*width//4, 3*height//4),
-                    'surrounding': (0, 0, width, height),
-                    'background': (0, 0, width, height)
-                }
-
-                # Find best matching region
-                best_match = 'center'  # default
-                max_match_length = 0
-
-                for region_name in regions.keys():
-                    if region_name in location_text and len(region_name) > max_match_length:
-                        best_match = region_name
-                        max_match_length = len(region_name)
-
-                return regions[best_match]
-
             for idx, obs in enumerate(observations):
                 color = self.colors[idx % len(self.colors)]

-                #
-
-
-
-
-                else:
-                    location = 'center'
-                    description = obs
-
-                # Get region coordinates
-                x1, y1, x2, y2 = get_region_coordinates(location)

                 # Draw rectangle
-                cv2.rectangle(image, (

-                # Add label
-                label =
                 label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
-
-
-                text_x = max(0, x1)
-                text_y = max(20, y1 - 5)
-
-                # Draw text background
-                cv2.rectangle(image,
-                              (text_x, text_y - label_size[1] - 5),
-                              (text_x + label_size[0], text_y),
-                              color, -1)
-
-                # Draw text
-                cv2.putText(image, label, (text_x, text_y - 5),
-                            font, font_scale, (255, 255, 255), thickness)

             return image

         def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
-            """Process frame and generate safety analysis."""
             if frame is None:
                 return None, "No image provided"

             analysis = self.analyze_frame(frame)
             display_frame = self.resize_image(frame.copy())

-            # Parse observations
             observations = []
             for line in analysis.split('\n'):
                 line = line.strip()
                 if line.startswith('-'):
                     if '<location>' in line and '</location>' in line:
                         start = line.find('<location>') + len('<location>')
                         end = line.find('</location>')
-                        observation = line[
-
-

-            # Draw observations
-
-            annotated_frame = self.draw_observations(display_frame, observations)
-            return annotated_frame, analysis

-            return

-    # Create interface
     monitor = SafetyMonitor()

     with gr.Blocks() as demo:
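The deleted get_region_coordinates helper resolved a location phrase to a fixed image region by longest keyword match. A standalone sketch of that lookup, with the region table abbreviated and a made-up phrase and frame size:

    def get_region_coordinates(location_text, width, height):
        # Abbreviated region table; the removed helper also covered 'top',
        # 'bottom', 'hands', 'machine', 'floor', and more.
        regions = {
            'center': (width // 3, height // 3, 2 * width // 3, 2 * height // 3),
            'top-left': (0, 0, width // 3, height // 3),
            'ground': (0, 2 * height // 3, width, height),
        }
        best, best_len = 'center', 0  # default to the central region
        for name in regions:
            if name in location_text.lower() and len(name) > best_len:
                best, best_len = name, len(name)
        return regions[best]

    print(get_region_coordinates("worker at the top-left scaffold", 800, 600))
    # (0, 0, 266, 200)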
@@ -480,7 +153,7 @@ def create_monitor_interface():
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
-           output_image = gr.Image(label="

        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)

@@ -500,15 +173,7 @@ def create_monitor_interface():
            outputs=[output_image, analysis_text]
        )

-       gr.Markdown("""
-       ## Instructions:
-       1. Upload a workplace image
-       2. View detected safety concerns
-       3. Check detailed analysis
-       """)
-
    return demo

-
-
-demo.launch()
New version of the changed sections (added lines marked with +):

 import cv2
 import numpy as np
 from groq import Groq
+import time
 from PIL import Image as PILImage
 import io
 import os
+import base64
+import random

 def create_monitor_interface():
     api_key = os.getenv("GROQ_API_KEY")

     class SafetyMonitor:
         def __init__(self):
             self.client = Groq()
             self.model_name = "llama-3.2-90b-vision-preview"
+            self.max_image_size = (800, 800)  # Increased size for better visibility
+            self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
+
         def resize_image(self, image):
             height, width = image.shape[:2]
             aspect = width / height
 ...
             return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

         def analyze_frame(self, frame: np.ndarray) -> str:
             if frame is None:
                 return "No frame received"
 ...
                frame = self.resize_image(frame)
                frame_pil = PILImage.fromarray(frame)

+               # Convert to base64 with minimal quality
                buffered = io.BytesIO()
                frame_pil.save(buffered,
                               format="JPEG",
+                              quality=30,
                               optimize=True)
                img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                image_url = f"data:image/jpeg;base64,{img_base64}"
 ...
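The added quality=30 trades JPEG fidelity for a much smaller base64 payload in the request body. A quick, hypothetical helper for checking that effect on a given frame_pil from the code above:

    import base64
    import io

    def jpeg_b64_size(pil_img, quality):
        """Length of the base64 string produced at a given JPEG quality."""
        buf = io.BytesIO()
        pil_img.save(buf, format="JPEG", quality=quality, optimize=True)
        return len(base64.b64encode(buf.getvalue()))

    # e.g. compare jpeg_b64_size(frame_pil, 95) with jpeg_b64_size(frame_pil, 30)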
                        "content": [
                            {
                                "type": "text",
+                               "text": """Analyze this workplace image and describe each safety concern in this format:
+- <location>Description</location>
+Use one line per issue, starting with a dash and location in tags."""
                            },
                            {
                                "type": "image_url",
 ...
                                }
                            }
                        ]
+                   },
+                   {
+                       "role": "assistant",
+                       "content": ""
                    }
                ],
+               temperature=0.1,
+               max_tokens=150,
+               top_p=1,
+               stream=False,
+               stop=None
                )
                return completion.choices[0].message.content
            except Exception as e:
+               print(f"Detailed error: {str(e)}")
                return f"Analysis Error: {str(e)}"

         def draw_observations(self, image, observations):
             height, width = image.shape[:2]
             font = cv2.FONT_HERSHEY_SIMPLEX
             font_scale = 0.5
             thickness = 2

+            # Generate random positions for each observation
             for idx, obs in enumerate(observations):
                 color = self.colors[idx % len(self.colors)]

+                # Generate random box position
+                box_width = width // 3
+                box_height = height // 3
+                x = random.randint(0, width - box_width)
+                y = random.randint(0, height - box_height)

                 # Draw rectangle
+                cv2.rectangle(image, (x, y), (x + box_width, y + box_height), color, 2)

+                # Add label with background
+                label = obs[:40] + "..." if len(obs) > 40 else obs
                 label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
+                cv2.rectangle(image, (x, y - 20), (x + label_size[0], y), color, -1)
+                cv2.putText(image, label, (x, y - 5), font, font_scale, (255, 255, 255), thickness)

             return image

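The new draw_observations places boxes at random positions rather than at locations derived from the text. A small sketch for exercising it deterministically; the canvas and notes are made up, and it assumes SafetyMonitor has been lifted to module scope and GROQ_API_KEY is set, since its __init__ creates a Groq client:

    import random
    import numpy as np

    random.seed(0)                                    # repeatable box placement
    canvas = np.zeros((600, 800, 3), dtype=np.uint8)  # blank 800x600 test image
    notes = ["No hard hat near the ladder", "Cable running across the walkway"]
    annotated = SafetyMonitor().draw_observations(canvas, notes)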
         def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
             if frame is None:
                 return None, "No image provided"

             analysis = self.analyze_frame(frame)
             display_frame = self.resize_image(frame.copy())

+            # Parse observations from the analysis
             observations = []
             for line in analysis.split('\n'):
                 line = line.strip()
                 if line.startswith('-'):
+                    # Extract text between <location> tags if present
                     if '<location>' in line and '</location>' in line:
                         start = line.find('<location>') + len('<location>')
                         end = line.find('</location>')
+                        observation = line[end + len('</location>'):].strip()
+                    else:
+                        observation = line[1:].strip()  # Remove the dash
+                    if observation:
+                        observations.append(observation)

+            # Draw observations on the image
+            annotated_frame = self.draw_observations(display_frame, observations)

+            return annotated_frame, analysis

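A self-contained sketch of the observation parsing added to process_frame, useful for checking the <location> reply format without calling the API; the sample reply text is made up:

    sample = """- <location>Upper left</location> Worker on a ladder without fall protection
    - <location>Floor</location> Loose cables across the walkway
    - General housekeeping could be improved"""

    observations = []
    for line in sample.split('\n'):
        line = line.strip()
        if line.startswith('-'):
            if '<location>' in line and '</location>' in line:
                end = line.find('</location>')
                observation = line[end + len('</location>'):].strip()
            else:
                observation = line[1:].strip()
            if observation:
                observations.append(observation)

    print(observations)
    # ['Worker on a ladder without fall protection',
    #  'Loose cables across the walkway',
    #  'General housekeeping could be improved']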
+    # Create the main interface
     monitor = SafetyMonitor()

     with gr.Blocks() as demo:
 ...

         with gr.Row():
             input_image = gr.Image(label="Upload Image")
+            output_image = gr.Image(label="Annotated Results")

         analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)

 ...
             outputs=[output_image, analysis_text]
         )

     return demo

+demo = create_monitor_interface()
+demo.launch()