radames committed on
Commit 31bcd4c
1 Parent(s): 9b5f363

use tinyVAE

Files changed (2)
  1. app-img2img.py +3 -4
  2. latent_consistency_img2img.py +15 -5
app-img2img.py CHANGED
@@ -55,10 +55,9 @@ else:
         custom_pipeline="latent_consistency_img2img.py",
         custom_revision="main",
     )
-    # TODO try to use tiny VAE
-    # pipe.vae = AutoencoderTiny.from_pretrained(
-    #     "madebyollin/taesd", torch_dtype=torch.float16, use_safetensors=True
-    # )
+    pipe.vae = AutoencoderTiny.from_pretrained(
+        "madebyollin/taesd", torch_dtype=torch.float16, use_safetensors=True
+    )
     pipe.set_progress_bar_config(disable=True)
     pipe.to(torch_device=torch_device, torch_dtype=torch_dtype).to(device)
     pipe.unet.to(memory_format=torch.channels_last)
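
For context, this change amounts to swapping the pipeline's default AutoencoderKL for TAESD once the pipeline is built. A minimal sketch of the resulting setup, assuming app-img2img.py imports AutoencoderTiny from diffusers; the base-model id below is a placeholder, since the real one sits above this hunk:

import torch

from diffusers import AutoencoderTiny, DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",  # placeholder id; the actual base model is outside this hunk
    custom_pipeline="latent_consistency_img2img.py",
    custom_revision="main",
)
# TAESD ("madebyollin/taesd") is a distilled, much smaller stand-in for the
# Stable Diffusion VAE, trading a little fidelity for faster encode/decode.
pipe.vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd", torch_dtype=torch.float16, use_safetensors=True
)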
latent_consistency_img2img.py CHANGED
@@ -25,6 +25,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
+    AutoencoderTiny,
     AutoencoderKL,
     ConfigMixin,
     DiffusionPipeline,
@@ -226,13 +227,22 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
             )
 
         elif isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
-                for i in range(batch_size)
-            ]
+            if isinstance(self.vae, AutoencoderTiny):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latents
+                    for i in range(batch_size)
+                ]
+            else:
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
+                    for i in range(batch_size)
+                ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            if isinstance(self.vae, AutoencoderTiny):
+                init_latents = self.vae.encode(image).latents
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
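
The new branches exist because the two VAE classes return different things from encode(): AutoencoderKL yields a latent distribution that has to be sampled (which is why a generator is threaded through), while AutoencoderTiny encodes deterministically and exposes the result directly as .latents. A minimal standalone sketch of the two paths, with an illustrative stock VAE id and a dummy tensor in place of a preprocessed image:

import torch

from diffusers import AutoencoderKL, AutoencoderTiny

image = torch.randn(1, 3, 512, 512)  # dummy batch standing in for a preprocessed frame

# AutoencoderKL.encode returns an output holding a latent distribution;
# sampling it is stochastic, so a torch.Generator can control reproducibility.
vae_kl = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative VAE id
with torch.no_grad():
    kl_latents = vae_kl.encode(image).latent_dist.sample()

# AutoencoderTiny.encode is deterministic: the output carries the latents
# directly, so there is no distribution to sample and no generator to pass.
vae_tiny = AutoencoderTiny.from_pretrained("madebyollin/taesd")
with torch.no_grad():
    tiny_latents = vae_tiny.encode(image).latents

The shared scaling line after the branch keeps working for both classes: TAESD's published config sets scaling_factor to 1.0, so the multiplication is effectively a no-op for the tiny VAE.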