import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import copy
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
import psutil
from gpuinfo import GPUInfo
import time
import gc
import torch
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)

# Landmark styles used by the "Hand Encoding" preprocessing: the palm/wrist
# landmark (index 0) is drawn in a different color for each hand so the
# ControlNet can infer handedness from the conditioning image.
right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
right_style_lm[0].color = (251, 206, 177)
left_style_lm[0].color = (255, 255, 225)


def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
    hand_landmarks_list = detection_result.hand_landmarks
    handedness_list = detection_result.handedness

    # Draw either on top of the input image (overlap) or on a blank canvas.
    if overlap:
        annotated_image = np.copy(rgb_image)
    else:
        annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands to visualize.
    for idx in range(len(hand_landmarks_list)):
        hand_landmarks = hand_landmarks_list[idx]
        handedness = handedness_list[idx]

        # Draw the hand landmarks.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])

        if hand_encoding:
            if handedness[0].category_name == "Left":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    left_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
            if handedness[0].category_name == "Right":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    right_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
        else:
            solutions.drawing_utils.draw_landmarks(
                annotated_image,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style())
    return annotated_image


def generate_annotation(img, overlap=False, hand_encoding=False):
    """Detect hand landmarks in `img` (numpy array) and return the annotated image (numpy array)."""
    # STEP 2: Create a HandLandmarker object.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # STEP 3: Load the input image.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # STEP 4: Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # STEP 5: Process the classification result. In this case, visualize it.
    annotated_image = draw_landmarks_on_image(
        image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
    return annotated_image


# Configuration for the two ControlNet variants: standard MediaPipe landmarks
# and hand-encoded landmarks (distinct left/right palm colors).
std_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

enc_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

std_controlnet, std_controlnet_params = FlaxControlNetModel.from_pretrained(
    std_args.controlnet_model_name_or_path,
    revision=std_args.controlnet_revision,
    from_pt=std_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

enc_controlnet, enc_controlnet_params = FlaxControlNetModel.from_pretrained(
    enc_args.controlnet_model_name_or_path,
    revision=enc_args.controlnet_revision,
    from_pt=enc_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

std_pipeline, std_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    std_args.pretrained_model_name_or_path,
    controlnet=std_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=std_args.revision,
    from_pt=std_args.from_pt,
)

enc_pipeline, enc_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    enc_args.pretrained_model_name_or_path,
    controlnet=enc_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=enc_args.revision,
    from_pt=enc_args.from_pt,
)

# Attach the ControlNet parameters and replicate everything across local devices
# for pmapped (jit=True) inference.
std_pipeline_params["controlnet"] = std_controlnet_params
std_pipeline_params = jax_utils.replicate(std_pipeline_params)

enc_pipeline_params["controlnet"] = enc_controlnet_params
enc_pipeline_params = jax_utils.replicate(enc_pipeline_params)

rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())

memory = psutil.virtual_memory()


def infer(prompt, negative_prompt, image, model_type="Standard"):
    time_start = time.time()

    prompts = num_samples * [prompt]
    if model_type == "Standard":
        prompt_ids = std_pipeline.prepare_text_inputs(prompts)
    elif model_type == "Hand Encoding":
        prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
    else:
        pass
    prompt_ids = shard(prompt_ids)

    # Build the conditioning image (landmarks on a black canvas) and a display overlay.
    if model_type == "Standard":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
    elif model_type == "Hand Encoding":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
    else:
        pass
    validation_image = Image.fromarray(annotated_image).convert("RGB")

    if model_type == "Standard":
        processed_image = std_pipeline.prepare_image_inputs(num_samples * [validation_image])
        processed_image = shard(processed_image)

        negative_prompt_ids = std_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
        negative_prompt_ids = shard(negative_prompt_ids)

        images = std_pipeline(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=std_pipeline_params,
            prng_seed=prng_seed,
            num_inference_steps=50,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
    elif model_type == "Hand Encoding":
        processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
        processed_image = shard(processed_image)

        negative_prompt_ids = enc_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
        negative_prompt_ids = shard(negative_prompt_ids)

        images = enc_pipeline(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=enc_pipeline_params,
            prng_seed=prng_seed,
            num_inference_steps=50,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
    else:
        pass

    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]

    # Running info.
    time_end = time.time()
    time_diff = time_end - time_start
    gc.collect()
    torch.cuda.empty_cache()

    memory = psutil.virtual_memory()
    gpu_utilization, gpu_memory = GPUInfo.gpu_usage()
    gpu_utilization = gpu_utilization[0] if len(gpu_utilization) > 0 else 0
    gpu_memory = gpu_memory[0] if len(gpu_memory) > 0 else 0
    system_info = f"""
    *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB.*
    *Processing time: {time_diff:.5} seconds.*
    *GPU Utilization: {gpu_utilization}%, GPU Memory: {gpu_memory}MiB.*
    """

    return [overlap_image, annotated_image] + results, system_info


with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("This is a ControlNet model that uses MediaPipe hand landmarks for control.")

    with gr.Box():
        gr.Markdown("""

## Summary 📋

""") with gr.Accordion("Detail information", open=False): gr.Markdown(""" As Stable diffusion and other diffusion models are notoriously poor at generating realistic hands for our project we decided to train a ControlNet model using MediaPipes landmarks in order to generate more realistic hands avoiding common issues such as unrealistic positions and irregular digits.

We opted to use the [HAnd Gesture Recognition Image Dataset](https://github.com/hukenovs/hagrid) (HaGRID) and [MediaPipe's Hand Landmarker](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker) to train a ControlNet that could potentially be used independently or as an in-painting tool.

To preprocess the data, we considered three options: standard MediaPipe hand landmarks, hand-encoded landmarks with distinct palm colors for the left and right hand, and MediaPipe Holistic landmarks.

We anecdotally determined that, when trained for fewer steps, the hand-encoded model performed better than the standard MediaPipe model due to the implied handedness. We theorize that with a larger dataset of more full-body hand and pose classifications, Holistic landmarks will provide the best images in the future; for the moment, however, the hand-encoded model performs best.
            """)

    # Information links
    with gr.Box():
        gr.Markdown("""

## Links 🔗

""") with gr.Accordion("Models 🚀", open=False): gr.Markdown("""

- [Standard Model](https://huggingface.co/Vincent-luo/controlnet-hands)

- [Model using Hand Encoding](https://huggingface.co/MakiPan/controlnet-encoded-hands-130k)

""") with gr.Accordion("Datasets 💾", open=False): gr.Markdown("""

- Dataset for Standard Model

- Dataset for Hand Encoding Model

""") with gr.Accordion("Preprocessing Scripts 📑", open=False): gr.Markdown("""

- Standard Data Preprocessing Script

- Hand Encoding Data Preprocessing Script

""") # How to use model with gr.Box(): gr.Markdown("""

## How to use ⌛️

""") with gr.Accordion("Generate image with ControlnetHand", open=True): gr.Markdown(""" - Step 1. Select preprocessing method (Standard or Hand encoding) - Step 2. Describe the image you want to create along with the hand details of the uploaded or captured image - Step 3. Provide a negative prompt that helps the model not to create redundant details - Step 4. Upload or capture by webcam a clear image of hands that are prominently visible in the foreground - Step 5. Submit and enjoy """) # Model input parameters model_type = gr.Radio(["Standard", "Hand Encoding"], value="Standard", label="Model preprocessing", info="We developed two models, one with standard MediaPipe landmarks, and one with different (but similar) coloring on palm landmarks to distinguish left and right") with gr.Row(): with gr.Column(): prompt_input = gr.Textbox(label="Prompt") negative_prompt = gr.Textbox(label="Negative Prompt") with gr.Box(): with gr.Tab("Upload Image"): upload_image = gr.Image(label="Upload Image", source="upload") with gr.Tab("Webcam"): webcam_image = gr.Image(label="Webcam", source="webcam") # output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=3, height='auto') submit_btn = gr.Button(value = "Submit") # inputs = [prompt_input, negative_prompt, input_image] # submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image]) system_info = gr.Markdown(f"*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*") with gr.Column(): output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto') gr.Examples( examples=[ [ "a woman is making an ok sign in front of a painting", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example.png" ], [ "a man with his hands up in the air making a rock sign", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example1.png" ], [ "a man is making a thumbs up gesture", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example2.png" ], [ "a woman is holding up her hand in front of a window", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example3.png" ], [ "a man with his finger on his lips", "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", "example4.png" ], ], inputs=[prompt_input, negative_prompt, upload_image, model_type], outputs=[output_image, system_info], fn=infer, cache_examples=True, ) # check source of image if upload_image and webcam_image is None: input_image = upload_image else: input_image = webcam_image inputs = [prompt_input, negative_prompt, input_image, model_type] submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image, system_info]) demo.launch()