English
Inference Endpoints
John6666 committed on
Commit
b7d193e
·
verified ·
1 Parent(s): 6551d57

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -5
handler.py CHANGED
@@ -58,11 +58,15 @@ def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
58
  return pipe
59
 
60
  def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
61
- int4_config = TorchAoConfig("int4dq")
62
- float8_config = TorchAoConfig("float8dq")
63
- vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype, quantization_config=float8_config)
64
- transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=float8_config)
65
- pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, transformer=transformer, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
 
 
 
 
66
  #pipe.transformer.fuse_qkv_projections()
67
  #pipe.vae.fuse_qkv_projections()
68
  pipe.to("cuda")
 
58
  return pipe
59
 
60
  def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
61
+ #int4_config = TorchAoConfig("int4dq")
62
+ #float8_config = TorchAoConfig("float8dq")
63
+ vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
64
+ quantize_(vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
65
+ transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype)
66
+ quantize_(transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
67
+ pipe = FluxPipeline.from_pretrained(repo_id, vae=None, transformer=None, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
68
+ pipe.transformer = transformer
69
+ pipe.vae = vae
70
  #pipe.transformer.fuse_qkv_projections()
71
  #pipe.vae.fuse_qkv_projections()
72
  pipe.to("cuda")