import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)

# mediapipe annotation
MARGIN = 10  # pixels
FONT_SIZE = 1
FONT_THICKNESS = 1
HANDEDNESS_TEXT_COLOR = (88, 205, 54)  # vibrant green

# MediaPipe hand-landmarker model bundle loaded by generate_annotation().
HAND_LANDMARKER_MODEL_PATH = "hand_landmarker.task"


def draw_landmarks_on_image(rgb_image, detection_result):
    """Draw the detected hand skeletons onto a black canvas.

    Args:
        rgb_image: Image array the detection was run on. Only its shape and
            dtype are used, because the canvas is ``np.zeros_like(rgb_image)``.
        detection_result: Result returned by ``HandLandmarker.detect``. Its
            ``hand_landmarks`` attribute is read.

    Returns:
        A new array shaped like ``rgb_image`` that contains only the landmark
        points and hand connections of every detected hand.
    """
    # Start from a black canvas rather than the photo: infer() feeds this
    # drawing to the ControlNet pipeline as its conditioning image.
    annotated_image = np.zeros_like(rgb_image)

    for hand_landmarks in detection_result.hand_landmarks:
        # drawing_utils expects a NormalizedLandmarkList proto, so convert the
        # task-API landmark objects into one.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            hand_landmarks_proto,
            solutions.hands.HAND_CONNECTIONS,
            solutions.drawing_styles.get_default_hand_landmarks_style(),
            solutions.drawing_styles.get_default_hand_connections_style(),
        )
    return annotated_image


def generate_annotation(img):
    """Detect up to two hands in ``img`` and return their landmark drawing.

    Args:
        img: RGB image as a numpy array, as supplied by ``gr.Image``.
            Presumably uint8 HxWx3 (mediapipe's SRGB format expects that);
            TODO confirm for all upload types.

    Returns:
        numpy array shaped like ``img`` with the hand skeletons drawn on a
        black background (see ``draw_landmarks_on_image``).
    """
    base_options = python.BaseOptions(model_asset_path=HAND_LANDMARKER_MODEL_PATH)
    options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # The context manager closes the landmarker after each request. Previously
    # a new detector was created per call and never released.
    with vision.HandLandmarker.create_from_options(options) as detector:
        detection_result = detector.detect(image)

    return draw_landmarks_on_image(image.numpy_view(), detection_result)


args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=jnp.bfloat16,
)

pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    dtype=jnp.bfloat16,
    revision=args.revision,
    from_pt=args.from_pt,
)
pipeline_params["controlnet"] = controlnet_params
# Copy the parameters onto every local device for the jit=True (pmapped) call.
pipeline_params = jax_utils.replicate(pipeline_params)

rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
# One key per device. NOTE: these keys are reused for every request, so the
# same inputs always produce the same images.
prng_seed = jax.random.split(rng, jax.device_count())


def infer(prompt, negative_prompt, image):
    """Generate ``num_samples`` images guided by the hand pose in ``image``.

    Args:
        prompt: Text prompt, replicated once per sample.
        negative_prompt: Negative text prompt, replicated once per sample.
        image: Input photo from ``gr.Image`` (numpy array), or ``None`` when
            the user submitted without uploading one.

    Returns:
        List for the gallery. The first element is the landmark annotation;
        the rest are the generated images.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    prompt_ids = shard(pipeline.prepare_text_inputs(num_samples * [prompt]))
    negative_prompt_ids = shard(
        pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    )

    annotated_image = generate_annotation(image)
    validation_image = Image.fromarray(annotated_image).convert("RGB")
    processed_image = shard(
        pipeline.prepare_image_inputs(num_samples * [validation_image])
    )

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the two leading axes (presumably device x per-device batch from
    # the pmapped call) into a single flat batch of H x W x C images.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return [annotated_image, *images]


with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        # NOTE(review): `.style()` is deprecated in newer Gradio releases; move
        # grid/height into the Gallery constructor when upgrading Gradio.
        output_image = gr.Gallery(
            label='Output Image', show_label=False, elem_id="gallery"
        ).style(grid=3, height='auto')
        submit_btn = gr.Button(value="Submit")
        inputs = [prompt_input, negative_prompt, input_image]
        submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

demo.launch()