File size: 2,384 Bytes
11d7b39
 
 
 
 
e0684b1
11d7b39
 
 
 
a18eb75
11d7b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d11a50
 
11d7b39
 
 
3d11a50
11d7b39
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr

from modules import sam
from modules.ui_utils import *
from modules.html_constants import *
from modules.model_downloader import *


class App:
    """Gradio front end that drives ``sam.SamInference`` from a web page.

    Constructing an instance calls ``download_sam_model_url()``, creates the
    ``gr.Blocks`` container styled with ``CSS``, and instantiates
    ``sam.SamInference``. Nothing is served until :meth:`launch` is called.
    """

    def __init__(self):
        # The download call runs before the inference object is created.
        download_sam_model_url()
        self.app = gr.Blocks(css=CSS)
        self.sam = sam.SamInference()

    def launch(self):
        """Build the page layout, wire the GENERATE button, and start serving.

        The click handler is ``self.sam.generate_mask_app``. It receives the
        input image followed by the six tunable values, in the order they are
        created below, and its results go to the gallery and the PSD file
        component.
        """
        with self.app:
            # Header note.
            with gr.Row():
                gr.Markdown(MARKDOWN_NOTE, elem_id="md_pgroject")

            # Original author's note on this row:
            # bug https://github.com/gradio-app/gradio/issues/3202
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=5):
                    source_image = gr.Image(label="Input image here")
                with gr.Column(scale=5):
                    # Tunable parameters. The list order is the positional
                    # order handed to the click handler after the image.
                    tunables = [
                        gr.Number(value=32, label="points_per_side"),
                        gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=0.88,
                            label="pred_iou_thresh",
                        ),
                        gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=0.95,
                            label="stability_score_thresh",
                        ),
                        gr.Number(value=0, label="crop_n_layers"),
                        gr.Number(value=1, label="crop_n_points_downscale_factor"),
                        gr.Number(value=0, label="min_mask_region_area"),
                    ]
                    # Static explanation block shown under the inputs.
                    gr.HTML(PARAMS_EXPLANATION, elem_id="html_param_explain")

            with gr.Row():
                run_button = gr.Button("GENERATE", variant="primary")

            with gr.Row():
                result_gallery = gr.Gallery(
                    label="Output will be shown here",
                    show_label=True,
                ).style(grid=5, height="auto")

            with gr.Row():
                psd_download = gr.File(label="PSD File will be shown here")

            run_button.click(
                fn=self.sam.generate_mask_app,
                inputs=[source_image, *tunables],
                outputs=[result_gallery, psd_download],
            )

        # Enable the request queue with api_open=False, then serve.
        queued_app = self.app.queue(api_open=False)
        queued_app.launch()


if __name__ == "__main__":
    # Script entry point: constructing App runs its __init__ (the
    # download_sam_model_url() call, the gr.Blocks container and the
    # SamInference instance); launch() then builds the layout and starts the
    # queued Gradio app.
    app = App()
    app.launch()