# Hugging Face Space source: app.py by capradeepgujaran
# Commit 32acaa8 ("Update app.py", verified), 21.6 kB
import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import base64
import torch
import warnings
from typing import Tuple, List, Dict, Optional
import os
# Silence every FutureWarning and UserWarning process-wide so third-party
# library chatter does not flood the console.
for _quiet_category in (FutureWarning, UserWarning):
    warnings.filterwarnings('ignore', category=_quiet_category)
del _quiet_category
class RobustSafetyMonitor:
    """YOLOv5 object detection combined with a Groq Llama-vision safety review.

    ``process_frame`` runs both stages on one RGB image. It returns a copy of
    the image with detection boxes drawn on it, plus the raw model reply. A box
    is captioned with a risk when the model's "Risk:" findings name the box's
    detected class.

    NOTE(review): ``create_monitor_interface`` in this file builds its own
    ``SafetyMonitor`` and does not use this class.
    """

    # Prompt sent with every image. Findings come back as
    # "Risk: <object/area> - <description>" lines (see parse_safety_analysis).
    ANALYSIS_PROMPT = (
        "Analyze this workplace image for safety risks. Focus on:\n"
        "1. Worker posture and positioning\n"
        "2. Equipment and tool safety\n"
        "3. Environmental hazards\n"
        "4. PPE compliance\n"
        "5. Material handling\n"
        "List each risk on a new line starting with 'Risk:'.\n"
        "Format: Risk: [Object/Area] - [Detailed description of hazard]"
    )
    # Detections at or below this confidence are discarded.
    MIN_DETECTION_CONF = 0.25
    # Longest on-image risk caption before it is truncated with "...".
    MAX_LABEL_CHARS = 50
    # Characters stripped from issue words before they are compared to labels.
    _WORD_PUNCTUATION = ".,;:!?()[]{}<>\"'*"

    def __init__(self):
        """Create the Groq client and load a CPU YOLOv5s model via ``torch.hub``.

        NOTE(review): ``torch.hub.load`` presumably downloads the
        ``ultralytics/yolov5`` repo and weights on first use; confirm network
        access where this runs.
        """
        self.client = Groq()
        # NOTE(review): Groq retires "-preview" model IDs over time; confirm
        # this one is still served.
        self.model_name = "llama-3.2-11b-vision-preview"
        self.max_image_size = (800, 800)  # (max width, max height) of the LLM copy
        self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
        # Load YOLOv5 with optimized settings
        self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
        self.yolo_model.conf = self.MIN_DETECTION_CONF  # low confidence threshold
        self.yolo_model.iou = 0.45  # IoU threshold
        self.yolo_model.classes = None  # keep every class
        self.yolo_model.max_det = 50  # cap on detections per image
        self.yolo_model.cpu()
        self.yolo_model.eval()
        # A label containing any of these words is always drawn (see draw_bounding_boxes).
        self.construction_keywords = [
            'person', 'worker', 'helmet', 'tool', 'machine', 'equipment',
            'brick', 'block', 'pile', 'stack', 'surface', 'floor', 'ground',
            'construction', 'building', 'structure'
        ]

    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
        """Return *frame* as a 3-channel RGB image no larger than ``max_image_size``.

        Raises:
            ValueError: if *frame* is None.
        """
        if frame is None:
            raise ValueError("No image provided")
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.ndim == 3 and frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
        return self.resize_image(frame)

    def resize_image(self, image: np.ndarray) -> np.ndarray:
        """Downscale *image* to fit ``max_image_size`` (width, height), keeping its aspect ratio.

        Images that already fit are returned unchanged (the same object).
        """
        height, width = image.shape[:2]
        max_width, max_height = self.max_image_size
        if width <= max_width and height <= max_height:
            return image
        # Use one scale factor for both axes. This keeps the aspect ratio and
        # also respects non-square limits. The clamps stop rounding from
        # overshooting a limit and stop a very thin axis collapsing to 0 px.
        scale = min(max_width / width, max_height / height)
        new_width = max(1, min(max_width, round(width * scale)))
        new_height = max(1, min(max_height, round(height * scale)))
        return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

    def encode_image(self, frame: np.ndarray) -> str:
        """Encode an RGB *frame* as a JPEG ``data:`` URL for the vision API.

        Raises:
            ValueError: if PIL cannot convert or encode the array.
        """
        try:
            # The frame is already RGB (preprocess_image converts to RGB).
            # Passing it straight to PIL avoids a channel swap that would show
            # the model a red/blue-inverted image.
            frame_pil = PILImage.fromarray(frame)
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=95)
        except Exception as e:
            raise ValueError(f"Error encoding image: {str(e)}") from e
        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return f"data:image/jpeg;base64,{img_base64}"

    def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Run YOLOv5 (with test-time augmentation) on *frame*.

        Returns:
            ``(boxes, names)``. ``boxes`` is a 2-D array whose rows unpack as
            ``x1, y1, x2, y2, confidence, class_id``; it keeps only rows with
            confidence above ``MIN_DETECTION_CONF``. ``names`` maps a class id
            to its label. On any failure the error is printed and
            ``(np.array([]), {})`` is returned.
        """
        try:
            if frame.ndim == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            with torch.no_grad():
                results = self.yolo_model(frame, augment=True)
            bbox_data = results.xyxy[0].cpu().numpy()
            # The vectorised confidence filter keeps the 2-D shape even when
            # nothing survives.
            keep = bbox_data[:, 4] > self.MIN_DETECTION_CONF
            return bbox_data[keep], results.names
        except Exception as e:
            print(f"Error in object detection: {str(e)}")
            return np.array([]), {}

    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
        """Ask the vision model for safety risks in *frame*.

        Returns:
            ``(issues, reply)``: the issues from ``parse_safety_analysis`` and
            the raw reply text. Returns ``([], "No frame received")`` for None.
            Any failure is printed and returned as ``([], "Analysis Error: ...")``.
        """
        if frame is None:
            return [], "No frame received"
        try:
            frame = self.preprocess_image(frame)
            image_base64 = self.encode_image(frame)
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self.ANALYSIS_PROMPT},
                            {"type": "image_url", "image_url": {"url": image_base64}},
                        ],
                    }
                ],
                temperature=0.7,
                max_tokens=1024,
                stream=False,
            )
            try:
                response = completion.choices[0].message.content
            except AttributeError:
                # Message object without .content: fall back to its string form.
                response = str(completion.choices[0].message)
            safety_issues = self.parse_safety_analysis(response)
            return safety_issues, response
        except Exception as e:
            print(f"Analysis error: {str(e)}")
            return [], f"Analysis Error: {str(e)}"

    @classmethod
    def _match_issue(cls, label: str, safety_issues: List[Dict]) -> Optional[Dict]:
        """Return the first issue whose ``object`` text contains a word of *label*, else None.

        Words are compared whole, after stripping punctuation. Short filler
        words in the issue text (for example "on" in "worker on ladder")
        therefore cannot match inside a label such as "person".
        """
        label_words = set(label.lower().split())
        for issue in safety_issues:
            issue_words = {
                word.strip(cls._WORD_PUNCTUATION)
                for word in issue.get('object', '').lower().split()
            }
            if label_words & issue_words:
                return issue
        return None

    def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
                            labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
        """Return a copy of *image* with the relevant detections drawn on it.

        A box is drawn when its label contains one of ``construction_keywords``
        or its confidence exceeds 0.35. The caption depends on whether the
        label names a parsed safety issue (see ``_match_issue``):

        * Matched boxes get "Risk: <description>" in colour (0, 0, 255),
          truncated to ``MAX_LABEL_CHARS``.
        * Other boxes get "<label> <confidence>" in the box colour.

        Person/worker/equipment/machine boxes with confidence above 0.5 also
        get a filled dot at their centre. An error on one box is printed and
        that box is skipped.
        """
        image_copy = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2
        for idx, bbox in enumerate(bboxes):
            try:
                x1, y1, x2, y2, conf, class_id = bbox
                label = labels[int(class_id)]
                label_lower = label.lower()
                is_relevant = any(keyword in label_lower for keyword in self.construction_keywords)
                if not (is_relevant or conf > 0.35):
                    continue
                color = self.colors[idx % len(self.colors)]
                x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
                cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
                # Keep the caption's baseline at y >= 20 so it stays visible
                # for boxes that touch the top edge.
                text_origin = (x1, max(y1 - 10, 20))
                issue = self._match_issue(label, safety_issues)
                if issue is not None:
                    caption = f"Risk: {issue.get('description', '')}"
                    if len(caption) > self.MAX_LABEL_CHARS:
                        caption = caption[:self.MAX_LABEL_CHARS] + "..."
                    cv2.putText(image_copy, caption, text_origin, font,
                                font_scale, (0, 0, 255), thickness)
                else:
                    cv2.putText(image_copy, f"{label} {conf:.2f}", text_origin, font,
                                font_scale, color, thickness)
                # Mark high-risk areas
                if conf > 0.5 and any(risk_word in label_lower for risk_word in
                                      ('worker', 'person', 'equipment', 'machine')):
                    center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
                    cv2.circle(image_copy, center, 5, (0, 0, 255), -1)
            except Exception as e:
                print(f"Error drawing box: {str(e)}")
        return image_copy

    def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
        """Detect objects, run the LLM safety review and annotate *frame*.

        Returns:
            ``(annotated copy of frame, analysis text)``. On error, returns
            ``(None, "Error processing image: ...")``.
        """
        if frame is None:
            return None, "No image provided"
        try:
            # Detection runs on the unresized frame, so its boxes are drawn
            # on that same frame. Only the LLM receives the downscaled copy.
            bbox_data, labels = self.detect_objects(frame)
            safety_issues, analysis = self.analyze_frame(frame)
            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)
            return annotated_frame, analysis
        except Exception as e:
            print(f"Processing error: {str(e)}")
            return None, f"Error processing image: {str(e)}"

    def parse_safety_analysis(self, analysis: str) -> List[Dict]:
        """Extract ``{"object": ..., "description": ...}`` dicts from "Risk:" lines.

        Each line containing "risk:" (case-insensitive) yields one issue:

        1. Take the text after the first "risk:", lower-cased.
        2. Strip surrounding markdown ``*`` emphasis.
        3. Split at the first spaced dash (" - ", en dash or em dash), so
           hyphenated words like "top-left" stay intact. If there is no spaced
           dash, split at the first bare "-". If there is no dash at all, use
           the whole text for both fields.
        4. Strip square brackets from both parts.

        Empty findings are skipped. Non-string input yields ``[]``.
        """
        safety_issues: List[Dict] = []
        if not isinstance(analysis, str):
            return safety_issues
        for line in analysis.split('\n'):
            lowered = line.lower()
            if 'risk:' not in lowered:
                continue
            finding = lowered.split('risk:', 1)[1].strip().strip('*').strip()
            if not finding:
                continue
            for separator in (' - ', ' \u2013 ', ' \u2014 ', '-'):
                if separator in finding:
                    obj, desc = finding.split(separator, 1)
                    break
            else:
                obj = desc = finding
            safety_issues.append({
                "object": obj.strip().strip('[]').strip(),
                "description": desc.strip().strip('[]').strip(),
            })
        return safety_issues
def create_monitor_interface():
    """Build the Gradio UI for single-image workplace safety analysis.

    A nested ``SafetyMonitor`` does the work. It sends each uploaded image to
    a Groq-hosted Llama vision model, extracts the
    ``<location>position:description</location>`` findings from the reply,
    and outlines the matching coarse image regions.

    Returns:
        The assembled ``gr.Blocks`` app. The caller is responsible for
        ``launch()``.
    """
    # Passed explicitly to the Groq client below.
    api_key = os.getenv("GROQ_API_KEY")

    # Prompt sent with every image. SafetyMonitor.parse_observations extracts
    # the "<location>position:description</location>" format it requests.
    analysis_prompt = (
        "Analyze this workplace image for safety hazards. For each hazard:\n"
        "1. Specify the exact location (e.g., center, top-left, bottom-right)\n"
        "2. Describe the safety concern in detail\n"
        "Format each finding as:\n"
        "- <location>position:detailed safety description</location>\n"
        "Consider:\n"
        "- PPE usage and compliance\n"
        "- Ergonomic risks\n"
        "- Equipment safety\n"
        "- Environmental hazards\n"
        "- Work procedures\n"
        "- Material handling\n"
    )

    class SafetyMonitor:
        """Groq vision client plus region-based annotation of its findings."""

        def __init__(self):
            """Create the Groq client and the rendering configuration."""
            self.client = Groq(api_key=api_key)
            # NOTE(review): Groq retires "-preview" model IDs over time; confirm
            # this one is still served.
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)  # (max width, max height)
            # Box/label colours, cycled per observation.
            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]

        @staticmethod
        def to_rgb(frame):
            """Return *frame* with 3 channels, converting grayscale and RGBA to RGB."""
            if frame.ndim == 2:
                return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            if frame.ndim == 3 and frame.shape[2] == 4:
                return cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            return frame

        def resize_image(self, image):
            """Downscale *image* to fit ``max_image_size``, keeping its aspect ratio.

            Images that already fit are returned unchanged (the same object).
            """
            height, width = image.shape[:2]
            max_width, max_height = self.max_image_size
            if width <= max_width and height <= max_height:
                return image
            # Use one scale factor for both axes. This keeps the aspect ratio
            # and honours non-square limits. The clamps stop rounding from
            # overshooting a limit and stop a thin axis collapsing to 0 px.
            scale = min(max_width / width, max_height / height)
            new_width = max(1, min(max_width, round(width * scale)))
            new_height = max(1, min(max_height, round(height * scale)))
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

        def encode_image(self, frame):
            """Return the RGB *frame* as a high-quality JPEG ``data:`` URL."""
            buffered = io.BytesIO()
            PILImage.fromarray(frame).save(buffered,
                                           format="JPEG",
                                           quality=95,  # high quality for better analysis
                                           optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return f"data:image/jpeg;base64,{img_base64}"

        def analyze_frame(self, frame: np.ndarray) -> str:
            """Return the model's safety analysis text for *frame*.

            Returns "No frame received" for None. If the API call fails, the
            error is printed and "Analysis Error: ..." is returned. Image
            conversion errors propagate to the caller.
            """
            if frame is None:
                return "No frame received"
            image_url = self.encode_image(self.resize_image(self.to_rgb(frame)))
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": analysis_prompt},
                                {"type": "image_url", "image_url": {"url": image_url}},
                            ],
                        }
                    ],
                    temperature=0.5,
                    max_tokens=500,
                    stream=False,
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Analysis error: {str(e)}")
                return f"Analysis Error: {str(e)}"

        def draw_observations(self, image, observations):
            """Outline and caption each observation on *image*.

            *image* is modified in place and also returned. Each observation
            is "location:description". The location words select one of the
            fixed regions below. "center" is used when there is no colon or
            no region name matches.
            """
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            thickness = 2
            # Each value is an (x1, y1, x2, y2) box given as fixed fractions of
            # the frame. The keys are the vocabulary matched against the
            # model's location text.
            regions = {
                # Basic positions
                'center': (width//3, height//3, 2*width//3, 2*height//3),
                'top': (width//4, 0, 3*width//4, height//3),
                'bottom': (width//4, 2*height//3, 3*width//4, height),
                'left': (0, height//4, width//3, 3*height//4),
                'right': (2*width//3, height//4, width, 3*height//4),
                'top-left': (0, 0, width//3, height//3),
                'top-right': (2*width//3, 0, width, height//3),
                'bottom-left': (0, 2*height//3, width//3, height),
                'bottom-right': (2*width//3, 2*height//3, width, height),
                # Work areas
                'workspace': (width//4, height//4, 3*width//4, 3*height//4),
                'working-area': (width//4, height//4, 3*width//4, 3*height//4),
                'machine': (2*width//3, 0, width, height),
                'equipment': (2*width//3, height//3, width, 2*height//3),
                'material': (0, 2*height//3, width//3, height),
                'ground': (0, 2*height//3, width, height),
                'floor': (0, 3*height//4, width, height),
                # Body regions
                'body': (width//3, height//3, 2*width//3, 2*height//3),
                'hands': (width//2, height//2, 3*width//4, 2*height//3),
                'head': (width//3, 0, 2*width//3, height//4),
                'feet': (width//3, 3*height//4, 2*width//3, height),
                'back': (width//3, height//3, 2*width//3, 2*height//3),
                'knees': (width//3, 2*height//3, 2*width//3, height),
                # Special areas
                'surrounding': (0, 0, width, height),
                'background': (0, 0, width, height),
            }

            def get_region_coordinates(location_text):
                """Return the box of the longest region name found in *location_text*."""
                # Join words with hyphens so "top left" matches "top-left"
                # instead of the shorter "top" or "left".
                normalized = "-".join(location_text.lower().replace("_", " ").split())
                best_match = max((name for name in regions if name in normalized),
                                 key=len, default='center')
                return regions[best_match]

            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]
                location, separator, description = obs.partition(':')
                if not separator:
                    location, description = 'center', obs
                description = description.strip()
                x1, y1, x2, y2 = get_region_coordinates(location)
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                label = description[:50] + "..." if len(description) > 50 else description
                label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
                # Place the caption just above the box, but no higher than y=20.
                text_x = max(0, x1)
                text_y = max(20, y1 - 5)
                # Filled background so the white text is readable.
                cv2.rectangle(image,
                              (text_x, text_y - label_size[1] - 5),
                              (text_x + label_size[0], text_y),
                              color, -1)
                cv2.putText(image, label, (text_x, text_y - 5),
                            font, font_scale, (255, 255, 255), thickness)
            return image

        @staticmethod
        def parse_observations(analysis):
            """Return the non-empty, stripped text between <location> and </location> on each line.

            Lines may use any bullet style (or none). Non-string input yields [].
            """
            open_tag, close_tag = '<location>', '</location>'
            observations = []
            if not isinstance(analysis, str):
                return observations
            for line in analysis.split('\n'):
                start = line.find(open_tag)
                if start == -1:
                    continue
                start += len(open_tag)
                end = line.find(close_tag, start)
                if end == -1:
                    continue
                observation = line[start:end].strip()
                if observation:
                    observations.append(observation)
            return observations

        def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
            """Analyse *frame* and return ``(annotated display copy, analysis text)``."""
            if frame is None:
                return None, "No image provided"
            analysis = self.analyze_frame(frame)
            # Build the display copy the same way as the analysed image
            # (RGB, downscaled), so the region boxes line up with what the
            # model saw.
            display_frame = self.resize_image(self.to_rgb(frame).copy())
            observations = self.parse_observations(analysis)
            if observations:
                return self.draw_observations(display_frame, observations), analysis
            return display_frame, analysis

    monitor = SafetyMonitor()

    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Safety Analysis")
        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)

        def analyze_image(image):
            """Gradio callback: return (annotated image, analysis text) and never raise."""
            if image is None:
                return None, "No image provided"
            try:
                return monitor.process_frame(image)
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"

        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )
        gr.Markdown(
            "\n"
            "## Instructions:\n"
            "1. Upload a workplace image\n"
            "2. View detected safety concerns\n"
            "3. Check detailed analysis\n"
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    # NOTE(review): the global is named `demo` on purpose. Gradio's reload
    # mode is believed to look for that name, so confirm before renaming.
    demo = create_monitor_interface()
    demo.launch()