okaris and multimodalart committed
Commit 650b332
1 Parent(s): 3f0445a

UI changes and ZeroGPU optimizations (#1)

- UI changes and ZeroGPU optimization (15766a45e9447be44476284711e250c201f27b0e)


Co-authored-by: Apolinário from multimodal AI art <[email protected]>

Files changed (1):
  app.py  +185 -55
app.py CHANGED
@@ -1,55 +1,175 @@
 import gradio as gr
 import spaces
-from omni_zero import OmniZeroSingle
+import os
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
-@spaces.GPU(duration=180)
+import sys
+sys.path.insert(0, './diffusers/src')
+
+import torch
+import torch.nn as nn
+
+# Hack for ZeroGPU
+torch.jit.script = lambda f: f
+####
+
+from huggingface_hub import snapshot_download
+from diffusers import DPMSolverMultistepScheduler
+from diffusers.models import ControlNetModel
+
+from transformers import CLIPVisionModelWithProjection
+
+from pipeline import OmniZeroPipeline
+from insightface.app import FaceAnalysis
+from controlnet_aux import ZoeDetector
+from utils import draw_kps, load_and_resize_image, align_images
+
+import cv2
+import numpy as np
+
+base_model = "frankjoshua/albedobaseXL_v13"
+
+snapshot_download("okaris/antelopev2", local_dir="./models/antelopev2")
+face_analysis = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
+face_analysis.prepare(ctx_id=0, det_size=(640, 640))
+
+dtype = torch.float16
+
+ip_adapter_plus_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "h94/IP-Adapter",
+    subfolder="models/image_encoder",
+    torch_dtype=dtype,
+).to("cuda")
+
+zoedepthnet_path = "okaris/zoe-depth-controlnet-xl"
+zoedepthnet = ControlNetModel.from_pretrained(zoedepthnet_path, torch_dtype=dtype).to("cuda")
+
+identitynet_path = "okaris/face-controlnet-xl"
+identitynet = ControlNetModel.from_pretrained(identitynet_path, torch_dtype=dtype).to("cuda")
+
+zoe_depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
+
+pipeline = OmniZeroPipeline.from_pretrained(
+    base_model,
+    controlnet=[identitynet, zoedepthnet],
+    torch_dtype=dtype,
+    image_encoder=ip_adapter_plus_image_encoder,
+).to("cuda")
+
+config = pipeline.scheduler.config
+config["timestep_spacing"] = "trailing"
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++", final_sigmas_type="zero")
+pipeline.load_ip_adapter(["okaris/ip-adapter-instantid", "h94/IP-Adapter", "h94/IP-Adapter"], subfolder=[None, "sdxl_models", "sdxl_models"], weight_name=["ip-adapter-instantid.bin", "ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus_sdxl_vit-h.safetensors"])
+
+def get_largest_face_embedding_and_kps(image, target_image=None):
+    face_info = face_analysis.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
+    if len(face_info) == 0:
+        return None, None
+    largest_face = sorted(face_info, key=lambda x: x['bbox'][2] * x['bbox'][3], reverse=True)[0]
+    face_embedding = torch.tensor(largest_face['embedding']).to("cuda")
+    if target_image is None:
+        target_image = image
+    zeros = np.zeros((target_image.size[1], target_image.size[0], 3), dtype=np.uint8)
+    face_kps_image = draw_kps(zeros, largest_face['kps'])
+    return face_embedding, face_kps_image
+
+@spaces.GPU()
 def generate(
-    seed=42,
     prompt="A person",
+    composition_image="https://github.com/okaris/omni-zero/assets/1448702/2ca63443-c7f3-4ba6-95c1-2a341414865f",
+    style_image="https://github.com/okaris/omni-zero/assets/1448702/64dc150b-f683-41b1-be23-b6a52c771584",
+    identity_image="https://github.com/okaris/omni-zero/assets/1448702/ba193a3a-f90e-4461-848a-560454531c58",
+    base_image="https://github.com/okaris/omni-zero/assets/1448702/2ca63443-c7f3-4ba6-95c1-2a341414865f",
+    seed=42,
     negative_prompt="blurry, out of focus",
     guidance_scale=3.0,
     number_of_images=1,
     number_of_steps=10,
-    base_image="https://github.com/okaris/omni-zero/assets/1448702/2ca63443-c7f3-4ba6-95c1-2a341414865f",
     base_image_strength=0.15,
-    composition_image="https://github.com/okaris/omni-zero/assets/1448702/2ca63443-c7f3-4ba6-95c1-2a341414865f",
     composition_image_strength=1.0,
-    style_image="https://github.com/okaris/omni-zero/assets/1448702/64dc150b-f683-41b1-be23-b6a52c771584",
     style_image_strength=1.0,
-    identity_image="https://github.com/okaris/omni-zero/assets/1448702/ba193a3a-f90e-4461-848a-560454531c58",
     identity_image_strength=1.0,
     depth_image=None,
     depth_image_strength=0.5,
+    progress=gr.Progress(track_tqdm=True)
 ):
+    resolution = 1024
+
+    if base_image is not None:
+        base_image = load_and_resize_image(base_image, resolution, resolution)
+    else:
+        if composition_image is not None:
+            base_image = load_and_resize_image(composition_image, resolution, resolution)
+        else:
+            raise ValueError("You must provide a base image or a composition image")
+
+    if depth_image is None:
+        depth_image = zoe_depth_detector(base_image, detect_resolution=resolution, image_resolution=resolution)
+    else:
+        depth_image = load_and_resize_image(depth_image, resolution, resolution)
+
+    base_image, depth_image = align_images(base_image, depth_image)
+
+    if composition_image is not None:
+        composition_image = load_and_resize_image(composition_image, resolution, resolution)
+    else:
+        composition_image = base_image
+
+    if style_image is not None:
+        style_image = load_and_resize_image(style_image, resolution, resolution)
+    else:
+        raise ValueError("You must provide a style image")
 
-    omni_zero = OmniZeroSingle(
-        base_model="frankjoshua/albedobaseXL_v13",
-    )
+    if identity_image is not None:
+        identity_image = load_and_resize_image(identity_image, resolution, resolution)
+    else:
+        raise ValueError("You must provide an identity image")
+
+    face_embedding_identity_image, target_kps = get_largest_face_embedding_and_kps(identity_image, base_image)
+    if face_embedding_identity_image is None:
+        raise ValueError("No face found in the identity image; the image might be cropped too tightly or the face is too small")
+
+    face_embedding_base_image, face_kps_base_image = get_largest_face_embedding_and_kps(base_image)
+    if face_embedding_base_image is not None:
+        target_kps = face_kps_base_image
 
-    images = omni_zero.generate(
-        seed=seed,
+    pipeline.set_ip_adapter_scale([identity_image_strength,
+        {
+            "down": { "block_2": [0.0, 0.0] },
+            "up": { "block_0": [0.0, style_image_strength, 0.0] }
+        },
+        {
+            "down": { "block_2": [0.0, composition_image_strength] },
+            "up": { "block_0": [0.0, 0.0, 0.0] }
+        }
+    ])
+
+    generator = torch.Generator(device="cpu").manual_seed(seed)
+
+    images = pipeline(
         prompt=prompt,
         negative_prompt=negative_prompt,
         guidance_scale=guidance_scale,
-        number_of_images=number_of_images,
-        number_of_steps=number_of_steps,
-        base_image=base_image,
-        base_image_strength=base_image_strength,
-        composition_image=composition_image,
-        composition_image_strength=composition_image_strength,
-        style_image=style_image,
-        style_image_strength=style_image_strength,
-        identity_image=identity_image,
-        identity_image_strength=identity_image_strength,
-        depth_image=depth_image,
-        depth_image_strength=depth_image_strength,
-    )
+        ip_adapter_image=[face_embedding_identity_image, style_image, composition_image],
+        image=base_image,
+        control_image=[target_kps, depth_image],
+        controlnet_conditioning_scale=[identity_image_strength, depth_image_strength],
+        identity_control_indices=[(0,0)],
+        num_inference_steps=number_of_steps,
+        num_images_per_prompt=number_of_images,
+        strength=(1-base_image_strength),
+        generator=generator,
+        seed=seed,
+    ).images
 
-    # for i, image in enumerate(images):
-    #     image.save(f"oz_output_{i}.jpg")
     return images
 
+# Move the components in the example fields outside so they are available when gr.Examples is instantiated
+
+
 with gr.Blocks() as demo:
+    gr.Markdown("<h1 style='text-align: center'>Omni Zero</h1>")
+    gr.Markdown("<h4 style='text-align: center'>A diffusion pipeline for zero-shot stylized portrait creation [<a href='https://github.com/okaris/omni-zero' target='_blank'>GitHub</a>], [<a href='https://styleof.com/s/remix-yourself' target='_blank'>StyleOf Remix Yourself</a>]</h4>")
     with gr.Row():
         with gr.Column():
             with gr.Row():
@@ -57,38 +177,42 @@ with gr.Blocks() as demo:
             with gr.Row():
                 negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, out of focus")
             with gr.Row():
-                seed = gr.Slider(label="Seed",step=1, minimum=0, maximum=10000000, value=42)
-                number_of_images = gr.Slider(label="Number of Outputs",step=1, minimum=1, maximum=4, value=1)
-            with gr.Row():
-                guidance_scale = gr.Slider(label="Guidance Scale",step=0.1, minimum=0.0, maximum=14.0, value=3.0)
-                number_of_steps = gr.Slider(label="Number of Steps",step=1, minimum=1, maximum=50, value=10)
-            with gr.Row():
-                with gr.Column():
+                with gr.Column(min_width=140):
                     with gr.Row():
-                        base_image = gr.Image(label="Base Image", value="https://github.com/okaris/omni-zero/assets/1448702/2ca63443-c7f3-4ba6-95c1-2a341414865f")
+                        composition_image = gr.Image(label="Composition")
                     with gr.Row():
-                        base_image_strength = gr.Slider(label="Base Image Strength",step=0.01, minimum=0.0, maximum=1.0, value=0.15)
-                with gr.Column():
+                        composition_image_strength = gr.Slider(label="Strength",step=0.01, minimum=0.0, maximum=1.0, value=1.0)
+                # with gr.Row():
+                with gr.Column(min_width=140):
                    with gr.Row():
-                        composition_image = gr.Image(label="Composition", value="https://github.com/okaris/omni-zero/assets/1448702/2ca63443-c7f3-4ba6-95c1-2a341414865f")
+                        style_image = gr.Image(label="Style Image")
                     with gr.Row():
-                        composition_image_strength = gr.Slider(label="Composition Image Strength",step=0.01, minimum=0.0, maximum=1.0, value=1.0)
-                # with gr.Row():
-                with gr.Column():
+                        style_image_strength = gr.Slider(label="Strength",step=0.01, minimum=0.0, maximum=1.0, value=1.0)
+                with gr.Column(min_width=140):
                     with gr.Row():
-                        style_image = gr.Image(label="Style Image", value="https://github.com/okaris/omni-zero/assets/1448702/64dc150b-f683-41b1-be23-b6a52c771584")
+                        identity_image = gr.Image(label="Identity Image")
                     with gr.Row():
-                        style_image_strength = gr.Slider(label="Style Image Strength",step=0.01, minimum=0.0, maximum=1.0, value=1.0)
-                with gr.Column():
-                    with gr.Row():
-                        identity_image = gr.Image(label="Identity Image", value="https://github.com/okaris/omni-zero/assets/1448702/ba193a3a-f90e-4461-848a-560454531c58")
-                    with gr.Row():
-                        identity_image_strength = gr.Slider(label="Identitiy Image Strenght",step=0.01, minimum=0.0, maximum=1.0, value=1.0)
-                # with gr.Column():
+                        identity_image_strength = gr.Slider(label="Strength",step=0.01, minimum=0.0, maximum=1.0, value=1.0)
+            with gr.Accordion("Advanced options", open=False):
+                with gr.Row():
+                    with gr.Column(min_width=140):
+                        with gr.Row():
+                            base_image = gr.Image(label="Base Image")
+                        with gr.Row():
+                            base_image_strength = gr.Slider(label="Strength",step=0.01, minimum=0.0, maximum=1.0, value=0.15, min_width=120)
+                    # with gr.Column(min_width=140):
                     # with gr.Row():
                     #     depth_image = gr.Image(label="depth_image", value=None)
                     # with gr.Row():
                     #     depth_image_strength = gr.Slider(label="depth_image_strength",step=0.01, minimum=0.0, maximum=1.0, value=0.5)
+
+                with gr.Row():
+                    seed = gr.Slider(label="Seed",step=1, minimum=0, maximum=10000000, value=42)
+                    number_of_images = gr.Slider(label="Number of Outputs",step=1, minimum=1, maximum=4, value=1)
+                with gr.Row():
+                    guidance_scale = gr.Slider(label="Guidance Scale",step=0.1, minimum=0.0, maximum=14.0, value=3.0)
+                    number_of_steps = gr.Slider(label="Number of Steps",step=1, minimum=1, maximum=50, value=10)
+
         with gr.Column():
             with gr.Row():
                 out = gr.Gallery(label="Output(s)")
@@ -97,24 +221,30 @@ with gr.Blocks() as demo:
             submit = gr.Button("Generate")
 
     submit.click(generate, inputs=[
-        seed,
         prompt,
+        composition_image,
+        style_image,
+        identity_image,
+        base_image,
+        seed,
         negative_prompt,
         guidance_scale,
         number_of_images,
         number_of_steps,
-        base_image,
         base_image_strength,
-        composition_image,
         composition_image_strength,
-        style_image,
         style_image_strength,
-        identity_image,
         identity_image_strength,
         ],
         outputs=[out]
    )
     # clear.click(lambda: None, None, chatbot, queue=False)
-
+    gr.Examples(
+        examples=[["A person", "https://github.com/okaris/omni-zero/assets/1448702/2ca63443-c7f3-4ba6-95c1-2a341414865f", "https://github.com/okaris/omni-zero/assets/1448702/64dc150b-f683-41b1-be23-b6a52c771584", "https://github.com/okaris/omni-zero/assets/1448702/ba193a3a-f90e-4461-848a-560454531c58"]],
+        inputs=[prompt, composition_image, style_image, identity_image],
+        outputs=[out],
+        fn=generate,
+        cache_examples="lazy",
+    )
 if __name__ == "__main__":
     demo.launch()
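
Note on the ZeroGPU change: the old app decorated generate() with @spaces.GPU(duration=180) and built the model inside the function; the new version loads everything at module scope and decorates only inference, so a GPU is attached per call rather than held for a fixed three minutes. A minimal sketch of that pattern, with an illustrative function body standing in for the real pipeline call:

    import gradio as gr
    import spaces
    import torch

    # The commit's "Hack for ZeroGPU": make TorchScript compilation a no-op
    # so nothing tries to JIT-compile models before a GPU is attached.
    torch.jit.script = lambda f: f

    # Heavy setup (snapshot_download, from_pretrained, .to("cuda")) belongs
    # here at module scope, so it runs once at Space startup.

    @spaces.GPU()  # a GPU is attached only while this function runs
    def generate(prompt: str) -> str:
        # Illustrative body; the real app runs the diffusion pipeline here.
        return f"would run on {torch.cuda.get_device_name(0)} for {prompt!r}"

    demo = gr.Interface(fn=generate, inputs="text", outputs="text")

    if __name__ == "__main__":
        demo.launch()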
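
The three entries passed to pipeline.set_ip_adapter_scale() line up with the three adapters loaded by load_ip_adapter(): InstantID for identity, and two copies of IP-Adapter Plus for style and composition. The dict form is diffusers' per-block scale control (the InstantStyle technique): zeros silence an adapter in every attention layer except the one it should steer. A sketch, assuming a pipeline with those adapters already loaded in that order:

    # Hypothetical slider values; in the app these come from the UI.
    identity_strength, style_strength, composition_strength = 1.0, 1.0, 1.0

    pipeline.set_ip_adapter_scale([
        identity_strength,  # plain float: identity adapter active everywhere
        {
            # style adapter: only one up-block attention layer is enabled
            "down": {"block_2": [0.0, 0.0]},
            "up": {"block_0": [0.0, style_strength, 0.0]},
        },
        {
            # composition adapter: only one down-block attention layer is enabled
            "down": {"block_2": [0.0, composition_strength]},
            "up": {"block_0": [0.0, 0.0, 0.0]},
        },
    ])

This is what lets each reference image keep its own independent "Strength" slider in the UI.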
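
The gr.Examples block also explains the comment about moving components "outside": every component named in inputs= and outputs= must already exist when gr.Examples is instantiated. With cache_examples="lazy" (Gradio 4.x), fn only runs the first time a visitor selects the example, and the result is cached afterwards, so startup spends no ZeroGPU time pre-rendering examples. Usage sketch with placeholder URLs:

    # Placeholder URLs for illustration; the commit uses its real asset links.
    example = ["A person", "https://example.com/composition.png",
               "https://example.com/style.png", "https://example.com/identity.png"]

    gr.Examples(
        examples=[example],
        inputs=[prompt, composition_image, style_image, identity_image],
        outputs=[out],
        fn=generate,            # runs on first selection...
        cache_examples="lazy",  # ...then the cached output is reused
    )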