import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import copy  # added
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
import psutil
from gpuinfo import GPUInfo
import time
import gc
import torch

from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)

right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
right_style_lm[0].color = (251, 206, 177)
left_style_lm[0].color = (255, 255, 225)


def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
    hand_landmarks_list = detection_result.hand_landmarks
    handedness_list = detection_result.handedness
    if overlap:
        annotated_image = np.copy(rgb_image)
    else:
        annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands to visualize.
    for idx in range(len(hand_landmarks_list)):
        hand_landmarks = hand_landmarks_list[idx]
        handedness = handedness_list[idx]

        # Draw the hand landmarks.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])
        if hand_encoding:
            if handedness[0].category_name == "Left":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    left_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
            if handedness[0].category_name == "Right":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    right_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
        else:
            solutions.drawing_utils.draw_landmarks(
                annotated_image,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style())
    return annotated_image


def generate_annotation(img, overlap=False, hand_encoding=False):
    """img (input): numpy array
    annotated_image (output): numpy array
    """
    # STEP 2: Create a HandLandmarker object.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options,
                                           num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # STEP 3: Load the input image.
    image = mp.Image(
        image_format=mp.ImageFormat.SRGB, data=img)

    # STEP 4: Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # STEP 5: Process the classification result. In this case, visualize it.
    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
    return annotated_image


std_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

enc_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

std_controlnet, std_controlnet_params = FlaxControlNetModel.from_pretrained(
    std_args.controlnet_model_name_or_path,
    revision=std_args.controlnet_revision,
    from_pt=std_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

enc_controlnet, enc_controlnet_params = FlaxControlNetModel.from_pretrained(
    enc_args.controlnet_model_name_or_path,
    revision=enc_args.controlnet_revision,
    from_pt=enc_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

std_pipeline, std_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    std_args.pretrained_model_name_or_path,
    # tokenizer=tokenizer,
    controlnet=std_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=std_args.revision,
    from_pt=std_args.from_pt,
)

enc_pipeline, enc_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    enc_args.pretrained_model_name_or_path,
    # tokenizer=tokenizer,
    controlnet=enc_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=enc_args.revision,
    from_pt=enc_args.from_pt,
)

std_pipeline_params["controlnet"] = std_controlnet_params
std_pipeline_params = jax_utils.replicate(std_pipeline_params)

enc_pipeline_params["controlnet"] = enc_controlnet_params
enc_pipeline_params = jax_utils.replicate(enc_pipeline_params)

rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())

memory = psutil.virtual_memory()


def infer(prompt, negative_prompt, image, model_type="Standard"):
    time_start = time.time()

    prompts = num_samples * [prompt]
    if model_type == "Standard":
        prompt_ids = std_pipeline.prepare_text_inputs(prompts)
    elif model_type == "Hand Encoding":
        prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
    else:
        pass
    prompt_ids = shard(prompt_ids)

    if model_type == "Standard":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
    elif model_type == "Hand Encoding":
        annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
        overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
    else:
        pass

    validation_image = Image.fromarray(annotated_image).convert("RGB")

    if model_type == "Standard":
        processed_image = std_pipeline.prepare_image_inputs(num_samples * [validation_image])
        processed_image = shard(processed_image)

        negative_prompt_ids = std_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
        negative_prompt_ids = shard(negative_prompt_ids)

        images = std_pipeline(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=std_pipeline_params,
            prng_seed=prng_seed,
            num_inference_steps=50,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
    elif model_type == "Hand Encoding":
        processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
        processed_image = shard(processed_image)

        negative_prompt_ids = enc_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
        negative_prompt_ids = shard(negative_prompt_ids)

        images = enc_pipeline(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=enc_pipeline_params,
            prng_seed=prng_seed,
            num_inference_steps=50,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
    else:
        pass

    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]

    # running info
    time_end = time.time()
    time_diff = time_end - time_start
    gc.collect()
    torch.cuda.empty_cache()

    memory = psutil.virtual_memory()
    gpu_utilization, gpu_memory = GPUInfo.gpu_usage()
    gpu_utilization = gpu_utilization[0] if len(gpu_utilization) > 0 else 0
    gpu_memory = gpu_memory[0] if len(gpu_memory) > 0 else 0
    system_info = f"""
    *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB.*
    *Processing time: {time_diff:.5} seconds.*
    *GPU Utilization: {gpu_utilization}%, GPU Memory: {gpu_memory}MiB.*
    """
    return [overlap_image, annotated_image] + results, system_info


with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("This model is a ControlNet model using MediaPipe hand landmarks for control.")

    with gr.Box():
        gr.Markdown("""
        """)
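
    # The UI wiring below is a minimal sketch and was not present in the
    # original file. It assumes the usual Gradio 3.x Blocks pattern for this
    # demo: the component names (model_selector, prompt_input,
    # negative_prompt_input, image_input, run_button, output_gallery,
    # system_info_box) are illustrative, and the input/output order simply
    # mirrors the signature and return value of infer().
    model_selector = gr.Radio(
        choices=["Standard", "Hand Encoding"],
        value="Standard",
        label="Model preprocessing",
    )
    prompt_input = gr.Textbox(label="Prompt")
    negative_prompt_input = gr.Textbox(label="Negative prompt")
    image_input = gr.Image(source="upload", type="numpy", label="Input image")
    run_button = gr.Button("Generate")
    output_gallery = gr.Gallery(label="Annotation and generated images")
    system_info_box = gr.Markdown()

    run_button.click(
        fn=infer,
        inputs=[prompt_input, negative_prompt_input, image_input, model_selector],
        outputs=[output_gallery, system_info_box],
    )

demo.launch()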