from ultralytics import YOLO
import cv2
import gradio as gr
import numpy as np
import os
import torch
from image_segmenter import ImageSegmenter

# params
CANCEL_PROCESSING = False

img_seg = ImageSegmenter(model_type="yolov8m-seg-custom")

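def resize(image):
    """
    Resize an image array to a fixed size that matches its orientation:
    480x640 (width x height) for portrait input, 640x480 for landscape or
    square input. Note that cv2.resize takes the target size as (width, height).
    """
    h, w = image.shape[:2]
    if h > w:
        return cv2.resize(image, (480, 640))
    else:
        return cv2.resize(image, (640, 480))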
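
# Sketch (illustrative helper, not part of the original app): cv2.resize takes
# `dsize` as (width, height), the reverse of NumPy's (height, width) shape
# order, which is why portrait images above are resized with (480, 640).
# `_target_dsize` is a hypothetical name that isolates the orientation check so
# it can be verified without OpenCV. Both branches force a fixed size, so inputs
# whose aspect ratio is not 3:4 or 4:3 are stretched rather than letterboxed.
from typing import Tuple


def _target_dsize(height: int, width: int) -> Tuple[int, int]:
    """Return the (width, height) dsize that resize() would pass to cv2.resize."""
    if height > width:
        return (480, 640)  # portrait: 480 px wide, 640 px tall
    return (640, 480)      # landscape or square: 640 px wide, 480 px tall

# Usage: _target_dsize(*image.shape[:2]) gives the same dsize resize(image) uses.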

def process_image(image):
    image = resize(image)
    prediction, _ = img_seg.predict(image)
    return prediction


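def process_video(vid_path=None):
    """
    Segment each frame of the video at `vid_path` and yield the annotated
    frames one at a time so Gradio can stream them to the output image.
    """
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = False  # clear any cancel request left from a previous run

    vid_cap = cv2.VideoCapture(vid_path)
    try:
        while vid_cap.isOpened() and not CANCEL_PROCESSING:
            ret, frame = vid_cap.read()
            if not ret:
                break  # end of video (or read error); without this the loop never ends
            print("Making frame predictions ....")
            frame = resize(frame)
            # OpenCV decodes frames as BGR; convert to RGB to match the Image tab input
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            prediction, _ = img_seg.predict(frame)
            yield prediction
    finally:
        vid_cap.release()  # free the file handle even if the stream is cut short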
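
# Sketch (hypothetical helpers, not used by the app): the frame-reading loop
# reduced to its control flow. cv2.VideoCapture.read() returns (False, None)
# once the stream is exhausted, while isOpened() typically stays True until
# release(), so the loop must break on a failed read or it spins forever.
# Checking a stop predicate before each read is what lets a Cancel button end
# the stream early. `_FakeCapture` stands in for cv2.VideoCapture so the logic
# can run without a video file.
from typing import Callable, Iterator, List, Optional, Tuple


class _FakeCapture:
    """Minimal stand-in exposing the isOpened()/read() pair used by the app."""

    def __init__(self, frames: List[str]) -> None:
        self._frames = list(frames)

    def isOpened(self) -> bool:
        return True  # like a real capture, stays "open" after the last frame

    def read(self) -> Tuple[bool, Optional[str]]:
        if not self._frames:
            return False, None  # end of stream
        return True, self._frames.pop(0)


def _stream_frames(cap, should_stop: Callable[[], bool]) -> Iterator[str]:
    """Yield frames until the capture is exhausted or should_stop() returns True."""
    while cap.isOpened() and not should_stop():
        ok, frame = cap.read()
        if not ok:
            break  # without this break the loop would never terminate
        yield frame

# Usage: list(_stream_frames(_FakeCapture(["f1", "f2"]), lambda: False))
# collects both frames and then returns.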

def update_segmentation_options(options):
    # `options` is the list of labels currently checked in the CheckboxGroup
    img_seg.is_show_bounding_boxes = 'Show Boundary Box' in options
    img_seg.is_show_segmentation = 'Show Segmentation Region' in options
    img_seg.is_show_segmentation_boundary = 'Show Segmentation Boundary' in options
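
# Sketch (hypothetical alternative, not wired into the UI): the same
# checkbox-to-flag mapping written as a lookup table, so adding an option means
# adding one entry. The attribute names are the ones set above; the segmenter is
# replaced by a types.SimpleNamespace purely for illustration. Note that
# "Show Segmentation Boundary" is not among the CheckboxGroup choices, so that
# flag is always False when set from the UI.
from types import SimpleNamespace

_OPTION_FLAGS = {
    "Show Boundary Box": "is_show_bounding_boxes",
    "Show Segmentation Region": "is_show_segmentation",
    "Show Segmentation Boundary": "is_show_segmentation_boundary",
}


def _apply_options(segmenter, options) -> None:
    """Set each flag to True exactly when its label is among the checked options."""
    for label, attr in _OPTION_FLAGS.items():
        setattr(segmenter, attr, label in options)

# Usage: _apply_options(SimpleNamespace(), ["Show Boundary Box"]) turns on only
# the bounding-box flag.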

def update_confidence_threshold(thres_val):
    # the slider runs from 1 to 100; store the threshold as a fraction (0.01 to 1.0)
    img_seg.confidence_threshold = thres_val / 100
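
# Sketch (hypothetical; assumes the segmenter keeps detections whose score is
# at or above `confidence_threshold`, but the exact comparison lives in
# ImageSegmenter and is not shown here): the slider runs from 1 to 100, so
# dividing by 100 yields a fraction in [0.01, 1.0], the same 0-1 scale as YOLO
# confidence scores.
from typing import List


def _filter_by_confidence(scores: List[float], slider_value: float) -> List[float]:
    """Keep scores >= slider_value / 100, mirroring the conversion above."""
    threshold = slider_value / 100
    return [s for s in scores if s >= threshold]

# Usage: with the default slider value of 60, scores 0.95 and 0.6 are kept and
# 0.3 is dropped.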

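def model_selector(model_type):
    """
    Replace the global segmenter with the model matching the dropdown choice.
    """
    global img_seg  # without this, the new segmenter is only a local variable

    if model_type == "Small - Better performance and less accuracy":
        yolo_model = "yolov8s_seg_custom"
    elif model_type == "Medium - Balanced performance and accuracy":
        yolo_model = "yolov8m-seg-custom"
    elif model_type == "Large - Slow performance and high accuracy":
        # NOTE: "Large" currently loads the same medium model
        yolo_model = "yolov8m-seg-custom"
    else:
        yolo_model = "yolov8m-seg-custom"

    img_seg = ImageSegmenter(model_type=yolo_model)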
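
# Sketch (hypothetical refactor, not wired into the UI): the if/elif chain as a
# dictionary lookup with a fallback, which keeps dropdown labels and model names
# in one place. The model names are copied verbatim from model_selector; whether
# "yolov8s_seg_custom" (underscores, unlike the others) matches an actual weights
# file is not confirmed here, and "Large" maps to the medium model as above.
_DEFAULT_MODEL = "yolov8m-seg-custom"
_MODEL_BY_LABEL = {
    "Small - Better performance and less accuracy": "yolov8s_seg_custom",
    "Medium - Balanced performance and accuracy": "yolov8m-seg-custom",
    "Large - Slow performance and high accuracy": "yolov8m-seg-custom",
}


def _model_name_for(label: str) -> str:
    """Map a dropdown label to a model name, falling back to the medium model."""
    return _MODEL_BY_LABEL.get(label, _DEFAULT_MODEL)

# Usage: img_seg = ImageSegmenter(model_type=_model_name_for(model_type)),
# inside a function that declares `global img_seg`.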

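def cancel():
    # `global` is required; a plain assignment would only create a local variable
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = True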
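
# Sketch (hypothetical alternative to the module-level boolean): Gradio
# typically runs synchronous handlers on worker threads, so a threading.Event is
# a natural way for one handler (Cancel) to signal another (the video
# generator), and it needs no `global` statement because only methods are
# called. Any single module-level flag, boolean or Event, is shared by every
# browser session, so Cancel in one tab would also stop streams in other tabs.
import threading

_cancel_event = threading.Event()


def _request_cancel() -> None:
    """What the Cancel button would call."""
    _cancel_event.set()


def _start_run() -> None:
    """What the video handler would call first, so a stale cancel does not leak in."""
    _cancel_event.clear()

# Usage: inside the frame loop, check `if _cancel_event.is_set(): break`.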

if __name__ == "__main__":

    # gradio gui app
    with gr.Blocks() as my_app:

        # title
        gr.Markdown("<h1><center>Hand detection and segmentation</center></h1>")

        # tabs
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image()
                    model_type_img = gr.Dropdown(
                        ["Small - Better performance and less accuracy", 
                         "Medium - Balanced performance and accuracy", 
                         "Large - Slow performance and high accuracy"], 
                        label="Model Type", value="Medium - Balanced performance and accuracy",
                        info="Select the inference model before running predictions!")
                    options_checkbox_img = gr.CheckboxGroup(["Show Boundary Box", "Show Segmentation Region"], label="Options")
                    conf_thres_img = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected")
                    submit_btn_img = gr.Button(value="Predict")                    

                with gr.Column(scale=2):
                    with gr.Row():
                        img_output = gr.Image(height=600, label="Segmentation")
            
            gr.Markdown("## Sample Images")
            gr.Examples(
                examples=[os.path.join(os.path.dirname(__file__), "assets/images/img_1.jpg"),
                          os.path.join(os.path.dirname(__file__), "assets/images/img_2.jpg")],
                inputs=img_input,
                outputs=img_output,
                fn=process_image,
                cache_examples=True,
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=1):
                    vid_input = gr.Video()
                    model_type_vid = gr.Dropdown(
                        ["Small - Better performance and less accuracy", 
                         "Medium - Balanced performance and accuracy", 
                         "Large - Slow performance and high accuracy"], 
                        label="Model Type", value="Medium - Balanced performance and accuracy",
                        info="Select the inference model before running predictions!")
                    
                    options_checkbox_vid = gr.CheckboxGroup(["Show Boundary Box", "Show Segmentation Region"], label="Options")
                    conf_thres_vid = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected")
                    with gr.Row():
                        cancel_btn = gr.Button(value="Cancel")
                        submit_btn_vid = gr.Button(value="Predict")
            
                with gr.Column(scale=2):
                    with gr.Row():
                        vid_output = gr.Image(height=600, label="Segmentation")
            
            gr.Markdown("## Sample Videos")
            gr.Examples(
                examples=[os.path.join(os.path.dirname(__file__), "assets/videos/vid_1.mp4"),
                          os.path.join(os.path.dirname(__file__), "assets/videos/vid_2.mp4"),],
                inputs=vid_input,
                # outputs=vid_output,
                # fn=process_video,
            )
            

        # image tab logic
        submit_btn_img.click(process_image, inputs=img_input, outputs=img_output)
        options_checkbox_img.change(update_segmentation_options, options_checkbox_img, [])
        conf_thres_img.change(update_confidence_threshold, conf_thres_img, [])
        model_type_img.change(model_selector, model_type_img, [])

        # video tab logic
        submit_btn_vid.click(process_video, inputs=vid_input, outputs=vid_output)
        model_type_vid.change(model_selector, model_type_vid, [])
        cancel_btn.click(cancel, inputs=[], outputs=[])
        options_checkbox_vid.change(update_segmentation_options, options_checkbox_vid, [])
        conf_thres_vid.change(update_confidence_threshold, conf_thres_vid, [])       


    my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)