English
Inference Endpoints
John6666 committed on
Commit
9ae3cc2
·
verified ·
1 Parent(s): a3e1aa2

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -67,8 +67,8 @@ def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
67
  vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype, quantization_config=float8_config)
68
  transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=float8_config)
69
  pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, transformer=transformer, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
70
- pipe.transformer.fuse_qkv_projections()
71
- pipe.vae.fuse_qkv_projections()
72
  pipe.to("cuda")
73
  return pipe
74
 
 
67
  vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype, quantization_config=float8_config)
68
  transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=float8_config)
69
  pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, transformer=transformer, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
70
+ #pipe.transformer.fuse_qkv_projections()
71
+ #pipe.vae.fuse_qkv_projections()
72
  pipe.to("cuda")
73
  return pipe
74