|
|
|
import os |
|
import utils.constants as constants |
|
|
|
import gradio as gr |
|
from torch import __version__ as torch_version_, version, cuda, bfloat16, float32, Generator, no_grad, backends |
|
from diffusers import FluxPipeline, FluxImg2ImgPipeline, FluxControlPipeline
|
import accelerate |
|
from transformers import AutoTokenizer |
|
import safetensors |
|
|
|
|
|
|
|
from PIL import Image |
|
from tempfile import NamedTemporaryFile |
|
|
|
from utils.image_utils import ( |
|
crop_and_resize_image, |
|
) |
|
from utils.version_info import (
    get_torch_info,
    initialize_cuda,
    release_torch_resources
)
|
import gc |
|
from utils.lora_details import get_trigger_words, approximate_token_count, split_prompt_precisely |
|
|
|
|
|
|
|
import warnings |
|
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*") |
|
|
|
|
|
|
|
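# Map of user-facing pipeline names to their diffusers pipeline classes.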
PIPELINE_CLASSES = { |
|
"FluxPipeline": FluxPipeline, |
|
"FluxImg2ImgPipeline": FluxImg2ImgPipeline, |
|
"FluxControlPipeline": FluxControlPipeline |
|
} |
|
|
|
def generate_image_from_text( |
|
text, |
|
model_name="black-forest-labs/FLUX.1-dev", |
|
lora_weights=None, |
|
conditioned_image=None, |
|
image_width=1344, |
|
image_height=848, |
|
guidance_scale=3.5, |
|
num_inference_steps=50, |
|
seed=0, |
|
additional_parameters=None, |
|
progress=gr.Progress(track_tqdm=True) |
|
): |
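    """
    Generate an image from a text prompt with a FLUX.1 pipeline, optionally
    applying LoRA weights and an OminiControl "subject" condition image.

    Returns the generated PIL.Image.
    """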
|
from src.condition import Condition |
|
device = "cuda" if cuda.is_available() else "cpu" |
|
print(f"device:{device}\nmodel_name:{model_name}\n") |
|
|
|
|
|
pipe = FluxPipeline.from_pretrained( |
|
model_name, |
|
torch_dtype=bfloat16 if device == "cuda" else float32 |
|
).to(device) |
|
pipe.enable_model_cpu_offload() |
|
|
|
|
|
tokenizer = pipe.tokenizer |
|
|
|
|
|
if getattr(tokenizer, 'add_prefix_space', False): |
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) |
|
|
|
pipe.tokenizer = tokenizer |
|
|
|
|
|
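    # Load any requested LoRA adapters; per-adapter settings come from constants.LORA_DETAILS.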
if lora_weights: |
|
for lora_weight in lora_weights: |
|
lora_configs = constants.LORA_DETAILS.get(lora_weight, []) |
|
if lora_configs: |
|
for config in lora_configs: |
|
weight_name = config.get("weight_name") |
|
adapter_name = config.get("adapter_name") |
|
pipe.load_lora_weights( |
|
lora_weight, |
|
weight_name=weight_name, |
|
adapter_name=adapter_name, |
|
                        token=constants.HF_API_TOKEN
|
) |
|
else: |
|
                pipe.load_lora_weights(lora_weight, token=constants.HF_API_TOKEN)
|
|
|
|
|
generator = Generator(device=device).manual_seed(seed) |
|
conditions = [] |
|
|
|
|
|
if conditioned_image is not None: |
|
conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024) |
|
condition = Condition("subject", conditioned_image) |
|
conditions.append(condition) |
|
|
|
|
|
generate_params = { |
|
"prompt": text, |
|
"height": image_height, |
|
"width": image_width, |
|
"guidance_scale": guidance_scale, |
|
"num_inference_steps": num_inference_steps, |
|
"generator": generator, |
|
"conditions": conditions if conditions else None |
|
} |
|
|
|
if additional_parameters: |
|
generate_params.update(additional_parameters) |
|
generate_params = {k: v for k, v in generate_params.items() if v is not None} |
|
|
|
|
|
result = pipe(**generate_params) |
|
image = result.images[0] |
|
pipe.unload_lora_weights() |
|
|
|
|
|
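    # Drop references and clear CUDA caches so GPU memory is released promptly.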
del result |
|
del conditions |
|
del generator |
|
del pipe |
|
cuda.empty_cache() |
|
cuda.ipc_collect() |
|
|
|
return image |
|
|
|
|
|
def generate_image_lowmem( |
|
text, |
|
neg_prompt=None, |
|
model_name="black-forest-labs/FLUX.1-dev", |
|
lora_weights=None, |
|
conditioned_image=None, |
|
image_width=1368, |
|
image_height=848, |
|
guidance_scale=3.5, |
|
num_inference_steps=30, |
|
seed=0, |
|
true_cfg_scale=1.0, |
|
pipeline_name="FluxPipeline", |
|
strength=0.75, |
|
additional_parameters=None, |
|
progress=gr.Progress(track_tqdm=True) |
|
): |
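    """
    Low-memory image generation. Selects the pipeline class named by
    `pipeline_name`, enables CPU offload (plus VAE slicing/tiling for
    FluxPipeline), loads any LoRA weights described in constants.LORA_DETAILS,
    and optionally conditions generation on an input image. Long prompts are
    split across the pipeline's two text encoders.

    Returns the generated PIL.Image.
    """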
|
|
|
pipeline_class = PIPELINE_CLASSES.get(pipeline_name) |
|
if not pipeline_class: |
|
raise ValueError(f"Unsupported pipeline type '{pipeline_name}'. " |
|
f"Available options: {list(PIPELINE_CLASSES.keys())}") |
|
|
|
initialize_cuda() |
|
device = "cuda" if cuda.is_available() else "cpu" |
|
from src.condition import Condition |
|
|
|
print(f"device:{device}\nmodel_name:{model_name}\nlora_weights:{lora_weights}\n") |
|
print(f"\n {get_torch_info()}\n") |
|
|
|
with no_grad(): |
|
|
|
pipe = pipeline_class.from_pretrained( |
|
model_name, |
|
torch_dtype=bfloat16 if device == "cuda" else float32 |
|
).to(device) |
|
|
|
|
|
|
|
|
|
if pipeline_name == "FluxPipeline": |
|
pipe.enable_model_cpu_offload() |
|
pipe.vae.enable_slicing() |
|
pipe.vae.enable_tiling() |
|
else: |
|
pipe.enable_model_cpu_offload() |
|
|
|
|
|
tokenizer = pipe.tokenizer |
|
|
|
|
|
if getattr(tokenizer, 'add_prefix_space', False): |
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, device_map = 'cpu') |
|
|
|
pipe.tokenizer = tokenizer |
|
pipe.to(device) |
|
|
|
        flash_attention_enabled = backends.cuda.flash_sdp_enabled()
        if not flash_attention_enabled:
            # Flash attention is unavailable; fall back to xFormers memory-efficient
            # attention (assumes the xformers package is installed and supports this model).
            try:
                pipe.enable_xformers_memory_efficient_attention()
                print("\nEnabled xFormers memory-efficient attention.\n")
            except Exception as e:
                print(f"\nCould not enable xFormers memory-efficient attention: {e}\n")
        else:
            pipe.attn_implementation = "flash_attention_2"
            print("\nEnabled flash_attention_2.\n")
|
|
|
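        # Default OminiControl condition type; individual LoRA configs may override it below.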
condition_type = "subject" |
|
|
|
|
|
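        # Load each requested LoRA. Entries in constants.LORA_DETAILS may specify a
        # weight file, adapter name, source collection, extra pipe method calls, and
        # a condition type.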
if lora_weights: |
|
for lora_weight in lora_weights: |
|
lora_configs = constants.LORA_DETAILS.get(lora_weight, []) |
|
lora_weight_set = False |
|
if lora_configs: |
|
for config in lora_configs: |
|
|
|
if 'weight_name' in config: |
|
weight_name = config.get("weight_name") |
|
adapter_name = config.get("adapter_name") |
|
lora_collection = config.get("lora_collection") |
|
                            if weight_name and adapter_name and lora_collection and not lora_weight_set:
|
pipe.load_lora_weights( |
|
lora_collection, |
|
weight_name=weight_name, |
|
adapter_name=adapter_name, |
|
token=constants.HF_API_TOKEN |
|
) |
|
lora_weight_set = True |
|
                                print(f"\npipe.load_lora_weights({lora_collection}, weight_name={weight_name}, adapter_name={adapter_name})\n")
|
                            elif weight_name and adapter_name is None and lora_collection and not lora_weight_set:
|
pipe.load_lora_weights( |
|
lora_collection, |
|
weight_name=weight_name, |
|
token=constants.HF_API_TOKEN |
|
) |
|
lora_weight_set = True |
|
                                print(f"\npipe.load_lora_weights({lora_collection}, weight_name={weight_name})\n")
|
                            elif weight_name and adapter_name and not lora_weight_set:
|
pipe.load_lora_weights( |
|
lora_weight, |
|
weight_name=weight_name, |
|
adapter_name=adapter_name, |
|
token=constants.HF_API_TOKEN |
|
) |
|
lora_weight_set = True |
|
                                print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name})\n")
|
                            elif weight_name and adapter_name is None and not lora_weight_set:
|
pipe.load_lora_weights( |
|
lora_weight, |
|
weight_name=weight_name, |
|
token=constants.HF_API_TOKEN |
|
) |
|
lora_weight_set = True |
|
                                print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name})\n")
|
                            elif not lora_weight_set:
|
pipe.load_lora_weights( |
|
lora_weight, |
|
token=constants.HF_API_TOKEN |
|
) |
|
lora_weight_set = True |
|
                                print(f"\npipe.load_lora_weights({lora_weight})\n")
|
|
|
if 'pipe' in config: |
|
pipe_config = config['pipe'] |
|
for method_name, params in pipe_config.items(): |
|
method = getattr(pipe, method_name, None) |
|
if method: |
|
print(f"Applying pipe method: {method_name} with params: {params}") |
|
method(**params) |
|
else: |
|
print(f"Method {method_name} not found in pipe.") |
|
if 'condition_type' in config: |
|
condition_type = config['condition_type'] |
|
                            if condition_type in ("coloring", "deblurring", "fill", "depth", "canny", "subject"):
                                print(f"\nEnabled {condition_type}.\n")
                            else:
                                print(f"Condition type {condition_type} not implemented.")
|
else: |
|
                    pipe.load_lora_weights(lora_weight, token=constants.HF_API_TOKEN)
|
|
|
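        # Seed the generator for reproducible results and gather any condition inputs.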
generator = Generator(device=device).manual_seed(seed) |
|
conditions = [] |
|
if conditioned_image is not None: |
|
conditioned_image = crop_and_resize_image(conditioned_image, image_width, image_height) |
|
condition = Condition(condition_type, conditioned_image) |
|
conditions.append(condition) |
|
print(f"\nAdded conditioned image.\n {conditioned_image.size}") |
|
|
|
            # Merge image-conditioning parameters, preserving anything the caller passed in.
            if additional_parameters is None:
                additional_parameters = {}
            additional_parameters.update({
                "strength": strength,
                "image": conditioned_image,
            })
        else:
            print("\nNo conditioned image provided.")
            if additional_parameters is None:
                additional_parameters = {}
            if neg_prompt is not None:
                true_cfg_scale = 1.1
                additional_parameters.update({
                    "negative_prompt": neg_prompt,
                    "true_cfg_scale": true_cfg_scale,
                })
|
|
|
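        # FLUX uses two text encoders; prompts that would overflow CLIP's ~77-token
        # limit are split, with the remainder passed as prompt_2 to the T5 encoder.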
if approximate_token_count(text) > 76: |
|
prompt, prompt2 = split_prompt_precisely(text) |
|
            prompt_parameters = {
                "prompt": prompt,
                "prompt_2": prompt2
            }
        else:
            prompt_parameters = {
                "prompt": text
            }
|
additional_parameters.update(prompt_parameters) |
|
|
|
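        # Assemble the final pipeline arguments, dropping any None values.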
        generate_params = {
            "height": image_height,
            "width": image_width,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }
|
if additional_parameters: |
|
generate_params.update(additional_parameters) |
|
generate_params = {k: v for k, v in generate_params.items() if v is not None} |
|
print(f"generate_params: {generate_params}") |
|
|
|
result = pipe(**generate_params) |
|
image = result.images[0] |
|
|
|
del result |
|
del conditions |
|
del generator |
|
|
|
del pipe |
|
        if device == "cuda":
            cuda.empty_cache()
            cuda.ipc_collect()
            print(cuda.memory_summary(device=None, abbreviated=False))
|
|
|
return image |
|
|
|
def generate_ai_image_local(
|
map_option, |
|
prompt_textbox_value, |
|
neg_prompt_textbox_value, |
|
model="black-forest-labs/FLUX.1-dev", |
|
lora_weights=None, |
|
conditioned_image=None, |
|
height=512, |
|
width=912, |
|
num_inference_steps=30, |
|
guidance_scale=3.5, |
|
seed=777, |
|
pipeline_name="FluxPipeline", |
|
strength=0.75, |
|
progress=gr.Progress(track_tqdm=True) |
|
): |
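    """
    Resolve the prompt and LoRA-specific parameters for the selected map
    option, run generate_image_lowmem(), and save the result to a temporary
    PNG file.

    Returns the path to the saved image, or None if generation failed.
    """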
|
release_torch_resources() |
|
    print("Generating image with the low-memory pipeline")
|
try: |
|
if map_option != "Prompt": |
|
prompt = constants.PROMPTS[map_option] |
|
negative_prompt = constants.NEGATIVE_PROMPTS.get(map_option, "") |
|
else: |
|
prompt = prompt_textbox_value |
|
negative_prompt = neg_prompt_textbox_value or "" |
|
|
|
additional_parameters = {} |
|
if lora_weights: |
|
for lora_weight in lora_weights: |
|
lora_configs = constants.LORA_DETAILS.get(lora_weight, []) |
|
for config in lora_configs: |
|
if 'parameters' in config: |
|
additional_parameters.update(config['parameters']) |
|
elif 'trigger_words' in config: |
|
trigger_words = get_trigger_words(lora_weight) |
|
prompt = f"{trigger_words} {prompt}" |
|
for key, value in additional_parameters.items(): |
|
if key in ['height', 'width', 'num_inference_steps', 'max_sequence_length']: |
|
additional_parameters[key] = int(value) |
|
elif key in ['guidance_scale','true_cfg_scale']: |
|
additional_parameters[key] = float(value) |
|
height = additional_parameters.pop('height', height) |
|
width = additional_parameters.pop('width', width) |
|
num_inference_steps = additional_parameters.pop('num_inference_steps', num_inference_steps) |
|
guidance_scale = additional_parameters.pop('guidance_scale', guidance_scale) |
|
print("Generating image with the following parameters:") |
|
print(f"Model: {model}") |
|
print(f"LoRA Weights: {lora_weights}") |
|
print(f"Prompt: {prompt}") |
|
print(f"Neg Prompt: {negative_prompt}") |
|
print(f"Height: {height}") |
|
print(f"Width: {width}") |
|
print(f"Number of Inference Steps: {num_inference_steps}") |
|
print(f"Guidance Scale: {guidance_scale}") |
|
print(f"Seed: {seed}") |
|
print(f"Additional Parameters: {additional_parameters}") |
|
print(f"Conditioned Image: {conditioned_image}") |
|
print(f"Conditioned Image Strength: {strength}") |
|
print(f"pipeline: {pipeline_name}") |
|
image = generate_image_lowmem( |
|
text=prompt, |
|
model_name=model, |
|
neg_prompt=negative_prompt, |
|
lora_weights=lora_weights, |
|
conditioned_image=conditioned_image, |
|
image_width=width, |
|
image_height=height, |
|
guidance_scale=guidance_scale, |
|
num_inference_steps=num_inference_steps, |
|
seed=seed, |
|
pipeline_name=pipeline_name, |
|
strength=strength, |
|
additional_parameters=additional_parameters |
|
) |
|
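        # Persist the image to a temporary PNG and track it for later cleanup.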
with NamedTemporaryFile(delete=False, suffix=".png") as tmp: |
|
image.save(tmp.name, format="PNG") |
|
constants.temp_files.append(tmp.name) |
|
print(f"Image saved to {tmp.name}") |
|
gc.collect() |
|
return tmp.name |
|
except Exception as e: |
|
print(f"Error generating AI image: {e}") |
|
gc.collect() |
|
return None |
|
|
|
|
|
def merge_LoRA_weights(model="black-forest-labs/FLUX.1-dev", |
|
lora_weights="Borcherding/FLUX.1-dev-LoRA-FractalLand-v0.1"): |
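    """
    Merge a LoRA into its FLUX.1 base model and save the result as a
    standalone pipeline directory named '<lora-name>-merged'.
    """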
|
|
|
model_suffix = model.split("/")[-1] |
|
if model_suffix not in lora_weights: |
|
raise ValueError(f"The model suffix '{model_suffix}' must be in the lora_weights string '{lora_weights}' to proceed.") |
|
|
|
    pipe = FluxPipeline.from_pretrained(model, torch_dtype=bfloat16)
    pipe.load_lora_weights(lora_weights)
    # Fuse the LoRA into the base weights so the saved pipeline actually contains
    # the merged model (save_pretrained alone would not include unfused LoRA layers).
    pipe.fuse_lora()
    lora_name = lora_weights.split("/")[-1] + "-merged"
    pipe.save_pretrained(lora_name)
    pipe.unload_lora_weights()
|
|
|
|