File size: 2,673 Bytes
0691c7d
baea9b2
 
488d99e
baea9b2
 
 
 
2fbf361
b32b0a3
 
488d99e
08430c8
 
488d99e
08430c8
 
 
 
488d99e
 
2fbf361
488d99e
2fbf361
baea9b2
48ec822
08430c8
 
0691c7d
576e22a
488d99e
0691c7d
488d99e
 
 
0691c7d
488d99e
0691c7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488d99e
 
baea9b2
b32b0a3
 
 
 
 
 
0691c7d
b32b0a3
 
 
0691c7d
b32b0a3
 
488d99e
576e22a
b32b0a3
 
576e22a
 
0691c7d
576e22a
 
b32b0a3
488d99e
 
b32b0a3
 
488d99e
 
0691c7d
488d99e
 
576e22a
5ae5bca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Optional

import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference

# Device used for both models and for inference. The autocast and TF32 setup
# below is CUDA-specific (it queries CUDA device 0), so this module as written
# requires a GPU; switching DEVICE alone is not enough to run on CPU.
DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")

# Enter a bfloat16 autocast context at import time and never exit it.
# NOTE(review): PyTorch autocast state is thread-local, so this only affects
# the importing thread (e.g. the model loading below), not request-handling
# threads; process_image applies its own autocast decorator. Confirm this
# global enter is still needed.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Allow TF32 matmul and cuDNN kernels on GPUs with compute capability >= 8
# (Ampere and newer). Only CUDA device index 0 is inspected.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


# Load both models once at import time; process_image reuses these globals on
# every request.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)


@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, text_input) -> Optional[Image.Image]:
    """Segment the object described by a text prompt and return its mask.

    Florence-2 open-vocabulary detection finds objects matching the prompt.
    SAM is then run on those detections to produce segmentation masks.

    Args:
        image_input: The uploaded image (a PIL image, since the input
            component uses ``type='pil'``), or ``None`` if nothing was
            uploaded.
        text_input: Text prompt describing what to segment.

    Returns:
        An 8-bit image where pixels inside the FIRST detection's mask are 255
        and all others are 0. Returns ``None`` when the image or prompt is
        missing, or when nothing was detected; the user is informed via
        ``gr.Info`` in those cases.
    """
    if image_input is None:
        gr.Info("Please upload an image.")
        return None

    # Whitespace-only prompts are treated as missing rather than being sent
    # to the detector.
    if not text_input or not text_input.strip():
        gr.Info("Please enter a text prompt.")
        return None

    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )
    # Bail out before SAM when the detector found nothing: there are no
    # detections to segment. NOTE(review): assumes run_sam_inference is not
    # meant to be called with an empty detection set — confirm.
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None

    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    # Guard against SAM yielding no detections or no mask array at all, which
    # would otherwise fail on the indexing below.
    if len(detections) == 0 or detections.mask is None:
        gr.Info("No objects detected.")
        return None

    # Only the first detection's mask is returned; any further matches are
    # dropped. Boolean mask -> 0/255 uint8 so it renders as black/white.
    first_mask = detections.mask[0]
    return Image.fromarray(first_mask.astype("uint8") * 255)


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: inputs (image, prompt) and the submit button.
        with gr.Column():
            image_input_component = gr.Image(type="pil", label="Upload image")
            text_input_component = gr.Textbox(
                label="Text prompt", placeholder="Enter text prompts"
            )
            submit_button_component = gr.Button(
                value="Submit", variant="primary"
            )
        # Right column: the resulting mask.
        with gr.Column():
            image_output_component = gr.Image(label="Output mask")

    # Clicking "Submit" and pressing Enter in the prompt box trigger the same
    # pipeline with identical inputs/outputs, so wire both in one loop
    # (button first, then textbox, matching the original registration order).
    for register_event in (
        submit_button_component.click,
        text_input_component.submit,
    ):
        register_event(
            fn=process_image,
            inputs=[image_input_component, text_input_component],
            outputs=[image_output_component],
        )

demo.launch(debug=False, show_error=True)