tight-inversion commited on
Commit
0f1b614
·
1 Parent(s): e818c21

Align with pulid demo

Browse files
Files changed (1) hide show
  1. app.py +263 -273
app.py CHANGED
@@ -20,292 +20,283 @@ from pulid.pipeline_flux import PuLIDPipeline
20
  from pulid.utils import resize_numpy_image_long, seed_everything
21
 
22
 
23
- def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
24
  t5 = load_t5(device, max_length=128)
25
  clip = load_clip(device)
26
  model = load_flow_model(name, device="cpu" if offload else device)
27
  model.eval()
28
- ae = load_ae(name, device=device)
29
  return model, ae, t5, clip
30
 
 
31
  class FluxGenerator:
32
- def __init__(self, model_name: str, device: str, offload: bool, aggressive_offload: bool, args):
33
- self.device = torch.device(device)
34
  self.offload = False
35
- self.aggressive_offload = aggressive_offload
36
- self.model_name = model_name
37
- self.model, self.ae, self.t5, self.clip_model = get_models(
38
- model_name,
39
  device=self.device,
40
  offload=self.offload,
41
- fp8=args.fp8,
42
- )
43
- self.pulid_model = PuLIDPipeline(self.model, device='cuda', weight_dtype=torch.bfloat16)
44
- self.pulid_model.load_pretrain(args.pretrained_model)
45
-
46
- @spaces.GPU(duration=30)
47
- @torch.inference_mode()
48
- def generate_image(
49
- self,
50
- prompt: str,
51
- id_image = None,
52
- width: int = 512,
53
- height: int = 512,
54
- num_steps: int = 20,
55
- start_step: int = 0,
56
- guidance: float = 4.0,
57
- seed: int = -1,
58
- id_weight: float = 1.0,
59
- neg_prompt: str = "",
60
- true_cfg: float = 1.0,
61
- timestep_to_start_cfg: int = 1,
62
- max_sequence_length: int = 128,
63
- gamma: float = 0.5,
64
- eta: float = 0.7,
65
- s: float = 0,
66
- tau: float = 5,
67
- perform_inversion: bool = True,
68
- perform_reconstruction: bool = False,
69
- perform_editing: bool = True,
70
- inversion_true_cfg: float = 1.0,
71
- ):
72
- """
73
- Core function that performs the image generation.
74
- """
75
- self.t5.to(self.device)
76
- self.clip_model.to(self.device)
77
- self.ae.to(self.device)
78
- self.model.to(self.device)
79
- self.t5.max_length = max_sequence_length
80
-
81
- # If seed == -1, random
82
- seed = int(seed)
83
- if seed == -1:
84
- seed = None
85
-
86
- opts = SamplingOptions(
87
- prompt=prompt,
88
- width=width,
89
- height=height,
90
- num_steps=num_steps,
91
- guidance=guidance,
92
- seed=seed,
93
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- if opts.seed is None:
96
- opts.seed = torch.Generator(device="cpu").seed()
97
-
98
- seed_everything(opts.seed)
99
-
100
- print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
101
- t0 = time.perf_counter()
102
-
103
- use_true_cfg = abs(true_cfg - 1.0) > 1e-6
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- # 1) Prepare input noise
107
- noise = get_noise(
108
- num_samples=1,
109
- height=opts.height,
110
- width=opts.width,
111
- device=self.device,
112
- dtype=torch.bfloat16,
113
- seed=opts.seed,
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  )
115
- bs, c, h, w = noise.shape
116
- noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
117
- if noise.shape[0] == 1 and bs > 1:
118
- noise = repeat(noise, "1 ... -> bs ...", bs=bs)
119
- # Encode id_image directly here
120
- encode_t0 = time.perf_counter()
121
-
122
- # Resize image
123
- id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
124
-
125
- # Convert image to torch.Tensor and scale to [-1, 1]
126
- x = np.array(id_image).astype(np.float32)
127
- x = torch.from_numpy(x) # shape: (H, W, C)
128
- x = (x / 127.5) - 1.0 # now in [-1, 1]
129
- x = rearrange(x, "h w c -> 1 c h w") # shape: (1, C, H, W)
130
- x = x.to(self.device)
131
- # Encode with autocast
132
- self.ae.encoder.to(self.device)
133
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
134
- x = self.ae.encode(x)
135
-
136
- x = x.to(torch.bfloat16)
137
-
138
- # Offload if needed
139
- if self.offload:
140
- self.ae.encoder.to("cpu")
141
- torch.cuda.empty_cache()
142
-
143
- encode_t1 = time.perf_counter()
144
- print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
145
-
146
- timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
147
-
148
- # 2) Prepare text embeddings
149
- if self.offload:
150
- self.t5 = self.t5.to(self.device)
151
- self.clip_model = self.clip_model.to(self.device)
152
-
153
- inp = prepare(t5=self.t5, clip=self.clip_model, img=x, prompt=opts.prompt)
154
- inp_inversion = prepare(t5=self.t5, clip=self.clip_model, img=x, prompt="")
155
- inp_neg = None
156
- if use_true_cfg:
157
- inp_neg = prepare(t5=self.t5, clip=self.clip_model, img=x, prompt=neg_prompt)
158
-
159
- # Offload text encoders, load ID detection to GPU
160
- if self.offload:
161
- self.t5 = self.t5.cpu()
162
- self.clip_model = self.clip_model.cpu()
163
- torch.cuda.empty_cache()
164
- self.pulid_model.components_to_device(torch.device("cuda"))
165
-
166
- # 3) ID Embeddings (optional)
167
- id_embeddings = None
168
- uncond_id_embeddings = None
169
- if id_image is not None:
170
- id_image = np.array(id_image)
171
- id_image = resize_numpy_image_long(id_image, 1024)
172
- id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
173
- else:
174
- id_embeddings = None
175
- uncond_id_embeddings = None
176
-
177
- # Offload ID pipeline, load main FLUX model to GPU
178
- if self.offload:
179
- self.pulid_model.components_to_device(torch.device("cpu"))
180
- torch.cuda.empty_cache()
181
-
182
- if self.aggressive_offload:
183
- self.model.components_to_gpu()
184
- else:
185
- self.model = self.model.to(self.device)
186
-
187
- y_0 = inp["img"].clone().detach()
188
-
189
- inverted = None
190
- if perform_inversion:
191
- inverted = rf_inversion(
192
- self.model,
193
- **inp_inversion,
194
- timesteps=timesteps,
195
- guidance=opts.guidance,
196
- id=id_embeddings,
197
- id_weight=id_weight,
198
- start_step=start_step,
199
- uncond_id=uncond_id_embeddings,
200
- true_cfg=inversion_true_cfg,
201
- timestep_to_start_cfg=timestep_to_start_cfg,
202
- neg_txt=inp_neg["txt"] if use_true_cfg else None,
203
- neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
204
- neg_vec=inp_neg["vec"] if use_true_cfg else None,
205
- aggressive_offload=self.aggressive_offload,
206
- y_1=noise,
207
- gamma=gamma
208
- )
209
-
210
- img = inverted
211
- else:
212
- img = noise
213
- inp["img"] = img
214
- inp_inversion["img"] = img
215
-
216
- recon = None
217
- if perform_reconstruction:
218
- recon = rf_denoise(
219
- self.model,
220
- **inp_inversion,
221
- timesteps=timesteps,
222
- guidance=opts.guidance,
223
- id=id_embeddings,
224
- id_weight=id_weight,
225
- start_step=start_step,
226
- uncond_id=uncond_id_embeddings,
227
- true_cfg=inversion_true_cfg,
228
- timestep_to_start_cfg=timestep_to_start_cfg,
229
- neg_txt=inp_neg["txt"] if use_true_cfg else None,
230
- neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
231
- neg_vec=inp_neg["vec"] if use_true_cfg else None,
232
- aggressive_offload=self.aggressive_offload,
233
- y_0=y_0,
234
- eta=eta,
235
- s=s,
236
- tau=tau,
237
- )
238
-
239
- edited = None
240
- if perform_editing:
241
- edited = rf_denoise(
242
- self.model,
243
- **inp,
244
- timesteps=timesteps,
245
- guidance=opts.guidance,
246
- id=id_embeddings,
247
- id_weight=id_weight,
248
- start_step=start_step,
249
- uncond_id=uncond_id_embeddings,
250
- true_cfg=true_cfg,
251
- timestep_to_start_cfg=timestep_to_start_cfg,
252
- neg_txt=inp_neg["txt"] if use_true_cfg else None,
253
- neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
254
- neg_vec=inp_neg["vec"] if use_true_cfg else None,
255
- aggressive_offload=self.aggressive_offload,
256
- y_0=y_0,
257
- eta=eta,
258
- s=s,
259
- tau=tau,
260
- )
261
-
262
- # Offload flux model, load auto-decoder
263
- self.ae.decoder.to(self.device)
264
- if self.offload:
265
- self.model.cpu()
266
- torch.cuda.empty_cache()
267
- self.ae.decoder.to(x.device)
268
-
269
- # 5) Decode latents
270
- if edited is not None:
271
- edited = unpack(edited.float(), opts.height, opts.width)
272
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
273
- edited = self.ae.decode(edited)
274
-
275
- if inverted is not None:
276
- inverted = unpack(inverted.float(), opts.height, opts.width)
277
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
278
- inverted = self.ae.decode(inverted)
279
-
280
- if recon is not None:
281
- recon = unpack(recon.float(), opts.height, opts.width)
282
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
283
- recon = self.ae.decode(recon)
284
-
285
- if self.offload:
286
- self.ae.decoder.cpu()
287
- torch.cuda.empty_cache()
288
-
289
- t1 = time.perf_counter()
290
- print(f"Done in {t1 - t0:.2f} seconds.")
291
-
292
- # Convert to PIL
293
- if edited is not None:
294
- edited = edited.clamp(-1, 1)
295
- edited = rearrange(edited[0], "c h w -> h w c")
296
- edited = Image.fromarray((127.5 * (edited + 1.0)).cpu().byte().numpy())
297
-
298
- if inverted is not None:
299
- inverted = inverted.clamp(-1, 1)
300
- inverted = rearrange(inverted[0], "c h w -> h w c")
301
- inverted = Image.fromarray((127.5 * (inverted + 1.0)).cpu().byte().numpy())
302
-
303
- if recon is not None:
304
- recon = recon.clamp(-1, 1)
305
- recon = rearrange(recon[0], "c h w -> h w c")
306
- recon = Image.fromarray((127.5 * (recon + 1.0)).cpu().byte().numpy())
307
-
308
- return edited, str(opts.seed), self.pulid_model.debug_img_list
309
 
310
  # <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
311
  _HEADER_ = '''
@@ -322,7 +313,6 @@ _CITE_ = r"""
322
 
323
  def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
324
  offload: bool = False, aggressive_offload: bool = False):
325
- generator = FluxGenerator(model_name, device, offload, aggressive_offload, args)
326
 
327
  with gr.Blocks() as demo:
328
  gr.Markdown(_HEADER_)
@@ -404,7 +394,7 @@ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_ava
404
  gr.Examples(examples=example_inps, inputs=[prompt, id_image, id_weight, guidance, seed, true_cfg])
405
 
406
  generate_btn.click(
407
- fn=generator.generate_image,
408
  inputs=[prompt, id_image, width, height, num_steps, start_step, guidance, seed, id_weight, neg_prompt,
409
  true_cfg, timestep_to_start_cfg, max_sequence_length, gamma, eta, s, tau],
410
  outputs=[output_image, seed_output, intermediate_output],
 
20
  from pulid.utils import resize_numpy_image_long, seed_everything
21
 
22
 
23
+ def get_models(name: str, device: torch.device, offload: bool):
24
  t5 = load_t5(device, max_length=128)
25
  clip = load_clip(device)
26
  model = load_flow_model(name, device="cpu" if offload else device)
27
  model.eval()
28
+ ae = load_ae(name, device="cpu" if offload else device)
29
  return model, ae, t5, clip
30
 
31
+
32
  class FluxGenerator:
33
+ def __init__(self):
34
+ self.device = torch.device('cuda')
35
  self.offload = False
36
+ self.model_name = 'flux-dev'
37
+ self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
38
+ self.model_name,
 
39
  device=self.device,
40
  offload=self.offload,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
42
+ self.pulid_model = PuLIDPipeline(self.model, 'cuda', weight_dtype=torch.bfloat16)
43
+ self.pulid_model.load_pretrain()
44
+
45
+
46
+ flux_generator = FluxGenerator()
47
+
48
+
49
+ @spaces.GPU(duration=30)
50
+ @torch.inference_mode()
51
+ def generate_image(
52
+ prompt: str,
53
+ id_image = None,
54
+ width: int = 512,
55
+ height: int = 512,
56
+ num_steps: int = 20,
57
+ start_step: int = 0,
58
+ guidance: float = 4.0,
59
+ seed: int = -1,
60
+ id_weight: float = 1.0,
61
+ neg_prompt: str = "",
62
+ true_cfg: float = 1.0,
63
+ timestep_to_start_cfg: int = 1,
64
+ max_sequence_length: int = 128,
65
+ gamma: float = 0.5,
66
+ eta: float = 0.7,
67
+ s: float = 0,
68
+ tau: float = 5,
69
+ perform_inversion: bool = True,
70
+ perform_reconstruction: bool = False,
71
+ perform_editing: bool = True,
72
+ inversion_true_cfg: float = 1.0,
73
+ ):
74
+ """
75
+ Core function that performs the image generation.
76
+ """
77
+ # self.t5.to(self.device)
78
+ # self.clip_model.to(self.device)
79
+ # self.ae.to(self.device)
80
+ # self.model.to(self.device)
81
+
82
+ flux_generator.t5.max_length = max_sequence_length
83
+
84
+ # If seed == -1, random
85
+ seed = int(seed)
86
+ if seed == -1:
87
+ seed = None
88
+
89
+ opts = SamplingOptions(
90
+ prompt=prompt,
91
+ width=width,
92
+ height=height,
93
+ num_steps=num_steps,
94
+ guidance=guidance,
95
+ seed=seed,
96
+ )
97
+
98
+ if opts.seed is None:
99
+ opts.seed = torch.Generator(device="cpu").seed()
100
+
101
+ seed_everything(opts.seed)
102
+
103
+ print(f"Generating prompt: '{opts.prompt}' (seed={opts.seed})...")
104
+ t0 = time.perf_counter()
105
+
106
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-6
107
+
108
+
109
+ # 1) Prepare input noise
110
+ noise = get_noise(
111
+ num_samples=1,
112
+ height=opts.height,
113
+ width=opts.width,
114
+ device=flux_generator.device,
115
+ dtype=torch.bfloat16,
116
+ seed=opts.seed,
117
+ )
118
+ bs, c, h, w = noise.shape
119
+ noise = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
120
+ if noise.shape[0] == 1 and bs > 1:
121
+ noise = repeat(noise, "1 ... -> bs ...", bs=bs)
122
+ # Encode id_image directly here
123
+ encode_t0 = time.perf_counter()
124
+
125
+ # Resize image
126
+ id_image = id_image.resize((opts.width, opts.height), resample=Image.LANCZOS)
127
+
128
+ # Convert image to torch.Tensor and scale to [-1, 1]
129
+ x = np.array(id_image).astype(np.float32)
130
+ x = torch.from_numpy(x) # shape: (H, W, C)
131
+ x = (x / 127.5) - 1.0 # now in [-1, 1]
132
+ x = rearrange(x, "h w c -> 1 c h w") # shape: (1, C, H, W)
133
+ x = x.to(flux_generator.device)
134
+ # Encode with autocast
135
+ with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
136
+ x = flux_generator.ae.encode(x)
137
+
138
+ x = x.to(torch.bfloat16)
139
+
140
+ # Offload if needed
141
+ if flux_generator.offload:
142
+ flux_generator.ae.encoder.to("cpu")
143
+ torch.cuda.empty_cache()
144
+
145
+ encode_t1 = time.perf_counter()
146
+ print(f"Encoded in {encode_t1 - encode_t0:.2f} seconds.")
147
+
148
+ timesteps = get_schedule(opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=False)
149
+
150
+ # 2) Prepare text embeddings
151
+ if flux_generator.offload:
152
+ flux_generator.t5 = flux_generator.t5.to(flux_generator.device)
153
+ flux_generator.clip_model = flux_generator.clip_model.to(flux_generator.device)
154
+
155
+ inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=opts.prompt)
156
+ inp_inversion = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt="")
157
+ inp_neg = None
158
+ if use_true_cfg:
159
+ inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip_model, img=x, prompt=neg_prompt)
160
+
161
+ # Offload text encoders, load ID detection to GPU
162
+ if flux_generator.offload:
163
+ flux_generator.t5 = flux_generator.t5.cpu()
164
+ flux_generator.clip_model = flux_generator.clip_model.cpu()
165
+ torch.cuda.empty_cache()
166
+ flux_generator.pulid_model.components_to_device(torch.device("cuda"))
167
+
168
+ # 3) ID Embeddings (optional)
169
+ id_embeddings = None
170
+ uncond_id_embeddings = None
171
+ if id_image is not None:
172
+ id_image = np.array(id_image)
173
+ id_image = resize_numpy_image_long(id_image, 1024)
174
+ id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
175
+ else:
176
+ id_embeddings = None
177
+ uncond_id_embeddings = None
178
 
179
+ y_0 = inp["img"].clone().detach()
180
+
181
+ inverted = None
182
+ if perform_inversion:
183
+ inverted = rf_inversion(
184
+ flux_generator.model,
185
+ **inp_inversion,
186
+ timesteps=timesteps,
187
+ guidance=opts.guidance,
188
+ id=id_embeddings,
189
+ id_weight=id_weight,
190
+ start_step=start_step,
191
+ uncond_id=uncond_id_embeddings,
192
+ true_cfg=inversion_true_cfg,
193
+ timestep_to_start_cfg=timestep_to_start_cfg,
194
+ neg_txt=inp_neg["txt"] if use_true_cfg else None,
195
+ neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
196
+ neg_vec=inp_neg["vec"] if use_true_cfg else None,
197
+ aggressive_offload=flux_generator.aggressive_offload,
198
+ y_1=noise,
199
+ gamma=gamma
200
+ )
201
 
202
+ img = inverted
203
+ else:
204
+ img = noise
205
+ inp["img"] = img
206
+ inp_inversion["img"] = img
207
+
208
+ recon = None
209
+ if perform_reconstruction:
210
+ recon = rf_denoise(
211
+ flux_generator.model,
212
+ **inp_inversion,
213
+ timesteps=timesteps,
214
+ guidance=opts.guidance,
215
+ id=id_embeddings,
216
+ id_weight=id_weight,
217
+ start_step=start_step,
218
+ uncond_id=uncond_id_embeddings,
219
+ true_cfg=inversion_true_cfg,
220
+ timestep_to_start_cfg=timestep_to_start_cfg,
221
+ neg_txt=inp_neg["txt"] if use_true_cfg else None,
222
+ neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
223
+ neg_vec=inp_neg["vec"] if use_true_cfg else None,
224
+ aggressive_offload=flux_generator.aggressive_offload,
225
+ y_0=y_0,
226
+ eta=eta,
227
+ s=s,
228
+ tau=tau,
229
+ )
230
 
231
+ edited = None
232
+ if perform_editing:
233
+ edited = rf_denoise(
234
+ flux_generator.model,
235
+ **inp,
236
+ timesteps=timesteps,
237
+ guidance=opts.guidance,
238
+ id=id_embeddings,
239
+ id_weight=id_weight,
240
+ start_step=start_step,
241
+ uncond_id=uncond_id_embeddings,
242
+ true_cfg=true_cfg,
243
+ timestep_to_start_cfg=timestep_to_start_cfg,
244
+ neg_txt=inp_neg["txt"] if use_true_cfg else None,
245
+ neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
246
+ neg_vec=inp_neg["vec"] if use_true_cfg else None,
247
+ aggressive_offload=flux_generator.aggressive_offload,
248
+ y_0=y_0,
249
+ eta=eta,
250
+ s=s,
251
+ tau=tau,
252
  )
253
+
254
+ # Offload flux model, load auto-decoder
255
+ if flux_generator.offload:
256
+ flux_generator.model.cpu()
257
+ torch.cuda.empty_cache()
258
+ flux_generator.ae.decoder.to(x.device)
259
+
260
+ # 5) Decode latents
261
+ if edited is not None:
262
+ edited = unpack(edited.float(), opts.height, opts.width)
263
+ with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
264
+ edited = flux_generator.ae.decode(edited)
265
+
266
+ if inverted is not None:
267
+ inverted = unpack(inverted.float(), opts.height, opts.width)
268
+ with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
269
+ inverted = flux_generator.ae.decode(inverted)
270
+
271
+ if recon is not None:
272
+ recon = unpack(recon.float(), opts.height, opts.width)
273
+ with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
274
+ recon = flux_generator.ae.decode(recon)
275
+
276
+ if flux_generator.offload:
277
+ flux_generator.ae.decoder.cpu()
278
+ torch.cuda.empty_cache()
279
+
280
+ t1 = time.perf_counter()
281
+ print(f"Done in {t1 - t0:.2f} seconds.")
282
+
283
+ # Convert to PIL
284
+ if edited is not None:
285
+ edited = edited.clamp(-1, 1)
286
+ edited = rearrange(edited[0], "c h w -> h w c")
287
+ edited = Image.fromarray((127.5 * (edited + 1.0)).cpu().byte().numpy())
288
+
289
+ if inverted is not None:
290
+ inverted = inverted.clamp(-1, 1)
291
+ inverted = rearrange(inverted[0], "c h w -> h w c")
292
+ inverted = Image.fromarray((127.5 * (inverted + 1.0)).cpu().byte().numpy())
293
+
294
+ if recon is not None:
295
+ recon = recon.clamp(-1, 1)
296
+ recon = rearrange(recon[0], "c h w -> h w c")
297
+ recon = Image.fromarray((127.5 * (recon + 1.0)).cpu().byte().numpy())
298
+
299
+ return edited, str(opts.seed), flux_generator.pulid_model.debug_img_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  # <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
302
  _HEADER_ = '''
 
313
 
314
  def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
315
  offload: bool = False, aggressive_offload: bool = False):
 
316
 
317
  with gr.Blocks() as demo:
318
  gr.Markdown(_HEADER_)
 
394
  gr.Examples(examples=example_inps, inputs=[prompt, id_image, id_weight, guidance, seed, true_cfg])
395
 
396
  generate_btn.click(
397
+ fn=generate_image,
398
  inputs=[prompt, id_image, width, height, num_steps, start_step, guidance, seed, id_weight, neg_prompt,
399
  true_cfg, timestep_to_start_cfg, max_sequence_length, gamma, eta, s, tau],
400
  outputs=[output_image, seed_output, intermediate_output],