Prime Cai committed
Commit dbabfb0 · 1 Parent(s): ec81f63

add num of images

Files changed (2):
  1. app.py (+18, -6)
  2. pipeline.py (+5, -5)
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
 from diffusers.utils import load_image
 from pipeline import FluxConditionalPipeline
 from transformer import FluxTransformer2DConditionalModel
-
+from recaption import enhance_prompt
 import os
 
 pipe = None
@@ -44,7 +44,8 @@ def generate_image(
     gemini_prompt: bool = True,
     guidance: float = 3.5,
     i_guidance: float = 1.0,
-    t_guidance: float = 1.0
+    t_guidance: float = 1.0,
+    num_images: int = 4,
 ):
     w, h, min_size = image.size[0], image.size[1], min(image.size)
     image = image.crop(
@@ -52,8 +53,13 @@ def generate_image(
     ).resize((512, 512))
 
     control_image = load_image(image)
+    text_list = []
+    for _ in range(num_images):
+        if gemini_prompt:
+            text = enhance_prompt(image, text.strip())
+        text_list.append(text.strip())
     result_image = pipe(
-        prompt=text.strip(),
+        prompt=text_list,
         negative_prompt="",
         num_inference_steps=28,
         height=512,
@@ -63,7 +69,7 @@ def generate_image(
         guidance_scale_real_i=i_guidance,
         guidance_scale_real_t=t_guidance,
         gemini_prompt=gemini_prompt,
-    ).images[0]
+    ).images
 
     return result_image
 
@@ -125,6 +131,10 @@ with demo:
         <a href="https://huggingface.co/datasets/primecai/dsd_data" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace%20-Data-yellow" style="display:inline-block;"></a>
         <a href="https://huggingface.co/primecai/dsd_model" target="_blank"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-green" style="display:inline-block;"></a>
         <a href="https://x.com/prime_cai?lang=en" target="_blank"><img src="https://img.shields.io/twitter/follow/prime_cai?style=social" style="display:inline-block;"></a>
+
+        <div style="text-align: center;">
+            The model does have randomness because of both the Gemini prompt enhancement and the diffusion initial noises. Please give it a few tries to get the best results.
+        </div>
         </div>
         """
     )
@@ -132,14 +142,16 @@ with demo:
     iface = gr.Interface(
         fn=generate_image,
         inputs=[
-            gr.Image(type="pil", width=512),
+            gr.Image(type="pil", width=300),
             gr.Textbox(lines=2, label="text", info="Could be something as simple as 'this character playing soccer'."),
             gr.Checkbox(label="Gemini prompt", value=True, info="Use Gemini to enhance the prompt. This is recommended for most cases, unless you have a specific prompt similar to the examples in mind."),
             gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale", info="Tip: start with 3.5, then gradually increase if the consistency is consistently off"),
             gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.5, label="real guidance scale for image", info="Tip: increase if the image is not consistent"),
             gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for prompt", info="Tip: increase if the prompt is not consistent"),
+            gr.Slider(minimum=1, maximum=5, step=1, value=4, label="Number of images", info="Select how many images to generate"),
         ],
-        outputs=gr.Image(type="pil"),
+        # outputs=gr.Image(type="pil"),
+        outputs=gr.Gallery(label="Generated Images", height=544),
        # examples=get_samples(),
        live=False,
    )
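For reference, here is a minimal, self-contained sketch (a hypothetical demo, not this repo's app.py) of the Gradio pattern the commit switches to: the generation function now returns a list of PIL images, which pairs with a gr.Gallery output instead of a single gr.Image.

import gradio as gr
from PIL import Image

def make_images(num_images):
    # Stand-in for the diffusion call: returns `num_images` solid-color frames
    # so the Gallery wiring can be tested without loading any model.
    return [Image.new("RGB", (512, 512), (40 * i, 80, 120)) for i in range(int(num_images))]

sketch_demo = gr.Interface(
    fn=make_images,
    inputs=gr.Slider(minimum=1, maximum=5, step=1, value=4, label="Number of images"),
    outputs=gr.Gallery(label="Generated Images", height=544),
    live=False,
)

if __name__ == "__main__":
    sketch_demo.launch()

Because pipe(...).images is already a list, generate_image can hand its result straight to the Gallery without further conversion.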
pipeline.py CHANGED
@@ -39,7 +39,7 @@ from diffusers.utils import (
 )
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from recaption import enhance_prompt
+# from recaption import enhance_prompt
 
 
 if is_torch_xla_available():
@@ -722,8 +722,8 @@ class FluxConditionalPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
 
         device = self._execution_device
 
-        if gemini_prompt:
-            prompt = enhance_prompt(image, prompt)
+        # if gemini_prompt:
+        #     prompt = enhance_prompt(image, prompt)
         # if gemini_prompt:
         #     while True:
         #         try:
@@ -779,8 +779,8 @@ class FluxConditionalPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
         # 3. Preprocess image
         image = self.image_processor.preprocess(image)
         # image = image[..., :512]
-        image = torch.nn.functional.interpolate(image, size=512)
-        black_image = torch.full((1, 3, 512, 512), -1.0)
+        image = torch.nn.functional.interpolate(image, size=512).repeat(batch_size, 1, 1, 1)
+        black_image = torch.full((batch_size, 3, 512, 512), -1.0)
         image = torch.cat([image, black_image], dim=3)
         latents_cond = self.vae.encode(image.to(self.vae.dtype).to(self.vae.device)).latent_dist.sample()
         latents_cond = (
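The pipeline.py change appears to exist so the conditioning tensors' batch dimension stays in step with the number of prompts. Below is a minimal sketch of that shape bookkeeping in plain PyTorch, with a random tensor standing in for the preprocessed image; batch_size here corresponds to the length of the prompt list (num_images in app.py).

import torch

batch_size = 4  # corresponds to num_images / len(prompt list) in app.py

# Stand-in for the single preprocessed conditioning image (N=1, C=3, H=512, W=512).
image = torch.randn(1, 3, 512, 512)

# Resize, then repeat along the batch dimension so every prompt gets the same condition.
image = torch.nn.functional.interpolate(image, size=512).repeat(batch_size, 1, 1, 1)

# The black half concatenated on the width axis must use the same batch size,
# otherwise torch.cat would fail with a shape mismatch.
black_image = torch.full((batch_size, 3, 512, 512), -1.0)
image = torch.cat([image, black_image], dim=3)

print(image.shape)  # torch.Size([4, 3, 512, 1024])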