kfahn commited on
Commit
090c9fa
1 Parent(s): f9183eb

Update app.py

Browse files

Adding negative prompts back

Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -18,23 +18,22 @@ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
18
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
19
  )
20
 
21
- def infer(prompt, image):
22
- #def infer(prompt):
23
  params["controlnet"] = controlnet_params
24
 
25
  num_samples = 1 #jax.device_count()
26
  rng = create_key(0)
27
  rng = jax.random.split(rng, jax.device_count())
28
- im = image
29
- image = Image.fromarray(im)
30
 
31
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
32
- #negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
33
  processed_image = pipe.prepare_image_inputs([image] * num_samples)
34
 
35
  p_params = replicate(params)
36
  prompt_ids = shard(prompt_ids)
37
- #negative_prompt_ids = shard(negative_prompt_ids)
38
  processed_image = shard(processed_image)
39
 
40
  output = pipe(
@@ -43,7 +42,7 @@ def infer(prompt, image):
43
  params=p_params,
44
  prng_seed=rng,
45
  num_inference_steps=50,
46
- #neg_prompt_ids=negative_prompt_ids,
47
  jit=True,
48
  ).images
49
 
 
18
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
19
  )
20
 
21
+ def infer(prompts, negative_prompts, image):
22
+
23
  params["controlnet"] = controlnet_params
24
 
25
  num_samples = 1 #jax.device_count()
26
  rng = create_key(0)
27
  rng = jax.random.split(rng, jax.device_count())
28
+ image = Image.fromarray(image)
 
29
 
30
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
31
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
32
  processed_image = pipe.prepare_image_inputs([image] * num_samples)
33
 
34
  p_params = replicate(params)
35
  prompt_ids = shard(prompt_ids)
36
+ negative_prompt_ids = shard(negative_prompt_ids)
37
  processed_image = shard(processed_image)
38
 
39
  output = pipe(
 
42
  params=p_params,
43
  prng_seed=rng,
44
  num_inference_steps=50,
45
+ neg_prompt_ids=negative_prompt_ids,
46
  jit=True,
47
  ).images
48