# utils/ai_generator_diffusers_flux.py
import os
import utils.constants as constants
#import spaces
import gradio as gr
from torch import __version__ as torch_version_, version, cuda, bfloat16, float32, Generator, no_grad, backends
from diffusers import FluxPipeline, FluxImg2ImgPipeline, FluxControlPipeline
import accelerate
from transformers import AutoTokenizer
import safetensors
#import xformers
#from diffusers.utils import load_image
#from huggingface_hub import hf_hub_download
from PIL import Image
from tempfile import NamedTemporaryFile
from utils.image_utils import (
crop_and_resize_image,
)
from utils.version_info import (
get_torch_info,
# get_diffusers_version,
# get_transformers_version,
# get_xformers_version,
initialize_cuda,
release_torch_resources
)
import gc
from utils.lora_details import get_trigger_words, approximate_token_count, split_prompt_precisely
#from utils.color_utils import detect_color_format
#import utils.misc as misc
#from pathlib import Path
import warnings
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
#print(torch_version_) # Ensure it's 2.0 or newer
#print(cuda.is_available()) # Ensure CUDA is available
PIPELINE_CLASSES = {
"FluxPipeline": FluxPipeline,
"FluxImg2ImgPipeline": FluxImg2ImgPipeline,
"FluxControlPipeline": FluxControlPipeline
}
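# Example lookup, mirroring how generate_image_lowmem selects its pipeline below
# (kept as a comment; the repo id is the same default model used throughout this file):
# pipeline_class = PIPELINE_CLASSES["FluxImg2ImgPipeline"]
# pipe = pipeline_class.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=bfloat16)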
#@spaces.GPU()
def generate_image_from_text(
text,
model_name="black-forest-labs/FLUX.1-dev",
lora_weights=None,
conditioned_image=None,
image_width=1344,
image_height=848,
guidance_scale=3.5,
num_inference_steps=50,
seed=0,
additional_parameters=None,
progress=gr.Progress(track_tqdm=True)
):
from src.condition import Condition
device = "cuda" if cuda.is_available() else "cpu"
print(f"device:{device}\nmodel_name:{model_name}\n")
# Initialize the pipeline
pipe = FluxPipeline.from_pretrained(
model_name,
torch_dtype=bfloat16 if device == "cuda" else float32
).to(device)
pipe.enable_model_cpu_offload()
# Access the tokenizer from the pipeline
tokenizer = pipe.tokenizer
# Handle add_prefix_space attribute
if getattr(tokenizer, 'add_prefix_space', False):
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# Update the pipeline's tokenizer
pipe.tokenizer = tokenizer
# Load and apply LoRA weights
if lora_weights:
for lora_weight in lora_weights:
lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
if lora_configs:
for config in lora_configs:
weight_name = config.get("weight_name")
adapter_name = config.get("adapter_name")
pipe.load_lora_weights(
lora_weight,
weight_name=weight_name,
adapter_name=adapter_name,
token=constants.HF_API_TOKEN
)
else:
pipe.load_lora_weights(lora_weight, token=constants.HF_API_TOKEN)
# Set the random seed for reproducibility
generator = Generator(device=device).manual_seed(seed)
conditions = []
# Handle conditioned image if provided
if conditioned_image is not None:
conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024)
condition = Condition("subject", conditioned_image)
conditions.append(condition)
# Prepare parameters for image generation
generate_params = {
"prompt": text,
"height": image_height,
"width": image_width,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": generator,
"conditions": conditions if conditions else None
}
if additional_parameters:
generate_params.update(additional_parameters)
generate_params = {k: v for k, v in generate_params.items() if v is not None}
# Generate the image
result = pipe(**generate_params)
image = result.images[0]
pipe.unload_lora_weights()
# Clean up
del result
del conditions
del generator
del pipe
cuda.empty_cache()
cuda.ipc_collect()
return image
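# A minimal, commented-out usage sketch for generate_image_from_text. It is kept commented
# so importing this module never triggers a model download; the prompt text below is a
# placeholder, not a value used elsewhere in this project.
#
# if __name__ == "__main__":
#     sample = generate_image_from_text(
#         "a hexagonal tile map of a medieval village, top-down view",
#         model_name="black-forest-labs/FLUX.1-dev",
#         image_width=1344,
#         image_height=848,
#         guidance_scale=3.5,
#         num_inference_steps=50,
#         seed=42,
#     )
#     sample.save("sample_flux_output.png")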
#@spaces.GPU(progress=gr.Progress(track_tqdm=True))
def generate_image_lowmem(
text,
neg_prompt=None,
model_name="black-forest-labs/FLUX.1-dev",
lora_weights=None,
conditioned_image=None,
image_width=1368,
image_height=848,
guidance_scale=3.5,
num_inference_steps=30,
seed=0,
true_cfg_scale=1.0,
pipeline_name="FluxPipeline",
strength=0.75,
additional_parameters=None,
progress=gr.Progress(track_tqdm=True)
):
# Retrieve the pipeline class from the mapping
pipeline_class = PIPELINE_CLASSES.get(pipeline_name)
if not pipeline_class:
raise ValueError(f"Unsupported pipeline type '{pipeline_name}'. "
f"Available options: {list(PIPELINE_CLASSES.keys())}")
initialize_cuda()
device = "cuda" if cuda.is_available() else "cpu"
from src.condition import Condition
print(f"device:{device}\nmodel_name:{model_name}\nlora_weights:{lora_weights}\n")
print(f"\n {get_torch_info()}\n")
# Disable gradient calculations
with no_grad():
# Initialize the pipeline inside the context manager
pipe = pipeline_class.from_pretrained(
model_name,
torch_dtype=bfloat16 if device == "cuda" else float32
).to(device)
# Optionally skip CPU offload entirely when there is enough VRAM; the sequential variant below
# is a lower-memory (but slower) alternative to enable_model_cpu_offload():
# pipe.enable_sequential_cpu_offload()
if pipeline_name == "FluxPipeline":
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
else:
pipe.enable_model_cpu_offload()
# Access the tokenizer from the pipeline
tokenizer = pipe.tokenizer
# Check if add_prefix_space is set and convert to slow tokenizer if necessary
if getattr(tokenizer, 'add_prefix_space', False):
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, device_map='cpu')
# Update the pipeline's tokenizer
pipe.tokenizer = tokenizer
# pipe.to(device) is omitted here: the pipeline was already moved in from_pretrained(...).to(device),
# and moving it again after enable_model_cpu_offload() conflicts with offloading.
flash_attention_enabled = backends.cuda.flash_sdp_enabled()
if not flash_attention_enabled:
# Flash attention is unavailable; optionally enable xFormers memory-efficient attention instead:
# pipe.enable_xformers_memory_efficient_attention()
print("\nFlash attention not available; using default attention.\n")
else:
pipe.attn_implementation = "flash_attention_2"
print("\nEnabled flash_attention_2.\n")
condition_type = "subject"
# Load LoRA weights
# note: loading several LoRA weights under distinct adapter names is not handled yet; that needs
# pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125]) -- see the sketch after this function
if lora_weights:
for lora_weight in lora_weights:
lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
lora_weight_set = False
if lora_configs:
for config in lora_configs:
# Load LoRA weights with optional weight_name and adapter_name
if 'weight_name' in config:
weight_name = config.get("weight_name")
adapter_name = config.get("adapter_name")
lora_collection = config.get("lora_collection")
if weight_name and adapter_name and lora_collection and not lora_weight_set:
pipe.load_lora_weights(
lora_collection,
weight_name=weight_name,
adapter_name=adapter_name,
token=constants.HF_API_TOKEN
)
lora_weight_set = True
print(f"\npipe.load_lora_weights({lora_collection}, weight_name={weight_name}, adapter_name={adapter_name})\n")
elif weight_name and adapter_name is None and lora_collection and not lora_weight_set:
pipe.load_lora_weights(
lora_collection,
weight_name=weight_name,
token=constants.HF_API_TOKEN
)
lora_weight_set = True
print(f"\npipe.load_lora_weights({lora_collection}, weight_name={weight_name})\n")
elif weight_name and adapter_name and not lora_weight_set:
pipe.load_lora_weights(
lora_weight,
weight_name=weight_name,
adapter_name=adapter_name,
token=constants.HF_API_TOKEN
)
lora_weight_set = True
print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name})\n")
elif weight_name and adapter_name is None and not lora_weight_set:
pipe.load_lora_weights(
lora_weight,
weight_name=weight_name,
token=constants.HF_API_TOKEN
)
lora_weight_set = True
print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name})\n")
elif not lora_weight_set:
pipe.load_lora_weights(
lora_weight,
token=constants.HF_API_TOKEN
)
lora_weight_set = True
print(f"\npipe.load_lora_weights({lora_weight})\n")
# Apply 'pipe' configurations if present
if 'pipe' in config:
pipe_config = config['pipe']
for method_name, params in pipe_config.items():
method = getattr(pipe, method_name, None)
if method:
print(f"Applying pipe method: {method_name} with params: {params}")
method(**params)
else:
print(f"Method {method_name} not found in pipe.")
if 'condition_type' in config:
condition_type = config['condition_type']
if condition_type == "coloring":
#pipe.enable_coloring()
print("\nEnabled coloring.\n")
elif condition_type == "deblurring":
#pipe.enable_deblurring()
print("\nEnabled deblurring.\n")
elif condition_type == "fill":
#pipe.enable_fill()
print("\nEnabled fill.\n")
elif condition_type == "depth":
#pipe.enable_depth()
print("\nEnabled depth.\n")
elif condition_type == "canny":
#pipe.enable_canny()
print("\nEnabled canny.\n")
elif condition_type == "subject":
#pipe.enable_subject()
print("\nEnabled subject.\n")
else:
print(f"Condition type {condition_type} not implemented.")
else:
pipe.load_lora_weights(lora_weight, token=constants.HF_API_TOKEN)
# Set the random seed for reproducibility
generator = Generator(device=device).manual_seed(seed)
conditions = []
# Guard against a missing additional_parameters dict so the update() calls below are safe
if additional_parameters is None:
additional_parameters = {}
if conditioned_image is not None:
conditioned_image = crop_and_resize_image(conditioned_image, image_width, image_height)
condition = Condition(condition_type, conditioned_image)
conditions.append(condition)
print(f"\nAdded conditioned image.\n {conditioned_image.size}")
# Add image-to-image parameters without discarding any caller-supplied additional_parameters
additional_parameters.update({
"strength": strength,
"image": conditioned_image,
})
else:
print("\nNo conditioned image provided.")
if neg_prompt:
true_cfg_scale = 1.1
additional_parameters.update({
"negative_prompt": neg_prompt,
"true_cfg_scale": true_cfg_scale,
})
# handle long prompts by splitting them
if approximate_token_count(text) > 76:
prompt, prompt2 = split_prompt_precisely(text)
prompt_parameters = {
"prompt": prompt,
"prompt_2": prompt2
}
else:
prompt_parameters = {
"prompt": text
}
additional_parameters.update(prompt_parameters)
# Combine all parameters
generate_params = {
"height": image_height,
"width": image_width,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
"generator": generator,
}
if additional_parameters:
generate_params.update(additional_parameters)
generate_params = {k: v for k, v in generate_params.items() if v is not None}
print(f"generate_params: {generate_params}")
# Generate the image
result = pipe(**generate_params)
image = result.images[0]
# Clean up
del result
del conditions
del generator
# Delete the pipeline and clear cache
del pipe
cuda.empty_cache()
cuda.ipc_collect()
print(cuda.memory_summary(device=None, abbreviated=False))
return image
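# The loader above applies one LoRA adapter at a time. The helper below is a minimal,
# untested sketch of the multi-adapter case mentioned in the note inside generate_image_lowmem,
# combining adapters with diffusers' set_adapters(). The function name and the adapter_specs
# structure are assumptions, not part of the existing code.
def set_multiple_lora_adapters(pipe, adapter_specs):
    """Sketch: adapter_specs is a list of (repo_id, weight_name, adapter_name, weight) tuples."""
    adapter_names = []
    adapter_weights = []
    for repo_id, weight_name, adapter_name, weight in adapter_specs:
        # Load each LoRA under its own adapter name
        pipe.load_lora_weights(
            repo_id,
            weight_name=weight_name,
            adapter_name=adapter_name,
            token=constants.HF_API_TOKEN
        )
        adapter_names.append(adapter_name)
        adapter_weights.append(weight)
    # Activate all adapters at once, e.g. set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
    pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
    return pipe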
def generate_ai_image_local(
map_option,
prompt_textbox_value,
neg_prompt_textbox_value,
model="black-forest-labs/FLUX.1-dev",
lora_weights=None,
conditioned_image=None,
height=512,
width=912,
num_inference_steps=30,
guidance_scale=3.5,
seed=777,
pipeline_name="FluxPipeline",
strength=0.75,
progress=gr.Progress(track_tqdm=True)
):
release_torch_resources()
print("Generating image with lowmem")
try:
if map_option != "Prompt":
prompt = constants.PROMPTS[map_option]
negative_prompt = constants.NEGATIVE_PROMPTS.get(map_option, "")
else:
prompt = prompt_textbox_value
negative_prompt = neg_prompt_textbox_value or ""
#full_prompt = f"{prompt} {negative_prompt}"
additional_parameters = {}
if lora_weights:
for lora_weight in lora_weights:
lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
for config in lora_configs:
if 'parameters' in config:
additional_parameters.update(config['parameters'])
elif 'trigger_words' in config:
trigger_words = get_trigger_words(lora_weight)
prompt = f"{trigger_words} {prompt}"
for key, value in additional_parameters.items():
if key in ['height', 'width', 'num_inference_steps', 'max_sequence_length']:
additional_parameters[key] = int(value)
elif key in ['guidance_scale','true_cfg_scale']:
additional_parameters[key] = float(value)
height = additional_parameters.pop('height', height)
width = additional_parameters.pop('width', width)
num_inference_steps = additional_parameters.pop('num_inference_steps', num_inference_steps)
guidance_scale = additional_parameters.pop('guidance_scale', guidance_scale)
print("Generating image with the following parameters:")
print(f"Model: {model}")
print(f"LoRA Weights: {lora_weights}")
print(f"Prompt: {prompt}")
print(f"Neg Prompt: {negative_prompt}")
print(f"Height: {height}")
print(f"Width: {width}")
print(f"Number of Inference Steps: {num_inference_steps}")
print(f"Guidance Scale: {guidance_scale}")
print(f"Seed: {seed}")
print(f"Additional Parameters: {additional_parameters}")
print(f"Conditioned Image: {conditioned_image}")
print(f"Conditioned Image Strength: {strength}")
print(f"pipeline: {pipeline_name}")
image = generate_image_lowmem(
text=prompt,
model_name=model,
neg_prompt=negative_prompt,
lora_weights=lora_weights,
conditioned_image=conditioned_image,
image_width=width,
image_height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
seed=seed,
pipeline_name=pipeline_name,
strength=strength,
additional_parameters=additional_parameters
)
with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
image.save(tmp.name, format="PNG")
constants.temp_files.append(tmp.name)
print(f"Image saved to {tmp.name}")
gc.collect()
return tmp.name
except Exception as e:
print(f"Error generating AI image: {e}")
gc.collect()
return None
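# A minimal, commented-out usage sketch for generate_ai_image_local. The prompt strings are
# placeholders; "Prompt" is the map_option branch that uses the textbox values directly.
#
# png_path = generate_ai_image_local(
#     map_option="Prompt",
#     prompt_textbox_value="an isometric hexagon grid over rolling green hills",
#     neg_prompt_textbox_value="blurry, low quality",
#     model="black-forest-labs/FLUX.1-dev",
#     height=512,
#     width=912,
#     num_inference_steps=30,
#     guidance_scale=3.5,
#     seed=777,
#     pipeline_name="FluxPipeline",
# )
# print(png_path)  # path to the temporary PNG file, or None on failure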
# does not work as written; see the hedged fuse_lora() sketch after this function
def merge_LoRA_weights(model="black-forest-labs/FLUX.1-dev",
lora_weights="Borcherding/FLUX.1-dev-LoRA-FractalLand-v0.1"):
model_suffix = model.split("/")[-1]
if model_suffix not in lora_weights:
raise ValueError(f"The model suffix '{model_suffix}' must be in the lora_weights string '{lora_weights}' to proceed.")
pipe = FluxPipeline.from_pretrained(model, torch_dtype=bfloat16)
pipe.load_lora_weights(lora_weights)
pipe.save_lora_weights(os.getenv("TMPDIR"))
lora_name = lora_weights.split("/")[-1] + "-merged"
pipe.save_pretrained(lora_name)
pipe.unload_lora_weights()
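# A hedged alternative sketch for merging LoRA weights into the base model using diffusers'
# fuse_lora(). The function name and output-directory handling are assumptions and this is
# untested here; saving a fully fused FLUX.1 checkpoint also requires substantial disk space.
def merge_lora_weights_fused(model="black-forest-labs/FLUX.1-dev",
                             lora_weights="Borcherding/FLUX.1-dev-LoRA-FractalLand-v0.1",
                             output_dir=None):
    pipe = FluxPipeline.from_pretrained(model, torch_dtype=bfloat16)
    pipe.load_lora_weights(lora_weights, token=constants.HF_API_TOKEN)
    # Bake the LoRA into the base weights, then drop the separate adapter state
    pipe.fuse_lora(lora_scale=1.0)
    pipe.unload_lora_weights()
    output_dir = output_dir or (lora_weights.split("/")[-1] + "-merged")
    pipe.save_pretrained(output_dir)
    return output_dir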