"""Gradio demo: zero-shot object detection with OWL-ViT.

The user supplies an image URL, a comma-separated list of text queries and a
score threshold; the app returns the image with a labelled box drawn around
every detection whose score reaches the threshold.
"""

from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

MODEL_NAME = "google/owlvit-large-patch14"
REQUEST_TIMEOUT_S = 30  # avoid hanging a worker forever on a dead URL
BOX_COLOR = (255, 0, 0)  # red, because the array is kept in RGB order
BOX_THICKNESS = 5
LABEL_OFFSET_BELOW = 25  # pixels below the box's bottom edge
LABEL_OFFSET_INSIDE = 10  # pixels above the bottom edge when there is no room

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = OwlViTForObjectDetection.from_pretrained(MODEL_NAME).to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained(MODEL_NAME)


def query_image(img_url: str, text_queries: str, score_threshold: float) -> np.ndarray:
    """Detect the queried objects in the image at *img_url* and draw them.

    Args:
        img_url: URL of the image to download and analyse.
        text_queries: Comma-separated text queries, e.g. ``"cat, remote"``.
            Surrounding whitespace is stripped and empty entries are dropped.
        score_threshold: Minimum detection score for a box to be drawn.

    Returns:
        The image as an RGB ``numpy`` array with a rectangle and the matching
        query text drawn for every detection at or above the threshold.

    Raises:
        ValueError: If *text_queries* contains no non-empty query.
        requests.RequestException: If the download fails, times out or the
            server answers with an HTTP error status.
    """
    queries = [q.strip() for q in text_queries.split(",") if q.strip()]
    if not queries:
        raise ValueError("At least one non-empty text query is required.")

    response = requests.get(img_url, timeout=REQUEST_TIMEOUT_S)
    response.raise_for_status()

    # Force 3-channel RGB so RGBA / palette / grayscale inputs work both with
    # the processor and with the colour drawing below.
    img = np.array(Image.open(BytesIO(response.content)).convert("RGB"))
    img_height = img.shape[0]

    # post_process rescales the predicted boxes using (height, width).
    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing runs on CPU together with target_sizes.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, score, label in zip(boxes.tolist(), scores.tolist(), labels.tolist()):
        if score < score_threshold:
            continue

        x_min, y_min, x_max, y_max = (int(v) for v in box)
        img = cv2.rectangle(
            img, (x_min, y_min), (x_max, y_max), BOX_COLOR, BOX_THICKNESS
        )

        # Put the label just below the box, unless that would fall outside
        # the actual image; then draw it inside the box instead.
        if y_max + LABEL_OFFSET_BELOW > img_height:
            text_y = y_max - LABEL_OFFSET_INSIDE
        else:
            text_y = y_max + LABEL_OFFSET_BELOW

        img = cv2.putText(
            img, queries[int(label)], (x_min, text_y), font, 1, BOX_COLOR, 2, cv2.LINE_AA
        )
    return img


description = """
DEMO
"""

demo = gr.Interface(
    query_image,
    inputs=["text", "text", gr.Slider(0, 1, value=0.1)],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    examples=[],
)

if __name__ == "__main__":
    demo.launch()