|
import spaces |
|
import time |
|
import os |
|
|
|
import gradio as gr |
|
import torch |
|
from einops import rearrange |
|
from PIL import Image |
|
|
|
from flux.cli import SamplingOptions |
|
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack |
|
from flux.util import load_ae, load_clip, load_flow_model, load_t5 |
|
from pulid.pipeline_flux import PuLIDPipeline |
|
from pulid.utils import resize_numpy_image_long |
|
|
|
|
|
def get_models(name: str, device: torch.device, offload: bool):
    """Load the FLUX flow model, autoencoder and both text encoders.

    The T5 and CLIP encoders are always loaded onto ``device``. The flow model
    and autoencoder go to ``device`` too, unless ``offload`` is set, in which
    case they are loaded on the CPU.

    Returns:
        Tuple ``(model, ae, t5, clip)``; the flow model is put in eval mode.
    """
    placement = "cpu" if offload else device

    text_t5 = load_t5(device, max_length=128)
    text_clip = load_clip(device)

    flow = load_flow_model(name, device=placement)
    flow.eval()

    autoencoder = load_ae(name, device=placement)
    return flow, autoencoder, text_t5, text_clip
|
|
|
|
|
class FluxGenerator:
    """Load and hold the FLUX models plus the PuLID pipeline used for generation.

    Args:
        device: Device passed to ``get_models`` (the text encoders always land
            there; the flow model and autoencoder too unless ``offload``).
            PuLID is built on this same device. Defaults to ``'cuda'``.
        offload: Passed to ``get_models``; when true the flow model and
            autoencoder are loaded on the CPU. ``generate_image`` reads this
            flag to move models between the CPU and ``device`` per request.
        model_name: FLUX checkpoint name passed to ``get_models``.

    All defaults equal the values that used to be hard-coded, so
    ``FluxGenerator()`` behaves exactly as before.
    """

    def __init__(self, device='cuda', offload=False, model_name='flux-dev'):
        self.device = torch.device(device)
        self.offload = offload
        self.model_name = model_name
        self.model, self.ae, self.t5, self.clip = get_models(
            self.model_name,
            device=self.device,
            offload=self.offload,
        )
        # Build PuLID on the configured device instead of a hard-coded 'cuda'
        # so it always agrees with ``self.device``. It previously received a
        # plain string, so keep passing one (e.g. 'cuda' or 'cuda:1').
        # NOTE(review): with offload=True the flow model starts on the CPU
        # while PuLID is built on ``device`` — confirm the pipeline handles that.
        self.pulid_model = PuLIDPipeline(self.model, str(self.device), weight_dtype=torch.bfloat16)
        self.pulid_model.load_pretrain()
|
|
|
|
|
# Built once at import time: every model (FLUX + PuLID) is loaded here and the
# instance is shared by all calls to ``generate_image``.
flux_generator = FluxGenerator()
|
|
|
|
|
@spaces.GPU
@torch.inference_mode()
def generate_image(
    width,
    height,
    num_steps,
    start_step,
    guidance,
    seed,
    prompt,
    id_image=None,
    id_weight=1.0,
    neg_prompt="",
    true_cfg=1.0,
    timestep_to_start_cfg=1,
    max_sequence_length=128,
):
    """Generate one image with FLUX, optionally ID-conditioned through PuLID.

    All models come from the module-level ``flux_generator``.

    Args:
        width, height: Output size, passed to ``SamplingOptions``/``get_noise``/``unpack``.
        num_steps: Number of steps handed to ``get_schedule``.
        start_step: Forwarded to ``denoise``.
        guidance: Forwarded to ``denoise`` as ``guidance``.
        seed: Any value ``int()`` accepts; ``-1`` means draw a random seed.
        prompt: Positive text prompt.
        id_image: Optional identity image. It is resized with
            ``resize_numpy_image_long(..., 1024)`` (presumably a numpy array
            with its long side capped at 1024 — confirm) and turned into ID
            embeddings by the PuLID pipeline.
        id_weight: Forwarded to ``denoise``.
        neg_prompt: Negative prompt; only encoded when true CFG is active.
        true_cfg: True-CFG scale; considered active only when it differs from
            1.0 by more than 0.01.
        timestep_to_start_cfg: Forwarded to ``denoise``.
        max_sequence_length: Written to the shared T5 embedder's ``max_length``.

    Returns:
        Tuple ``(PIL.Image, str(seed_used), flux_generator.pulid_model.debug_img_list)``.
    """
    # Mutates the shared T5 embedder, so this value persists for later calls.
    flux_generator.t5.max_length = max_sequence_length

    seed = int(seed)
    if seed == -1:
        # -1 requests a random seed; one is drawn below once opts exists.
        seed = None

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if opts.seed is None:
        opts.seed = torch.Generator(device="cpu").seed()

    # NOTE(review): t0 (and t1 below) are never read — the timing is dead code.
    t0 = time.perf_counter()

    # Gates the negative-prompt encoding, the negative conditioning passed to
    # denoise, and the unconditional ID embedding.
    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    if id_image is not None:
        id_image = resize_numpy_image_long(id_image, 1024)
        id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
    else:
        id_embeddings = None
        uncond_id_embeddings = None

    x = get_noise(
        1,
        opts.height,
        opts.width,
        device=flux_generator.device,
        dtype=torch.bfloat16,
        seed=opts.seed,
    )
    # The second argument is presumably the packed image sequence length
    # (latent H * W / 4, i.e. 2x2 patches) — confirm against get_schedule.
    timesteps = get_schedule(
        opts.num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=True,
    )

    # Offload mode: bring the text encoders onto the device just for encoding.
    if flux_generator.offload:
        flux_generator.t5, flux_generator.clip = flux_generator.t5.to(flux_generator.device), flux_generator.clip.to(flux_generator.device)
    inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
    inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt) if use_true_cfg else None

    # Offload mode: text encoders back to CPU, flow model onto the device.
    if flux_generator.offload:
        flux_generator.t5, flux_generator.clip = flux_generator.t5.cpu(), flux_generator.clip.cpu()
        torch.cuda.empty_cache()
        flux_generator.model = flux_generator.model.to(flux_generator.device)

    x = denoise(
        flux_generator.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
        start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
        timestep_to_start_cfg=timestep_to_start_cfg,
        # inp_neg is None unless true CFG is active, hence the guards.
        neg_txt=inp_neg["txt"] if use_true_cfg else None,
        neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
        neg_vec=inp_neg["vec"] if use_true_cfg else None,
    )

    # Offload mode: flow model back to CPU, AE decoder onto the latents' device.
    if flux_generator.offload:
        flux_generator.model.cpu()
        torch.cuda.empty_cache()
        flux_generator.ae.decoder.to(x.device)

    x = unpack(x.float(), opts.height, opts.width)
    with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
        x = flux_generator.ae.decode(x)

    if flux_generator.offload:
        flux_generator.ae.decoder.cpu()
        torch.cuda.empty_cache()

    t1 = time.perf_counter()

    # Clamp to [-1, 1], take the first batch item and move channels last.
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")

    # Map [-1, 1] onto [0, 255]; ``.byte()`` truncates to uint8.
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
    return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
|
|
|
|
|
# Extra CSS for the Gradio app: hides the default page footer.
css = """
footer {
    visibility: hidden;
}
"""
|
|
|
def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
                offload: bool = False):
    """Build the Gradio UI for PuLID-FLUX.

    The UI drives the module-level ``generate_image``, which uses the models
    already loaded into ``flux_generator``.

    Args:
        args: Parsed command-line namespace. Not read by this function.
        model_name: Not read by this function (kept for interface compatibility).
        device: Not read by this function (kept for interface compatibility).
        offload: Not read by this function (kept for interface compatibility).

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """

    def generate_for_ui(*inputs):
        # generate_image returns (image, seed, debug images), but only one
        # output component is bound below. With a single output Gradio passes
        # the entire return value to that component, so the tuple would reach
        # gr.Image and fail to postprocess. Hand back just the image.
        image, _used_seed, _debug_images = generate_image(*inputs)
        return image

    with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
                id_image = gr.Image(label="ID Image")
                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Examples")

                # Each entry is [prompt, path to ID image].
                all_examples = [
                    ['a woman holding sign with glowing green text \"PuLID for FLUX\"', 'example_inputs/liuyifei.png'],
                    ['portrait, side view', 'example_inputs/liuyifei.png'],
                    ['white-haired woman with vr technology atmosphere', 'example_inputs/liuyifei.png'],
                    ['a young child is eating Icecream', 'example_inputs/liuyifei.png'],
                    ['a man is holding a sign with text \"PuLID for FLUX\", winter, snowing', 'example_inputs/pengwei.jpg'],
                    ['portrait, candle light', 'example_inputs/pengwei.jpg'],
                    ['profile shot dark photo of a 25-year-old male with smoke', 'example_inputs/pengwei.jpg'],
                    ['American Comics, 1boy', 'example_inputs/pengwei.jpg'],
                    ['portrait, pixar', 'example_inputs/pengwei.jpg'],
                    ['portrait, made of ice sculpture', 'example_inputs/lecun.jpg'],
                ]

                example_images = [example[1] for example in all_examples]
                example_captions = [example[0] for example in all_examples]

                gallery = gr.Gallery(
                    value=list(zip(example_images, example_captions)),
                    label="Example Gallery",
                    show_label=False,
                    elem_id="gallery",
                    columns=5,
                    rows=2,
                    object_fit="contain",
                    height="auto"
                )

                def fill_example(evt: gr.SelectData):
                    # Copy the selected example's prompt and image into the inputs.
                    return [all_examples[evt.index][i] for i in [0, 1]]

                gallery.select(
                    fill_example,
                    None,
                    [prompt, id_image],
                )

        # Hidden components supply fixed values for the remaining positional
        # parameters of generate_image, in signature order.
        generate_btn.click(
            fn=generate_for_ui,
            inputs=[
                gr.Slider(256, 1536, 896, step=16, visible=False),  # width
                gr.Slider(256, 1536, 1152, step=16, visible=False),  # height
                gr.Slider(1, 20, 20, step=1, visible=False),  # num_steps
                gr.Slider(0, 10, 0, step=1, visible=False),  # start_step
                gr.Slider(1.0, 10.0, 4, step=0.1, visible=False),  # guidance
                gr.Textbox(-1, visible=False),  # seed (-1 = random)
                prompt,
                id_image,
                gr.Slider(0.0, 3.0, 1, step=0.05, visible=False),  # id_weight
                gr.Textbox("bad quality, worst quality, text, signature, watermark, extra limbs", visible=False),  # neg_prompt
                gr.Slider(1.0, 10.0, 1, step=0.1, visible=False),  # true_cfg
                gr.Slider(0, 20, 1, step=1, visible=False),  # timestep_to_start_cfg
                gr.Slider(128, 512, 128, step=128, visible=False),  # max_sequence_length
            ],
            outputs=[output_image],
            # Keep the API endpoint name the event had when bound to generate_image.
            api_name="generate_image",
        )

    return demo
|
|
|
def build_arg_parser():
    """Return the command-line parser for this demo.

    Options:
        --name: model name; only ``flux-dev`` is accepted.
        --device: device string; defaults to ``cuda`` when available, else ``cpu``.
        --offload: flag forwarded to ``create_demo``.
        --port: server port. ``None`` (the default) lets Gradio pick its own
            port, which is what the app did before ``--port`` was honoured.
        --dev, --pretrained_model: development options.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
    # ``choices`` needs a list of names; ``list('flux-dev')`` would split the
    # string into characters and reject an explicit ``--name flux-dev``.
    parser.add_argument("--name", type=str, default="flux-dev", choices=["flux-dev"],
                        help="currently only support flux-dev")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use")
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    parser.add_argument("--port", type=int, default=None, help="Port to use")
    parser.add_argument("--dev", action='store_true', help="Development mode")
    parser.add_argument("--pretrained_model", type=str, help='for development')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    import huggingface_hub

    # login(None) falls back to an interactive prompt, which would block a
    # headless launch; only log in when a token is actually configured.
    hf_token = os.getenv('HF_TOKEN')
    if hf_token:
        huggingface_hub.login(hf_token)

    demo = create_demo(args, args.name, args.device, args.offload)
    # server_port=None keeps Gradio's own port selection (previous behaviour).
    demo.launch(server_port=args.port)