Surn committed · Commit 7a9500a · Parent(s): 8534154

Add progress and change tokenization

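The "progress" half of this commit threads progress=gr.Progress(track_tqdm=True) through the generation functions so Gradio can mirror the diffusers/tqdm progress bars in the UI. A minimal, self-contained sketch of that mechanism (not code from this repo):

    import time
    import gradio as gr
    from tqdm import tqdm

    def slow_task(steps, progress=gr.Progress(track_tqdm=True)):
        # With track_tqdm=True, any tqdm loop inside the handler is mirrored
        # in the Gradio progress bar shown on the output component.
        for _ in tqdm(range(int(steps)), desc="working"):
            time.sleep(0.1)
        return f"finished {int(steps)} steps"

    with gr.Blocks() as demo:
        steps = gr.Slider(1, 50, value=10, step=1, label="Steps")
        result = gr.Textbox(label="Result")
        gr.Button("Run").click(fn=slow_task, inputs=steps, outputs=result)

    # demo.launch()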
app.py CHANGED
@@ -6,8 +6,6 @@ from tempfile import NamedTemporaryFile
 from pathlib import Path
 import atexit
 import random
-import spaces
-
 # Import constants
 import utils.constants as constants
 
@@ -161,8 +159,7 @@ def get_model_and_lora(model_textbox):
     default_model = model_textbox
     return default_model, []
 
-#@spaces.GPU(duration=256)
-def generate_input_image_click(map_option, prompt_textbox_value, negative_prompt_textbox_value, model_textbox_value, use_conditioned_image=False, strength=0.5, image_format="16:9", scale_factor=3):
+def generate_input_image_click(map_option, prompt_textbox_value, negative_prompt_textbox_value, model_textbox_value, seed=None, use_conditioned_image=False, strength=0.5, image_format="16:9", scale_factor=3, progress=gr.Progress(track_tqdm=True)):
     # Get the model and LoRA weights
     model, lora_weights = get_model_and_lora(model_textbox_value)
     global current_prerendered_image
@@ -191,7 +188,8 @@ def generate_input_image_click(map_option, prompt_textbox_value, negative_prompt
         conditioned_image,
         stength=strength,
         height=height,
-        width=width
+        width=width,
+        seed=seed
     )
 
     # Open the generated image
@@ -413,13 +411,24 @@ with gr.Blocks(css_paths="style_20250128.css", title="HexaGrid Creator", theme='
                 label="Map Options",
                 choices=list(constants.PROMPTS.keys()),
                 value="Alien Landscape",
-                elem_classes="solid"
+                elem_classes="solid",
+                scale=0
             )
         with gr.Column():
             # Add Dropdown for sizing of Images, height and width based on selection. Options are 16x9, 16x10, 4x5, 1x1
             # The values of height and width are based on common resolutions for each aspect ratio
             # Default to 16x9, 912x512
-            image_size_ratio = gr.Dropdown(label="Image Size", choices=["16:9", "16:10", "4:5", "4:3", "2:1","3:2","1:1", "9:16", "10:16", "5:4", "3:4","1:2", "2:3"], value="16:9", elem_classes="solid", type="value",interactive=True)
+            image_size_ratio = gr.Dropdown(label="Image Size", choices=["16:9", "16:10", "4:5", "4:3", "2:1","3:2","1:1", "9:16", "10:16", "5:4", "3:4","1:2", "2:3"], value="16:9", elem_classes="solid", type="value", scale=0, interactive=True)
+        with gr.Column():
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=constants.MAX_SEED,
+                step=1,
+                value=0,
+                scale=0
+            )
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True, scale=0, interactive=True)
            prompt_textbox = gr.Textbox(
                label="Prompt",
                visible=False,
@@ -571,7 +580,7 @@ with gr.Blocks(css_paths="style_20250128.css", title="HexaGrid Creator", theme='
     )
     generate_input_image.click(
         fn=generate_input_image_click,
-        inputs=[map_options, prompt_textbox, negative_prompt_textbox, model_textbox, gr.State(False), gr.State(0.5), image_size_ratio],
+        inputs=[map_options, prompt_textbox, negative_prompt_textbox, model_textbox,gr.State( seed if randomize_seed==False else random.randint(0, constants.MAX_SEED)), gr.State(False), gr.State(0.5), image_size_ratio],
         outputs=[input_image], scroll_to_output=True
     )
     generate_depth_button.click(
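Note that the seed expression above is wrapped in gr.State(...), which is evaluated once when the Blocks UI is built rather than on every click; the usual pattern passes the slider and checkbox as inputs and resolves the seed inside the callback. A minimal sketch of that pattern, reusing the component names from the diff (handler body is illustrative only):

    import random
    import gradio as gr

    MAX_SEED = 2**31 - 1  # same value as constants.MAX_SEED (np.iinfo(np.int32).max)

    def resolve_seed(seed, randomize_seed):
        # Runs on every click, so a fresh seed is drawn each time when randomizing.
        return random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    def generate(prompt, seed, randomize_seed):
        seed = resolve_seed(seed, randomize_seed)
        return f"would generate {prompt!r} with seed {seed}"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
        randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
        result = gr.Textbox(label="Result")
        gr.Button("Generate").click(generate, inputs=[prompt, seed, randomize_seed], outputs=result)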
utils/ai_generator.py CHANGED
@@ -1,9 +1,8 @@
 # utils/ai_generator.py
-
+import gradio as gr
 import os
 import time
 from turtle import width # Added for implementing delays
-import spaces
 import torch
 import random
 from utils.ai_generator_diffusers_flux import generate_ai_image_local
@@ -15,8 +14,7 @@ from PIL import Image
 from tempfile import NamedTemporaryFile
 import utils.constants as constants
 
-
-def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512):
+def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512, progress=gr.Progress(track_tqdm=True)):
     # Initialize the InferenceClient
     client = InferenceClient()
     # Generate the image from the text
@@ -40,12 +38,13 @@ def generate_ai_image(
     width=912,
     height=512,
     strength=0.5,
+    seed = random.randint(0, constants.MAX_SEED),
+    progress=gr.Progress(track_tqdm=True),
     *args,
     **kwargs
-):
-    seed = random.randint(1, 99999)
-    if torch.cuda.is_available():
-        print("Local GPU available. Generating image locally.")
+):
+    if (torch.cuda.is_available() and torch.cuda.device_count() >= 1):
+        print("Local GPU available. Generating image locally.")
     if conditioned_image is not None:
         pipeline = "FluxImg2ImgPipeline"
         return generate_ai_image_local(
@@ -69,10 +68,11 @@ def generate_ai_image(
             neg_prompt_textbox_value,
             model,
             height=height,
-            width=width
+            width=width,
+            seed=seed
         )
 
-def generate_ai_image_remote(map_option, prompt_textbox_value, neg_prompt_textbox_value, model, height=512, width=912, num_inference_steps=30, guidance_scale=3.5, seed=777):
+def generate_ai_image_remote(map_option, prompt_textbox_value, neg_prompt_textbox_value, model, height=512, width=912, num_inference_steps=30, guidance_scale=3.5, seed=777,progress=gr.Progress(track_tqdm=True)):
     max_retries = 3
     retry_delay = 4 # Initial delay in seconds
 
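generate_ai_image now draws its default seed from constants.MAX_SEED and only takes the local path when a CUDA device is actually visible, otherwise falling back to the remote path. A stripped-down sketch of that dispatch (the two inner functions are placeholders, not this repo's implementations):

    import random
    import torch

    MAX_SEED = 2**31 - 1  # same value as constants.MAX_SEED (np.iinfo(np.int32).max)

    def run_locally(prompt, seed):
        return f"local GPU generation: {prompt!r} (seed={seed})"   # placeholder for generate_ai_image_local

    def run_remotely(prompt, seed):
        return f"remote API generation: {prompt!r} (seed={seed})"  # placeholder for generate_ai_image_remote

    def generate(prompt, seed=None):
        if seed is None:
            seed = random.randint(0, MAX_SEED)
        # Same guard as the diff: CUDA support *and* at least one visible device.
        if torch.cuda.is_available() and torch.cuda.device_count() >= 1:
            return run_locally(prompt, seed)
        return run_remotely(prompt, seed)

    print(generate("a hexagon tabletop map"))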
 
utils/ai_generator_diffusers_flux.py CHANGED
@@ -1,11 +1,12 @@
 # utils/ai_generator_diffusers_flux.py
+import gradio as gr
 import os
 import utils.constants as constants
 import spaces
 import torch
-from diffusers import FluxPipeline,FluxImg2ImgPipeline,FluxControlPipeline, DiffusionPipeline
+from diffusers import FluxPipeline,FluxImg2ImgPipeline,FluxControlPipeline
 import accelerate
-import transformers
+from transformers import AutoTokenizer
 import safetensors
 import xformers
 from diffusers.utils import load_image
@@ -27,7 +28,6 @@ from utils.color_utils import detect_color_format
 import utils.misc as misc
 from pathlib import Path
 import warnings
-
 warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
 #print(torch.__version__) # Ensure it's 2.0 or newer
 #print(torch.cuda.is_available()) # Ensure CUDA is available
@@ -36,7 +36,6 @@ PIPELINE_CLASSES = {
     "FluxPipeline": FluxPipeline,
     "FluxImg2ImgPipeline": FluxImg2ImgPipeline
 }
-
 @spaces.GPU(duration=140)
 def generate_image_from_text(
     text,
@@ -48,16 +47,28 @@
     guidance_scale=3.5,
     num_inference_steps=50,
     seed=0,
-    additional_parameters=None
+    additional_parameters=None,
+    progress=gr.Progress(track_tqdm=True)
 ):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     print(f"device:{device}\nmodel_name:{model_name}\n")
+
+    # Initialize the pipeline
     pipe = FluxPipeline.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
     ).to(device)
-    pipe = pipe.to(device)
     pipe.enable_model_cpu_offload()
+
+    # Access the tokenizer from the pipeline
+    tokenizer = pipe.tokenizer
+
+    # Handle add_prefix_space attribute
+    if getattr(tokenizer, 'add_prefix_space', False):
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+        # Update the pipeline's tokenizer
+        pipe.tokenizer = tokenizer
+
     # Load and apply LoRA weights
     if lora_weights:
         for lora_weight in lora_weights:
@@ -74,12 +85,18 @@ def generate_image_from_text(
             )
         else:
             pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
+
+    # Set the random seed for reproducibility
     generator = torch.Generator(device=device).manual_seed(seed)
     conditions = []
+
+    # Handle conditioned image if provided
    if conditioned_image is not None:
        conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024)
        condition = Condition("subject", conditioned_image)
        conditions.append(condition)
+
+    # Prepare parameters for image generation
    generate_params = {
        "prompt": text,
        "height": image_height,
@@ -89,12 +106,24 @@ def generate_image_from_text(
        "generator": generator,
        "conditions": conditions if conditions else None
    }
+
    if additional_parameters:
        generate_params.update(additional_parameters)
    generate_params = {k: v for k, v in generate_params.items() if v is not None}
+
+    # Generate the image
    result = pipe(**generate_params)
    image = result.images[0]
    pipe.unload_lora_weights()
+
+    # Clean up
+    del result
+    del conditions
+    del generator
+    del pipe
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
+
    return image
 
 @spaces.GPU(duration=140)
@@ -111,18 +140,19 @@ def generate_image_lowmem(
     seed=0,
     true_cfg_scale=1.0,
     pipeline_name="FluxPipeline",
-    strength=0.75,
-    additional_parameters=None
-):
-    print(f"\n {get_torch_info()}\n")
+    strength=0.75,
+    additional_parameters=None,
+    progress=gr.Progress(track_tqdm=True)
+):
     # Retrieve the pipeline class from the mapping
     pipeline_class = PIPELINE_CLASSES.get(pipeline_name)
     if not pipeline_class:
         raise ValueError(f"Unsupported pipeline type '{pipeline_name}'. "
                          f"Available options: {list(PIPELINE_CLASSES.keys())}")
+
     device = "cuda" if torch.cuda.is_available() else "cpu"
     print(f"device:{device}\nmodel_name:{model_name}\nlora_weights:{lora_weights}\n")
-
+    print(f"\n {get_torch_info()}\n")
     # Disable gradient calculations
     with torch.no_grad():
         # Initialize the pipeline inside the context manager
@@ -134,6 +164,16 @@ def generate_image_lowmem(
         pipe.enable_model_cpu_offload()
         # alternative version that may be more efficient
         # pipe.enable_sequential_cpu_offload()
+
+        # Access the tokenizer from the pipeline
+        tokenizer = pipe.tokenizer
+
+        # Check if add_prefix_space is set and convert to slow tokenizer if necessary
+        if getattr(tokenizer, 'add_prefix_space', False):
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+            # Update the pipeline's tokenizer
+            pipe.tokenizer = tokenizer
+
         flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
         if flash_attention_enabled == False:
             #Enable xFormers memory-efficient attention (optional)
@@ -282,6 +322,7 @@ def generate_ai_image_local (
     seed=777,
     pipeline_name="FluxPipeline",
     strength=0.75,
+    progress=gr.Progress(track_tqdm=True)
 ):
     try:
         if map_option != "Prompt":
@@ -306,10 +347,10 @@ def generate_ai_image_local (
                    additional_parameters[key] = int(value)
                elif key in ['guidance_scale','true_cfg_scale']:
                    additional_parameters[key] = float(value)
-        height = additional_parameters.get('height', height)
-        width = additional_parameters.get('width', width)
-        num_inference_steps = additional_parameters.get('num_inference_steps', num_inference_steps)
-        guidance_scale = additional_parameters.get('guidance_scale', guidance_scale)
+        height = additional_parameters.pop('height', height)
+        width = additional_parameters.pop('width', width)
+        num_inference_steps = additional_parameters.pop('num_inference_steps', num_inference_steps)
+        guidance_scale = additional_parameters.pop('guidance_scale', guidance_scale)
        print("Generating image with the following parameters:")
        print(f"Model: {model}")
        print(f"LoRA Weights: {lora_weights}")
@@ -347,7 +388,6 @@ def generate_ai_image_local (
         return None
 
 # does not work
-#@spaces.GPU(duration=256)
 def merge_LoRA_weights(model="black-forest-labs/FLUX.1-dev",
                        lora_weights="Borcherding/FLUX.1-dev-LoRA-FractalLand-v0.1"):
 
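The "tokenization" half of the commit replaces the pipeline's fast tokenizer with the slow implementation whenever add_prefix_space is set, since that option is not honored by an already-instantiated fast tokenizer. A standalone sketch of the same check (the helper and model name are illustrative, not part of the commit):

    from transformers import AutoTokenizer

    def ensure_slow_tokenizer_if_needed(pipe, model_name):
        """Swap in a slow tokenizer when the current one requests add_prefix_space."""
        tokenizer = pipe.tokenizer
        if getattr(tokenizer, "add_prefix_space", False):
            # Reload as a slow (Python) tokenizer; a fast tokenizer cannot simply
            # toggle add_prefix_space after it has been created.
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
            pipe.tokenizer = tokenizer
        return pipe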
 
utils/constants.py CHANGED
@@ -4,6 +4,7 @@
 import os
 from pathlib import Path
 from dotenv import load_dotenv
+import numpy as np
 
 #Set the environment variables
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"
@@ -32,6 +33,7 @@ if not HF_API_TOKEN:
     raise ValueError("HF_TOKEN is not set. Please check your .env file.")
 
 default_lut_example_img = "./LUT/daisy.jpg"
+MAX_SEED = np.iinfo(np.int32).max
 
 PROMPTS = {
     "BorderBlack": "eight_color (tabletop_map built from small hexagon pieces) as ((empty black on all sides), barren alien_world_map), with light_blue_is_rivers and brown_is_mountains and red_is_volcano and [white_is_snow at the top and bottom of map] as (four_color background: light_blue, green, tan, brown), horizontal_gradient is (brown to tan to green to light_blue to blue) and vertical_gradient is (white to blue to (green, tan and red) to blue to white), (middle is dark, no_reflections, no_shadows), ((partial hexes on edges and sides are black))",
utils/live_preview_helpers.py ADDED
@@ -0,0 +1,166 @@
+import torch
+import numpy as np
+from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
+from typing import Any, Dict, List, Optional, Union
+
+# Helper functions
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+# FLUX pipeline function
+@torch.inference_mode()
+def flux_pipe_call_that_returns_an_iterable_of_images(
+    self,
+    prompt: Union[str, List[str]] = None,
+    prompt_2: Optional[Union[str, List[str]]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 28,
+    timesteps: List[int] = None,
+    guidance_scale: float = 3.5,
+    num_images_per_prompt: Optional[int] = 1,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    max_sequence_length: int = 512,
+    good_vae: Optional[Any] = None,
+):
+    height = height or self.default_sample_size * self.vae_scale_factor
+    width = width or self.default_sample_size * self.vae_scale_factor
+
+    # 1. Check inputs
+    self.check_inputs(
+        prompt,
+        prompt_2,
+        height,
+        width,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        max_sequence_length=max_sequence_length,
+    )
+
+    self._guidance_scale = guidance_scale
+    self._joint_attention_kwargs = joint_attention_kwargs
+    self._interrupt = False
+
+    # 2. Define call parameters
+    batch_size = 1 if isinstance(prompt, str) else len(prompt)
+    device = self._execution_device
+
+    # 3. Encode prompt
+    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+        prompt=prompt,
+        prompt_2=prompt_2,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        max_sequence_length=max_sequence_length,
+        lora_scale=lora_scale,
+    )
+    # 4. Prepare latent variables
+    num_channels_latents = self.transformer.config.in_channels // 4
+    latents, latent_image_ids = self.prepare_latents(
+        batch_size * num_images_per_prompt,
+        num_channels_latents,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        latents,
+    )
+    # 5. Prepare timesteps
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+    image_seq_len = latents.shape[1]
+    mu = calculate_shift(
+        image_seq_len,
+        self.scheduler.config.base_image_seq_len,
+        self.scheduler.config.max_image_seq_len,
+        self.scheduler.config.base_shift,
+        self.scheduler.config.max_shift,
+    )
+    timesteps, num_inference_steps = retrieve_timesteps(
+        self.scheduler,
+        num_inference_steps,
+        device,
+        timesteps,
+        sigmas,
+        mu=mu,
+    )
+    self._num_timesteps = len(timesteps)
+
+    # Handle guidance
+    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
+    # 6. Denoising loop
+    for i, t in enumerate(timesteps):
+        if self.interrupt:
+            continue
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+        noise_pred = self.transformer(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
+        # Yield intermediate result
+        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+        image = self.vae.decode(latents_for_image, return_dict=False)[0]
+        yield self.image_processor.postprocess(image, output_type=output_type)[0]
+
+        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+        torch.cuda.empty_cache()
+
+    # Final image using good_vae
+    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
+    image = good_vae.decode(latents, return_dict=False)[0]
+    self.maybe_free_model_hooks()
+    torch.cuda.empty_cache()
+    yield self.image_processor.postprocess(image, output_type=output_type)[0]
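The helper above is written as an unbound pipeline method (note the self parameter): it yields a cheaply decoded preview after every denoising step and, at the end, a final frame decoded with the higher-quality VAE passed in as good_vae. A usage sketch following the Hugging Face FLUX live-preview demo pattern; the taef1 preview VAE and the binding shown here are assumptions, since this commit does not include the calling code:

    import torch
    from diffusers import AutoencoderKL, AutoencoderTiny, FluxPipeline
    from utils.live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    # Tiny VAE for cheap per-step previews, full VAE for the final frame (assumed models).
    taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
    good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)

    # Bind the helper as a method so `self` resolves to the pipeline.
    pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
        flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
    )

    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt="a hexagon tabletop map",
        num_inference_steps=8,
        generator=torch.Generator(device=device).manual_seed(0),
        good_vae=good_vae,
    ):
        img.save("latest_preview.png")  # e.g. stream each frame to a Gradio Image component instead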