"""Gradio demo with two style-transfer tabs.

* **FLUX** runs a FLUX.1-dev pipeline with the InstantX IP-Adapter in-process.
* **SDXL** forwards the request to the remote ``Hatman/InstantStyle`` app via
  ``gradio_client``.
"""

import gc
import os
import random
import tempfile

import gradio as gr
from gradio_client import Client, file
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import spaces
import torch
import torch.cuda
from transformers import AutoProcessor, SiglipVisionModel

from attention_processor import IPAFluxAttnProcessor2_0
from infer_flux_ipa_siglip import MLPProjModel, IPAdapter
from pipeline_flux_ipa import FluxPipeline
from transformer_flux import FluxTransformer2DModel

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# NOTE(review): DEVICE is computed here but the adapter below is pinned to
# "cuda"; confirm whether the adapter should follow this constant.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

image_encoder_path = "google/siglip-so400m-patch14-384"
ipadapter_path = hf_hub_download(
    repo_id="InstantX/FLUX.1-dev-IP-Adapter",
    filename="ip-adapter.bin",
)

# Model loading happens once, at import time.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

ip_model = IPAdapter(
    pipe,
    image_encoder_path,
    ipadapter_path,
    device="cuda",
    num_tokens=128,
)


def clear_gpu_memory() -> None:
    """Clear GPU memory and cache.

    Empties the CUDA cache and collects CUDA IPC handles when CUDA is
    available, then always runs the Python garbage collector.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()


def resize_img(image, max_size=1024):
    """Return ``image`` rescaled, keeping its aspect ratio.

    A single scaling factor ``max_size / max(width, height)`` is applied to
    both sides, so larger images shrink and smaller ones are enlarged.

    Args:
        image: A PIL image (anything exposing ``.size`` and ``.resize``).
        max_size: Target length for the longer side, in pixels.

    Returns:
        The resized image, resampled with ``Image.LANCZOS``.
    """
    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    # int() truncates, so a very thin image could end up with a 0-pixel side,
    # which PIL rejects; keep every side at least one pixel long.
    new_width = max(1, int(width * scaling_factor))
    new_height = max(1, int(height * scaling_factor))
    return image.resize((new_width, new_height), Image.LANCZOS)


@spaces.GPU
def process_image(
    image,
    prompt: str,
    scale,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image with the FLUX IP-Adapter from a reference image.

    Args:
        image: Reference image; a PIL image or an array accepted by
            ``Image.fromarray``. ``None`` means nothing was uploaded.
        prompt: Text prompt forwarded to ``ip_model.generate``.
        scale: Value forwarded as ``scale`` to ``ip_model.generate``.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, draw a new seed in ``[0, MAX_SEED]``.
        width: Output width forwarded to ``ip_model.generate``.
        height: Output height forwarded to ``ip_model.generate``.
        progress: Gradio progress tracker (mirrors tqdm progress in the UI).

    Returns:
        ``(image, seed)`` where ``image`` is the first generated result, or
        ``(None, seed)`` when no reference image was supplied.
    """
    clear_gpu_memory()

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if image is None:
        return None, seed

    # Ensure image is a PIL Image
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    image = resize_img(image)

    try:
        result = ip_model.generate(
            pil_image=image,
            prompt=prompt,
            scale=scale,
            width=width,
            height=height,
            seed=seed,
        )
    finally:
        # Release cached GPU memory even when generation fails.
        clear_gpu_memory()

    return result[0], seed


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in ``[0, 2000]`` if requested, else ``seed``."""
    if randomize_seed:
        seed = random.randint(0, 2000)
    return seed


def create_image_sdxl(
    image_pil,
    prompt: str,
    n_prompt: str,
    scale,
    control_scale,
    guidance_scale: float,
    num_inference_steps: int,
    seed: int,
    target: str = "Load only style blocks",
):
    """Generate an image through the remote ``Hatman/InstantStyle`` app.

    The style image is written to a temporary PNG, sent to the remote
    ``/create_image`` endpoint together with the other settings, and the
    temporary file is removed afterwards.

    Args:
        image_pil: Style image (PIL). ``None`` means nothing was uploaded.
        prompt: Positive prompt.
        n_prompt: Negative prompt.
        scale: Value forwarded as ``scale`` to the remote endpoint.
        control_scale: Value forwarded as ``control_scale``.
        guidance_scale: Value forwarded as ``guidance_scale``.
        num_inference_steps: Value forwarded as ``num_inference_steps``.
        seed: Value forwarded as ``seed``.
        target: Value forwarded as ``target`` (the style-mode choice).

    Returns:
        Whatever the remote endpoint returns, or ``None`` when no image was
        supplied or the request failed (the error is printed).
    """
    if image_pil is None:
        return None

    tmp_path = None
    try:
        # A per-call temp file keeps concurrent requests from overwriting
        # each other's upload.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
        image_pil.save(tmp_path, format="PNG")

        client = Client("Hatman/InstantStyle")
        result = client.predict(
            image_pil=file(tmp_path),
            prompt=prompt,
            n_prompt=n_prompt,
            # Forward the caller's value; it was previously pinned to 1,
            # which made the UI slider a no-op.
            scale=scale,
            control_scale=control_scale,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            target=target,
            api_name="/create_image",
        )
        return result
    except Exception as e:
        # UI boundary: report the failure and show an empty result instead
        # of crashing the event handler.
        print(f"Error in create_image_sdxl: {str(e)}")
        return None
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is harmless.
                pass


# UI CSS
css = """
::-webkit-scrollbar {
    display: none;
}
#component-0 {
    max-width: 900px;
    margin: 0 auto;
}
.center-markdown {
    text-align: center !important;
    display: flex !important;
    justify-content: center !important;
    width: 100% !important;
}
.gradio-row {
    display: flex !important;
    gap: 1rem !important;
    flex-wrap: nowrap !important;
}
.gradio-column {
    flex: 1 1 0 !important;
    min-width: 0 !important;
}
"""

title = r"""
Two different models using the IP Adapter with InstantStyle to preserve style across text-to-image generation.
"""

# `description` is rendered below the title; it was referenced by the layout
# but never defined, which raised NameError at import time.
# NOTE(review): wording written from what the two tabs do; adjust as desired.
description = r"""
Use the **FLUX** tab to generate with the FLUX.1-dev IP-Adapter, or the **SDXL** tab to generate through the hosted InstantStyle app.
"""

article = r"""
---
```bibtex
@article{wang2024instantstyle,
  title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
  author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
  journal={arXiv preprint arXiv:2404.02733},
  year={2024}
}
```
"""

# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order; confirm it matches the intended page layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title, elem_classes="center-markdown")
    gr.Markdown(description, elem_classes="center-markdown")

    with gr.Tab("FLUX"):
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                input_image = gr.Image(
                    label="Input Image",
                    type="pil",
                )
                scale = gr.Slider(
                    label="Image Scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7,
                )
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                run_button = gr.Button("Generate", variant="primary")

            with gr.Column(scale=1, min_width=300):
                result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

        # The seed slider is also an output so the UI shows the seed used.
        run_button.click(
            fn=process_image,
            inputs=[
                input_image,
                prompt,
                scale,
                seed,
                randomize_seed,
                width,
                height,
            ],
            outputs=[result, seed],
        )

    with gr.Tab("SDXL"):
        with gr.Row():
            with gr.Column():
                image_pil = gr.Image(label="Style Image", type="pil")
                target_radio = gr.Radio(
                    [
                        "Load only style blocks",
                        "Load only layout blocks",
                        "Load style+layout block",
                        "Load original IP-Adapter",
                    ],
                    value="Load only style blocks",
                    label="Style mode",
                )
                prompt_textbox = gr.Textbox(
                    label="Prompt",
                    value="a dog, masterpiece, best quality, high quality",
                )
                scale_slider_sdxl = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.01,
                    value=1.0,
                    label="Scale",
                )

                with gr.Accordion(open=False, label="Advanced Options"):
                    control_scale_slider = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Controlnet conditioning scale",
                    )
                    n_prompt_textbox = gr.Textbox(
                        label="Neg Prompt",
                        value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                    )
                    guidance_scale_slider = gr.Slider(
                        minimum=1,
                        maximum=15.0,
                        step=0.01,
                        value=5.0,
                        label="guidance scale",
                    )
                    num_inference_steps_slider = gr.Slider(
                        minimum=5,
                        maximum=50.0,
                        step=1.0,
                        value=20,
                        label="num inference steps",
                    )
                    seed_slider_sdxl = gr.Slider(
                        minimum=-1000000,
                        maximum=1000000,
                        value=1,
                        step=1,
                        label="Seed Value",
                    )
                    randomize_seed_checkbox_sdxl = gr.Checkbox(
                        label="Randomize seed",
                        value=True,
                    )

                generate_button = gr.Button("Generate Image", variant="primary")

            with gr.Column():
                generated_image = gr.Image(label="Generated Image", show_label=False)

        # First settle the seed (unqueued), then run the remote generation
        # with the updated seed value.
        generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed_slider_sdxl, randomize_seed_checkbox_sdxl],
            outputs=seed_slider_sdxl,
            queue=False,
            api_name=False,
        ).then(
            fn=create_image_sdxl,
            inputs=[
                image_pil,
                prompt_textbox,
                n_prompt_textbox,
                scale_slider_sdxl,
                control_scale_slider,
                guidance_scale_slider,
                num_inference_steps_slider,
                seed_slider_sdxl,
                target_radio,
            ],
            outputs=[generated_image],
        )

    gr.Markdown(article)

if __name__ == "__main__":
    demo.launch(
        share=True,
        show_error=True,
        quiet=False,
    )