yanze committed on
Commit
3f80493
1 Parent(s): 60c3461

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -119
app.py CHANGED
@@ -24,125 +24,128 @@ def get_models(name: str, device: torch.device, offload: bool):
24
 
25
 
26
  class FluxGenerator:
27
- def __init__(self, model_name: str, device: str, offload: bool, args):
28
- self.device = torch.device(device)
29
- self.offload = offload
30
- self.model_name = model_name
31
  self.model, self.ae, self.t5, self.clip = get_models(
32
- model_name,
33
  device=self.device,
34
  offload=self.offload,
35
  )
36
- self.pulid_model = PuLIDPipeline(self.model, device, weight_dtype=torch.bfloat16)
37
- self.pulid_model.load_pretrain(args.pretrained_model)
38
-
39
- @spaces.GPU
40
- def generate_image(
41
- self,
42
- width,
43
- height,
44
- num_steps,
45
- start_step,
46
- guidance,
47
- seed,
48
- prompt,
49
- id_image=None,
50
- id_weight=1.0,
51
- neg_prompt="",
52
- true_cfg=1.0,
53
- timestep_to_start_cfg=1,
54
- max_sequence_length=128,
55
- ):
56
- self.t5.max_length = max_sequence_length
57
-
58
- seed = int(seed)
59
- if seed == -1:
60
- seed = None
61
-
62
- opts = SamplingOptions(
63
- prompt=prompt,
64
- width=width,
65
- height=height,
66
- num_steps=num_steps,
67
- guidance=guidance,
68
- seed=seed,
69
- )
70
-
71
- if opts.seed is None:
72
- opts.seed = torch.Generator(device="cpu").seed()
73
- print(f"Generating '{opts.prompt}' with seed {opts.seed}")
74
- t0 = time.perf_counter()
75
-
76
- use_true_cfg = abs(true_cfg - 1.0) > 1e-2
77
-
78
- if id_image is not None:
79
- id_image = resize_numpy_image_long(id_image, 1024)
80
- id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
81
- else:
82
- id_embeddings = None
83
- uncond_id_embeddings = None
84
-
85
- # prepare input
86
- x = get_noise(
87
- 1,
88
- opts.height,
89
- opts.width,
90
- device=self.device,
91
- dtype=torch.bfloat16,
92
- seed=opts.seed,
93
- )
94
- timesteps = get_schedule(
95
- opts.num_steps,
96
- x.shape[-1] * x.shape[-2] // 4,
97
- shift=True,
98
- )
99
-
100
- if self.offload:
101
- self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
102
- inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)
103
- inp_neg = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
104
-
105
- # offload TEs to CPU, load model to gpu
106
- if self.offload:
107
- self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
108
- torch.cuda.empty_cache()
109
- self.model = self.model.to(self.device)
110
-
111
- # denoise initial noise
112
- x = denoise(
113
- self.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
114
- start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
115
- timestep_to_start_cfg=timestep_to_start_cfg,
116
- neg_txt=inp_neg["txt"] if use_true_cfg else None,
117
- neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
118
- neg_vec=inp_neg["vec"] if use_true_cfg else None,
119
- )
120
-
121
- # offload model, load autoencoder to gpu
122
- if self.offload:
123
- self.model.cpu()
124
- torch.cuda.empty_cache()
125
- self.ae.decoder.to(x.device)
126
-
127
- # decode latents to pixel space
128
- x = unpack(x.float(), opts.height, opts.width)
129
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
130
- x = self.ae.decode(x)
131
-
132
- if self.offload:
133
- self.ae.decoder.cpu()
134
- torch.cuda.empty_cache()
135
-
136
- t1 = time.perf_counter()
137
-
138
- print(f"Done in {t1 - t0:.1f}s.")
139
- # bring into PIL format
140
- x = x.clamp(-1, 1)
141
- # x = embed_watermark(x.float())
142
- x = rearrange(x[0], "c h w -> h w c")
143
-
144
- img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
145
- return img, str(opts.seed), self.pulid_model.debug_img_list
 
 
 
146
 
147
  _HEADER_ = '''
148
  <div style="text-align: center; max-width: 650px; margin: 0 auto;">
@@ -169,8 +172,6 @@ If you have any questions or feedbacks, feel free to open a discussion or contac
169
 
170
  def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
171
  offload: bool = False):
172
- generator = FluxGenerator(model_name, device, offload, args)
173
-
174
  with gr.Blocks() as demo:
175
  gr.Markdown(_HEADER_)
176
 
@@ -267,7 +268,7 @@ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_ava
267
  label='true CFG')
268
 
269
  generate_btn.click(
270
- fn=generator.generate_image,
271
  inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
272
  true_cfg, timestep_to_start_cfg, max_sequence_length],
273
  outputs=[output_image, seed_output, intermediate_output],
@@ -282,7 +283,8 @@ if __name__ == "__main__":
282
  parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
283
  parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
284
  help="currently only support flux-dev")
285
- parser.add_argument("--device", type=str, default="cuda", help="Device to use")
 
286
  parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
287
  parser.add_argument("--port", type=int, default=8080, help="Port to use")
288
  parser.add_argument("--dev", action='store_true', help="Development mode")
 
24
 
25
 
26
  class FluxGenerator:
27
+ def __init__(self):
28
+ self.device = torch.device('cuda')
29
+ self.offload = False
30
+ self.model_name = 'flux-dev'
31
  self.model, self.ae, self.t5, self.clip = get_models(
32
+ self.model_name,
33
  device=self.device,
34
  offload=self.offload,
35
  )
36
+ self.pulid_model = PuLIDPipeline(self.model, 'cuda', weight_dtype=torch.bfloat16)
37
+ self.pulid_model.load_pretrain()
38
+
39
+
40
+ flux_generator = FluxGenerator()
41
+
42
+
43
+ @spaces.GPU
44
+ def generate_image(
45
+ width,
46
+ height,
47
+ num_steps,
48
+ start_step,
49
+ guidance,
50
+ seed,
51
+ prompt,
52
+ id_image=None,
53
+ id_weight=1.0,
54
+ neg_prompt="",
55
+ true_cfg=1.0,
56
+ timestep_to_start_cfg=1,
57
+ max_sequence_length=128,
58
+ ):
59
+ flux_generator.t5.max_length = max_sequence_length
60
+
61
+ seed = int(seed)
62
+ if seed == -1:
63
+ seed = None
64
+
65
+ opts = SamplingOptions(
66
+ prompt=prompt,
67
+ width=width,
68
+ height=height,
69
+ num_steps=num_steps,
70
+ guidance=guidance,
71
+ seed=seed,
72
+ )
73
+
74
+ if opts.seed is None:
75
+ opts.seed = torch.Generator(device="cpu").seed()
76
+ print(f"Generating '{opts.prompt}' with seed {opts.seed}")
77
+ t0 = time.perf_counter()
78
+
79
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-2
80
+
81
+ if id_image is not None:
82
+ id_image = resize_numpy_image_long(id_image, 1024)
83
+ id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
84
+ else:
85
+ id_embeddings = None
86
+ uncond_id_embeddings = None
87
+
88
+ # prepare input
89
+ x = get_noise(
90
+ 1,
91
+ opts.height,
92
+ opts.width,
93
+ device=flux_generator.device,
94
+ dtype=torch.bfloat16,
95
+ seed=opts.seed,
96
+ )
97
+ timesteps = get_schedule(
98
+ opts.num_steps,
99
+ x.shape[-1] * x.shape[-2] // 4,
100
+ shift=True,
101
+ )
102
+
103
+ if flux_generator.offload:
104
+ flux_generator.t5, flux_generator.clip = flux_generator.t5.to(flux_generator.device), flux_generator.clip.to(flux_generator.device)
105
+ inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
106
+ inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
107
+
108
+ # offload TEs to CPU, load model to gpu
109
+ if flux_generator.offload:
110
+ flux_generator.t5, flux_generator.clip = flux_generator.t5.cpu(), flux_generator.clip.cpu()
111
+ torch.cuda.empty_cache()
112
+ flux_generator.model = flux_generator.model.to(flux_generator.device)
113
+
114
+ # denoise initial noise
115
+ x = denoise(
116
+ flux_generator.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
117
+ start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
118
+ timestep_to_start_cfg=timestep_to_start_cfg,
119
+ neg_txt=inp_neg["txt"] if use_true_cfg else None,
120
+ neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
121
+ neg_vec=inp_neg["vec"] if use_true_cfg else None,
122
+ )
123
+
124
+ # offload model, load autoencoder to gpu
125
+ if flux_generator.offload:
126
+ flux_generator.model.cpu()
127
+ torch.cuda.empty_cache()
128
+ flux_generator.ae.decoder.to(x.device)
129
+
130
+ # decode latents to pixel space
131
+ x = unpack(x.float(), opts.height, opts.width)
132
+ with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
133
+ x = flux_generator.ae.decode(x)
134
+
135
+ if flux_generator.offload:
136
+ flux_generator.ae.decoder.cpu()
137
+ torch.cuda.empty_cache()
138
+
139
+ t1 = time.perf_counter()
140
+
141
+ print(f"Done in {t1 - t0:.1f}s.")
142
+ # bring into PIL format
143
+ x = x.clamp(-1, 1)
144
+ # x = embed_watermark(x.float())
145
+ x = rearrange(x[0], "c h w -> h w c")
146
+
147
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
148
+ return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
149
 
150
  _HEADER_ = '''
151
  <div style="text-align: center; max-width: 650px; margin: 0 auto;">
 
172
 
173
  def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
174
  offload: bool = False):
 
 
175
  with gr.Blocks() as demo:
176
  gr.Markdown(_HEADER_)
177
 
 
268
  label='true CFG')
269
 
270
  generate_btn.click(
271
+ fn=generate_image,
272
  inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
273
  true_cfg, timestep_to_start_cfg, max_sequence_length],
274
  outputs=[output_image, seed_output, intermediate_output],
 
283
  parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
284
  parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
285
  help="currently only support flux-dev")
286
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
287
+ help="Device to use")
288
  parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
289
  parser.add_argument("--port", type=int, default=8080, help="Port to use")
290
  parser.add_argument("--dev", action='store_true', help="Development mode")