import time

import gradio as gr
# NOTE(review): `spaces` is intentionally kept ahead of `torch`, as in the
# original import order — presumably required by ZeroGPU; confirm before
# letting an import sorter reorder these lines.
import spaces
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image

from flux.sampling import denoise, get_noise, get_schedule, prepare, rf_denoise, rf_inversion, unpack
from flux.util import (
    SamplingOptions,
    load_ae,
    load_clip,
    load_flow_model,
    load_flow_model_quintized,
    load_t5,
)
from pulid.pipeline_flux import PuLIDPipeline
from pulid.utils import resize_numpy_image_long, seed_everything


def get_models(name: str, device: torch.device, offload: bool):
    """Load the flow model, autoencoder and both text encoders.

    Args:
        name: Model identifier forwarded to ``load_flow_model`` / ``load_ae``.
        device: Device the T5 and CLIP encoders are loaded on; also used for
            the flow model and autoencoder unless ``offload`` is set.
        offload: When true, the flow model and autoencoder are loaded on the
            CPU instead of ``device``.

    Returns:
        A ``(model, ae, t5, clip)`` tuple.
    """
    t5 = load_t5(device, max_length=128)
    clip = load_clip(device)
    heavy_device = "cpu" if offload else device
    model = load_flow_model(name, device=heavy_device)
    model.eval()
    ae = load_ae(name, device=heavy_device)
    return model, ae, t5, clip


class FluxGenerator:
    """Container for the models shared by every ``generate_image`` call."""

    def __init__(self):
        self.device = torch.device('cuda')
        self.offload = False
        # generate_image forwards this flag to rf_inversion / rf_denoise; it
        # was previously read without ever being defined.
        self.aggressive_offload = False
        self.model_name = 'flux-dev'
        self.model, self.ae, self.t5, self.clip = get_models(
            self.model_name,
            device=self.device,
            offload=self.offload,
        )
        self.pulid_model = PuLIDPipeline(self.model, 'cuda', weight_dtype=torch.bfloat16)
        self.pulid_model.load_pretrain()


flux_generator = FluxGenerator()


def _decode_latents(latents, height: int, width: int):
    """Unpack packed latents and decode them with the autoencoder (None passes through)."""
    if latents is None:
        return None
    latents = unpack(latents.float(), height, width)
    with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
        return flux_generator.ae.decode(latents)


def _to_pil(decoded):
    """Convert the first image of a decoded [-1, 1] batch to a PIL image (None passes through)."""
    if decoded is None:
        return None
    decoded = decoded.clamp(-1, 1)
    decoded = rearrange(decoded[0], "c h w -> h w c")
    return Image.fromarray((127.5 * (decoded + 1.0)).cpu().byte().numpy())


@spaces.GPU(duration=30)
@torch.inference_mode()
def generate_image(
    prompt: str,
    id_image = None,
    width: int = 512,
    height: int = 512,
    num_steps: int = 20,
    start_step: int = 0,
    guidance: float = 4.0,
    seed: int = -1,
    id_weight: float = 1.0,
    neg_prompt: str = "",
    true_cfg: float = 1.0,
    timestep_to_start_cfg: int = 1,
    max_sequence_length: int = 128,
    gamma: float = 0.5,
    eta: float = 0.7,
    s: float = 0,
    tau: float = 5,
    perform_inversion: bool = True,
    perform_reconstruction: bool = False,
    perform_editing: bool = True,
    inversion_true_cfg: float = 1.0,
):
    """
    Core function that performs the image generation.

    The input image is encoded with the autoencoder, optionally inverted with
    ``rf_inversion`` (empty prompt), and then optionally reconstructed and/or
    edited with ``rf_denoise``.

    Args:
        prompt: Text prompt used for the editing pass.
        id_image: Input PIL image; required. It is converted to RGB and
            resized to ``(width, height)`` before encoding.
        width, height: Working resolution.
        num_steps: Number of steps passed to ``get_schedule``.
        start_step, id_weight, timestep_to_start_cfg: Forwarded unchanged to
            the samplers.
        guidance: Guidance value forwarded to the samplers.
        seed: RNG seed; ``-1`` draws a random one.
        neg_prompt: Negative prompt, only encoded when ``true_cfg`` != 1.
        true_cfg: CFG value for the editing pass.
        max_sequence_length: Assigned to ``flux_generator.t5.max_length``.
        gamma: Forwarded to ``rf_inversion``.
        eta, s, tau: Forwarded to ``rf_denoise``.
        perform_inversion: Run ``rf_inversion``; otherwise start from noise.
        perform_reconstruction: Run ``rf_denoise`` with the empty prompt.
        perform_editing: Run ``rf_denoise`` with ``prompt``.
        inversion_true_cfg: CFG value for the inversion and reconstruction
            passes.

    Returns:
        ``(edited, seed, debug_images)`` where ``edited`` is a PIL image (or
        ``None`` when editing is disabled), ``seed`` is the seed used as a
        string, and ``debug_images`` is ``pulid_model.debug_img_list``.

    Raises:
        gr.Error: If no input image is supplied.
    """
    # The image is mandatory: it is encoded below regardless of the flags.
    if id_image is None:
        raise gr.Error("Please provide an input image.")

    flux_generator.t5.max_length = max_sequence_length

    # If seed == -1, random
    seed = int(seed)
    if seed == -1:
        seed = None

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if opts.seed is None:
        opts.seed = torch.Generator(device="cpu").seed()

    seed_everything(opts.seed)
    print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
    t0 = time.perf_counter()

    use_true_cfg = abs(true_cfg - 1.0) > 1e-6

    # 1) Prepare input noise
    noise = get_noise(
        num_samples=1,
        height=opts.height,
        width=opts.width,
        device=flux_generator.device,
        dtype=torch.bfloat16,
        seed=opts.seed,
    )
    noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Encode id_image directly here
    encode_t0 = time.perf_counter()

    # Force three channels (RGBA / grayscale uploads would otherwise break the
    # "h w c" layout expected below), then resize to the working resolution.
    source_image = id_image.convert("RGB").resize((opts.width, opts.height), resample=Image.LANCZOS)

    # Convert image to torch.Tensor and scale to [-1, 1]
    x = torch.from_numpy(np.array(source_image).astype(np.float32))  # (H, W, C)
    x = (x / 127.5) - 1.0
    x = rearrange(x, "h w c -> 1 c h w")
    x = x.to(flux_generator.device)

    # NOTE(review): get_models loads the flow model and AE on the CPU when
    # offload is enabled, and nothing here moves the encoder or the flow model
    # to the GPU first — confirm before turning offload on.
    with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
        x = flux_generator.ae.encode(x)
    x = x.to(torch.bfloat16)

    # Offload if needed
    if flux_generator.offload:
        flux_generator.ae.encoder.to("cpu")
        torch.cuda.empty_cache()

    encode_t1 = time.perf_counter()
    print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")

    timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)

    # 2) Prepare text embeddings
    if flux_generator.offload:
        flux_generator.t5 = flux_generator.t5.to(flux_generator.device)
        flux_generator.clip = flux_generator.clip.to(flux_generator.device)

    inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
    inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt="")
    inp_neg = None
    if use_true_cfg:
        inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt)

    # Offload text encoders, load ID detection to GPU
    if flux_generator.offload:
        flux_generator.t5 = flux_generator.t5.cpu()
        flux_generator.clip = flux_generator.clip.cpu()
        torch.cuda.empty_cache()
        flux_generator.pulid_model.components_to_device(torch.device("cuda"))

    # 3) ID Embeddings
    id_array = resize_numpy_image_long(np.array(source_image), 1024)
    id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(
        id_array, cal_uncond=use_true_cfg
    )

    y_0 = inp["img"].clone().detach()

    # Keyword arguments shared by the inversion, reconstruction and editing passes.
    sampler_kwargs = dict(
        timesteps=timesteps,
        guidance=opts.guidance,
        id=id_embeddings,
        id_weight=id_weight,
        start_step=start_step,
        uncond_id=uncond_id_embeddings,
        timestep_to_start_cfg=timestep_to_start_cfg,
        neg_txt=inp_neg["txt"] if use_true_cfg else None,
        neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
        neg_vec=inp_neg["vec"] if use_true_cfg else None,
        aggressive_offload=flux_generator.aggressive_offload,
    )

    # 4) Inversion / reconstruction / editing
    inverted = None
    if perform_inversion:
        inverted = rf_inversion(
            flux_generator.model,
            **inp_inversion,
            **sampler_kwargs,
            true_cfg=inversion_true_cfg,
            y_1=noise,
            gamma=gamma,
        )
        img = inverted
    else:
        img = noise

    inp["img"] = img
    inp_inversion["img"] = img

    recon = None
    if perform_reconstruction:
        recon = rf_denoise(
            flux_generator.model,
            **inp_inversion,
            **sampler_kwargs,
            true_cfg=inversion_true_cfg,
            y_0=y_0,
            eta=eta,
            s=s,
            tau=tau,
        )

    edited = None
    if perform_editing:
        edited = rf_denoise(
            flux_generator.model,
            **inp,
            **sampler_kwargs,
            true_cfg=true_cfg,
            y_0=y_0,
            eta=eta,
            s=s,
            tau=tau,
        )

    # Offload flux model, load auto-decoder
    if flux_generator.offload:
        flux_generator.model.cpu()
        torch.cuda.empty_cache()
        flux_generator.ae.decoder.to(x.device)

    # 5) Decode latents
    # NOTE(review): only `edited` is returned; `inverted` and `recon` are still
    # decoded to keep the original behaviour — drop them if they stay unused.
    edited = _decode_latents(edited, opts.height, opts.width)
    inverted = _decode_latents(inverted, opts.height, opts.width)
    recon = _decode_latents(recon, opts.height, opts.width)

    if flux_generator.offload:
        flux_generator.ae.decoder.cpu()
        torch.cuda.empty_cache()

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.2f} seconds.")

    # Convert to PIL
    edited = _to_pil(edited)
    inverted = _to_pil(inverted)
    recon = _to_pil(recon)

    return edited, str(opts.seed), flux_generator.pulid_model.debug_img_list
# Paper: PuLID: Pure and Lightning ID Customization via Contrastive Alignment | Codes: GitHub
_HEADER_ = '''