File size: 4,380 Bytes
affe6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import random
import numpy as np
from PIL import Image
import base64
from io import BytesIO

import torch
import torchvision.transforms.functional as F
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
import gradio as gr

# Pick the best accelerator that is actually present instead of hard-coding one:
# CUDA (typical on Linux & Windows), then MPS (Apple Silicon), otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# float16 halves memory use on accelerators (the original note reports that
# float16 pictures look a bit worse than float32); CPU execution stays in
# float32, matching what `run` forces for its "CPU" choice.
weight_type = torch.float32 if device == "cpu" else torch.float16

# Load the sketch ControlNet and the base pipeline once at import time and
# place both on the selected device.
controlnet = ControlNetModel.from_pretrained(
    "IDKiro/sdxs-512-dreamshaper-sketch", torch_dtype=weight_type
).to(device)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "IDKiro/sdxs-512-dreamshaper", controlnet=controlnet, torch_dtype=weight_type
)
pipe.to(device)

# Prompt presets offered in the UI. Each entry pairs a display name with a
# template in which the literal "{prompt}" marks where the user's text goes.
style_list = [
    dict(name="No Style", prompt="{prompt}"),
    dict(
        name="Cinematic",
        prompt="cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    ),
    # Additional styles omitted for brevity
]

# Largest value representable as a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Preset selected when the UI first loads.
DEFAULT_STYLE_NAME = "No Style"
# Lookup table: preset name -> prompt template (insertion order preserved).
styles = dict((entry["name"], entry["prompt"]) for entry in style_list)
STYLE_NAMES = [*styles]


def pil_image_to_data_url(img, format="PNG"):
    """Serialize an image into a base64 ``data:`` URL.

    Args:
        img: Object exposing ``save(fp, format=...)``; it is asked to write
            itself into an in-memory byte buffer.
        format: Format name handed to ``img.save``; its lower-cased form is
            used as the ``image/<subtype>`` part of the URL.

    Returns:
        A string of the form ``data:image/<format>;base64,<payload>``.
    """
    with BytesIO() as stream:
        img.save(stream, format=format)
        raw_bytes = stream.getvalue()
    encoded = base64.b64encode(raw_bytes).decode()
    subtype = format.lower()
    return "data:image/" + subtype + ";base64," + encoded


def _resolve_device_and_dtype(device_type, param_dtype):
    """Translate the UI's device/precision choices into a torch device and dtype.

    "CPU" always maps to ("cpu", float32). Any other choice prefers CUDA, then
    MPS, and falls back to CPU when no accelerator is available instead of
    failing on a hard-coded "cuda". float16 is only used on an accelerator and
    only when ``param_dtype`` is the string 'torch.float16'.
    """
    if device_type == "CPU":
        target_device = "cpu"
    elif torch.cuda.is_available():
        target_device = "cuda"
    else:
        mps_backend = getattr(torch.backends, "mps", None)
        has_mps = mps_backend is not None and mps_backend.is_available()
        target_device = "mps" if has_mps else "cpu"

    if target_device != "cpu" and param_dtype == 'torch.float16':
        return target_device, torch.float16
    return target_device, torch.float32


def run(
    image,
    prompt,
    prompt_template,
    style_name,
    controlnet_conditioning_scale,
    device_type="GPU",
    param_dtype='torch.float16',
):
    """Generate one 512x512 image from a control image and a text prompt.

    Args:
        image: PIL image used as the ControlNet condition, or ``None``.
        prompt: User text substituted into ``prompt_template``.
        prompt_template: Template whose literal ``{prompt}`` is replaced by
            ``prompt``.
        style_name: Accepted because the UI passes it; not read here.
        controlnet_conditioning_scale: Forwarded unchanged to the pipeline.
        device_type: "CPU" forces CPU/float32; anything else requests an
            accelerator (CUDA, then MPS, then CPU as a last resort).
        param_dtype: 'torch.float16' selects half precision on an accelerator;
            any other value selects float32.

    Returns:
        A 3-tuple ``(output_image, input_link_update, output_link_update)``.
        When ``image`` is ``None`` the output image is a blank white 512x512
        grayscale placeholder and both updates link to that placeholder.
    """
    print(f"prompt: {prompt}")

    # Nothing to condition on: answer with a placeholder before touching the
    # model, so the pipeline is not moved between devices for no reason.
    if image is None:
        ones = Image.new("L", (512, 512), 255)
        temp_url = pil_image_to_data_url(ones)
        return ones, gr.update(link=temp_url), gr.update(link=temp_url)

    target_device, target_dtype = _resolve_device_and_dtype(device_type, param_dtype)
    # Positional (device, dtype) is used because the torch_device=/torch_dtype=
    # keyword spellings are not accepted by every diffusers release.
    pipe.to(target_device, target_dtype)

    prompt = prompt_template.replace("{prompt}", prompt)

    # Invert the RGB input before using it as the control image.
    control_image = image.convert("RGB")
    control_image = Image.fromarray(255 - np.array(control_image))

    # Single-step generation with classifier-free guidance disabled.
    output_pil = pipe(
        prompt=prompt,
        image=control_image,
        width=512,
        height=512,
        guidance_scale=0.0,
        num_inference_steps=1,
        num_images_per_prompt=1,
        output_type="pil",
        controlnet_conditioning_scale=controlnet_conditioning_scale,
    ).images[0]

    input_image_url = pil_image_to_data_url(control_image)
    output_image_url = pil_image_to_data_url(output_pil)
    return (
        output_pil,
        gr.update(link=input_image_url),
        gr.update(link=output_image_url),
    )


def _run_image_only(
    image,
    prompt,
    prompt_template,
    style_name,
    controlnet_conditioning_scale,
    device_type,
    param_dtype,
):
    """Call ``run`` and keep only the generated image.

    ``run`` returns three values (the image plus two link updates), but the UI
    below wires a single output component, so the event handlers must hand
    back exactly one value.
    """
    return run(
        image,
        prompt,
        prompt_template,
        style_name,
        controlnet_conditioning_scale,
        device_type,
        param_dtype,
    )[0]


# Two-column layout: inputs (webcam, prompt, style, settings) on the left,
# generated image on the right.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# SDXS-512-DreamShaper-Webcam")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## INPUT")
            # The control image comes from the webcam.
            image = gr.Image(
                source="webcam", type="pil", label="Webcam Image", interactive=True
            )

            prompt = gr.Textbox(label="Prompt", value="", show_label=True)
            style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
            prompt_template = gr.Textbox(label="Prompt Style Template", value=styles[DEFAULT_STYLE_NAME])

            controlnet_conditioning_scale = gr.Slider(label="Control Strength", minimum=0, maximum=1, step=0.01, value=0.8)

            device_choices = ['GPU','CPU']
            device_type = gr.Radio(device_choices, label='Device', value=device_choices[0], interactive=True)

            dtype_choices = ['torch.float16','torch.float32']
            param_dtype = gr.Radio(dtype_choices, label='torch.weight_type', value=dtype_choices[0], interactive=True)

        with gr.Column():
            gr.Markdown("## OUTPUT")
            result = gr.Image(label="Result", show_label=False, show_download_button=True)

    # Component order must match the parameter order of `run`.
    inputs = [image, prompt, prompt_template, style, controlnet_conditioning_scale, device_type, param_dtype]
    outputs = [result]
    # Regenerate when the prompt is submitted or the webcam image changes.
    prompt.submit(fn=_run_image_only, inputs=inputs, outputs=outputs)
    # Picking a style overwrites the editable template with that preset.
    style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_template])
    image.change(_run_image_only, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    # Script entry point: turn on the request queue, then start the app with
    # debug output enabled.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)