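# Source: app.py as shown in a web file viewer (viewer residue kept below as comments).
# Vincent-luo's picture
# Create app.py
# b2c8596
# raw
# history blame
# 2.78 kB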
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
from diffusers import (
FlaxControlNetModel,
FlaxStableDiffusionControlNetPipeline,
)
# Model identifiers and loading options for the base Stable Diffusion
# checkpoint and the hand-conditioned ControlNet.
args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

# Single dtype used when loading both the ControlNet and the pipeline.
weight_dtype = jnp.float32

# One sample is generated per visible JAX device; the fixed root key is
# split so that every device receives its own PRNG key.
num_samples = jax.device_count()
rng = jax.random.PRNGKey(0)
prng_seed = jax.random.split(rng, num_samples)

# Load the ControlNet first so it can be handed to the pipeline constructor.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    from_pt=args.controlnet_from_pt,
    revision=args.controlnet_revision,
    dtype=weight_dtype,
)

pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    from_pt=args.from_pt,
    revision=args.revision,
    dtype=weight_dtype,
)

# The pipeline expects the ControlNet weights under the "controlnet" key of
# its parameter tree; attach them, then replicate the tree for the devices.
pipeline_params["controlnet"] = controlnet_params
pipeline_params = jax_utils.replicate(pipeline_params)
def infer(prompt, negative_prompt, image):
    """Generate one image from a prompt and a conditioning image.

    The prompt, negative prompt and conditioning image are each repeated
    ``num_samples`` times, sharded, and passed to the module-level
    ``pipeline`` together with the module-level ``pipeline_params`` and
    ``prng_seed``. The same ``prng_seed`` is reused on every call.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        image: Conditioning image as an array accepted by
            ``PIL.Image.fromarray`` (this is what the Gradio image input
            supplies), or ``None`` when the user provided no image.

    Returns:
        The first image of the generated batch.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Gradio passes None when the image input is left empty; fail with a
    # readable message instead of an AttributeError from Image.fromarray.
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Tokenize one copy of the prompt per sample and shard the batch.
    prompt_ids = shard(pipeline.prepare_text_inputs([prompt] * num_samples))
    negative_prompt_ids = shard(
        pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    )

    # Normalize the conditioning image to RGB before preprocessing.
    validation_image = Image.fromarray(image).convert("RGB")
    processed_image = shard(
        pipeline.prepare_image_inputs([validation_image] * num_samples)
    )

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the two leading axes into a single batch axis, keeping the last
    # three axes unchanged, then return only the first result.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return images[0]
# Build the demo page: a heading, a short description, and a single column
# holding the two text inputs, the input/output images and the submit button.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("In this app, you can find different ControlNets with different filters. ")

    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        output_image = gr.Image(label="Output Image")
        submit_btn = gr.Button(value="Submit")

    # Clicking the button runs `infer` on the three inputs and shows the
    # returned image in the output component.
    inputs = [prompt_input, negative_prompt, input_image]
    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

# Start the Gradio server.
demo.launch()