import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)


def draw_landmarks_on_image(rgb_image, detection_result):
    """Draw the detected hand skeletons onto a black canvas the same size as
    the input image; this annotation is the ControlNet conditioning image."""
    hand_landmarks_list = detection_result.hand_landmarks
    annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands and draw each skeleton.
    for hand_landmarks in hand_landmarks_list:
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            hand_landmarks_proto,
            solutions.hands.HAND_CONNECTIONS,
            solutions.drawing_styles.get_default_hand_landmarks_style(),
            solutions.drawing_styles.get_default_hand_connections_style(),
        )
    return annotated_image


def generate_annotation(img):
    """Detect hand landmarks in `img` (an RGB numpy array) and return the
    landmark annotation image as a numpy array."""
    # Create a HandLandmarker object. This expects the model file
    # `hand_landmarker.task` to be present in the working directory.
    base_options = python.BaseOptions(model_asset_path="hand_landmarker.task")
    options = vision.HandLandmarkerOptions(base_options=base_options,
                                           num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # Wrap the numpy array as a MediaPipe image and detect hand landmarks.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
    detection_result = detector.detect(image)

    # Visualize the detection result.
    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result)
    return annotated_image
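# A quick local sanity check for the annotation step (a sketch, not part of
# the app; assumes `hand_landmarker.task` has been downloaded and that
# `example.png` from the examples below exists):
#
#   img = np.asarray(Image.open("example.png").convert("RGB"))
#   control = generate_annotation(img)
#   Image.fromarray(control).save("control.png")  # black canvas + hand skeletons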
args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    # tokenizer=tokenizer,
    controlnet=controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=args.revision,
    from_pt=args.from_pt,
)

pipeline_params["controlnet"] = controlnet_params
# Copy the parameters onto every local device so the jitted (pmapped)
# pipeline can generate one sample per device.
pipeline_params = jax_utils.replicate(pipeline_params)

rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())


def infer(prompt, negative_prompt, image):
    prompts = num_samples * [prompt]
    prompt_ids = pipeline.prepare_text_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    annotated_image = generate_annotation(image)
    validation_image = Image.fromarray(annotated_image).convert("RGB")
    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
    processed_image = shard(processed_image)

    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    negative_prompt_ids = shard(negative_prompt_ids)

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (num_devices, per_device_batch, H, W, C) output back into
    # a flat list of images, prepending the annotation for display.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]
    return [annotated_image] + results
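# How the sharding fits together (a sketch, assuming 8 local devices):
#
#   x = pipeline.prepare_text_inputs(8 * ["a photo"])  # shape (8, 77)
#   shard(x).shape                                     # (8, 1, 77)
#
# With `jit=True` the pipeline pmaps over the leading device axis, so each
# device denoises one sample using its own key from `prng_seed`.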
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("This demo uses a ControlNet conditioned on MediaPipe hand landmarks to guide Stable Diffusion.")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
            input_image = gr.Image(label="Input Image")
            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            output_image = gr.Gallery(label="Output Image", show_label=False, elem_id="gallery").style(grid=2, height="auto")

    gr.Examples(
        examples=[
            [
                "a woman is making an ok sign in front of a painting",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example.png",
            ],
            [
                "a man with his hands up in the air making a rock sign",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example1.png",
            ],
            [
                "a man is making a thumbs up gesture",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example2.png",
            ],
            [
                "a woman is holding up her hand in front of a window",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example3.png",
            ],
            [
                "a man with his finger on his lips",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example4.png",
            ],
        ],
        inputs=[prompt_input, negative_prompt, input_image],
        outputs=[output_image],
        fn=infer,
        cache_examples=True,
    )

    inputs = [prompt_input, negative_prompt, input_image]
    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

demo.launch()
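# When running locally behind a firewall, a temporary public link can be
# requested instead (a hypothetical tweak, not in the original app):
#
#   demo.launch(share=True)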