File size: 24,488 Bytes
0578219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
from flask import Flask, request, jsonify, render_template, url_for
from flask_socketio import SocketIO
import threading
from ultralytics import YOLO
import numpy as np
import cv2
import matplotlib.pyplot as plt
import importlib
from segment_anything import sam_model_registry, SamPredictor
import os
from werkzeug.utils import secure_filename
import logging
import json
import shutil
import sys
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Flask application plus the Socket.IO server used to push retraining
# progress events ('training_update' / 'training_complete') to browsers.
app = Flask(__name__)
socketio = SocketIO(app)

# Configure logging: INFO and above go to the root logger; module code logs via `logger`.
# NOTE(review): this runs after SocketIO(app) is constructed; the Socket.IO library may
# configure its own loggers differently depending on whether the root logger already
# has handlers, so keep this order unless that is verified.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
class Config:
    """
    Settings copied into ``app.config`` by ``app.config.from_object``.

    Every folder and model file lives below ``<directory of this file>/static``,
    which lets uploads and generated results be linked through Flask's
    ``static`` endpoint.
    """

    BASE_DIR = os.path.abspath(os.path.dirname(__file__))

    # --- uploads -------------------------------------------------------------
    UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static', 'uploads')
    ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB upper bound on request bodies

    # --- SAM -----------------------------------------------------------------
    SAM_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'sam', 'sam_results')
    SAM_CHECKPOINT = os.path.join(BASE_DIR, 'static', 'sam', 'sam_vit_h_4b8939.pth')
    SAM_2 = os.path.join(BASE_DIR, 'static', 'sam', 'sam2.1_hiera_large.pt')

    # --- YOLO ----------------------------------------------------------------
    YOLO_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_yolo.pt')
    RETRAINED_MODEL_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_retrained.pt')
    YOLO_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'yolo_results')
    AREA_DATA_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'area_data')
    DATA_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'data.yaml')
    YOLO_TRAIN_IMAGE_FOLDER = os.path.join(
        BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'images')
    YOLO_TRAIN_LABEL_FOLDER = os.path.join(
        BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'labels')

app.config.from_object(Config)

# Ensure directories exist: create every folder the request handlers write
# into, so none of them has to check before saving a file.
for _folder_key in (
    'UPLOAD_FOLDER',
    'SAM_RESULT_FOLDER',
    'YOLO_RESULT_FOLDER',
    'YOLO_TRAIN_IMAGE_FOLDER',
    'YOLO_TRAIN_LABEL_FOLDER',
    'AREA_DATA_FOLDER',
):
    os.makedirs(app.config[_folder_key], exist_ok=True)
del _folder_key


# Initialize Yolo model
# Both models are loaded once at import time into module globals (`model`,
# `predictor`) that the request handlers share. A load failure is logged and
# re-raised, so importing this module fails instead of serving without models.
try:
    model = YOLO(app.config['YOLO_PATH'])
except Exception as e:
    logger.error(f"Failed to initialize YOLO model: {str(e)}")
    raise

# Initialize the SAM 2.1 predictor on CPU; /upload_sam embeds images into it
# and /generate_mask prompts it.
# NOTE(review): Config.SAM_CHECKPOINT (the SAM v1 ViT-H weights) is not loaded
# anywhere visible here; only the SAM 2.1 checkpoint is used.
try:
    sam2_checkpoint = app.config['SAM_2']
    # NOTE(review): presumably resolved by the sam2 package against its bundled
    # config directory rather than the working directory — confirm.
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
    predictor = SAM2ImagePredictor(sam2_model)
except Exception as e:
    logger.error(f"Failed to initialize SAM model: {str(e)}")
    raise

def allowed_file(filename):
    """
    Return True when *filename* ends in an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last dot is checked, case-insensitively; a name
    with no dot at all is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in app.config['ALLOWED_EXTENSIONS']

def scale_coordinates(coords, original_dims, target_dims):
    """
    Map [x, y] points from one (width, height) space into another.

    Args:
        coords: iterable of [x, y] pairs expressed in the original space
        original_dims: (width, height) of the space the points come from
        target_dims: (width, height) of the destination space

    Returns:
        A new list of [x, y] lists; each component is multiplied by its axis
        ratio and truncated toward zero with int().

    Raises:
        ZeroDivisionError: if either original dimension is 0 (even for empty coords).
    """
    ratio_x = target_dims[0] / original_dims[0]
    ratio_y = target_dims[1] / original_dims[1]

    scaled = []
    for point in coords:
        scaled.append([int(point[0] * ratio_x), int(point[1] * ratio_y)])
    return scaled

def scale_box(box, original_dims, target_dims):
    """
    Map an [x1, y1, x2, y2] box from one (width, height) space into another.

    Args:
        box: sequence whose first four items are x1, y1, x2, y2 in the original space
        original_dims: (width, height) of the space the box comes from
        target_dims: (width, height) of the destination space

    Returns:
        [x1, y1, x2, y2] with x values scaled by the width ratio and y values
        by the height ratio, each truncated toward zero with int().

    Raises:
        ZeroDivisionError: if either original dimension is 0.
    """
    ratio_x = target_dims[0] / original_dims[0]
    ratio_y = target_dims[1] / original_dims[1]

    left, top, right, bottom = box[0], box[1], box[2], box[3]
    return [
        int(left * ratio_x),
        int(top * ratio_y),
        int(right * ratio_x),
        int(bottom * ratio_y),
    ]

def retrain_model_fn(epochs=5, img_size=640, batch_size=8, device="cpu"):
    """
    Fine-tune the shared YOLO model on the dataset at app.config['DATA_PATH'].

    Intended to run on a background thread (start_retraining launches it).
    Progress is pushed to Socket.IO clients:

    - 'training_update'   {'epoch': n, 'status': 'Epoch n complete'} after each epoch
    - 'training_complete' {'status': 'Retraining complete'} after the weights are saved
    - 'training_error'    {'status': 'Retraining failed: ...'} if training or saving fails

    The trained weights overwrite app.config['YOLO_PATH']. The module-level
    ``retraining_status`` dict is updated so that it ends in 'completed' or
    'failed' instead of staying 'in_progress'.

    Args:
        epochs: number of training epochs (default 5)
        img_size: training image size, passed to YOLO as ``imgsz``
        batch_size: training batch size
        device: device string passed to YOLO; adjust to system capabilities
    """
    def on_epoch_end(trainer):
        # The ultralytics trainer numbers epochs from 0.
        epoch = trainer.epoch + 1
        retraining_status['progress'] = f"Epoch {epoch}/{epochs}"
        socketio.emit('training_update', {
            'epoch': epoch,
            'status': f"Epoch {epoch} complete"
        })

    # One train() call keeps the optimizer state, warm-up and LR schedule
    # continuous across epochs. Calling train(epochs=1) in a loop restarted all
    # of them (and created a new run directory) every epoch. The callback
    # provides the same per-epoch progress the loop was used for.
    model.add_callback('on_fit_epoch_end', on_epoch_end)
    try:
        model.train(
            data=app.config['DATA_PATH'],
            epochs=epochs,
            imgsz=img_size,
            batch=batch_size,
            device=device,
        )
        # Save before announcing completion, so clients are never told
        # "complete" for weights that failed to persist.
        model.save(app.config['YOLO_PATH'])
    except Exception as e:
        # Without this, a failure would die silently in the worker thread.
        logger.exception("Model retraining failed")
        retraining_status['status'] = 'failed'
        retraining_status['message'] = str(e)
        socketio.emit('training_error', {'status': f"Retraining failed: {e}"})
        return
    finally:
        # Unregister, otherwise the next retraining would emit every update twice.
        epoch_callbacks = model.callbacks.get('on_fit_epoch_end', [])
        if on_epoch_end in epoch_callbacks:
            epoch_callbacks.remove(on_epoch_end)

    retraining_status['status'] = 'completed'
    retraining_status['message'] = "Retraining complete"
    socketio.emit('training_complete', {'status': "Retraining complete"})
    logger.info("Model retrained successfully")

@app.route('/')
def index():
    """Render the landing page (``templates/index.html``)."""
    page = 'index.html'
    return render_template(page)

@app.route('/yolo')
def yolo():
    """Render ``templates/yolo.html``."""
    page = 'yolo.html'
    return render_template(page)

@app.route('/upload_sam', methods=['POST'])
def upload_sam_file():
    """
    Save an uploaded image and embed it into the shared SAM 2 predictor.

    Expects a multipart form with a ``file`` field (PNG/JPG/JPEG). The file is
    stored in UPLOAD_FOLDER under its sanitised name and passed to
    ``predictor.set_image`` so that /generate_mask can prompt the predictor
    without re-embedding.

    NOTE(review): ``predictor`` is a single module-level instance, so concurrent
    uploads overwrite each other's embedding.

    Returns:
        200 JSON with 'message', 'image_url', 'filename' (the sanitised name to
        send back to /generate_mask) and 'dimensions' ({'width', 'height'}).
        400 JSON {'error': ...} for a missing, empty, disallowed or undecodable upload.
        500 JSON {'error': ...} for unexpected server errors.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400

        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400

        # Never trust the client's path: keep only a safe basename.
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        image = cv2.imread(filepath)
        if image is None:
            # The extension passed, but the content is not a decodable image.
            # That is a client error (400), and the file is useless to keep.
            try:
                os.remove(filepath)
            except OSError:
                logger.warning(f"Could not remove undecodable upload: {filepath}")
            return jsonify({'error': 'Failed to load uploaded image'}), 400

        # OpenCV decodes to BGR; convert to RGB before embedding.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)
        logger.info("Image embedded successfully")

        height, width = image.shape[:2]

        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")

        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
            'dimensions': {
                'width': width,
                'height': height
            }
        })

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Upload error: {e}")
        return jsonify({'error': 'Server error during upload'}), 500

@app.route('/upload_yolo', methods=['POST'])
def upload_yolo_file():
    """
    Save an uploaded image so it can later be processed by name.

    Expects a multipart form with a ``file`` field (PNG/JPG/JPEG). The file is
    stored in UPLOAD_FOLDER under its sanitised name. No model runs here.
    /classify and /generate_mask read images from UPLOAD_FOLDER by the
    returned ``filename``.

    Returns:
        200 JSON with 'message', 'image_url' and 'filename' (the sanitised name).
        400 JSON {'error': ...} for a missing, empty, disallowed or undecodable upload.
        500 JSON {'error': ...} for unexpected server errors.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400

        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400

        # Never trust the client's path: keep only a safe basename.
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Reject content that is not a decodable image now (as /upload_sam does),
        # rather than letting /classify fail on it later.
        if cv2.imread(filepath) is None:
            try:
                os.remove(filepath)
            except OSError:
                logger.warning(f"Could not remove undecodable upload: {filepath}")
            return jsonify({'error': 'Failed to load uploaded image'}), 400

        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")

        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
        })

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Upload error: {e}")
        return jsonify({'error': 'Server error during upload'}), 500

# Text styling for the labels generate_mask draws on its overlay image.
_LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
_LABEL_FONT_SCALE = 0.6
_LABEL_FONT_COLOR = (0, 0, 0)              # black text ...
_LABEL_FONT_THICKNESS = 1
_LABEL_BACKGROUND_COLOR = (255, 255, 255)  # ... on a white box
# YOLO class indices written to the label file (the mapping /classify reads back).
_COMPONENT_CLASS_ID = 0
_VOID_CLASS_ID = 1


def _mask_bbox(mask):
    """Return (x_min, y_min, x_max, y_max) of the non-zero pixels of *mask*, or None if it is empty."""
    ys, xs = np.nonzero(mask)  # rows are y, columns are x
    if xs.size == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


def _draw_label(img, text, bbox):
    """Draw *text* on a white box near the top-right corner of *bbox*, kept inside *img* horizontally."""
    _, y_min, x_max, _ = bbox
    (text_w, text_h), _ = cv2.getTextSize(text, _LABEL_FONT, _LABEL_FONT_SCALE, _LABEL_FONT_THICKNESS)
    # Start at the box's right edge, but pull left to keep a 10 px right margin
    # for the actual text being drawn, and never start left of the image.
    x = max(0, min(x_max, img.shape[1] - text_w - 10))
    # putText draws upwards from the baseline, so put the baseline below the box top.
    y = y_min + text_h + 5
    cv2.rectangle(img, (x, y - text_h - 2), (x + text_w, y + 2), _LABEL_BACKGROUND_COLOR, -1)
    cv2.putText(img, text, (x, y), _LABEL_FONT, _LABEL_FONT_SCALE, _LABEL_FONT_COLOR,
                _LABEL_FONT_THICKNESS, cv2.LINE_AA)


def _masks_to_yolo_rows(masks, class_id, largest_only):
    """Convert masks to YOLO-seg rows [class_id, x1, y1, x2, y2, ...] with coordinates normalised to 0..1."""
    rows = []
    for mask in masks:
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if largest_only and contours:
            contours = [max(contours, key=cv2.contourArea)]
        height, width = mask.shape
        for contour in contours:
            normalized_points = contour.reshape(-1, 2) / [width, height]
            rows.append([class_id] + normalized_points.flatten().tolist())
    return rows


@app.route('/generate_mask', methods=['POST'])
def generate_mask():
    """
    Segment voids and components with SAM 2 and save a YOLO training sample.

    Request JSON:
        filename: name of an image in UPLOAD_FOLDER (as returned by an upload endpoint)
        void_points: list of [x, y] click points normalised to 0..1, one per void
        component_boxes: list of normalised [x1, y1, x2, y2] boxes, one per component
            (NOTE(review): treated as corner coordinates because they are passed
            to SAM's ``box`` prompt as two corners — confirm the client agrees)

    Side effects:
        - writes YOLO-seg polygons to YOLO_TRAIN_LABEL_FOLDER/<image stem>.txt
          (class 1 = void: every contour; class 0 = component: largest contour)
        - copies the image into YOLO_TRAIN_IMAGE_FOLDER
        - saves an overlay (voids red, components green, numbered labels)
          as SAM_RESULT_FOLDER/segmented_<filename>

    Returns:
        JSON {'status': 'success', 'train_image_url', 'result_path'}, or
        {'error': ...} with 400 (no filename), 404 (unknown image) or 500.

    NOTE(review): prompts go to the module-level ``predictor``, which holds
    whatever image /upload_sam embedded last — not necessarily ``filename``.
    """
    try:
        data = request.get_json(silent=True) or {}
        normalized_void_points = data.get('void_points', [])
        normalized_component_boxes = data.get('component_boxes', [])
        # The name comes from the client and is joined into several server
        # paths below, so reduce it to a safe basename first.
        filename = secure_filename(data.get('filename') or '')

        if not filename:
            return jsonify({'error': 'No filename provided'}), 400

        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404

        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_height, image_width = image.shape[:2]

        # Scaling from a (1, 1) space turns normalised coordinates into pixels.
        pixel_dims = (image_width, image_height)
        void_points = scale_coordinates(normalized_void_points, (1, 1), pixel_dims)
        component_boxes = [scale_box(box, (1, 1), pixel_dims) for box in normalized_component_boxes]
        logger.info(f"Void points: {void_points}")
        logger.info(f"Component boxes: {component_boxes}")

        # One positive click per void; of the candidate masks returned, keep the best-scoring one.
        void_masks = []
        for point in void_points:
            masks, scores, _ = predictor.predict(
                point_coords=np.array([point]),
                point_labels=np.array([1]),
                multimask_output=True,
            )
            if len(masks) > 0:
                best = int(np.argmax(scores))
                void_masks.append(masks[best])
                logger.info(f"Processed void point {point} with score {scores[best]}")

        # One box prompt per component, again keeping the best-scoring mask.
        component_masks = []
        for box in component_boxes:
            masks, scores, _ = predictor.predict(
                box=np.array([[box[0], box[1]], [box[2], box[3]]]),
                multimask_output=True,
            )
            if len(masks) > 0:
                component_masks.append(masks[int(np.argmax(scores))])
                logger.info(f"Processed component box {box}")

        # Overlay: voids blended 50% toward red. Their union is collected so
        # components (blended toward green) never paint over a void.
        combined_image = image.copy()
        void_union = np.zeros((image_height, image_width), dtype=bool)
        for mask in void_masks:
            region = mask.astype(bool)
            void_union |= region
            combined_image[region, 0] = np.clip(0.5 * image[region, 0] + 0.5 * 255, 0, 255)
            combined_image[region, 1] = np.clip(0.5 * image[region, 1], 0, 255)
            combined_image[region, 2] = np.clip(0.5 * image[region, 2], 0, 255)
        for mask in component_masks:
            region = mask.astype(bool) & ~void_union
            combined_image[region, 0] = np.clip(0.5 * image[region, 0], 0, 255)
            combined_image[region, 1] = np.clip(0.5 * image[region, 1] + 0.5 * 255, 0, 255)
            combined_image[region, 2] = np.clip(0.5 * image[region, 2], 0, 255)
        logger.debug(f"Drew {len(void_masks)} void and {len(component_masks)} component masks")

        # Labels go on top of the coloured masks; an empty mask has nothing to label.
        for number, mask in enumerate(void_masks, start=1):
            bbox = _mask_bbox(mask)
            if bbox is not None:
                _draw_label(combined_image, f"Void {number}", bbox)
        for number, mask in enumerate(component_masks, start=1):
            bbox = _mask_bbox(mask)
            if bbox is not None:
                _draw_label(combined_image, f"Component {number}", bbox)

        label_rows = (
            _masks_to_yolo_rows(void_masks, _VOID_CLASS_ID, largest_only=False)
            + _masks_to_yolo_rows(component_masks, _COMPONENT_CLASS_ID, largest_only=True)
        )
        # YOLO pairs images/<stem>.<ext> with labels/<stem>.txt, so the image
        # extension must be dropped ('a.jpg.txt' would never be found).
        label_filename = os.path.splitext(filename)[0] + '.txt'
        label_path = os.path.join(app.config['YOLO_TRAIN_LABEL_FOLDER'], label_filename)
        with open(label_path, 'w', encoding='utf-8') as label_file:
            label_file.writelines(" ".join(map(str, row)) + "\n" for row in label_rows)

        train_image_filepath = os.path.join(app.config['YOLO_TRAIN_IMAGE_FOLDER'], filename)
        shutil.copy(image_path, train_image_filepath)
        train_image_url = url_for('static', filename=f'yolo/dataset_yolo/train/images/{filename}')

        result_filename = f'segmented_{filename}'
        result_path = os.path.join(app.config['SAM_RESULT_FOLDER'], result_filename)
        plt.imsave(result_path, combined_image)  # combined_image is RGB, as matplotlib expects
        logger.info("Mask generation completed successfully")

        return jsonify({
            'status': 'success',
            'train_image_url': train_image_url,
            'result_path': url_for('static', filename=f'sam/sam_results/{result_filename}')
        })

    except Exception as e:
        logger.exception(f"Mask generation error: {e}")
        return jsonify({'error': str(e)}), 500

@app.route('/classify', methods=['POST'])
def classify():
    """
    Run the YOLO segmentation model on an uploaded image and report void coverage.

    Request JSON: {'filename': <name of an image in UPLOAD_FOLDER>}.

    For each detected component (class 0), the response reports its area and
    how much of it is covered by detected voids (class 1). The statistics are
    also written to AREA_DATA_FOLDER/area_data_<filename>.json. An annotated
    image is saved to YOLO_RESULT_FOLDER/output_<filename>.

    Returns:
        JSON with
        - status: 'success'
        - result_path: URL of the annotated image
        - area_data: one dict per component with keys 'Image', 'Component',
          'Area', 'Void Area (pixels)' (pixels covered by the union of all voids),
          'Void Area %' and 'Max Void Area %' (largest single void's share)
        - area_data_path: URL of the JSON file holding area_data
        or {'error': <message>} with status 400, 404 or 500.

    NOTE(review): areas are counted on ``result.masks.data``, i.e. at the model's
    mask resolution, which may differ from the original image size. Percentages
    are unaffected, but the pixel counts may not be original-image pixels.
    """
    try:
        data = request.get_json(silent=True) or {}
        # The name comes from the client and is joined into server paths below.
        filename = secure_filename(data.get('filename') or '')
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400

        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404

        # BGR array from OpenCV, which is what YOLO expects for numpy input.
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500

        result = model(image)[0]

        component_masks = []
        void_masks = []
        # result.masks is None when the model detects nothing.
        if result.masks is not None:
            for mask, label in zip(result.masks.data, result.boxes.cls):
                mask_array = mask.cpu().numpy().astype(bool)
                class_id = int(label)
                if class_id == 1:  # void
                    void_masks.append(mask_array)
                elif class_id == 0:  # component
                    component_masks.append(mask_array)

        # Union of all voids, so pixels where void masks overlap are counted once.
        void_union = np.logical_or.reduce(void_masks) if void_masks else None

        area_data = []
        for number, component_mask in enumerate(component_masks, start=1):
            component_area = int(np.count_nonzero(component_mask))
            void_area = 0
            max_void_area_percentage = 0
            if void_union is not None and component_area > 0:
                void_area = int(np.count_nonzero(void_union & component_mask))
                largest_overlap = max(np.count_nonzero(v & component_mask) for v in void_masks)
                max_void_area_percentage = largest_overlap / component_area * 100

            area_data.append({
                "Image": filename,
                'Component': f'Component {number}',
                'Area': component_area,
                'Void Area (pixels)': void_area,
                'Void Area %': void_area / component_area * 100 if component_area > 0 else 0,
                'Max Void Area %': max_void_area_percentage
            })

        area_data_filename = f'area_data_{filename}.json'
        area_data_path = os.path.join(app.config['AREA_DATA_FOLDER'], area_data_filename)
        with open(area_data_path, 'w', encoding='utf-8') as json_file:
            json.dump(area_data, json_file, indent=4)

        # Results.plot() returns a BGR array; matplotlib writes arrays as RGB,
        # so without this conversion red and blue come out swapped.
        annotated_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)

        output_filename = f'output_{filename}'
        output_image_path = os.path.join(app.config['YOLO_RESULT_FOLDER'], output_filename)
        plt.imsave(output_image_path, annotated_image)
        logger.info("Classification completed successfully")

        return jsonify({
            'status': 'success',
            'result_path': url_for('static', filename=f'yolo/yolo_results/{output_filename}'),
            'area_data': area_data,
            'area_data_path': url_for('static', filename=f'yolo/area_data/{area_data_filename}')
        })
    except Exception as e:
        logger.exception(f"Classification error: {e}")
        return jsonify({'error': str(e)}), 500

# Module-level record of the background retraining job; start_retraining()
# switches 'status' from 'idle' to 'in_progress' when it launches a job.
retraining_status = dict(
    status='idle',
    progress=None,
    message=None,
)

# Serialises start requests and remembers the running worker, so a second POST
# cannot launch a concurrent training job on the shared YOLO model.
_retraining_lock = threading.Lock()
_retraining_thread = None


@app.route('/start_retraining', methods=['GET', 'POST'])
def start_retraining():
    """
    Render the retraining page (GET) or launch YOLO retraining in the background (POST).

    POST resets ``retraining_status`` to in-progress and runs ``retrain_model_fn``
    on a worker thread. Progress is reported by that function, not here.
    Liveness is judged from the worker thread itself, so a job that ended
    without updating ``retraining_status`` never blocks a new one.

    Returns:
        GET: the rendered ``retrain.html`` page.
        POST: {'status': 'started'}, or {'status': 'already_running'} with
        HTTP 409 while a previous retraining thread is still alive.
    """
    global _retraining_thread

    if request.method != 'POST':
        # GET request - render the retraining page
        return render_template('retrain.html')

    with _retraining_lock:
        if _retraining_thread is not None and _retraining_thread.is_alive():
            return jsonify({'status': 'already_running'}), 409

        # Reset status (dict is mutated in place, so no `global` is needed for it)
        retraining_status['status'] = 'in_progress'
        retraining_status['progress'] = 'Initializing'
        retraining_status['message'] = None

        _retraining_thread = threading.Thread(target=retrain_model_fn, name='yolo-retraining')
        _retraining_thread.start()

    return jsonify({'status': 'started'})

# Event handler for client connection
@socketio.on('connect')
def handle_connect():
    """Log each new Socket.IO client connection through the module logger (not stdout)."""
    logger.info('Client connected')


if __name__ == '__main__':
    # Development server with debug mode on; listens on port 5001.
    # NOTE(review): Flask-SocketIO documents socketio.run(app, ...) as the way to
    # start the server — confirm the Socket.IO transports behave as intended with app.run.
    app.run(debug=True, port=5001)