Spaces:
openfree
/
Running on Zero

openfree committed on
Commit
35f124b
•
1 Parent(s): 92fcf53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +249 -1
app.py CHANGED
@@ -1,2 +1,250 @@
 
 
1
  import os
2
- exec(os.environ.get('APP'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import time
3
  import os
4
+
5
+ import gradio as gr
6
+ import torch
7
+ from einops import rearrange
8
+ from PIL import Image
9
+
10
+ from flux.cli import SamplingOptions
11
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
12
+ from flux.util import load_ae, load_clip, load_flow_model, load_t5
13
+ from pulid.pipeline_flux import PuLIDPipeline
14
+ from pulid.utils import resize_numpy_image_long
15
+
16
+
17
def get_models(name: str, device: torch.device, offload: bool):
    """Load the two text encoders, the flow model and the autoencoder.

    Args:
        name: Checkpoint name forwarded to ``load_flow_model`` and ``load_ae``.
        device: Device the T5 and CLIP encoders are loaded onto. Also used for
            the flow model and autoencoder unless ``offload`` is set.
        offload: When true, the flow model and autoencoder are loaded onto
            ``"cpu"`` instead of ``device``.

    Returns:
        The tuple ``(model, ae, t5, clip)``.
    """
    # The flow model and autoencoder share the same placement rule.
    weights_device = "cpu" if offload else device

    # Text encoders always go straight to the requested device.
    t5 = load_t5(device, max_length=128)
    clip = load_clip(device)

    model = load_flow_model(name, device=weights_device)
    model.eval()
    ae = load_ae(name, device=weights_device)

    return model, ae, t5, clip
24
+
25
+
26
class FluxGenerator:
    """Container for the FLUX models and the PuLID pipeline.

    The constructor takes no arguments; it sets ``device``, ``offload``,
    ``model_name``, ``model``, ``ae``, ``t5``, ``clip`` and ``pulid_model``.
    """

    def __init__(self):
        # Fixed configuration: CUDA device, no CPU offloading, 'flux-dev' weights.
        self.device = torch.device('cuda')
        self.offload = False
        self.model_name = 'flux-dev'

        loaded = get_models(self.model_name, device=self.device, offload=self.offload)
        self.model, self.ae, self.t5, self.clip = loaded

        # The PuLID pipeline wraps the flow model loaded above.
        pipeline = PuLIDPipeline(self.model, 'cuda', weight_dtype=torch.bfloat16)
        pipeline.load_pretrain()
        self.pulid_model = pipeline
38
+
39
+
40
# Module-level instance shared by generate_image. Constructing it calls
# get_models and PuLIDPipeline.load_pretrain, so this runs when the module is imported.
flux_generator = FluxGenerator()
41
+
42
+
43
@spaces.GPU
@torch.inference_mode()
def generate_image(
    width,
    height,
    num_steps,
    start_step,
    guidance,
    seed,
    prompt,
    id_image=None,
    id_weight=1.0,
    neg_prompt="",
    true_cfg=1.0,
    timestep_to_start_cfg=1,
    max_sequence_length=128,
):
    """Generate one image with the module-level ``flux_generator``.

    Args:
        width: Output width, forwarded to ``SamplingOptions``.
        height: Output height, forwarded to ``SamplingOptions``.
        num_steps: Number of steps, forwarded to ``SamplingOptions`` and used
            to build the timestep schedule.
        start_step: Forwarded to ``denoise`` as ``start_step``.
        guidance: Guidance value, forwarded to ``SamplingOptions``.
        seed: Anything ``int()`` accepts. ``-1`` means "draw a fresh seed".
        prompt: Positive text prompt.
        id_image: Optional identity image. When given it is passed through
            ``resize_numpy_image_long(id_image, 1024)`` and then to the PuLID
            pipeline's ``get_id_embedding``.
        id_weight: Forwarded to ``denoise`` as ``id_weight``.
        neg_prompt: Negative prompt; only encoded when true CFG is active.
        true_cfg: True-CFG scale. Values within 0.01 of 1.0 disable the
            negative-prompt path.
        timestep_to_start_cfg: Forwarded to ``denoise``.
        max_sequence_length: Assigned to ``flux_generator.t5.max_length``.

    Returns:
        ``(img, seed_str, debug_imgs)``: the decoded ``PIL.Image``, the seed
        actually used as a string, and ``pulid_model.debug_img_list``.
    """
    gen = flux_generator
    gen.t5.max_length = max_sequence_length

    seed = int(seed)
    if seed == -1:
        seed = None

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if opts.seed is None:
        # No seed requested: draw one so it can be reported back to the caller.
        opts.seed = torch.Generator(device="cpu").seed()

    # true_cfg close to 1.0 is treated as "no true CFG".
    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    if id_image is not None:
        id_image = resize_numpy_image_long(id_image, 1024)
        id_embeddings, uncond_id_embeddings = gen.pulid_model.get_id_embedding(
            id_image, cal_uncond=use_true_cfg
        )
    else:
        id_embeddings = None
        uncond_id_embeddings = None

    # prepare input
    x = get_noise(
        1,
        opts.height,
        opts.width,
        device=gen.device,
        dtype=torch.bfloat16,
        seed=opts.seed,
    )
    timesteps = get_schedule(
        opts.num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=True,
    )

    if gen.offload:
        gen.t5, gen.clip = gen.t5.to(gen.device), gen.clip.to(gen.device)
    inp = prepare(t5=gen.t5, clip=gen.clip, img=x, prompt=opts.prompt)

    # The negative prompt is only encoded when true CFG is active; otherwise
    # denoise receives None for all three negative-conditioning arguments.
    if use_true_cfg:
        inp_neg = prepare(t5=gen.t5, clip=gen.clip, img=x, prompt=neg_prompt)
        neg_kwargs = {
            "neg_txt": inp_neg["txt"],
            "neg_txt_ids": inp_neg["txt_ids"],
            "neg_vec": inp_neg["vec"],
        }
    else:
        neg_kwargs = {"neg_txt": None, "neg_txt_ids": None, "neg_vec": None}

    # offload TEs to CPU, load model to gpu
    if gen.offload:
        gen.t5, gen.clip = gen.t5.cpu(), gen.clip.cpu()
        torch.cuda.empty_cache()
        gen.model = gen.model.to(gen.device)

    # denoise initial noise
    x = denoise(
        gen.model,
        **inp,
        timesteps=timesteps,
        guidance=opts.guidance,
        id=id_embeddings,
        id_weight=id_weight,
        start_step=start_step,
        uncond_id=uncond_id_embeddings,
        true_cfg=true_cfg,
        timestep_to_start_cfg=timestep_to_start_cfg,
        **neg_kwargs,
    )

    # offload model, load autoencoder to gpu
    if gen.offload:
        gen.model.cpu()
        torch.cuda.empty_cache()
        gen.ae.decoder.to(x.device)

    # decode latents to pixel space
    x = unpack(x.float(), opts.height, opts.width)
    with torch.autocast(device_type=gen.device.type, dtype=torch.bfloat16):
        x = gen.ae.decode(x)

    if gen.offload:
        gen.ae.decoder.cpu()
        torch.cuda.empty_cache()

    # bring into PIL format: clamp to [-1, 1], move channels last, map to 0..255
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")

    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
    return img, str(opts.seed), gen.pulid_model.debug_img_list
148
+
149
+
150
# Custom stylesheet passed to gr.Blocks in create_demo; hides the page's <footer> element.
css = """
footer {
visibility: hidden;
}
"""
155
+
156
+
157
def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
                offload: bool = False):
    """Build and return the Gradio ``Blocks`` UI.

    The visible UI has a prompt box, an ID image input (webcam or upload), a
    generate button, an output image and a gallery of style examples. Selecting
    a gallery item copies that example's prompt into the prompt box. All other
    ``generate_image`` arguments come from hidden widgets holding fixed defaults.

    ``args``, ``model_name``, ``device`` and ``offload`` are accepted but not
    read anywhere in this function body.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # NOTE(review): the original indentation was lost in this copy of the file.
    # The gallery is assumed to live inside the style column and the event
    # wiring directly under the Blocks context - confirm against the repository.
    with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
        gr.Markdown("### 'AI ํฌํ†  ์ง€๋‹ˆ'์ด์šฉ ์•ˆ๋‚ด: 1) '์Šคํƒ€์ผ'์ค‘ ํ•˜๋‚˜๋ฅผ ์„ ํƒ. 2) ์›น์บ ์„ ํด๋ฆญํ•˜๏ฟฝ๏ฟฝ๏ฟฝ ์–ผ๊ตด์ด ๋ณด์ด๋ฉด ์นด๋ฉ”๋ผ ๋ฒ„ํŠผ ํด๋ฆญ. 3) '์ƒ์„ฑ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•˜๊ณ  ๊ธฐ๋‹ค๋ฆฌ๋ฉด ๋ฉ๋‹ˆ๋‹ค.")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="ํ”„๋กฌํ”„ํŠธ", value="์ดˆ์ƒํ™”, ์ƒ‰๊ฐ, ์˜ํ™”์ ")
                id_image = gr.Image(label="ID ์ด๋ฏธ์ง€", sources=["webcam", "upload"], type="numpy")
                generate_btn = gr.Button("์ƒ์„ฑ")

            with gr.Column():
                output_image = gr.Image(label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### ์Šคํƒ€์ผ")

                # Each entry is [title, prompt text, preview image path].
                all_examples = [
                    ["์šฐ์ฃผ ์—ฌํ–‰I", "I am an astronaut on a spacewalk. There is no helmet, and my face is visible. The background is Earth & starship as seen from space shuttle.", "example_inputs/1.webp"],
                    ["์šฐ์ฃผ ์—ฌํ–‰II", "I am an astronaut on a spacewalk. There is no helmet, and my face is visible. The background is Earth & starship as seen from space shuttle.I am holding sign with glowing green text \"I Love Mom\"", "example_inputs/2.webp"],
                    ["๋‚ด๊ฐ€ ์–ด๋ฅธ์ด ๋˜๋ฉด", "profile photo of a 40-year-old Adult Looking straight ahead, wear suite", "example_inputs/3.webp"],
                    ["์•„์ด์–ธ๋งจ ๋ณ€์‹ ", "I am an \"IRON MAN\"", "example_inputs/4.webp"],
                    ["ํ™”์„ฑ ํƒํ—˜", "I am wearing a spacesuit and have become an astronaut walking on Mars. I'm not wearing a helmet. I'm looking straight ahead. The background is a desolate area of Mars, and a space rover and a space station can be seen.", "example_inputs/5.webp"],
                    ["์ŠคํŒŒ์ด๋”๋งจ", "I am an \"spider MAN\"", "example_inputs/6.webp"],
                    ["์šฐ์ฃผ์„  ์กฐ์ข…", "I am wearing a spacesuit and have become an astronaut. I am piloting a spacecraft. Through the spacecraft's window, I can see outer space.", "example_inputs/7.webp"],
                    ["๋งŒํ™” ์ฃผ์ธ๊ณต", "portrait, pixar style", "example_inputs/8.webp"],
                    ["์›๋”์šฐ๋จผ", "I am an \"wonder woman\"", "example_inputs/9.webp"],
                    ["์นด์šฐ๋ณด์ด", "Cowboy, american comics style", "example_inputs/10.webp"],
                ]

                # The gallery shows only the preview images, in list order.
                preview_paths = [image_path for _title, _text, image_path in all_examples]
                example_gallery = gr.Gallery(
                    preview_paths,
                    label="์Šคํƒ€์ผ ์˜ˆ์‹œ",
                    elem_id="gallery",
                    columns=5,
                    rows=2,
                    object_fit="contain",
                    height="auto"
                )

        def fill_example(evt: gr.SelectData):
            # Map the selected gallery index back to that example's prompt text.
            _title, example_prompt, _image_path = all_examples[evt.index]
            return example_prompt

        example_gallery.select(fill_example, None, [prompt])

        # Hidden widgets feed fixed defaults into generate_image. They are
        # created in the same order as generate_image's positional parameters.
        width_input = gr.Slider(256, 1536, 896, step=16, visible=False)
        height_input = gr.Slider(256, 1536, 1152, step=16, visible=False)
        num_steps_input = gr.Slider(1, 20, 20, step=1, visible=False)
        start_step_input = gr.Slider(0, 10, 0, step=1, visible=False)
        guidance_input = gr.Slider(1.0, 10.0, 4, step=0.1, visible=False)
        seed_input = gr.Textbox(-1, visible=False)
        id_weight_input = gr.Slider(0.0, 3.0, 1, step=0.05, visible=False)
        neg_prompt_input = gr.Textbox("Low quality, worst quality, text, signature, watermark, extra limbs", visible=False)
        true_cfg_input = gr.Slider(1.0, 10.0, 1, step=0.1, visible=False)
        cfg_start_input = gr.Slider(0, 20, 1, step=1, visible=False)
        max_seq_len_input = gr.Slider(128, 512, 128, step=128, visible=False)

        generate_btn.click(
            fn=lambda *values: generate_image(*values)[0],  # return only the first item (the image)
            inputs=[
                width_input,
                height_input,
                num_steps_input,
                start_step_input,
                guidance_input,
                seed_input,
                prompt,
                id_image,
                id_weight_input,
                neg_prompt_input,
                true_cfg_input,
                cfg_start_input,
                max_seq_len_input,
            ],
            outputs=[output_image],
        )

    return demo
230
+
231
+
232
if __name__ == "__main__":
    # Script entry point: parse CLI flags, log in to the Hugging Face Hub,
    # then build and launch the Gradio demo.
    import argparse

    parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
    # choices must be a list of complete names. list('flux-dev') splits the
    # string into single characters, which made "--name flux-dev" an invalid choice.
    parser.add_argument("--name", type=str, default="flux-dev", choices=["flux-dev"],
                        help="ํ˜„์žฌ๋Š” flux-dev๋งŒ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="์‚ฌ์šฉํ•  ๋””๋ฐ”์ด์Šค")
    parser.add_argument("--offload", action="store_true", help="์‚ฌ์šฉํ•˜์ง€ ์•Š์„ ๋•Œ ๋ชจ๋ธ์„ CPU๋กœ ์˜ฎ๊น๋‹ˆ๋‹ค")
    # NOTE(review): --port, --dev and --pretrained_model are parsed but never
    # read below; demo.launch() is called without arguments.
    parser.add_argument("--port", type=int, default=8080, help="์‚ฌ์šฉํ•  ํฌํŠธ")
    parser.add_argument("--dev", action='store_true', help="๊ฐœ๋ฐœ ๋ชจ๋“œ")
    parser.add_argument("--pretrained_model", type=str, help='๊ฐœ๋ฐœ์šฉ')
    args = parser.parse_args()

    import huggingface_hub
    # The token is read from the HF_TOKEN environment variable.
    huggingface_hub.login(os.getenv('HF_TOKEN'))

    demo = create_demo(args, args.name, args.device, args.offload)
    demo.launch()