John6666 committed on
Commit
fecd4a4
·
verified ·
1 Parent(s): f28bd15

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -2
handler.py CHANGED
@@ -29,8 +29,10 @@ class EndpointHandler:
29
  vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
30
  #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config).to("cuda")
31
  self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
32
- self.pipeline.transformer.fuse_qkv_projections().to(memory_format=torch.channels_last)
33
- self.pipeline.vae.fuse_qkv_projections().to(memory_format=torch.channels_last)
 
 
34
  if IS_COMPILE: self.pipeline = compile_pipeline(self.pipeline)
35
  self.pipeline.to(distributed_state.device)
36
 
 
29
  vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
30
  #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config).to("cuda")
31
  self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
32
+ self.pipeline.transformer.fuse_qkv_projections()
33
+ self.pipeline.transformer.to(memory_format=torch.channels_last)
34
+ self.pipeline.vae.fuse_qkv_projections()
35
+ self.pipeline.vae.to(memory_format=torch.channels_last)
36
  if IS_COMPILE: self.pipeline = compile_pipeline(self.pipeline)
37
  self.pipeline.to(distributed_state.device)
38