amildravid4292 committed
Commit 7217618 · verified · 1 parent: 45ec4cb

Update app.py

Files changed (1)
  1. app.py +43 -64
app.py CHANGED
@@ -140,7 +140,7 @@ def sample_then_run(net):
     return net, image
 
 @torch.no_grad()
-@spaces.GPU(duration=120)
+@spaces.GPU()
 def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     mean.to(device)
     std.to(device)
@@ -197,77 +197,75 @@ def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
 
     image = Image.fromarray((image * 255).round().astype("uint8"))
 
-    del network
 
     return image
 
 
+
+
 @torch.no_grad()
-@spaces.GPU(duration=120)
-def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
-    device = self.device
-    self.unet.to(device)
-    self.text_encoder.to(device)
-    self.vae.to(device)
-    self.mean.to(device)
-    self.std.to(device)
-    self.v.to(device)
-    self.proj.to(device)
-    self.weights = torch.load("model.pt").to(device)
-    self.young.to(device)
-    self.pointy.to(device)
-    self.wavy.to(device)
-    self.thick.to(device)
+@spaces.GPU()
+def edit_inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
+    mean.to(device)
+    std.to(device)
+    v.to(device)
+    young.to(device)
+    pointy.to(device)
+    wavy.to(device)
+    thick.to(device)
 
-    network = LoRAw2w( self.weights.bfloat16(), self.mean.bfloat16(), self.std.bfloat16(), self.v[:, :1000].bfloat16(),
-                    self.unet,
+
+    weights = torch.load(net).to(device)
+    network = LoRAw2w(weights, mean, std, v[:, :1000],
+                    unet,
                     rank=1,
                     multiplier=1.0,
                     alpha=27.0,
                     train_method="xattn-strict"
                     ).to(device, torch.bfloat16)
-
-
-    original_weights = self.weights.clone()
+
+
+
+
 
     #pad to same number of PCs
-    pcs_original = original_weights.shape[1]
-    pcs_edits = self.young.shape[1]
+    pcs_original = weights.shape[1]
+    pcs_edits = young.shape[1]
     padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
-    young_pad = torch.cat((self.young, padding), 1)
-    pointy_pad = torch.cat((self.pointy, padding), 1)
-    wavy_pad = torch.cat((self.wavy, padding), 1)
-    thick_pad = torch.cat((self.thick, padding), 1)
+    young_pad = torch.cat((young, padding), 1)
+    pointy_pad = torch.cat((pointy, padding), 1)
+    wavy_pad = torch.cat((wavy, padding), 1)
+    thick_pad = torch.cat((thick, padding), 1)
 
 
-    edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
+    edited_weights = weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad
 
     generator = torch.Generator(device=device).manual_seed(seed)
     latents = torch.randn(
-        (1, self.unet.in_channels, 512 // 8, 512 // 8),
+        (1, unet.in_channels, 512 // 8, 512 // 8),
         generator = generator,
-        device = self.device
+        device = device
     ).bfloat16()
 
 
-    text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
 
-    text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
 
     max_length = text_input.input_ids.shape[-1]
-    uncond_input = self.tokenizer(
+    uncond_input = tokenizer(
         [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
     )
-    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()
-    self.noise_scheduler.set_timesteps(ddim_steps)
-    latents = latents * self.noise_scheduler.init_noise_sigma
+    noise_scheduler.set_timesteps(ddim_steps)
+    latents = latents * noise_scheduler.init_noise_sigma
 
 
 
-    for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
+    for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
         latent_model_input = torch.cat([latents] * 2)
-        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
 
         if t>start_noise:
             pass
@@ -276,7 +274,7 @@ def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, se
         network.reset()
 
         with network:
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
 
 
         #guidance
@@ -285,31 +283,12 @@ def edit_inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, se
         latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
 
     latents = 1 / 0.18215 * latents
-    image = self.vae.decode(latents.float()).sample
+    image = vae.decode(latents).sample
     image = (image / 2 + 0.5).clamp(0, 1)
     image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
     image = Image.fromarray((image * 255).round().astype("uint8"))
 
-    return image
-
-    # @torch.no_grad()
-    # @spaces.GPU(duration=120)
-    # def sample_then_run(self):
-    #     self.unet = UNet2DConditionModel.from_pretrained(
-    #         "stablediffusionapi/realistic-vision-v51" , subfolder="unet", revision=None
-    #     )
-    #     self.unet.to(self.device, dtype=torch.bfloat16)
-    #     self.weights = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
-
-    #     prompt = "sks person"
-    #     negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
-    #     seed = 5
-    #     cfg = 3.0
-    #     steps = 25
-    #     image = self.inference(prompt, negative_prompt, cfg, steps, seed)
-    #     torch.save(self.weights.cpu().detach(), "model.pt" )
-    #     return image, "model.pt"
-
+    return net, image
 
 
 class CustomImageDataset(Dataset):
@@ -535,9 +514,9 @@ with gr.Blocks(css="style.css") as demo:
 
     sample.click(fn=sample_then_run,inputs = [net], outputs=[net, input_image])
 
-    # submit.click(
-    #     fn=model.edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
-    # )
+    submit.click(
+        fn=edit_inference, inputs=[net, prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[net, gallery]
+    )
     # file_input.change(fn=model.file_upload, inputs=file_input, outputs = gallery)
 
 
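
The diff converts `edit_inference` from a method on a long-lived class into a module-level handler: the model state (`net`, a path to saved weights) comes in through the Gradio event, the weights are reloaded inside the call, and the state is returned back out, which is the shape `@spaces.GPU()` expects for functions that run in short-lived ZeroGPU allocations. Below is a minimal sketch of that pattern, not part of this commit; `generate` and its placeholder body are hypothetical stand-ins for the real diffusion loop, and `spaces` is only available on Hugging Face Spaces.

    import gradio as gr
    import spaces
    import torch

    @spaces.GPU()  # ZeroGPU attaches a GPU only for the duration of this call
    def generate(net, prompt, seed):
        # `net` arrives from gr.State as a file path, so each call reloads the
        # weights instead of reading attributes off a live object
        weights = torch.load(net)
        generator = torch.Generator().manual_seed(int(seed))
        # ... run the diffusion pipeline with `weights` and `prompt` here ...
        image = (torch.rand(512, 512, 3, generator=generator) * 255).to(torch.uint8).numpy()  # placeholder output
        return net, image  # thread the state back out, as edit_inference now does

    with gr.Blocks() as demo:
        net = gr.State("model.pt")  # hypothetical initial weights file
        prompt = gr.Textbox(label="Prompt")
        seed = gr.Number(value=5, label="Seed")
        out = gr.Image(label="Result")
        gr.Button("Generate").click(fn=generate, inputs=[net, prompt, seed], outputs=[net, out])

    demo.launch()

Keeping the state as a picklable value rather than object attributes is what lets the handler run in a fresh GPU worker on every click, which also explains why the commit drops the commented-out class-based `sample_then_run` and re-enables `submit.click` against the new function.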