amildravid4292 committed (verified)
Commit f821ec0 · Parent(s): 7dcf34d

Update app.py

Files changed (1): app.py (+12 -12)
app.py CHANGED
@@ -159,39 +159,39 @@ def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
 
     generator = torch.Generator(device=device).manual_seed(seed)
     latents = torch.randn(
-        (1, self.unet.in_channels, 512 // 8, 512 // 8),
+        (1, unet.in_channels, 512 // 8, 512 // 8),
         generator = generator,
-        device = self.device
+        device = device
     ).bfloat16()
 
 
-    text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_input = self.tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
 
-    text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
 
     max_length = text_input.input_ids.shape[-1]
-    uncond_input = self.tokenizer(
+    uncond_input = tokenizer(
         [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
     )
-    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()
-    self.noise_scheduler.set_timesteps(ddim_steps)
+    noise_scheduler.set_timesteps(ddim_steps)
     latents = latents * self.noise_scheduler.init_noise_sigma
 
-    for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
+    for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
         latent_model_input = torch.cat([latents] * 2)
-        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
 
         with network:
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
 
         #guidance
         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-        latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
 
+        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
 
     latents = 1 / 0.18215 * latents
-    image = self.vae.decode(latents.float()).sample
+    image = vae.decode(latents.float()).sample
     image = (image / 2 + 0.5).clamp(0, 1)
     image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
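
Note: the commit converts this body from method-style to module-level calls, but two `self.` references survive on the new side: the `+` line for `text_input` still calls `self.tokenizer(...)`, and the unchanged context line `latents = latents * self.noise_scheduler.init_noise_sigma` keeps its qualifier. Since `inference` has no `self` parameter, both would raise `NameError: name 'self' is not defined` when executed. Below is a minimal sketch of the fully converted body, assuming (as the rest of the diff implies) that `unet`, `vae`, `tokenizer`, `text_encoder`, `noise_scheduler`, `network`, and `device` are module-level globals in app.py; the comments and the two remaining `self.` removals are mine, everything else mirrors the `+` side.

import torch
import tqdm

def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    # deterministic starting noise: 1 sample, UNet channel count, 64x64 grid (512 / VAE factor 8)
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # encode the prompt; self.tokenizer -> tokenizer (fixes the leftover self.)
    text_input = tokenizer(prompt, padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # encode the negative prompt to the same length for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([negative_prompt], padding="max_length",
                             max_length=max_length, return_tensors="pt")
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()

    noise_scheduler.set_timesteps(ddim_steps)
    # self.noise_scheduler -> noise_scheduler (fixes the leftover self.)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        # duplicate latents so one UNet pass covers both the uncond and cond branches
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        with network:  # the app's weight-patching context manager
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample

        # classifier-free guidance: eps = eps_uncond + s * (eps_text - eps_uncond)
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # undo the Stable Diffusion v1 latent scale and decode with the VAE
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents.float()).sample
    image = (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] -> [0, 1]
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    # the remainder of the function (e.g. the return) lies outside this hunk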
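If the app's Gradio handler invokes it the way the signature suggests, a call might look like this (argument values are illustrative, not taken from the diff):

image = inference(net, "a portrait photo", "", guidance_scale=7.5, ddim_steps=50, seed=42)

Also worth noting: the visible body never uses the `net` parameter; the `with network:` block refers to the global `network`, which is presumably constructed from `net` elsewhere in app.py.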