# Region-based MultiDiffusion Gradio demo: the user paints colour regions on a
# canvas, each non-white colour becomes a mask with its own prompt, and
# MultiDiffusion generates one image from the master prompt + region prompts.
#
# NOTE(review): this copy appears whitespace-collapsed, and its HTML markup looks
# stripped (the canvas_html literal, the f-string <div> templates inside
# process_sketch, and the gr.Blocks layout between `css` and the event wiring).
# It will not parse as-is -- restore from the upstream source before editing code.
#
# Flow visible below:
# - `sd` (MultiDiffusion on "cuda", version "2.1") is constructed at import time.
# - load_js fetches sketch-canvas.js and injects it as a <script type="module">.
#   set_canvas_size calls a page-global _updateCanvas (presumably defined by that
#   script -- confirm) with 512x512 / 768x512 / 512x768 for square / horizontal /
#   vertical.
# - get_js_colors returns the canvas element's `_data`, which process_sketch reads
#   as canvas_data['image'] and canvas_data['colors'].
# - process_sketch(canvas_data, binary_matrixes):
#   - Clears binary_matrixes in place.
#   - Decodes the base64 part after the first ',' of canvas_data['image'] (assumes
#     a data-URL-style string) into an RGB numpy array.
#   - Parses each colour with rgb[4:-1] (assumes "rgb(r,g,b)" strings -- confirm).
#   - Appends a create_binary_matrix mask for every colour that is not pure white.
#   - Returns gr.update objects that show one row per mask, out of MAX_COLORS rows.
# - process_generation(...):
#   - Seeds randomly in [1, 2147483647] when seed == -1.
#   - Maps aspect to (width, height), falling back to square.
#   - Keeps only as many region prompts as there are masks, and prepends the master
#     prompt.
#   - Uses the same negative prompt for every prompt.
#   - Builds the foreground masks at 1/8 resolution, with
#     background = 1 - sum(foreground), clamped at 0.
#   - Calls sd.generate.
#
# Review notes (code intentionally unchanged here):
# - os.environ['SPACE_ID'] raises KeyError when SPACE_ID is unset (e.g. local
#   runs); os.environ.get('SPACE_ID', '') would avoid that. `True if c else False`
#   is redundant for both flags. is_shared_ui / is_gpu_associated are not read in
#   the visible code.
# - process_generation builds a brand-new MultiDiffusion on EVERY call whenever the
#   model is not "stabilityai/stable-diffusion-2-1-base". It also never switches
#   back to a 2.1 pipeline when the base id is picked again after another model.
# - If the sketch has no non-white colours, binary_matrixes is empty and
#   torch.cat([...]) over an empty list raises.
# - `boostrapping` is misspelled, but it is both the parameter name and the wired
#   component name, so renaming it must touch both.
# - In process_sketch, `colors` first holds RGB tuples and is later rebound to a
#   list of gr.update objects. print(masks.size()) is leftover debug output.
# - cv2, requests, get_views, get_high_freq_colors and color_quantization are
#   unused in the visible code; the stripped UI section may use them -- verify
#   before removing.
# - `_js=` is the Gradio 3.x event keyword (renamed `js=` in Gradio 4, I believe):
#   pin gradio<4 or migrate.
import gradio as gr import numpy as np import cv2 from PIL import Image import torch import base64 import requests import random import os from io import BytesIO from region_control import MultiDiffusion, get_views, preprocess_mask, seed_everything from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix MAX_COLORS = 12 sd = MultiDiffusion("cuda", "2.1") is_shared_ui = True if "weizmannscience/multidiffusion-region-based" in os.environ['SPACE_ID'] else False is_gpu_associated = True if torch.cuda.is_available() else False canvas_html = "
" load_js = """ async () => { const url = "https://huggingface.co./datasets/radames/gradio-components/raw/main/sketch-canvas.js" fetch(url) .then(res => res.text()) .then(text => { const script = document.createElement('script'); script.type = "module" script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' })); document.head.appendChild(script); }); } """ get_js_colors = """ async (canvasData) => { const canvasEl = document.getElementById("canvas-root"); return [canvasEl._data] } """ set_canvas_size =""" async (aspect) => { if(aspect ==='square'){ _updateCanvas(512,512) } if(aspect ==='horizontal'){ _updateCanvas(768,512) } if(aspect ==='vertical'){ _updateCanvas(512,768) } } """ def process_sketch(canvas_data, binary_matrixes): binary_matrixes.clear() base64_img = canvas_data['image'] image_data = base64.b64decode(base64_img.split(',')[1]) image = Image.open(BytesIO(image_data)).convert("RGB") im2arr = np.array(image) colors = [tuple(map(int, rgb[4:-1].split(','))) for rgb in canvas_data['colors']] colors_fixed = [] for color in colors: r, g, b = color if any(c != 255 for c in (r, g, b)): binary_matrix = create_binary_matrix(im2arr, (r,g,b)) binary_matrixes.append(binary_matrix) colors_fixed.append(gr.update(value=f'
')) visibilities = [] colors = [] for n in range(MAX_COLORS): visibilities.append(gr.update(visible=False)) colors.append(gr.update(value=f'
')) for n in range(len(colors_fixed)): visibilities[n] = gr.update(visible=True) colors[n] = colors_fixed[n] return [gr.update(visible=True), binary_matrixes, *visibilities, *colors] def process_generation(model, binary_matrixes, boostrapping, aspect, steps, seed, master_prompt, negative_prompt, *prompts): global sd if(model != "stabilityai/stable-diffusion-2-1-base"): sd = MultiDiffusion("cuda", model) if(seed == -1): seed = random.randint(1, 2147483647) seed_everything(seed) dimensions = {"square": (512, 512), "horizontal": (768, 512), "vertical": (512, 768)} width, height = dimensions.get(aspect, dimensions["square"]) clipped_prompts = prompts[:len(binary_matrixes)] prompts = [master_prompt] + list(clipped_prompts) neg_prompts = [negative_prompt] * len(prompts) fg_masks = torch.cat([preprocess_mask(mask_path, height // 8, width // 8, "cuda") for mask_path in binary_matrixes]) bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True) bg_mask[bg_mask < 0] = 0 masks = torch.cat([bg_mask, fg_masks]) print(masks.size()) image = sd.generate(masks, prompts, neg_prompts, height, width, steps, bootstrapping=boostrapping) return(image) css = ''' #color-bg{display:flex;justify-content: center;align-items: center;} .color-bg-item{width: 100%; height: 32px} #main_button{width:100%} ") aspect.change(None, inputs=[aspect], outputs=None, _js = set_canvas_size) button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors, queue=False) final_run_btn.click(process_generation, inputs=[model, binary_matrixes, boostrapping, aspect, steps, seed, general_prompt, negative_prompt, *prompts], outputs=out_image) demo.load(None, None, None, _js=load_js) demo.launch(debug=True)