Vincent-luo committed on
Commit
b2c8596
1 Parent(s): 3087f4c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from flax import jax_utils
4
+ from flax.training.common_utils import shard
5
+ from PIL import Image
6
+ from argparse import Namespace
7
+ import gradio as gr
8
+
9
+ from diffusers import (
10
+ FlaxControlNetModel,
11
+ FlaxStableDiffusionControlNetPipeline,
12
+ )
13
+
14
+
15
# Single dtype handed to both model loaders below.
weight_dtype = jnp.float32

# Hub identifiers and loading flags for the base Stable Diffusion checkpoint
# and the ControlNet checkpoint; the two `from_pretrained` calls read them.
args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

# Load the ControlNet module together with its parameters.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    dtype=weight_dtype,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
)

# Build the pipeline around that ControlNet; no safety checker is attached.
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    revision=args.revision,
    from_pt=args.from_pt,
    dtype=weight_dtype,
)

# The ControlNet weights were loaded separately, so register them under the
# "controlnet" key before copying the whole parameter tree to every device.
pipeline_params["controlnet"] = controlnet_params
pipeline_params = jax_utils.replicate(pipeline_params)

# One sample per device, each with its own PRNG key split from a fixed root
# key (seed 0).
num_samples = jax.device_count()
rng = jax.random.PRNGKey(0)
prng_seed = jax.random.split(rng, num_samples)
49
+
50
+
51
def infer(prompt, negative_prompt, image):
    """Generate an image from a prompt pair and a conditioning image.

    The prompt, negative prompt and conditioning image are each repeated
    ``num_samples`` times and sharded before being handed to the module-level
    ``pipeline``. Only the first generated image is returned.

    The module-level ``prng_seed`` is passed unchanged on every call.

    Args:
        prompt: Text prompt taken from the UI.
        negative_prompt: Negative text prompt taken from the UI.
        image: Conditioning image as an array accepted by
            ``PIL.Image.fromarray``; ``None`` when the user submitted
            without providing an image.

    Returns:
        The first image of the batch produced by the pipeline.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Without this guard a missing upload fails inside PIL with an opaque
    # error; surface a readable message in the UI instead.
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Tokenize one copy of each text input per sample, then shard.
    prompt_ids = shard(pipeline.prepare_text_inputs([prompt] * num_samples))
    negative_prompt_ids = shard(
        pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    )

    # Force RGB before the pipeline's own image preprocessing.
    validation_image = Image.fromarray(image).convert("RGB")
    processed_image = shard(
        pipeline.prepare_image_inputs([validation_image] * num_samples)
    )

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the first two axes into a single batch axis while keeping the
    # last three axes (presumably height, width, channels -- TODO confirm).
    flat_batch = images.shape[0] * images.shape[1]
    images = images.reshape((flat_batch,) + images.shape[-3:])

    return images[0]
77
+
78
+
79
# Gradio front end: two text boxes and an image in, one generated image out.
# NOTE(review): the nesting below was reconstructed from a flattened diff in
# which indentation was lost -- confirm which widgets belong inside the Column.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("In this app, you can find different ControlNets with different filters. ")

    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        output_image = gr.Image(label="Output Image")
        submit_btn = gr.Button("Submit")

        # Order must match the positional parameters of `infer`.
        inputs = [prompt_input, negative_prompt, input_image]
        submit_btn.click(fn=infer, inputs=inputs, outputs=output_image)

demo.launch()