File size: 1,949 Bytes
6c6a66c ae592b3 6c6a66c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# Use GPU if available
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14").to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")
def query_image(img_url, text_queries, score_threshold):
text_queries = text_queries.split(",")
response = requests.get(img_url)
img = Image.open(BytesIO(response.content))
img = np.array(img)
target_sizes = torch.Tensor([img.shape[:2]])
inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
outputs.logits = outputs.logits.cpu()
outputs.pred_boxes = outputs.pred_boxes.cpu()
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
font = cv2.FONT_HERSHEY_SIMPLEX
for box, score, label in zip(boxes, scores, labels):
box = [int(i) for i in box.tolist()]
if score >= score_threshold:
img = cv2.rectangle(img, box[:2], box[2:], (255,0,0), 5)
if box[3] + 25 > 768:
y = box[3] - 10
else:
y = box[3] + 25
img = cv2.putText(
img, text_queries[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
)
return img
description = """
DEMO
"""
demo = gr.Interface(
query_image,
inputs=["text", "text", gr.Slider(0, 1, value=0.1)],
outputs="image",
title="Zero-Shot Object Detection with OWL-ViT",
description=description,
examples=[],
)
demo.launch() |