code modifications to add the hyperlinks to model and dataset
#3 by Jayabalambika - opened
- app.py +78 -272
- requirements.txt +1 -3
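
The hyperlinks in question point at the Hugging Face model and dataset repos that appear in the diff below. As a rough, non-authoritative sketch (not code from this PR; the demo object name is illustrative), such links are typically rendered in a Gradio Blocks app with gr.Markdown:

# Hedged sketch only: renders model/dataset hyperlinks via gr.Markdown.
# The URLs are the ones referenced elsewhere in this diff.
import gradio as gr

with gr.Blocks() as links_demo:
    gr.Markdown("### Links")
    gr.Markdown(
        "[Standard Model](https://huggingface.co/Vincent-luo/controlnet-hands) | "
        "[Dataset for Standard Model](https://huggingface.co/datasets/MakiPan/hagrid250k-blip2)"
    )

if __name__ == "__main__":
    links_demo.launch()
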
app.py
CHANGED
@@ -5,7 +5,7 @@ from flax.training.common_utils import shard
 from PIL import Image
 from argparse import Namespace
 import gradio as gr
-
+
 import numpy as np
 import mediapipe as mp
 from mediapipe import solutions
@@ -13,64 +13,44 @@ from mediapipe.framework.formats import landmark_pb2
 from mediapipe.tasks import python
 from mediapipe.tasks.python import vision
 import cv2
-import psutil
-from gpuinfo import GPUInfo
-import time
-import gc
-import torch

 from diffusers import (
     FlaxControlNetModel,
     FlaxStableDiffusionControlNetPipeline,
 )
-right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
-left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
-right_style_lm[0].color=(251, 206, 177)
-left_style_lm[0].color=(255, 255, 225)
-
-def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
-    hand_landmarks_list = detection_result.hand_landmarks
-    handedness_list = detection_result.handedness
-    if overlap:
-        annotated_image = np.copy(rgb_image)
-    else:
-        annotated_image = np.zeros_like(rgb_image)

-    # Loop through the detected hands to visualize.
-    for idx in range(len(hand_landmarks_list)):
-        hand_landmarks = hand_landmarks_list[idx]
-        handedness = handedness_list[idx]
-        # Draw the hand landmarks.
-        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
-        hand_landmarks_proto.landmark.extend([
-            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in hand_landmarks
-        ])
-        if hand_encoding:
-            if handedness[0].category_name == "Left":
-                solutions.drawing_utils.draw_landmarks(
-                    annotated_image,
-                    hand_landmarks_proto,
-                    solutions.hands.HAND_CONNECTIONS,
-                    left_style_lm,
-                    solutions.drawing_styles.get_default_hand_connections_style())
-            if handedness[0].category_name == "Right":
-                solutions.drawing_utils.draw_landmarks(
-                    annotated_image,
-                    hand_landmarks_proto,
-                    solutions.hands.HAND_CONNECTIONS,
-                    right_style_lm,
-                    solutions.drawing_styles.get_default_hand_connections_style())
-        else:
-            solutions.drawing_utils.draw_landmarks(
-                annotated_image,
-                hand_landmarks_proto,
-                solutions.hands.HAND_CONNECTIONS,
-                solutions.drawing_styles.get_default_hand_landmarks_style(),
-                solutions.drawing_styles.get_default_hand_connections_style())
-
-    return annotated_image

-
+# mediapipe annotation
+MARGIN = 10 # pixels
+FONT_SIZE = 1
+FONT_THICKNESS = 1
+HANDEDNESS_TEXT_COLOR = (88, 205, 54) # vibrant green
+
+def draw_landmarks_on_image(rgb_image, detection_result):
+    hand_landmarks_list = detection_result.hand_landmarks
+    handedness_list = detection_result.handedness
+    annotated_image = np.zeros_like(rgb_image)
+
+    # Loop through the detected hands to visualize.
+    for idx in range(len(hand_landmarks_list)):
+        hand_landmarks = hand_landmarks_list[idx]
+        handedness = handedness_list[idx]
+
+        # Draw the hand landmarks.
+        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
+        hand_landmarks_proto.landmark.extend([
+            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in hand_landmarks
+        ])
+        solutions.drawing_utils.draw_landmarks(
+            annotated_image,
+            hand_landmarks_proto,
+            solutions.hands.HAND_CONNECTIONS,
+            solutions.drawing_styles.get_default_hand_landmarks_style(),
+            solutions.drawing_styles.get_default_hand_connections_style())
+
+    return annotated_image
+
+def generate_annotation(img):
     """img(input): numpy array
     annotated_image(output): numpy array
     """
@@ -88,260 +68,91 @@ def generate_annotation(img, overlap=False, hand_encoding=False)
     detection_result = detector.detect(image)

     # STEP 5: Process the classification result. In this case, visualize it.
-    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result
+    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result)
     return annotated_image

-
-
-
-
-
-
-
-    controlnet_revision=None,
-    controlnet_from_pt=False,
-)
-enc_args = Namespace(
-    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
-    revision="non-ema",
-    from_pt=True,
-    controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
-    controlnet_revision=None,
-    controlnet_from_pt=False,
-)
-
-std_controlnet, std_controlnet_params = FlaxControlNetModel.from_pretrained(
-    std_args.controlnet_model_name_or_path,
-    revision=std_args.controlnet_revision,
-    from_pt=std_args.controlnet_from_pt,
-    dtype=jnp.float32, # jnp.bfloat16
-)
-enc_controlnet, enc_controlnet_params = FlaxControlNetModel.from_pretrained(
-    enc_args.controlnet_model_name_or_path,
-    revision=enc_args.controlnet_revision,
-    from_pt=enc_args.controlnet_from_pt,
-    dtype=jnp.float32, # jnp.bfloat16
+args = Namespace(
+    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
+    revision="non-ema",
+    from_pt=True,
+    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
+    controlnet_revision=None,
+    controlnet_from_pt=False,
 )

-
-
-
-
-    # tokenizer=tokenizer,
-    controlnet=std_controlnet,
-    safety_checker=None,
+controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+    args.controlnet_model_name_or_path,
+    revision=args.controlnet_revision,
+    from_pt=args.controlnet_from_pt,
     dtype=jnp.float32, # jnp.bfloat16
-    revision=std_args.revision,
-    from_pt=std_args.from_pt,
 )
-
-
+
+pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+    args.pretrained_model_name_or_path,
     # tokenizer=tokenizer,
-    controlnet=
+    controlnet=controlnet,
     safety_checker=None,
     dtype=jnp.float32, # jnp.bfloat16
-    revision=
-    from_pt=
+    revision=args.revision,
+    from_pt=args.from_pt,
 )


-
-
-
-enc_pipeline_params["controlnet"] = enc_controlnet_params
-enc_pipeline_params = jax_utils.replicate(enc_pipeline_params)
+pipeline_params["controlnet"] = controlnet_params
+pipeline_params = jax_utils.replicate(pipeline_params)

 rng = jax.random.PRNGKey(0)
 num_samples = jax.device_count()
 prng_seed = jax.random.split(rng, jax.device_count())
-memory = psutil.virtual_memory()

-
-
+
+def infer(prompt, negative_prompt, image):
     prompts = num_samples * [prompt]
-
-        prompt_ids = std_pipeline.prepare_text_inputs(prompts)
-    elif model_type=="Hand Encoding":
-        prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
-    else:
-        pass
+    prompt_ids = pipeline.prepare_text_inputs(prompts)
     prompt_ids = shard(prompt_ids)

-
-        annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
-        overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
-    elif model_type=="Hand Encoding":
-        annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
-        overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
-
-    else:
-        pass
+    annotated_image = generate_annotation(image)
     validation_image = Image.fromarray(annotated_image).convert("RGB")
+    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
+    processed_image = shard(processed_image)

-
-
-        processed_image = shard(processed_image)
+    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
+    negative_prompt_ids = shard(negative_prompt_ids)

-
-
+    images = pipeline(
+        prompt_ids=prompt_ids,
+        image=processed_image,
+        params=pipeline_params,
+        prng_seed=prng_seed,
+        num_inference_steps=50,
+        neg_prompt_ids=negative_prompt_ids,
+        jit=True,
+    ).images

-        images = std_pipeline(
-            prompt_ids=prompt_ids,
-            image=processed_image,
-            params=std_pipeline_params,
-            prng_seed=prng_seed,
-            num_inference_steps=50,
-            neg_prompt_ids=negative_prompt_ids,
-            jit=True,
-        ).images
-    elif model_type=="Hand Encoding":
-        processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
-        processed_image = shard(processed_image)

-        negative_prompt_ids = enc_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
-        negative_prompt_ids = shard(negative_prompt_ids)
-
-        images = enc_pipeline(
-            prompt_ids=prompt_ids,
-            image=processed_image,
-            params=enc_pipeline_params,
-            prng_seed=prng_seed,
-            num_inference_steps=50,
-            neg_prompt_ids=negative_prompt_ids,
-            jit=True,
-        ).images
-
-    else:
-        pass
     images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])

     results = [i for i in images]
-
-    # running info
-    time_end = time.time()
-    time_diff = time_end - time_start
-    gc.collect()
-    torch.cuda.empty_cache()
-    memory = psutil.virtual_memory()
-    gpu_utilization, gpu_memory = GPUInfo.gpu_usage()
-    gpu_utilization = gpu_utilization[0] if len(gpu_utilization) > 0 else 0
-    gpu_memory = gpu_memory[0] if len(gpu_memory) > 0 else 0
-    system_info = f"""
-    *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB.*
-    *Processing time: {time_diff:.5} seconds.*
-    *GPU Utilization: {gpu_utilization}%, GPU Memory: {gpu_memory}MiB.*
-    """
-    return [overlap_image, annotated_image] + results, system_info
+    return [annotated_image] + results


 with gr.Blocks(theme='gradio/soft') as demo:
     gr.Markdown("## Stable Diffusion with Hand Control")
     gr.Markdown("This model is a ControlNet model using MediaPipe hand landmarks for control.")
-    with gr.Box():
-        gr.Markdown("""<h2><b>Summary 📋</b></h2>""")
-        with gr.Accordion("Detail information", open=False):
-            gr.Markdown("""
-As Stable diffusion and other diffusion models are notoriously poor at generating realistic hands for our project we decided to train a ControlNet model using MediaPipes landmarks in order to generate more realistic hands avoiding common issues such as unrealistic positions and irregular digits.
-<br>
-We opted to use the [HAnd Gesture Recognition Image Dataset](https://github.com/hukenovs/hagrid) (HaGRID) and [MediaPipe's Hand Landmarker](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker) to train a control net that could potentially be used independently or as an in-painting tool.
-To preprocess the data there were three options we considered:
-<ul>
-<li>The first was to use Mediapipes built-in draw landmarks function. This was an obvious first choice however we noticed with low training steps that the model couldn't easily distinguish handedness and would often generate the wrong hand for the conditioning image.</li>
-<center>
-<table><tr>
-<td>
-<p align="center" style="padding: 10px">
-<img alt="Forwarding" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid250k-blip2/--/MakiPan--hagrid250k-blip2/train/29/image/image.jpg" width="200">
-<br>
-<em style="color: grey">Original Image</em>
-</p>
-</td>
-<td>
-<p align="center">
-<img alt="Routing" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid250k-blip2/--/MakiPan--hagrid250k-blip2/train/29/conditioning_image/image.jpg" width="200">
-<br>
-<em style="color: grey">Conditioning Image</em>
-</p>
-</td>
-</tr></table>
-</center>
-<li>To counter this issue we changed the palm landmark colors with the intention to keep the color similar in order to learn that they provide similar information, but different to make the model know which hands were left or right.</li>
-<center>
-<table><tr>
-<td>
-<p align="center" style="padding: 10px">
-<img alt="Forwarding" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid-hand-enc-250k/--/MakiPan--hagrid-hand-enc-250k/train/96/image/image.jpg" width="200">
-<br>
-<em style="color: grey">Original Image</em>
-</p>
-</td>
-<td>
-<p align="center">
-<img alt="Routing" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid-hand-enc-250k/--/MakiPan--hagrid-hand-enc-250k/train/96/conditioning_image/image.jpg" width="200">
-<br>
-<em style="color: grey">Conditioning Image</em>
-</p>
-</td>
-</tr></table>
-</center>
-<li>The last option was to use <a href="https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html">MediaPipe Holistic</a> to provide pose face and hand landmarks to the ControlNet. This method was promising in theory, however, the HaGRID dataset was not suitable for this method as the Holistic model performs poorly with partial body and obscurely cropped images.</li>
-</ul>
-We anecdotally determined that when trained at lower steps the encoded hand model performed better than the standard MediaPipe model due to implied handedness. We theorize that with a larger dataset of more full-body hand and pose classifications, Holistic landmarks will provide the best images in the future however for the moment the hand-encoded model performs best.
-            """)
-
-    # Information links
-    with gr.Box():
-        gr.Markdown("""<h2><b>Links 🔗</b></h2>""")
-        with gr.Accordion("Models 🚀", open=False):
-            gr.Markdown("""
-<h4><a href="https://huggingface.co/Vincent-luo/controlnet-hands">Standard Model</a></h4>
-<h4> <a href="https://huggingface.co/MakiPan/controlnet-encoded-hands-130k/">Model using Hand Encoding</a></h4>
-            """)
-
-        with gr.Accordion("Datasets 💾", open=False):
-            gr.Markdown("""
-<h4> <a href="https://huggingface.co/datasets/MakiPan/hagrid250k-blip2">Dataset for Standard Model</a></h4>
-<h4> <a href="https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k">Dataset for Hand Encoding Model</a></h4>
-            """)
-
-        with gr.Accordion("Preprocessing Scripts 📑", open=False):
-            gr.Markdown("""
-<h4> <a href="https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py">Standard Data Preprocessing Script</a></h4>
-<h4> <a href="https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py">Hand Encoding Data Preprocessing Script</a></h4></center>
-            """)
-
-    # How to use model
-    with gr.Box():
-        gr.Markdown("""<h2><b>How to use ⌛️</b></h2>""")
-        with gr.Accordion("Generate image with ControlnetHand", open=True):
-            gr.Markdown("""
-- Step 1. Select preprocessing method (Standard or Hand encoding)
-- Step 2. Describe the image you want to create along with the hand details of the uploaded or captured image
-- Step 3. Provide a negative prompt that helps the model not to create redundant details
-- Step 4. Upload or capture by webcam a clear image of hands that are prominently visible in the foreground
-- Step 5. Submit and enjoy
-            """)
-
-    # Model input parameters
-    model_type = gr.Radio(["Standard", "Hand Encoding"], value="Standard", label="Model preprocessing", info="We developed two models, one with standard MediaPipe landmarks, and one with different (but similar) coloring on palm landmarks to distinguish left and right")

     with gr.Row():
         with gr.Column():
             prompt_input = gr.Textbox(label="Prompt")
             negative_prompt = gr.Textbox(label="Negative Prompt")
-
-            with gr.Tab("Upload Image"):
-                upload_image = gr.Image(label="Upload Image", source="upload")
-            with gr.Tab("Webcam"):
-                webcam_image = gr.Image(label="Webcam", source="webcam")
+            input_image = gr.Image(label="Input Image")
             # output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=3, height='auto')
             submit_btn = gr.Button(value = "Submit")
             # inputs = [prompt_input, negative_prompt, input_image]
             # submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])
-
+
         with gr.Column():
             output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto')
-
+
     gr.Examples(
         examples=[
             [
@@ -370,18 +181,13 @@ with gr.Blocks(theme='gradio/soft') as demo:
                "example4.png"
            ],
        ],
-        inputs=[prompt_input, negative_prompt,
-        outputs=[output_image
+        inputs=[prompt_input, negative_prompt, input_image],
+        outputs=[output_image],
         fn=infer,
         cache_examples=True,
     )
-
-
-
-    else:
-        input_image = webcam_image
-
-    inputs = [prompt_input, negative_prompt, input_image, model_type]
-    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image, system_info])
+
+    inputs = [prompt_input, negative_prompt, input_image]
+    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

 demo.launch()
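
The retained inference path prepares inputs with shard() and replicates the pipeline parameters across devices before the pmapped pipeline call (jit=True). A minimal sketch, outside this PR and with illustrative shapes only, of what those two helpers do to a pytree:

# Hedged sketch: why the app shard()-s inputs and replicate()-s params before
# calling the pmapped Flax pipeline. Shapes below are illustrative only.
import jax
import numpy as np
from flax import jax_utils
from flax.training.common_utils import shard

num_devices = jax.device_count()

# One prompt row per device; shard() folds the leading axis into
# (num_devices, batch_per_device, ...).
prompt_ids = np.zeros((num_devices, 77), dtype=np.int32)
sharded_ids = shard(prompt_ids)            # -> (num_devices, 1, 77)

# replicate() adds a leading device axis to every leaf of the params pytree,
# which is the layout a pmapped call expects.
params = {"w": np.ones((4, 4), dtype=np.float32)}
replicated = jax_utils.replicate(params)   # -> {"w": (num_devices, 4, 4)}

print(jax.tree_util.tree_map(lambda x: x.shape, sharded_ids))
print(jax.tree_util.tree_map(lambda x: x.shape, replicated))
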
requirements.txt
CHANGED
@@ -7,6 +7,4 @@ git+https://github.com/huggingface/diffusers@main
 opencv-python
 torch
 torchvision
-mediapipe==0.9.1
-gpuinfo
-psutil
+mediapipe==0.9.1
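
mediapipe==0.9.1 stays in the requirements because generate_annotation() in app.py relies on the MediaPipe tasks API (the detector.detect(image) context line in the diff above), while gpuinfo and psutil go away with the removed system-info reporting. A minimal sketch of that detection flow, assuming the standard hand_landmarker.task asset has been downloaded locally:

# Hedged sketch of the MediaPipe hand-landmark detection that feeds
# draw_landmarks_on_image(); the model asset path is an assumption.
import mediapipe as mp
import numpy as np
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

base_options = python.BaseOptions(model_asset_path="hand_landmarker.task")
options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
detector = vision.HandLandmarker.create_from_options(options)

img = np.zeros((256, 256, 3), dtype=np.uint8)            # stand-in for the uploaded image
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
detection_result = detector.detect(image)                # exposes .hand_landmarks / .handedness
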