Upload handler.py
Browse files- handler.py +9 -5
handler.py
CHANGED
@@ -58,11 +58,15 @@ def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
|
|
58 |
return pipe
|
59 |
|
60 |
def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
|
61 |
-
int4_config = TorchAoConfig("int4dq")
|
62 |
-
float8_config = TorchAoConfig("float8dq")
|
63 |
-
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
66 |
#pipe.transformer.fuse_qkv_projections()
|
67 |
#pipe.vae.fuse_qkv_projections()
|
68 |
pipe.to("cuda")
|
|
|
58 |
return pipe
|
59 |
|
60 |
def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
|
61 |
+
#int4_config = TorchAoConfig("int4dq")
|
62 |
+
#float8_config = TorchAoConfig("float8dq")
|
63 |
+
vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
|
64 |
+
quantize_(vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
65 |
+
transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype)
|
66 |
+
quantize_(transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
67 |
+
pipe = FluxPipeline.from_pretrained(repo_id, vae=None, transformer=None, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
|
68 |
+
pipe.transformer = transformer
|
69 |
+
pipe.vae = vae
|
70 |
#pipe.transformer.fuse_qkv_projections()
|
71 |
#pipe.vae.fuse_qkv_projections()
|
72 |
pipe.to("cuda")
|