"""Gradio demo for zero-shot object detection with OWL-ViT.

Users supply an image (webcam or upload) plus comma-separated text queries.
The app draws a box and the matching query for every detection whose score
reaches the chosen threshold.
"""
import torch
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import os

# Use GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_NAME = "google/owlvit-large-patch14"
model = OwlViTForObjectDetection.from_pretrained(MODEL_NAME).to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained(MODEL_NAME)

# Drawing style for detections; (255, 0, 0) is red if the array is RGB.
BOX_COLOR = (255, 0, 0)
BOX_THICKNESS = 5
FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE = 1
FONT_THICKNESS = 2
# Label baseline offsets (pixels) relative to the bottom edge of its box.
LABEL_OFFSET_BELOW = 25  # normal case: just below the box
LABEL_OFFSET_INSIDE = 10  # fallback: just inside the box's bottom edge

INSTRUCTIONS_MD = (
    "Insert an image below and add text descriptions of what you are looking for. \n"
    "If you wish for assistance to find the right text queries you can ask for help "
    "from [ChatBRD](https://chatbrd.novonordisk.com/#/) but remember you need to log "
    "on Novo's VPN before you can use it."
)
THRESHOLD_MD = (
    " \n You can also use the score threshold slider to set a threshold to filter "
    "out lower probability predictions. \n"
)


def query_image(img, text_queries, score_threshold):
    """Run OWL-ViT detection for the given text queries and draw the results.

    Args:
        img: Image from a ``gr.Image`` component. It is converted with
            ``np.array``; presumably an RGB uint8 array (TODO confirm).
        text_queries: Comma-separated descriptions of the objects to look
            for, e.g. ``"a cat, a remote control"``.
        score_threshold: Minimum detection score for a box to be drawn.

    Returns:
        A copy of the image with a box and the matching query text drawn for
        every detection whose score is at least ``score_threshold``.

    Raises:
        gr.Error: If no image or no non-empty text query was supplied.
    """
    if img is None:
        raise gr.Error("Please provide an image.")
    # Strip whitespace so "cat, dog" queries "dog" rather than " dog", and
    # drop empty entries left by stray or trailing commas.
    queries = [q.strip() for q in (text_queries or "").split(",") if q.strip()]
    if not queries:
        raise gr.Error("Please enter at least one text query.")

    img = np.array(img)
    height = img.shape[0]
    # post_process rescales boxes to pixels using (height, width) per image.
    target_sizes = torch.Tensor([img.shape[:2]])

    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing runs on CPU, alongside the CPU target_sizes tensor.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        img = cv2.rectangle(img, (x0, y0), (x1, y1), BOX_COLOR, BOX_THICKNESS)
        # Put the label just below the box unless that would run past the
        # bottom of the image; then draw it just inside the box instead.
        if y1 + LABEL_OFFSET_BELOW > height:
            text_y = y1 - LABEL_OFFSET_INSIDE
        else:
            text_y = y1 + LABEL_OFFSET_BELOW
        img = cv2.putText(
            img,
            queries[int(label)],  # label indexes into the queries passed to the processor
            (x0, text_y),
            FONT,
            FONT_SCALE,
            BOX_COLOR,
            FONT_THICKNESS,
            cv2.LINE_AA,
        )
    return img


def _build_query_tab(title, image_source):
    """Create one tab (image, queries, threshold, output) wired to query_image.

    Called inside the ``gr.Blocks`` context below so the components belong
    to ``demo``.

    Args:
        title: Tab label.
        image_source: ``source`` argument for ``gr.Image`` ("webcam"/"upload").

    Returns:
        ``(inputs, output)``: the list of input components and the output
        image component, e.g. for attaching ``gr.Examples`` later.
    """
    with gr.Tab(title):
        with gr.Row():
            with gr.Column():
                gr.Markdown(INSTRUCTIONS_MD)
                image_input = gr.Image(source=image_source)
                query_input = gr.Textbox()
                gr.Markdown(THRESHOLD_MD)
                threshold_input = gr.Slider(0, 1, value=0.1)
                submit_btn = gr.Button("Submit")
            output = gr.Image()
    inputs = [image_input, query_input, threshold_input]
    submit_btn.click(fn=query_image, inputs=inputs, outputs=output, queue=True)
    return inputs, output


with gr.Blocks() as demo:
    with gr.Column():
        _build_query_tab("Capture image with webcam", "webcam")
        _build_query_tab("Upload image", "upload")
        # TODO: optionally add example images (e.g. IMGP0178.jpg next to this
        # file) to the upload tab with gr.Examples, using the components
        # returned by _build_query_tab.


if __name__ == "__main__":
    demo.queue(
        # Raising concurrency_count in queue() also raises max_threads in launch().
        # NOTE(review): all workers share one model on one device; many
        # concurrent owlvit-large inferences may exhaust GPU memory.
        concurrency_count=40,
        # Maximum number of events the queue holds at once; when it is full,
        # new requests are turned away with a "queue full" message.
        max_size=25,
        # Restrict traffic to the UI instead of the auto-generated programmatic API.
        api_open=False,
    )
    # NOTE(review): credentials should not be committed to source control.
    # They are read from the environment; the literal fallbacks only keep
    # existing deployments working and should be removed (rotate the password).
    auth = (
        os.environ.get("APP_USERNAME", "novouser"),
        os.environ.get("APP_PASSWORD", "bstad2023"),
    )
    demo.launch(auth=auth)