# qrcode_ai/handler.py
# Author: simdi — commit 17c1dab "Update handler.py" (verified), 4.21 kB.
import torch
from diffusers import (
StableDiffusionControlNetPipeline,
ControlNetModel,
EulerAncestralDiscreteScheduler,
)
from typing import Dict, List, Any
import qrcode
import os
import base64
from io import BytesIO
from PIL import Image
# Base Stable Diffusion checkpoint the ControlNets are attached to.
MODEL_ID = "simdi/colorful_qr"

# Output resolution (pixels) requested from the pipeline.
WIDTH = HEIGHT = 768

# Selectable (tile_scale, brightness_scale) ControlNet conditioning pairs.
# Element order follows the ControlNet order used by the pipeline (tile first,
# brightness second). Pairs are ordered by tile scale, then brightness scale.
WEIGHT_PAIRS = [
    (tile_scale, brightness_scale)
    for tile_scale in (0.25, 0.35, 0.45)
    for brightness_scale in (0.20, 0.25)
]
def float_to_pair_index(f: float, *, num_pairs=None) -> int:
    """Map an "art scale" value to an index into ``WEIGHT_PAIRS``.

    The value is truncated toward zero and clamped into the valid range:

    * ``f`` in ``[0, num_pairs)`` -> ``int(f)`` (so ``0.7`` -> ``0``, ``2.9`` -> ``2``)
    * ``f >= num_pairs`` (including ``inf``) -> ``num_pairs - 1``
    * ``f <= 0``, ``-inf`` or ``NaN`` -> ``0``

    Negative inputs used to produce negative indices, which silently selected
    pairs from the end of the table or raised ``IndexError``/``OverflowError``.

    Args:
        f: Requested art scale.
        num_pairs: Size of the table being indexed. Defaults to
            ``len(WEIGHT_PAIRS)``.

    Returns:
        An index in ``range(num_pairs)``.

    Raises:
        ValueError: If the table size is smaller than 1.
    """
    length = len(WEIGHT_PAIRS) if num_pairs is None else num_pairs
    if length < 1:
        raise ValueError("num_pairs must be at least 1")
    # `not f > 0` is also true for NaN, which keeps NaN mapping to the first pair.
    if not f > 0:
        return 0
    if f >= length:
        return length - 1
    return int(f)
def select_weight_pair(f: float):
    """Return the ``(tile_scale, brightness_scale)`` pair chosen by *f*.

    *f* is translated to a table position by ``float_to_pair_index``.
    """
    index = float_to_pair_index(f)
    return WEIGHT_PAIRS[index]
def load_models():
    """Build the fp16 Stable Diffusion multi-ControlNet pipeline on CUDA.

    Two ControlNets (tile, then brightness) are loaded in float16 and attached
    to the ``MODEL_ID`` checkpoint. The scheduler is swapped for Euler
    Ancestral, and xformers memory-efficient attention is enabled.

    Returns:
        The ready-to-call ``StableDiffusionControlNetPipeline``.
    """
    # Order matters: conditioning scales and images are matched to the
    # ControlNets positionally (tile first, brightness second).
    controlnet_repos = (
        "lllyasviel/control_v11f1e_sd15_tile",
        "ioclab/control_v1p_sd15_brightness",
    )
    controlnets = [
        ControlNetModel.from_pretrained(repo, torch_dtype=torch.float16)
        for repo in controlnet_repos
    ]

    # Weights are cached under ./cache. Passing local_files_only=True here
    # would restrict loading to that cache.
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID,
        controlnet=controlnets,
        torch_dtype=torch.float16,
        cache_dir="cache",
    )
    pipeline = pipeline.to("cuda")

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    pipeline.enable_xformers_memory_efficient_attention()
    return pipeline
def resize_for_condition_image(input_image, resolution: int):
    """Convert *input_image* to RGB and rescale it for ControlNet conditioning.

    The image is scaled so its shorter side is approximately *resolution*
    pixels. Each side is then rounded to the nearest multiple of 64 and the
    image is resampled with Lanczos filtering.

    Args:
        input_image: A PIL-compatible image exposing ``convert``/``size``/``resize``.
        resolution: Target length, in pixels, of the shorter side.

    Returns:
        The resized RGB image.
    """
    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    scale = float(resolution) / min(height, width)

    def _snap_to_64(length):
        # Scale first, then round to the nearest multiple of 64.
        return int(round(length * scale / 64.0)) * 64

    target_size = (_snap_to_64(width), _snap_to_64(height))
    return rgb_image.resize(target_size, resample=Image.LANCZOS)
def generate_qr_code(content: str, resolution=None):
    """Render *content* as a QR code sized for use as a conditioning image.

    The code uses the highest error-correction level (H). Version 1 is the
    starting size; ``make(fit=True)`` grows the version as needed to hold
    *content*. The result is passed through ``resize_for_condition_image``.

    Args:
        content: Text/URL to encode.
        resolution: Target length of the image's shorter side in pixels.
            Defaults to ``min(WIDTH, HEIGHT)``, so the conditioning image
            follows the pipeline's output size instead of a duplicated literal.

    Returns:
        An RGB image of the QR code whose sides are multiples of 64.
    """
    if resolution is None:
        resolution = min(WIDTH, HEIGHT)

    # A freshly constructed QRCode is already empty, so no clear() is needed.
    qrcode_generator = qrcode.QRCode(
        version=1,
        error_correction=qrcode.ERROR_CORRECT_H,
        box_size=10,
        border=2,
    )
    qrcode_generator.add_data(content)
    qrcode_generator.make(fit=True)
    img = qrcode_generator.make_image(fill_color="black", back_color="white")
    return resize_for_condition_image(img, resolution)
def image_to_base64(image):
    """Serialize *image* as PNG and return the bytes as a base64 text string.

    Args:
        image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64-encoded PNG data, decoded to ``str`` as UTF-8.
    """
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
def generate_image_with_conditioning_scale(**inputs):
    """Run the ControlNet pipeline and return the images as base64 PNG dicts.

    Keyword Args:
        styles: A prompt string or a list of prompt strings. A bare string is
            treated as a single prompt.
        pair: Sequence of ControlNet conditioning scales, one per ControlNet
            in ``pipe``, in the pipeline's ControlNet order.
        pipe: The diffusers ControlNet pipeline to call.
        qr_image: Conditioning image, fed to every ControlNet.
        generator: ``torch.Generator`` used for sampling.

    Returns:
        A list of ``{"data": <base64 PNG>, "format": "png"}`` dicts, two images
        per prompt (``num_images_per_prompt=2``).
    """
    styles = inputs["styles"]
    # For a bare string, len(styles) would count characters and build a
    # negative-prompt batch that does not match the single prompt.
    if isinstance(styles, str):
        styles = [styles]
    # diffusers documents multi-ControlNet scales as List[float].
    scales = list(inputs["pair"])
    pipe = inputs["pipe"]
    qr_image = inputs["qr_image"]
    generator = inputs["generator"]

    images = pipe(
        prompt=styles,
        negative_prompt=[""] * len(styles),
        width=WIDTH,
        height=HEIGHT,
        guidance_scale=7.0,
        generator=generator,
        num_inference_steps=25,
        num_images_per_prompt=2,
        controlnet_conditioning_scale=scales,
        # Multi-ControlNet needs one conditioning image per ControlNet; the
        # same QR image drives all of them.
        image=[qr_image] * len(scales),
    ).images
    return [{"data": image_to_base64(image), "format": "png"} for image in images]
def generate_image(pipe, inputs):
    """Generate QR-code art for one request payload.

    Args:
        pipe: Pipeline returned by ``load_models``.
        inputs: Request dict with keys ``"styles"`` (prompt or prompts),
            ``"content"`` (text to encode in the QR code), ``"art_scale"``
            (selects the ControlNet weight pair), and an optional ``"seed"``
            (int-convertible).

    Returns:
        The list of base64 PNG dicts from
        ``generate_image_with_conditioning_scale``.

    Raises:
        KeyError: If a required key is missing from *inputs*.
    """
    styles = inputs["styles"]
    content = inputs["content"]
    art_scale = inputs["art_scale"]
    seed = inputs.get("seed")

    # QR rendering and weight selection are plain CPU/PIL work and need no
    # autocast or inference-mode scope.
    qr_image = generate_qr_code(content)
    pair = select_weight_pair(art_scale)

    # A fresh torch.Generator() starts from torch's fixed default seed, so
    # without "seed" repeated identical requests reuse the same initial noise
    # (unchanged behavior). An explicit seed lets callers vary or pin it.
    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(int(seed))

    with torch.inference_mode(), torch.autocast("cuda"):
        return generate_image_with_conditioning_scale(
            styles=styles,
            pair=pair,
            pipe=pipe,
            qr_image=qr_image,
            generator=generator,
        )
class EndpointHandler:
    """Callable request handler that wraps the QR-art generation pipeline."""

    def __init__(self, path=""):
        # `path` is not used here; models are loaded from the repo IDs
        # referenced inside load_models().
        self._model = load_models()

    def __call__(self, model_input: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate images for *model_input* and return them as base64 dicts."""
        return generate_image(self._model, model_input)