import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import base64
import requests
import random
import os
from io import BytesIO
from region_control import MultiDiffusion, get_views, preprocess_mask, seed_everything
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
# Upper bound on the number of distinct sketch colours, i.e. on the number of
# per-colour rows / regional prompts the UI exposes (see process_sketch).
MAX_COLORS = 12
# Shared MultiDiffusion pipeline, created once at import time on "cuda".
# process_generation may replace it when the user selects another model.
sd = MultiDiffusion("cuda", "2.1")
# True when running inside the "weizmannscience/multidiffusion-region-based"
# Space. SPACE_ID may be unset (e.g. when running locally), so fall back to an
# empty string instead of raising KeyError at import time.
is_shared_ui = "weizmannscience/multidiffusion-region-based" in os.environ.get('SPACE_ID', '')
is_gpu_associated = torch.cuda.is_available()
# Mount point for the JavaScript sketch canvas. get_js_colors looks this
# element up by id ("canvas-root"), so the id must stay in sync with it.
canvas_html = "<div id='canvas-root'></div>"
# JavaScript run once on page load (wired up via demo.load(..., _js=load_js)):
# fetches the sketch-canvas script as text and injects it into <head> as a
# <script type="module"> whose src is a blob: URL built from that text.
# Presumably the blob indirection exists because the raw file is not served
# with a JavaScript MIME type -- confirm before simplifying to src=url.
# NOTE(review): the host is written "huggingface.co." (trailing dot). That is
# the fully-qualified form of the same domain, but it looks like a typo; confirm.
# NOTE(review): there is no .catch() on the fetch chain, so if loading fails
# the canvas silently never appears.
load_js = """
async () => {
const url = "https://huggingface.co./datasets/radames/gradio-components/raw/main/sketch-canvas.js"
fetch(url)
.then(res => res.text())
.then(text => {
const script = document.createElement('script');
script.type = "module"
script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
document.head.appendChild(script);
});
}
"""
# JavaScript pre-processor for the "process sketch" button
# (button_run.click(process_sketch, ..., _js=get_js_colors)). It ignores its
# argument and returns the `_data` property of the #canvas-root element as the
# first input, which process_sketch receives as `canvas_data`.
# `_data` is presumably populated by the sketch-canvas script injected by
# load_js; process_sketch reads its 'image' and 'colors' keys -- confirm.
get_js_colors = """
async (canvasData) => {
const canvasEl = document.getElementById("canvas-root");
return [canvasEl._data]
}
"""
# JavaScript handler for aspect.change (no Python callback): resizes the
# browser canvas to match the sizes process_generation uses for each aspect
# ("square" 512,512; "horizontal" 768,512; "vertical" 512,768). Any other value
# leaves the canvas unchanged.
# `_updateCanvas` is not defined here -- presumably a global installed by the
# sketch-canvas script that load_js injects; verify.
set_canvas_size ="""
async (aspect) => {
if(aspect ==='square'){
_updateCanvas(512,512)
}
if(aspect ==='horizontal'){
_updateCanvas(768,512)
}
if(aspect ==='vertical'){
_updateCanvas(512,768)
}
}
"""
def process_sketch(canvas_data, binary_matrixes):
    """Split the user's colour sketch into one binary mask per drawn colour.

    Args:
        canvas_data: Payload returned by the ``get_js_colors`` JS hook. Keys
            read: ``'image'`` (a data URL whose part after the first comma is
            base64 image data) and ``'colors'`` (CSS strings of the form
            ``"rgb(r, g, b)"``).
        binary_matrixes: Session-state list. It is cleared and refilled in
            place with one mask per non-white colour.

    Returns:
        A list of Gradio outputs, in this order:
        - an update making the post-sketch section visible;
        - the refilled ``binary_matrixes``;
        - ``MAX_COLORS`` visibility updates, one per colour row;
        - ``MAX_COLORS`` value updates, one per colour swatch.

    Raises:
        gr.Error: if the sketch contains more than ``MAX_COLORS`` non-white
            colours. Previously this crashed with an IndexError. The state is
            left untouched in that case.
    """
    # Only the part after the data-URL header ("data:...;base64,") is base64.
    image_data = base64.b64decode(canvas_data['image'].split(',')[1])
    im2arr = np.array(Image.open(BytesIO(image_data)).convert("RGB"))

    # "rgb(12, 34, 56)" -> (12, 34, 56): drop the "rgb(" prefix and ")" suffix.
    colors = [tuple(map(int, rgb[4:-1].split(','))) for rgb in canvas_data['colors']]
    # White is treated as background, not as a region.
    region_colors = [color for color in colors if any(c != 255 for c in color)]

    # There are only MAX_COLORS rows/prompts in the UI; validate before
    # touching the state so a rejected sketch keeps the previous masks.
    if len(region_colors) > MAX_COLORS:
        raise gr.Error(f"Please use at most {MAX_COLORS} different colors in your sketch.")

    binary_matrixes.clear()
    swatches = []
    for r, g, b in region_colors:
        binary_matrixes.append(create_binary_matrix(im2arr, (r, g, b)))
        # Coloured preview box; the .color-bg-item class is sized in the page CSS.
        swatches.append(gr.update(
            value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'
        ))

    # Show one row per detected colour; hide and blank the remaining rows.
    used = len(swatches)
    visibilities = [gr.update(visible=i < used) for i in range(MAX_COLORS)]
    color_updates = swatches + [gr.update(value='') for _ in range(MAX_COLORS - used)]
    return [gr.update(visible=True), binary_matrixes, *visibilities, *color_updates]
def process_generation(model, binary_matrixes, boostrapping, aspect, steps, seed, master_prompt, negative_prompt, *prompts):
    """Run region-based MultiDiffusion on the processed sketch.

    Args:
        model: Model identifier chosen in the UI. The value
            "stabilityai/stable-diffusion-2-1-base" denotes the pipeline created
            at import time with ``MultiDiffusion("cuda", "2.1")``.
        binary_matrixes: Per-colour masks produced by ``process_sketch``.
        boostrapping: Forwarded to ``sd.generate`` as ``bootstrapping``. The
            misspelt parameter name is kept for interface compatibility.
        aspect: "square", "horizontal" or "vertical"; any other value falls
            back to square.
        steps: Forwarded unchanged to ``sd.generate``.
        seed: Seed passed to ``seed_everything``; -1 selects a random seed.
        master_prompt: Prompt paired with the background mask.
        negative_prompt: Negative prompt applied to every region.
        *prompts: One prompt per colour row. Only the first
            ``len(binary_matrixes)`` are used, paired with the masks in order.

    Returns:
        Whatever ``sd.generate`` returns (displayed as the output image).

    Raises:
        gr.Error: if no sketch has been processed yet (``binary_matrixes`` is
            empty). Without this check, ``torch.cat`` fails on an empty list.
    """
    global sd
    default_model = "stabilityai/stable-diffusion-2-1-base"

    if not binary_matrixes:
        raise gr.Error("Please draw a sketch and process it before generating.")

    # Reload only when the selection differs from the pipeline currently held
    # in `sd`. This avoids reloading on every run and makes switching back to
    # the default model actually restore it.
    loaded_model = getattr(process_generation, "_loaded_model", default_model)
    if model != loaded_model:
        # The default model is the "2.1" pipeline that was created at import time.
        sd = MultiDiffusion("cuda", "2.1" if model == default_model else model)
        process_generation._loaded_model = model

    if seed == -1:
        seed = random.randint(1, 2147483647)
    seed_everything(seed)

    dimensions = {"square": (512, 512), "horizontal": (768, 512), "vertical": (512, 768)}
    width, height = dimensions.get(aspect, dimensions["square"])

    # Slot 0 is the background (master prompt); slot i > 0 pairs with mask i - 1.
    prompts = [master_prompt, *prompts[:len(binary_matrixes)]]
    neg_prompts = [negative_prompt] * len(prompts)

    # Masks are resized to 1/8 of the pixel size (presumably the latent resolution).
    fg_masks = torch.cat([preprocess_mask(mask, height // 8, width // 8, "cuda") for mask in binary_matrixes])
    # Background is whatever the foreground masks leave uncovered. The summed
    # masks can exceed 1 (e.g. overlapping regions), so clamp negatives to 0.
    bg_mask = (1 - torch.sum(fg_masks, dim=0, keepdim=True)).clamp_(min=0)
    masks = torch.cat([bg_mask, fg_masks])

    return sd.generate(masks, prompts, neg_prompts, height, width, steps, bootstrapping=boostrapping)
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
")
aspect.change(None, inputs=[aspect], outputs=None, _js = set_canvas_size)
button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors, queue=False)
final_run_btn.click(process_generation, inputs=[model, binary_matrixes, boostrapping, aspect, steps, seed, general_prompt, negative_prompt, *prompts], outputs=out_image)
demo.load(None, None, None, _js=load_js)
demo.launch(debug=True)