import gradio as gr
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode

BICUBIC = InterpolationMode.BICUBIC

from utils import setup, get_similarity_map, get_noun_phrase, rgb_to_hsv, hsv_to_rgb
from vpt.launch import default_argument_parser
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import models
import string
import nltk

nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
from nltk.tokenize import word_tokenize
import torchvision
import spacy

# Download and load the spaCy English model
spacy.cli.download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")


def extract_objects(prompt):
    """Extract candidate object nouns from a caption, e.g. 'a boy flying a kite' -> ['boy', 'kite']."""
    doc = nlp(prompt)
    objects = set()
    for token in doc:
        # Keep nouns, proper nouns, and tokens that are part of a named entity
        if token.pos_ in {"NOUN", "PROPN"} or token.ent_type_:
            objects.add(token.text)
        # For compound nouns, keep the head noun as well
        if token.dep_ == "compound":
            objects.add(token.head.text)
    return list(objects)


args = default_argument_parser().parse_args()
cfg = setup(args)

multi_classes = True

device = "cuda" if torch.cuda.is_available() else "cpu"
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# The checkpoint was trained on 2 GPUs, so strip the "module." prefix
# before loading it on a single device.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v

Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")


def run(sketch, caption, threshold, seed):
    # Pick a random offset (0-3) into the colormap for the class colors
    color_seed = np.random.randint(0, 4)

    # Build the candidate classes from the caption
    caption = caption.replace('\n', ' ')
    classes = extract_objects(caption)
    # translator = str.maketrans('', '', string.punctuation)
    # caption = caption.translate(translator).lower()
    # words = word_tokenize(caption)
    # classes = get_noun_phrase(words)
    if len(classes) == 0 or not multi_classes:
        classes = [caption]

    colors = plt.get_cmap("Set1").colors
    classes_colors = colors[color_seed:len(classes) + color_seed]

    sketch2 = sketch['composite']

    # When the drawing tool is used, the strokes live only in the alpha channel
    if sketch2[:, :, 0:3].sum() == 0:
        temp = 255 - sketch2[:, :, 3]  # invert so strokes are dark on white
        sketch2 = np.repeat(temp[:, :, np.newaxis], 3, axis=2)
        temp2 = np.full_like(temp, 255)
        sketch2 = np.dstack((sketch2, temp2))

    sketch2 = np.array(sketch2)
    pil_img = Image.fromarray(sketch2).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
    # torchvision.utils.save_image(sketch_tensor, 'sketch_tensor.png')

    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)

    num_of_tokens = 3
    with torch.no_grad():
        sketch_features = Ours.encode_image(
            sketch_tensor, layers=[12],
            text_features=text_features - redundant_features, mode="test").squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)

    similarity = sketch_features @ (text_features - redundant_features).t()
    patches_similarity = similarity[0, num_of_tokens + 1:, :]
    pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
    # visualize_attention_maps_with_tokens(pixel_similarity, classes)
    pixel_similarity[pixel_similarity < threshold] = 0
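    # At this point `pixel_similarity` holds one dense similarity map per class,
    # upsampled to the sketch resolution; values under the confidence slider's
    # threshold were zeroed above so they render as background below.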
    pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
    # display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)

    # Find the class index with the highest similarity for each pixel
    class_indices = np.argmax(pixel_similarity_array, axis=0)

    # Create an HSV image placeholder of shape (H, W, 3)
    hsv_image = np.zeros(class_indices.shape + (3,))
    hsv_image[..., 2] = 1  # set Value to 1 for a white base

    # Fill the hue, saturation, and value channels per class
    for i, color in enumerate(classes_colors):
        rgb_color = np.array(color).reshape(1, 1, 3)
        hsv_color = rgb_to_hsv(rgb_color)
        mask = class_indices == i
        if i < len(classes):
            # Valid class index: colorize the region, modulating value by similarity
            hsv_image[..., 0][mask] = hsv_color[0, 0, 0]                    # Hue
            hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0   # Saturation
            hsv_image[..., 2][mask] = pixel_similarity_array[i][mask]       # Value
        else:
            # Indices beyond the class list are rendered black
            hsv_image[..., 0][mask] = 0
            hsv_image[..., 1][mask] = 0
            hsv_image[..., 2][mask] = 0

    # Keep the sketch's white background white
    mask_tensor_org = sketch2[:, :, 0] / 255
    hsv_image[mask_tensor_org >= 0.5] = [0, 0, 1]

    # Convert the HSV image back to RGB to display and save
    rgb_image = hsv_to_rgb(hsv_image)

    if len(classes) > 1:
        # Compute each class mask's centroid and render the class name there
        for i, class_name in enumerate(classes):
            mask = class_indices == i
            if np.any(mask):
                y, x = np.nonzero(mask)
                centroid_x, centroid_y = np.mean(x), np.mean(y)
                plt.text(centroid_x, centroid_y, class_name,
                         color=classes_colors[i], ha='center', va='center', fontsize=10,
                         bbox=dict(facecolor='lightgrey', edgecolor='none',
                                   boxstyle='round,pad=0.2', alpha=0.8))

    # Display the image with class names and save it
    plt.imshow(rgb_image)
    plt.axis('off')
    plt.tight_layout()
    # plt.savefig(f'poster_vis/{classes[0]}.png', bbox_inches='tight', pad_inches=0)
    plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
    plt.close()

    # rgb_image = Image.open(f'poster_vis/{classes[0]}.png')
    rgb_image = Image.open('output.png')
    return rgb_image


scripts = """
async () => {
    // START gallery format
    // Collect all gallery images, their original parent, and the element the
    // new container should be inserted before (the examples table)
    var images = document.querySelectorAll('.image_gallery');
    var originalParent = document.querySelector('#component-0');
    var beforeDiv = document.querySelector('.table-wrap').parentElement;

    // Create a new parent div for the gallery
    var parentDiv = document.createElement('div');
    parentDiv.id = "gallery_container";

    // Move each image into the new parent and wire up a click handler
    images.forEach(function(image, index) {
        parentDiv.appendChild(image);
        // Clicking a gallery image triggers the matching example-table row
        image.addEventListener('click', function() {
            let nth_ch = index + 1;
            document.querySelector('.tr-body:nth-child(' + nth_ch + ')').click();
            console.log('.tr-body:nth-child(' + nth_ch + ')');
        });
    });

    // Append the gallery container before the examples table
    originalParent.insertBefore(parentDiv, beforeDiv);
    // END gallery format

    // START confidence span
    // Get the slider's label span and its text content
    var selectedDiv = document.querySelector("label[for='range_id_0'] > span");
    var textContent = selectedDiv.textContent;
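    // The `label[for='range_id_0'] > span > span` rule in the css string below
    // underlines the <span> injected here, so only the word before the colon
    // ("Confidence") is decorated.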
    // Find the text before the first colon ':'
    var colonIndex = textContent.indexOf(':');
    var textBeforeColon = textContent.substring(0, colonIndex);

    // Wrap the text before the colon in a span element
    var spanElement = document.createElement('span');
    spanElement.textContent = textBeforeColon;

    // Replace the original text with a version containing the span
    selectedDiv.innerHTML = textContent.replace(textBeforeColon, spanElement.outerHTML);
    // END confidence span

    // START format the column names
    // Keep only the first word of each examples-table header, without the colon
    var elements = document.querySelectorAll('.tr-head > th');
    elements.forEach(function(element) {
        var text = element.textContent.trim();
        var wordWithoutColon = text.replace(':', '');
        var words = wordWithoutColon.split(' ');
        element.textContent = words[0];
    });

    document.querySelector('input[type=number]').disabled = true;
}
"""

css = """
gradio-app { background-color: white !important; }
.white-bg { background-color: white !important; }
.gray-border { border: 1px solid dimgrey !important; }
.border-radius { border-radius: 8px !important; }
.black-text { color: black !important; }
th { color: black !important; }
tr { background-color: white !important; color: black !important; }
td { border-bottom: 1px solid black !important; }
label[data-testid="block-label"] { background: white; color: black; font-weight: bold; }
.controls-wrap button:disabled { color: gray !important; background-color: white !important; }
.controls-wrap button:not(:disabled) { color: black !important; background-color: white !important; }
.source-wrap button { color: black !important; }
.toolbar-wrap button { color: black !important; }
.empty.wrap { color: black !important; }
textarea { background-color: #f7f9f8 !important; color: #afb0b1 !important; }
input[data-testid="number-input"] { background-color: #f7f9f8 !important; color: black !important; }
tr > th { border-bottom: 1px solid black !important; }
tr:hover { background: #f7f9f8 !important; }
#component-19 { justify-content: center !important; }
#component-19 > button { flex: none !important; background-color: black !important; font-weight: bold !important; }
.bold { font-weight: bold !important; }
span[data-testid="block-info"] { color: black !important; font-weight: bold !important; }
#component-14 > div { background-color: white !important; }
button[aria-label="Clear"] { background-color: white !important; color: black !important; }
#gallery_container { display: flex; flex-wrap: wrap; justify-content: start; }
.image_gallery { margin-bottom: 1rem; margin-right: 1rem; }
label[for='range_id_0'] > span > span { text-decoration: underline; }
.underline { text-decoration: underline; }
.mt-mb-1 { margin-top: 1rem; margin-bottom: 1rem; }
#gallery_container + div { visibility: hidden; height: 10px; }
input[type=number][disabled] { background-color: rgb(247, 249, 248) !important; color: black !important; -webkit-text-fill-color: black !important; }
#component-13 { display: flex; flex-direction: column; align-items: center; }
"""

with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:
    # Header (the surrounding tag markup is an assumption; only the text survives
    # in the source, and the project-page URL is assumed)
    gr.HTML("<h1>Open Vocabulary Scene Sketch Semantic Understanding</h1>")
    gr.HTML("<p>Ahmed Bourouis, Judith Ellen Fan, Yulia Gryaditskaya</p>")
    gr.HTML("<p>CVPR, 2024</p>")
    gr.HTML("<p><a href='https://ahmedbourouis.github.io/Scene_Sketch_Segmentation/'>Project page</a></p>")

    with gr.Row():
        with gr.Column():
            # in_image = gr.Image(label="Sketch", type="pil", sources="upload", height=512)
            in_canvas_image = gr.Sketchpad(
                # value=Image.new('RGB', (512, 512), color=(255, 255, 255)),
                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=2),
                image_mode="RGBA",
                elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                label="Sketch",
                canvas_size=(512, 512),
                sources=['upload'],
                interactive=True,
                layers=False,
                transforms=[])
            query_selector = 'button[aria-label="Upload button"]'
            upload_draw_btn = gr.HTML(f"""
""") # in_textbox = gr.Textbox( lines=2, elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,label="Caption your Sketch!", placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ") with gr.Column(): out_image = gr.Image( value=Image.new('RGB', (512, 512), color=(255, 255, 255)), elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] , type="pil", label="Segmented Sketch" ) #, height=512, width=512) # # gr.HTML("

Confidence: Adjust AI agent confidence in guessing categories

") # in_slider = gr.Slider(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] , # info="Adjust AI agent confidence in guessing categories", # label="Confidence:", # value=0.5 , interactive=True, step=0.05, minimum=0, maximum=1) with gr.Row(): with gr.Column(): in_textbox = gr.Textbox( lines=2, elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] ,label="Caption your Sketch!", placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ") with gr.Column(): # gr.HTML("

Confidence: Adjust AI agent confidence in guessing categories ") in_slider = gr.Slider(elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" ] , info="Adjust AI agent confidence in guessing categories", label="Confidence:", value=0.5 , interactive=True, step=0.05, minimum=0, maximum=1) with gr.Row(): segment_btn = gr.Button( 'Segment it¹ !' , elem_classes=["white-bg", "gray-border" , "border-radius" ,"own-shadow" , 'bold' , 'mt-mb-1' ] , size="sm") segment_btn.click(fn=run, inputs=[in_canvas_image , in_textbox , in_slider ], outputs=[out_image]) gallery_label = gr.HTML("

Gallery: you can click on any of the example sketches below to start segmenting them (or even drawing over them) ") gallery= gr.HTML(f"""
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_1.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_2.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/sketch_3.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004068.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000004546.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000005076.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000006336.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000011766.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024458.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000024931.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000034214.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000260974.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000268340.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000305414.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000484246.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000549338.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000038116.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000221509.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000246066.png', height=200, width=200)}
        {gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value='demo/000000001611.png', height=200, width=200)}
""") examples = gr.Examples( examples_per_page=30, examples=[ ['demo/sketch_1.png', 'giraffe looking at you', 0.6], ['demo/sketch_2.png', 'a kite flying in the sky', 0.6], ['demo/sketch_3.png', 'a girl playing', 0.6], ['demo/000000004068.png', 'car going so fast', 0.6], ['demo/000000004546.png', 'mountains in the background', 0.6], ['demo/000000005076.png', 'huge tree', 0.6], ['demo/000000006336.png', 'nice three sheeps', 0.6], ['demo/000000011766.png', 'bird minding its own business', 0.6], ['demo/000000024458.png', 'horse with a mask on', 0.6], ['demo/000000024931.png', 'some random person', 0.6], ['demo/000000034214.png', 'a cool kid on a skateboard', 0.6], ['demo/000000260974.png', 'the chair on the left', 0.6], ['demo/000000268340.png', 'stop sign', 0.6], ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6], ['demo/000000484246.png', 'giraffe with a loong neck', 0.6], ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6], ['demo/000000038116.png', 'a bat next to a kid', 0.6], ['demo/000000221509.png', 'funny looking cow', 0.6], ['demo/000000246066.png', 'bench in the park', 0.6], ['demo/000000001611.png', 'trees in the background', 0.6] ], inputs=[in_canvas_image, in_textbox , in_slider], fn=run, # cache_examples=True, ) gr.HTML("

¹This demo runs on a basic 2 vCPU. For instant segmentation, use a commercial Nvidia RTX 3090 GPU.
") gr.HTML("
¹We compare the entire caption to the scene sketch and threshold most similar pixels, without extracting individual classes.
") demo.launch(share=False)