|
import torch |
|
from diffusers import ( |
|
StableDiffusionControlNetPipeline, |
|
ControlNetModel, |
|
EulerAncestralDiscreteScheduler, |
|
) |
|
from typing import Dict, List, Any |
|
|
|
import qrcode |
|
import os |
|
import base64 |
|
from io import BytesIO |
|
|
|
from pyzbar.pyzbar import decode |
|
from PIL import Image |
|
|
|
# Stable Diffusion checkpoint that the ControlNets are attached to in `load_models`.
MODEL_ID = "simdi/colorful_qr"

# Width/height in pixels requested from the pipeline for every generated image.
WIDTH = 768
HEIGHT = 768

# Candidate ControlNet conditioning-scale pairs; `select_weight_pair` picks one
# per request. The grid is 3 first-weights x 2 second-weights, listed with the
# first weight varying slowest. Each pair is presumably (tile scale, brightness
# scale), matching the ControlNet order in `load_models` -- confirm if reordering.
WEIGHT_PAIRS = [
    (first_weight, second_weight)
    for first_weight in (0.25, 0.35, 0.45)
    for second_weight in (0.20, 0.25)
]
|
|
|
|
|
def float_to_pair_index(f: float, num_pairs: "int | None" = None) -> int:
    """Map an ``art_scale`` value to an index into ``WEIGHT_PAIRS``.

    Two conventions are accepted:

    * ``f >= 1`` is treated as a direct index: it is truncated with ``int``
      and clamped to the last valid index.
    * ``0 < f < 1`` is treated as a fraction of the table, so ``0.5``
      selects the middle entry and values close to 1 select the last one.

    Zero, negative values and NaN all map to index 0.

    Args:
        f: Requested art scale, either an index or a fraction.
        num_pairs: Size of the table being indexed. Defaults to
            ``len(WEIGHT_PAIRS)``.

    Returns:
        An index in ``range(num_pairs)``.

    Raises:
        ValueError: If ``num_pairs`` is smaller than 1.
    """
    length = len(WEIGHT_PAIRS) if num_pairs is None else num_pairs
    if length < 1:
        raise ValueError("num_pairs must be at least 1")
    last = length - 1

    # `not f > 0` is also true for NaN, so NaN gets a safe answer instead of
    # producing a bogus (or negative) index.
    if not f > 0:
        return 0
    # Anything past the end (including +inf, which int() cannot convert)
    # selects the last pair.
    if f >= length:
        return last
    if f < 1:
        # Fractional scale: spread (0, 1) proportionally over the table.
        # min() guards against float rounding pushing f * length up to length.
        return min(int(f * length), last)
    # Index-style scale in [1, length): truncate to an integer index.
    return int(f)
|
|
|
|
|
def select_weight_pair(f: float):
    """Return the ControlNet conditioning-scale pair selected by *f*.

    The index is computed by ``float_to_pair_index`` and used to look up an
    entry of ``WEIGHT_PAIRS``.
    """
    chosen_index = float_to_pair_index(f)
    return WEIGHT_PAIRS[chosen_index]
|
|
|
|
|
def load_models():
    """Build the QR-art ControlNet pipeline and move it to the GPU.

    Loads two fp16 ControlNets (tile first, then brightness), attaches them to
    the ``MODEL_ID`` checkpoint (cached under ``"cache"``), replaces the
    scheduler with Euler-Ancestral built from the existing scheduler config,
    and enables xformers memory-efficient attention.

    Returns:
        The configured ``StableDiffusionControlNetPipeline`` on ``"cuda"``.
    """
    half_precision = torch.float16
    # Per-ControlNet inputs at inference time (conditioning scales, control
    # images) are matched to these models by position, so keep this order.
    controlnet_repos = (
        "lllyasviel/control_v11f1e_sd15_tile",
        "ioclab/control_v1p_sd15_brightness",
    )
    controlnets = [
        ControlNetModel.from_pretrained(repo_id, torch_dtype=half_precision)
        for repo_id in controlnet_repos
    ]

    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID,
        controlnet=controlnets,
        torch_dtype=half_precision,
        cache_dir="cache",
    )
    pipeline = pipeline.to("cuda")

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    pipeline.enable_xformers_memory_efficient_attention()
    return pipeline
|
|
|
|
|
def resize_for_condition_image(input_image, resolution: int):
    """Convert *input_image* to RGB and rescale it for use as a control image.

    The image is scaled so its shorter side would equal *resolution*. Each
    side is then rounded to the nearest multiple of 64, and the image is
    resampled with Lanczos.

    Args:
        input_image: Image exposing ``convert``/``size``/``resize``
            (presumably a PIL image or a wrapper around one).
        resolution: Target length, in pixels, of the shorter side.

    Returns:
        The resized RGB image.
    """
    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    scale = float(resolution) / min(height, width)

    def _snap_to_64(side):
        # Scale the side, then snap it to the nearest multiple of 64.
        return int(round(side * scale / 64.0)) * 64

    target_size = (_snap_to_64(width), _snap_to_64(height))
    return rgb_image.resize(target_size, resample=Image.LANCZOS)
|
|
|
|
|
def generate_qr_code(content: str, resolution: int = WIDTH):
    """Render *content* as a QR code image sized for ControlNet conditioning.

    Uses the highest error-correction level (``H``), presumably so the code
    remains scannable after the diffusion pipeline restyles it. The QR
    version starts at 1 and grows automatically (``fit=True``) to fit
    *content*.

    Args:
        content: Text to encode.
        resolution: Target length of the shorter image side, before snapping
            to a multiple of 64. Defaults to ``WIDTH``, the generation width
            used by the pipeline, which keeps the previous output unchanged.

    Returns:
        The QR code as an RGB image produced by ``resize_for_condition_image``.
    """
    qr_builder = qrcode.QRCode(
        version=1,
        error_correction=qrcode.ERROR_CORRECT_H,
        box_size=10,
        border=2,
    )
    qr_builder.add_data(content)
    # fit=True lets the library pick a larger version when content needs it.
    qr_builder.make(fit=True)
    qr_image = qr_builder.make_image(fill_color="black", back_color="white")
    return resize_for_condition_image(qr_image, resolution)
|
|
|
|
|
def image_to_base64(image):
    """Encode *image* as PNG and return the bytes as a base64 text string.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        Base64 of the PNG bytes, decoded to ``str`` as UTF-8.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
|
|
|
|
|
def generate_image_with_conditioning_scale(**inputs):
    """Run the ControlNet pipeline and return the results as base64 PNGs.

    Keyword Args:
        styles: Prompt string or list of prompts. A single string is wrapped
            in a list so it lines up with the generated negative prompts.
        pair: Sequence of ControlNet conditioning scales, one per ControlNet
            attached to ``pipe`` (e.g. an entry of ``WEIGHT_PAIRS``).
        pipe: Pipeline returned by ``load_models``.
        qr_image: Control image; the same image is fed to every ControlNet.
        generator: ``torch.Generator`` controlling sampling randomness.
        num_images_per_prompt: Optional number of images per prompt
            (defaults to 2).

    Returns:
        A list of ``{"data": <base64 PNG>, "format": "png"}`` dicts, one per
        generated image.
    """
    styles = inputs["styles"]
    if isinstance(styles, str):
        # With a bare string, len(styles) would count characters and the
        # negative-prompt list would not match the prompt.
        styles = [styles]
    # diffusers documents controlnet_conditioning_scale as float or
    # List[float]; normalise tuples such as WEIGHT_PAIRS entries to a list.
    scales = list(inputs["pair"])
    pipe = inputs["pipe"]
    qr_image = inputs["qr_image"]
    generator = inputs["generator"]
    num_images_per_prompt = inputs.get("num_images_per_prompt", 2)

    images = pipe(
        prompt=styles,
        negative_prompt=[""] * len(styles),
        width=WIDTH,
        height=HEIGHT,
        guidance_scale=7.0,
        generator=generator,
        num_inference_steps=25,
        num_images_per_prompt=num_images_per_prompt,
        controlnet_conditioning_scale=scales,
        # One control image per ControlNet, matched to `scales` by position.
        image=[qr_image] * len(scales),
    ).images

    return [{"data": image_to_base64(image), "format": "png"} for image in images]
|
|
|
|
|
def generate_image(pipe, inputs):
    """Handle one request: build the QR code, choose weights and generate art.

    Args:
        pipe: Pipeline returned by ``load_models``.
        inputs: Request payload mapping. Required keys:
            - ``"styles"``: prompt or list of prompts.
            - ``"content"``: text to encode in the QR code.
            - ``"art_scale"``: value passed to ``select_weight_pair``.
            Optional ``"seed"`` (int) seeds the sampler so results are
            reproducible and can be varied per request.

    Returns:
        The list of base64 PNG dicts from
        ``generate_image_with_conditioning_scale``.
    """
    styles = inputs["styles"]
    content = inputs["content"]
    art_scale = inputs["art_scale"]
    seed = inputs.get("seed")

    with torch.inference_mode(), torch.autocast("cuda"):
        qr_image = generate_qr_code(content)
        # NOTE: an unseeded torch.Generator() starts from torch's fixed default
        # seed, so requests without "seed" keep the previous behaviour of
        # identical output for identical inputs.
        generator = torch.Generator()
        if seed is not None:
            generator.manual_seed(int(seed))
        pair = select_weight_pair(art_scale)
        return generate_image_with_conditioning_scale(
            styles=styles,
            pair=pair,
            pipe=pipe,
            qr_image=qr_image,
            generator=generator,
        )
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint entry point for QR-code art generation.

    The pipeline is loaded once when the handler is constructed; each call
    forwards the request payload to ``generate_image``.
    """

    def __init__(self, path=""):
        # `path` is not used here (models are fetched by repo id in
        # `load_models`); presumably kept because the hosting interface passes it.
        self._model = load_models()

    def __call__(self, model_input: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate images for *model_input* and return them as base64 PNG dicts."""
        return generate_image(self._model, model_input)
|
|