# Hugging Face Spaces status: Runtime error
import jax | |
import jax.numpy as jnp | |
from flax import jax_utils | |
from flax.training.common_utils import shard | |
from PIL import Image | |
from argparse import Namespace | |
import gradio as gr | |
from diffusers import ( | |
FlaxControlNetModel, | |
FlaxStableDiffusionControlNetPipeline, | |
) | |
# Loading options, collected on an argparse Namespace so the code below can
# read them as attributes (mirroring a script's parsed command-line args).
args = Namespace()
# Base Stable Diffusion checkpoint; its PyTorch weights are loaded (from_pt).
# NOTE(review): confirm this repo and its "non-ema" revision are still
# available on the Hugging Face Hub.
args.pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
args.revision = "non-ema"
args.from_pt = True
# Hand ControlNet checkpoint: default revision, weights loaded without
# PyTorch conversion (from_pt=False).
args.controlnet_model_name_or_path = "Vincent-luo/controlnet-hands"
args.controlnet_revision = None
args.controlnet_from_pt = False
# Compute dtype shared by the ControlNet and the base pipeline. Both loads
# below read it, so changing it here keeps the two models consistent.
weight_dtype = jnp.float32

# Load the ControlNet. from_pretrained returns the Flax module and its
# parameters as two separate objects.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=weight_dtype,
)

# Load the base Stable Diffusion pipeline with the ControlNet plugged in.
# The safety checker is disabled (safety_checker=None).
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    dtype=weight_dtype,
    revision=args.revision,
    from_pt=args.from_pt,
)
# Attach the ControlNet weights under the "controlnet" key of the pipeline params.
pipeline_params["controlnet"] = controlnet_params
# Replicate all parameters across the local devices so the jit=True
# (per-device) pipeline call in infer() receives a copy on each device.
pipeline_params = jax_utils.replicate(pipeline_params)

# One sample per device, with one PRNG key per device. The base key is fixed
# and the keys are created once here, so every request reuses the same noise.
rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, num_samples)
def infer(prompt, negative_prompt, image, num_inference_steps=50):
    """Generate one image from a text prompt, conditioned on an input image.

    Each input is replicated once per device (``num_samples``) and sharded
    across devices. The ControlNet pipeline then runs with ``jit=True``, and
    only the first generated image is returned.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt; may be an empty string.
        image: Conditioning image as an array, as passed by the ``gr.Image``
            input, or ``None`` when the user uploaded nothing.
        num_inference_steps: Number of denoising steps. The default of 50
            matches the previously hard-coded value.

    Returns:
        The first generated image, a 3-D array made of the trailing three axes
        of the pipeline output. Presumably this is (height, width, channels);
        confirm against the diffusers version in use.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        # Image.fromarray(None) fails with an opaque error. Report the
        # problem in the UI instead.
        raise gr.Error("Please upload an input image.")

    # Tokenize one copy of each prompt per device, then split along devices.
    prompt_ids = shard(pipeline.prepare_text_inputs([prompt] * num_samples))
    negative_prompt_ids = shard(
        pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    )

    # Same for the conditioning image; convert("RGB") drops alpha / expands gray.
    validation_image = Image.fromarray(image).convert("RGB")
    processed_image = shard(
        pipeline.prepare_image_inputs([validation_image] * num_samples)
    )

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    # Merge the leading (device, per-device batch) axes into one batch axis.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return images[0]
# Gradio UI: prompt, negative prompt and conditioning image in; one generated
# image out. The Submit button calls infer() with the three inputs in order.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    # The app loads a single hand ControlNet, so describe exactly that.
    gr.Markdown(
        "This app runs Stable Diffusion with a ControlNet for hands: upload a "
        "conditioning image, enter a prompt (and optionally a negative "
        "prompt), then click Submit."
    )
    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        output_image = gr.Image(label="Output Image")
        submit_btn = gr.Button(value="Submit")
    inputs = [prompt_input, negative_prompt, input_image]
    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

# Each request runs many jit-compiled denoising steps, so serve requests
# through Gradio's queue, which is meant for long-running jobs.
demo.queue()
demo.launch()