simdi committed on
Commit
9b34f86
1 Parent(s): 24e6f9a

initial commit

Browse files
Files changed (3) hide show
  1. handler.py +157 -0
  2. pre-requirements.txt +1 -0
  3. requirements.txt +3 -0
handler.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import os
from io import BytesIO
from typing import Any, Dict, List

import qrcode
import torch
from diffusers import (
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionControlNetPipeline,
)
from PIL import Image
13
+
14
# Hub repository of the base Stable Diffusion checkpoint loaded by load_models().
MODEL_ID = "simdi/colorful_qr"

# Output size, in pixels, requested from the pipeline.
WIDTH = 768
HEIGHT = 768

# Selectable ControlNet conditioning-scale pairs. Each tuple is passed as a
# whole to the pipeline's `controlnet_conditioning_scale`; the ControlNets are
# registered as [tile, brightness], so the first value presumably weights the
# tile net and the second the brightness net (confirm against diffusers docs).
WEIGHT_PAIRS = [
    (first_scale, second_scale)
    for first_scale in (0.25, 0.35, 0.45)
    for second_scale in (0.20, 0.25)
]
26
+
27
+
28
def float_to_pair_index(f: float, length: "int | None" = None):
    """Map a numeric scale onto a valid index into ``WEIGHT_PAIRS``.

    The value is truncated toward zero and clamped to ``[0, length - 1]``, so
    the result is always safe to use as a non-negative sequence index.

    Args:
        f: Requested position. Values at or above ``length`` select the last
            index; zero, negative values and NaN select index 0.
        length: Number of selectable entries. Defaults to
            ``len(WEIGHT_PAIRS)`` when omitted.

    Returns:
        An integer index in ``range(length)``.

    Raises:
        ValueError: If ``length`` is smaller than 1.
    """
    if length is None:
        length = len(WEIGHT_PAIRS)
    if length < 1:
        raise ValueError("length must be at least 1")

    # Compare before converting: int() cannot represent +/-inf or NaN.
    if f >= length:
        return length - 1
    # `not f > 0` (instead of `f <= 0`) also catches NaN, for which every
    # comparison is False. Without this guard a negative value would become a
    # negative index and silently pick an entry from the end of the list.
    if not f > 0:
        return 0
    return int(f)
42
+
43
+
44
def select_weight_pair(f: float):
    """Return the ``WEIGHT_PAIRS`` entry chosen by ``float_to_pair_index(f)``."""
    pair_index = float_to_pair_index(f)
    return WEIGHT_PAIRS[pair_index]
46
+
47
+
48
def load_models():
    """Build the two-ControlNet Stable Diffusion pipeline used by the handler.

    Loads the tile and brightness ControlNet checkpoints in float16, attaches
    them (tile first, brightness second) to the ``MODEL_ID`` pipeline, moves
    the pipeline to CUDA, replaces its scheduler with an
    ``EulerAncestralDiscreteScheduler`` built from the existing scheduler
    config, and enables xformers memory-efficient attention.

    Returns:
        The configured ``StableDiffusionControlNetPipeline``.
    """
    half_precision = torch.float16

    # Tile first, brightness second; the list below keeps this order.
    controlnet_repos = (
        "lllyasviel/control_v11f1e_sd15_tile",
        "ioclab/control_v1p_sd15_brightness",
    )
    controlnets = [
        ControlNetModel.from_pretrained(repo, torch_dtype=half_precision)
        for repo in controlnet_repos
    ]

    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID,
        controlnet=controlnets,
        torch_dtype=half_precision,
        cache_dir="cache",
    )
    pipeline = pipeline.to("cuda")

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    pipeline.enable_xformers_memory_efficient_attention()
    return pipeline
73
+
74
+
75
def resize_for_condition_image(input_image: Image.Image, resolution: int):
    """Resize an image so its shorter side is about ``resolution`` pixels.

    The image is converted to RGB, scaled uniformly so that the shorter side
    equals ``resolution``, and then each side is rounded to the nearest
    multiple of 64. Each side is kept at 64 pixels or more so that a small
    ``resolution`` cannot round down to an empty dimension.

    Args:
        input_image: Image to resize; it must provide ``convert``, ``size``
            and ``resize`` like a PIL image.
        resolution: Target length, in pixels, of the shorter side before
            rounding to a multiple of 64.

    Returns:
        A new RGB image resampled with the LANCZOS filter.
    """
    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    scale = float(resolution) / min(height, width)

    def _snap_to_64(side):
        # Round the scaled side to the nearest multiple of 64, but never below
        # one full step: a 0-pixel side makes resize() fail.
        return max(64, int(round(side * scale / 64.0)) * 64)

    target_size = (_snap_to_64(width), _snap_to_64(height))
    return rgb_image.resize(target_size, resample=Image.LANCZOS)
85
+
86
+
87
def generate_qr_code(content: str):
    """Render ``content`` as a black-on-white QR code for conditioning.

    The code is built with high (H) error correction, a box size of 10 and a
    border of 2, then passed through ``resize_for_condition_image`` with a
    resolution of 768.

    Args:
        content: Text to encode in the QR code.

    Returns:
        The resized QR image returned by ``resize_for_condition_image``.
    """
    qr_settings = {
        "version": 1,
        "error_correction": qrcode.ERROR_CORRECT_H,
        "box_size": 10,
        "border": 2,
    }
    qr = qrcode.QRCode(**qr_settings)
    qr.clear()
    qr.add_data(content)
    qr.make(fit=True)

    rendered = qr.make_image(fill_color="black", back_color="white")
    return resize_for_condition_image(rendered, 768)
100
+
101
+
102
def image_to_base64(image: Image.Image):
    """Encode ``image`` as PNG and return the bytes as a base64 text string.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG data, decoded as UTF-8 ``str``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
106
+
107
+
108
def generate_image_with_conditioning_scale(**inputs):
    """Run the pipeline once and return its images as base64 PNG payloads.

    Keyword Args:
        styles: Prompts passed to the pipeline as ``prompt``; one empty
            negative prompt is supplied per element (``len(styles)``).
        pair: Value forwarded as ``controlnet_conditioning_scale``.
        pipe: Callable pipeline whose result exposes ``.images``.
        qr_image: Conditioning image; it is supplied twice in a list.
        generator: Random generator forwarded to the pipeline.

    Returns:
        A list of ``{"data": <base64 PNG>, "format": "png"}`` dicts, one per
        image produced by the pipeline.

    Raises:
        KeyError: If any of the keyword arguments above is missing.
    """
    required_keys = ("styles", "pair", "pipe", "qr_image", "generator")
    prompts, scales, pipeline, condition_image, rng = (
        inputs[key] for key in required_keys
    )

    call_kwargs = {
        "prompt": prompts,
        "negative_prompt": [""] * len(prompts),
        "width": WIDTH,
        "height": HEIGHT,
        "guidance_scale": 7.0,
        "generator": rng,
        "num_inference_steps": 25,
        "num_images_per_prompt": 2,
        "controlnet_conditioning_scale": scales,
        "image": [condition_image] * 2,
    }
    output = pipeline(**call_kwargs)

    return [
        {"data": image_to_base64(picture), "format": "png"}
        for picture in output.images
    ]
129
+
130
+
131
def generate_image(pipe, inputs):
    """Generate QR-conditioned images for one request.

    Args:
        pipe: Pipeline forwarded to ``generate_image_with_conditioning_scale``.
        inputs: Mapping with the keys ``"styles"`` (prompts), ``"content"``
            (text to encode in the QR code) and ``"art_scale"`` (value used
            to pick a weight pair via ``select_weight_pair``).

    Returns:
        The list returned by ``generate_image_with_conditioning_scale``.

    Raises:
        KeyError: If ``inputs`` lacks one of the keys listed above.
    """
    prompts = inputs["styles"]
    qr_content = inputs["content"]
    scale_choice = inputs["art_scale"]

    # Everything below runs without autograd tracking and under CUDA autocast.
    with torch.inference_mode(), torch.autocast("cuda"):
        condition_image = generate_qr_code(qr_content)
        rng = torch.Generator()
        weight_pair = select_weight_pair(scale_choice)
        return generate_image_with_conditioning_scale(
            styles=prompts,
            pair=weight_pair,
            pipe=pipe,
            qr_image=condition_image,
            generator=rng,
        )
149
+
150
+
151
class EndpointHandler:
    """Callable wrapper that loads the pipeline once and serves requests."""

    def __init__(self, path=""):
        """Load the pipeline via ``load_models``; ``path`` is accepted but unused."""
        self._model = load_models()

    def __call__(self, model_input: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate images for ``model_input`` using the loaded pipeline.

        Args:
            model_input: Request mapping handed unchanged to
                ``generate_image``.

        Returns:
            The list of image payload dicts produced by ``generate_image``.
        """
        return generate_image(self._model, model_input)
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libzbar0
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ qrcode
2
+ pyzbar
3
+ pillow