# jotase's picture
# Update app.py
# 48ec822 verified
from typing import Optional
import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
# Inference device. This file targets a CUDA GPU only (the app is decorated
# for the `spaces` GPU runtime below); there is no CPU fallback here.
DEVICE = torch.device("cuda")

# Enter a bfloat16 CUDA autocast context at import time and never exit it,
# so the code that follows on this thread (including model loading below)
# runs under bf16 autocast.
# NOTE(review): autocast state is per-thread in PyTorch, which is presumably
# why `process_image` re-applies autocast via its own decorator — confirm.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# TF32 matmuls/convolutions require compute capability 8.x or newer; enable
# them on such GPUs for faster float32 math.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Load both models once at import time so every request reuses them.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, text_input) -> Optional[Image.Image]:
    """Segment what ``text_input`` describes and return it as a mask image.

    Florence-2 open-vocabulary detection turns the text prompt into bounding
    boxes. SAM then turns those boxes into segmentation masks. The mask of
    the first detection is returned as a black/white image.

    Args:
        image_input: The uploaded image as a PIL image (the input component
            is created with ``type='pil'``), or ``None`` if nothing was
            uploaded.
        text_input: Text prompt describing what to segment.

    Returns:
        A single-channel image that is 255 inside the first detection's mask
        and 0 elsewhere. Returns ``None`` after showing a ``gr.Info`` toast
        when the image or prompt is missing, or nothing is detected.
    """
    # An empty image component yields None; compare explicitly instead of
    # relying on the truthiness of a PIL image object.
    if image_input is None:
        gr.Info("Please upload an image.")
        return None
    # A whitespace-only prompt is as useless as an empty one.
    if not text_input or not text_input.strip():
        gr.Info("Please enter a text prompt.")
        return None

    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input,
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size,
    )

    # Stop before SAM when Florence found nothing. Prompting SAM with zero
    # boxes is at best wasted GPU time and may fail inside the predictor.
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None

    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    # Presumably `mask` holds one boolean mask per detection (confirm in
    # utils.sam); only the first is returned. Scaling 0/1 to 0/255 produces
    # a displayable 8-bit grayscale image.
    return Image.fromarray(detections.mask[0].astype("uint8") * 255)
# Two-column UI: image + prompt inputs on the left, the resulting mask on
# the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter text prompts')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(label='Output mask')

    # Clicking "Submit" and pressing Enter in the prompt box run the same
    # pipeline. Register both triggers from one place so their wiring cannot
    # drift apart. Listeners must be attached inside the Blocks context.
    for trigger in (submit_button_component.click, text_input_component.submit):
        trigger(
            fn=process_image,
            inputs=[image_input_component, text_input_component],
            outputs=[image_output_component],
        )

demo.launch(debug=False, show_error=True)