amildravid4292 committed · verified
Commit a020647 · 1 Parent(s): f00e65e

Update app.py
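
Replace the face-upload/inversion UI with weight-space attribute editing. Loads the pinverse_1000pc.pt projection and builds debiased "Young", "Pointy_Nose", and "Bags_Under_Eyes" directions via get_direction/debias; adds an edit_inference function that injects slider-scaled weight edits into the sampled model once the denoising timestep falls below an "Injection Step" threshold, restoring the original weights afterwards; adds a sample_then_run helper with separate galleries for the sampled and edited identities.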

Files changed (1):
  1. app.py +153 -21
app.py CHANGED
@@ -19,9 +19,13 @@ global vae
 global text_encoder
 global tokenizer
 global noise_scheduler
+global young_val
+global pointy_val
+global bags_val
 device = "cuda:0"
 generator = torch.Generator(device=device)
 
+
 models_path = snapshot_download(repo_id="Snapchat/w2w")
 
 mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
@@ -30,6 +34,7 @@ v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
 proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
 df = torch.load(f"{models_path}/identity_df.pt")
 weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
+pinverse = torch.load(f"{models_path}/pinverse_1000pc.pt").bfloat16().to(device)
 
 unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
 global network
@@ -38,16 +43,64 @@ def sample_model():
     global unet
     del unet
     global network
+
     unet, _, _, _, _ = load_models(device)
     network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
 
-### start off with an initial model
-sample_model()
+
+@torch.no_grad()
+def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
+    global device
+    global generator
+    global unet
+    global vae
+    global text_encoder
+    global tokenizer
+    global noise_scheduler
+    generator = generator.manual_seed(seed)
+    latents = torch.randn(
+        (1, unet.in_channels, 512 // 8, 512 // 8),
+        generator = generator,
+        device = device
+    ).bfloat16()
+
+
+    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+
+    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+
+    max_length = text_input.input_ids.shape[-1]
+    uncond_input = tokenizer(
+        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
+    )
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+    noise_scheduler.set_timesteps(ddim_steps)
+    latents = latents * noise_scheduler.init_noise_sigma
+
+    for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
+        latent_model_input = torch.cat([latents] * 2)
+        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+        with network:
+            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+        #guidance
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+
+    latents = 1 / 0.18215 * latents
+    image = vae.decode(latents).sample
+    image = (image / 2 + 0.5).clamp(0, 1)
+    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
 
+    image = Image.fromarray((image * 255).round().astype("uint8"))
+
+    return [image]
 
 
 @torch.no_grad()
-def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
+def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3):
+
     global device
     global generator
     global unet
@@ -55,6 +108,15 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     global text_encoder
     global tokenizer
     global noise_scheduler
+    global young
+    global pointy
+    global bags
+
+    original_weights = network.proj.clone()
+
+
+    edited_weights = original_weights+a1*young+a2*pointy+a3*bags
+
    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
@@ -76,11 +138,23 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     noise_scheduler.set_timesteps(ddim_steps)
     latents = latents * noise_scheduler.init_noise_sigma
 
+
+
     for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
         latent_model_input = torch.cat([latents] * 2)
         latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+
+        if t>start_noise:
+            pass
+        elif t<=start_noise:
+            network.proj = torch.nn.Parameter(edited_weights)
+            network.reset()
+
+
         with network:
             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+
+
         #guidance
         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -89,13 +163,61 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     latents = 1 / 0.18215 * latents
     image = vae.decode(latents).sample
     image = (image / 2 + 0.5).clamp(0, 1)
+
     image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
 
     image = Image.fromarray((image * 255).round().astype("uint8"))
 
+    #reset weights back to original
+    network.proj = torch.nn.Parameter(original_weights)
+    network.reset()
+
     return [image]
 
+
+
+
+def sample_then_run():
+    global young_val
+    global pointy_val
+    global bags_val
+    global young
+    global pointy
+    global bags
+
+    sample_model()
+
+    young_val = network.proj@young[0]/(torch.norm(young)**2).item()
+    pointy_val = network.proj@pointy[0]/(torch.norm(pointy)**2).item()
+    bags_val = network.proj@bags[0]/(torch.norm(bags)**2).item()
 
+    prompt = "sks person"
+    negative_prompt = "low quality, blurry, unfinished, cartoon"
+    seed = 5
+    cfg = 3.0
+    steps = 50
+    image = inference( prompt, negative_prompt, cfg, steps, seed)
+    return image
+
+
+#directions
+global young
+global pointy
+global bags
+young = get_direction(df, "Young", pinverse, 1000, device)
+young = debias(young, "Male", df, pinverse, device)
+young_max = torch.max(proj@young[0]/(torch.norm(young))**2).item()
+young_min = torch.min(proj@young[0]/(torch.norm(young))**2).item()
+
+pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
+pointy_max = torch.max(proj@pointy[0]/(torch.norm(pointy))**2).item()
+pointy_min = torch.min(proj@pointy[0]/(torch.norm(pointy))**2).item()
+
+bags = get_direction(df, "Bags_Under_Eyes", pinverse, 1000, device)
+bags_max = torch.max(proj@bags[0]/(torch.norm(bags))**2).item()
+bags_min = torch.min(proj@bags[0]/(torch.norm(bags))**2).item()
+
+
 
 css = ''
 with gr.Blocks(css=css) as demo:
@@ -103,36 +225,41 @@ with gr.Blocks(css=css) as demo:
     gr.Markdown("Demo for the [h94/IP-Adapter-FaceID model](https://huggingface.co/h94/IP-Adapter-FaceID) - Generate AI images with your own face - Non-commercial license")
     with gr.Row():
         with gr.Column():
-            files = gr.Files(
-                label="Upload a photo of your face to invert, or sample a new model",
-                file_types=["image"]
-            )
-            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
-
             sample = gr.Button("Sample New Model")
+            gallery1 = gr.Gallery(label="Identity from Sampled Model")
 
-        with gr.Column(visible=False) as clear_button:
-            remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
+        with gr.Column():
             prompt = gr.Textbox(label="Prompt",
-                info="Make sure to include 'sks person'" ,
-                placeholder="sks person",
-                value="sks person")
+                info="Make sure to include 'sks person'" ,
+                placeholder="sks person",
+                value="sks person")
             negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
-            seed = gr.Number(value=5, label="Seed", interactive=True)
+            seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
             cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
-            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
+            steps = gr.Slider(label="Inference Steps", precision=0, value=50, step=1, minimum=0, maximum=100, interactive=True)
+            injection_step = gr.Slider(label="Injection Step", precision=0, value=800, step=1, minimum=0, maximum=1000, interactive=True)
+
+
+            with gr.Row():
+                a1 = gr.Slider(label="Young", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True)
+                a2 = gr.Slider(label="Pointy Nose", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True)
+                a3 = gr.Slider(label="Undereye Bags", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True)
 
 
             submit = gr.Button("Submit")
 
         with gr.Column():
-            gallery = gr.Gallery(label="Generated Images")
+            gallery2 = gr.Gallery(label="Identity from Edited Model")
+
+
+
+
+    sample.click(fn=sample_then_run, outputs=gallery1)
 
-    sample.click(fn=sample_model)
 
-    submit.click(fn=inference,
-        inputs=[prompt, negative_prompt, cfg, steps, seed],
-        outputs=gallery)
+    submit.click(fn=edit_inference,
+        inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3],
+        outputs=gallery2)
 
 
 
@@ -140,3 +267,8 @@ with gr.Blocks(css=css) as demo:
 
 
 demo.launch(share=True)
+
+
+
+
+
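
For readers skimming the diff, the core of the new edit_inference path is a linear edit in the model's projected weight space: the sampled identity's coefficients are shifted along precomputed attribute directions, the edited weights are injected only once the denoising timestep drops below the "Injection Step" threshold, and the originals are restored afterwards. Below is a minimal sketch of that arithmetic with stand-in tensors; the shapes, values, and the simplified injection check are illustrative assumptions, while the names (network.proj, young, a1, start_noise) mirror the committed code.

import torch

dim = 1000                         # matches the 1000-PC projection (proj_1000pc.pt)
weights = torch.randn(1, dim)      # stand-in for network.proj, a sampled identity
young = torch.randn(1, dim)        # stand-in for a direction from get_direction/debias
a1 = 50_000.0                      # slider strength (UI range is -1e6..1e6)

# coefficient of this identity along the direction, as computed in sample_then_run
young_val = (weights @ young[0] / torch.norm(young) ** 2).item()

# linear edit, as in edit_inference: original + a1*young (+ a2*pointy + a3*bags)
original_weights = weights.clone()
edited_weights = original_weights + a1 * young

# during denoising, the edit is injected only once t <= start_noise,
# then the original weights are restored after the image is decoded
t, start_noise = 600, 800
if t <= start_noise:
    weights = edited_weights       # late injection, as in the diffusion loop
weights = original_weights         # reset, as edit_inference does at the end

Deferring the injection until the later denoising steps, rather than applying the edit from step 0, appears intended to preserve the sampled identity's overall structure while still shifting the targeted attribute.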