# Gradio demo that runs a Flax Stable Diffusion pipeline conditioned by a
# ControlNet checkpoint and exposes it through a small web UI.

from argparse import Namespace

import gradio as gr
import jax
import jax.numpy as jnp
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image

# Checkpoint locations and loading options for the base model and ControlNet.
args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

weight_dtype = jnp.float32

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=jnp.float32,
)

pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    dtype=weight_dtype,
    revision=args.revision,
    from_pt=args.from_pt,
)

# The pipeline parameters are loaded without the ControlNet weights, so add
# them before replicating the whole parameter tree across devices.
pipeline_params["controlnet"] = controlnet_params
pipeline_params = jax_utils.replicate(pipeline_params)

# One sample per device; the seed is fixed, so every request reuses the same
# per-device PRNG keys.
rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, num_samples)


def infer(prompt, negative_prompt, image):
    """Generate one image from a prompt pair and a conditioning image.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.
        image: Conditioning image as an array accepted by
            ``PIL.Image.fromarray``, or ``None`` when the user submitted
            without providing an image.

    Returns:
        The first image of the pipeline output after the two leading axes
        of the result have been merged.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Without this guard a missing upload crashes inside PIL with an opaque
    # error; gr.Error surfaces a readable message in the UI instead.
    if image is None:
        raise gr.Error("Please upload an input image before submitting.")

    # Each input is repeated once per device and then sharded so that every
    # device receives exactly one sample.
    prompt_ids = shard(pipeline.prepare_text_inputs([prompt] * num_samples))
    negative_prompt_ids = shard(
        pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    )

    validation_image = Image.fromarray(image).convert("RGB")
    processed_image = shard(
        pipeline.prepare_image_inputs([validation_image] * num_samples)
    )

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the two leading axes into one, keeping the last three axes
    # unchanged, then return only the first image.
    images = images.reshape(
        (images.shape[0] * images.shape[1],) + images.shape[-3:]
    )
    return images[0]


with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("In this app, you can find different ControlNets with different filters. ")

    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        output_image = gr.Image(label="Output Image")
        submit_btn = gr.Button(value="Submit")

        inputs = [prompt_input, negative_prompt, input_image]
        submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

demo.launch()