import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import base64
import torch
import warnings
from typing import Tuple, List, Dict, Optional
import os

# Suppress warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


class RobustSafetyMonitor:
    def __init__(self):
        """Initialize the safety detection tool with improved configuration."""
        self.client = Groq()
        self.model_name = "llama-3.2-11b-vision-preview"
        self.max_image_size = (800, 800)
        self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]

        # Load YOLOv5 with optimized settings
        self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
        self.yolo_model.conf = 0.25      # Lower confidence threshold
        self.yolo_model.iou = 0.45       # Adjusted IOU threshold
        self.yolo_model.classes = None   # Detect all classes
        self.yolo_model.max_det = 50     # Increased maximum detections
        self.yolo_model.cpu()
        self.yolo_model.eval()

        # Construction-specific keywords
        self.construction_keywords = [
            'person', 'worker', 'helmet', 'tool', 'machine', 'equipment',
            'brick', 'block', 'pile', 'stack', 'surface', 'floor', 'ground',
            'construction', 'building', 'structure'
        ]

    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
        """Process image for analysis."""
        if frame is None:
            raise ValueError("No image provided")

        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif len(frame.shape) == 3 and frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

        return self.resize_image(frame)

    def resize_image(self, image: np.ndarray) -> np.ndarray:
        """Resize image while maintaining aspect ratio."""
        height, width = image.shape[:2]
        if height > self.max_image_size[1] or width > self.max_image_size[0]:
            aspect = width / height
            if width > height:
                new_width = self.max_image_size[0]
                new_height = int(new_width / aspect)
            else:
                new_height = self.max_image_size[1]
                new_width = int(new_height * aspect)
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
        return image

    def encode_image(self, frame: np.ndarray) -> str:
        """Convert image to base64 encoding."""
        try:
            frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=95)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return f"data:image/jpeg;base64,{img_base64}"
        except Exception as e:
            raise ValueError(f"Error encoding image: {str(e)}")

    def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Enhanced object detection using YOLOv5."""
        try:
            # Ensure proper image format
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

            # Run inference with augmentation
            with torch.no_grad():
                results = self.yolo_model(frame, augment=True)

            # Get detections
            bbox_data = results.xyxy[0].cpu().numpy()
            labels = results.names

            # Filter and process detections
            processed_boxes = []
            for box in bbox_data:
                x1, y1, x2, y2, conf, cls = box
                if conf > 0.25:  # Keep lower confidence threshold
                    processed_boxes.append(box)

            return np.array(processed_boxes), labels
        except Exception as e:
            print(f"Error in object detection: {str(e)}")
            return np.array([]), {}

    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
        """Perform safety analysis using Llama Vision."""
        if frame is None:
            return [], "No frame received"

        try:
            frame = self.preprocess_image(frame)
            image_base64 = self.encode_image(frame)

            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": """Analyze this workplace image for safety risks. Focus on:
1. Worker posture and positioning
2. Equipment and tool safety
3. Environmental hazards
4. PPE compliance
5. Material handling

List each risk on a new line starting with 'Risk:'.
Format: Risk: [Object/Area] - [Detailed description of hazard]"""
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": image_base64
                                }
                            }
                        ]
                    }
                ],
                temperature=0.7,
                max_tokens=1024,
                stream=False
            )

            try:
                response = completion.choices[0].message.content
            except AttributeError:
                response = str(completion.choices[0].message)

            safety_issues = self.parse_safety_analysis(response)
            return safety_issues, response

        except Exception as e:
            print(f"Analysis error: {str(e)}")
            return [], f"Analysis Error: {str(e)}"

    def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
                            labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
        """Improved bounding box visualization."""
        image_copy = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2

        for idx, bbox in enumerate(bboxes):
            try:
                x1, y1, x2, y2, conf, class_id = bbox
                label = labels[int(class_id)]

                # Check if object is construction-related
                is_relevant = any(keyword in label.lower() for keyword in self.construction_keywords)

                if is_relevant or conf > 0.35:
                    color = self.colors[idx % len(self.colors)]

                    # Convert coordinates to integers
                    x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])

                    # Draw bounding box
                    cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)

                    # Check for associated safety issues
                    risk_found = False
                    for safety_issue in safety_issues:
                        issue_keywords = safety_issue.get('object', '').lower().split()
                        if any(keyword in label.lower() for keyword in issue_keywords):
                            label_text = f"Risk: {safety_issue.get('description', '')}"
                            y_pos = max(y1 - 10, 20)
                            cv2.putText(image_copy, label_text, (x1, y_pos), font,
                                        font_scale, (0, 0, 255), thickness)
                            risk_found = True
                            break

                    if not risk_found:
                        label_text = f"{label} {conf:.2f}"
                        y_pos = max(y1 - 10, 20)
                        cv2.putText(image_copy, label_text, (x1, y_pos), font,
                                    font_scale, color, thickness)

                    # Mark high-risk areas
                    if conf > 0.5 and any(risk_word in label.lower()
                                          for risk_word in ['worker', 'person', 'equipment', 'machine']):
                        cv2.circle(image_copy, (int((x1 + x2) / 2), int((y1 + y2) / 2)),
                                   5, (0, 0, 255), -1)

            except Exception as e:
                print(f"Error drawing box: {str(e)}")
                continue

        return image_copy

    def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
        """Main processing pipeline for safety analysis."""
        if frame is None:
            return None, "No image provided"

        try:
            # Detect objects
            bbox_data, labels = self.detect_objects(frame)

            # Get safety analysis
            safety_issues, analysis = self.analyze_frame(frame)

            # Draw annotations
            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)

            return annotated_frame, analysis
        except Exception as e:
            print(f"Processing error: {str(e)}")
            return None, f"Error processing image: {str(e)}"

    def parse_safety_analysis(self, analysis: str) -> List[Dict]:
        """Parse the safety analysis text."""
        safety_issues = []
        if not isinstance(analysis, str):
            return safety_issues

        for line in analysis.split('\n'):
            if "risk:" in line.lower():
                try:
                    parts = line.lower().split('risk:', 1)[1].strip()
                    if '-' in parts:
                        obj, desc = parts.split('-', 1)
                    else:
                        obj, desc = parts, parts
                    safety_issues.append({
                        "object": obj.strip(),
"description": desc.strip() }) except Exception as e: print(f"Error parsing line: {line}, Error: {str(e)}") continue return safety_issues def create_monitor_interface(): api_key = os.getenv("GROQ_API_KEY") class SafetyMonitor: def __init__(self): """Initialize Safety Monitor with configuration.""" self.client = Groq() self.model_name = "llama-3.2-90b-vision-preview" self.max_image_size = (800, 800) self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)] def resize_image(self, image): """Resize image while maintaining aspect ratio.""" height, width = image.shape[:2] aspect = width / height if width > height: new_width = min(self.max_image_size[0], width) new_height = int(new_width / aspect) else: new_height = min(self.max_image_size[1], height) new_width = int(new_height * aspect) return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) def analyze_frame(self, frame: np.ndarray) -> str: """Analyze frame for safety concerns.""" if frame is None: return "No frame received" # Convert and resize image if len(frame.shape) == 2: frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) elif len(frame.shape) == 3 and frame.shape[2] == 4: frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) frame = self.resize_image(frame) frame_pil = PILImage.fromarray(frame) # Convert to base64 buffered = io.BytesIO() frame_pil.save(buffered, format="JPEG", quality=95, # High quality for better analysis optimize=True) img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') image_url = f"data:image/jpeg;base64,{img_base64}" try: completion = self.client.chat.completions.create( model=self.model_name, messages=[ { "role": "user", "content": [ { "type": "text", "text": """Analyze this workplace image for safety hazards. For each hazard: 1. Specify the exact location (e.g., center, top-left, bottom-right) 2. 
2. Describe the safety concern in detail

Format each finding as:
- position:detailed safety description

Consider:
- PPE usage and compliance
- Ergonomic risks
- Equipment safety
- Environmental hazards
- Work procedures
- Material handling
"""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": image_url
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.5,
                    max_tokens=500,
                    stream=False
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Analysis error: {str(e)}")
                return f"Analysis Error: {str(e)}"

        def draw_observations(self, image, observations):
            """Draw safety observations with accurate locations."""
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            thickness = 2

            def get_region_coordinates(location_text):
                """Get coordinates based on location description."""
                location_text = location_text.lower()
                regions = {
                    # Basic positions
                    'center': (width // 3, height // 3, 2 * width // 3, 2 * height // 3),
                    'top': (width // 4, 0, 3 * width // 4, height // 3),
                    'bottom': (width // 4, 2 * height // 3, 3 * width // 4, height),
                    'left': (0, height // 4, width // 3, 3 * height // 4),
                    'right': (2 * width // 3, height // 4, width, 3 * height // 4),
                    'top-left': (0, 0, width // 3, height // 3),
                    'top-right': (2 * width // 3, 0, width, height // 3),
                    'bottom-left': (0, 2 * height // 3, width // 3, height),
                    'bottom-right': (2 * width // 3, 2 * height // 3, width, height),
                    # Work areas
                    'workspace': (width // 4, height // 4, 3 * width // 4, 3 * height // 4),
                    'machine': (2 * width // 3, 0, width, height),
                    'equipment': (2 * width // 3, height // 3, width, 2 * height // 3),
                    'material': (0, 2 * height // 3, width // 3, height),
                    'ground': (0, 2 * height // 3, width, height),
                    'floor': (0, 3 * height // 4, width, height),
                    # Body regions
                    'body': (width // 3, height // 3, 2 * width // 3, 2 * height // 3),
                    'hands': (width // 2, height // 2, 3 * width // 4, 2 * height // 3),
                    'head': (width // 3, 0, 2 * width // 3, height // 4),
                    'feet': (width // 3, 3 * height // 4, 2 * width // 3, height),
                    'back': (width // 3, height // 3, 2 * width // 3, 2 * height // 3),
                    'knees': (width // 3, 2 * height // 3, 2 * width // 3, height),
                    # Special areas
                    'working-area': (width // 4, height // 4, 3 * width // 4, 3 * height // 4),
                    'surrounding': (0, 0, width, height),
                    'background': (0, 0, width, height)
                }

                # Find best matching region (longest region name contained in the location text)
                best_match = 'center'  # default
                max_match_length = 0
                for region_name in regions.keys():
                    if region_name in location_text and len(region_name) > max_match_length:
                        best_match = region_name
                        max_match_length = len(region_name)

                return regions[best_match]

            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]

                # Split location and description if available
                parts = obs.split(':')
                if len(parts) >= 2:
                    location = parts[0]
                    description = ':'.join(parts[1:])
                else:
                    location = 'center'
                    description = obs

                # Get region coordinates
                x1, y1, x2, y2 = get_region_coordinates(location)

                # Draw rectangle
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

                # Add label
                label = description[:50] + "..." if len(description) > 50 else description
                label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]

                # Position text above box
                text_x = max(0, x1)
                text_y = max(20, y1 - 5)

                # Draw text background
                cv2.rectangle(image, (text_x, text_y - label_size[1] - 5),
                              (text_x + label_size[0], text_y), color, -1)

                # Draw text
                cv2.putText(image, label, (text_x, text_y - 5), font,
                            font_scale, (255, 255, 255), thickness)

            return image

        def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
            """Process frame and generate safety analysis."""
            if frame is None:
                return None, "No image provided"

            analysis = self.analyze_frame(frame)
            display_frame = self.resize_image(frame.copy())

            # Parse observations; the prompt asks for lines formatted as "- position:description"
            observations = []
            for line in analysis.split('\n'):
                line = line.strip()
                if line.startswith('-'):
                    observation = line.lstrip('- ').strip()
                    if observation:
                        observations.append(observation)

            # Draw observations
            if observations:
                annotated_frame = self.draw_observations(display_frame, observations)
                return annotated_frame, analysis

            return display_frame, analysis

    # Create interface
    monitor = SafetyMonitor()

    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90B Vision")

        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Safety Analysis")

        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)

        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"

        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions:
        1. Upload a workplace image
        2. View detected safety concerns
        3. Check detailed analysis
        """)

    return demo


if __name__ == "__main__":
    demo = create_monitor_interface()
    demo.launch()
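
# ---------------------------------------------------------------------------
# Running the demo -- a minimal sketch; the package list and filename below are
# assumptions based on the imports above, not part of the original script:
#
#   pip install gradio groq opencv-python torch numpy pillow
#   export GROQ_API_KEY="your-groq-api-key"   # the Groq client reads this from the environment
#   python safety_monitor.py                  # hypothetical filename; launches the Gradio UI locally
#
# The first torch.hub.load() call downloads the YOLOv5 repository and weights, so
# internet access is needed once. Note that the Gradio interface wires up the nested
# SafetyMonitor class; RobustSafetyMonitor (YOLOv5 + Llama Vision) is available for
# programmatic use.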