Upload handler.py

handler.py  CHANGED  (+6, -1)
@@ -8,6 +8,7 @@ from PIL import Image
 import torch
 from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
 from huggingface_hub import hf_hub_download
+import gc

 IS_COMPILE = False
 IS_TURBO = True
@@ -59,6 +60,7 @@ def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
     pipe.fuse_lora()
+    pipe.unload_lora_weights()
     pipe.transformer.fuse_qkv_projections()
     pipe.vae.fuse_qkv_projections()
     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
@@ -73,6 +75,7 @@ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
     pipe.fuse_lora()
+    pipe.unload_lora_weights()
     pipe.transformer.fuse_qkv_projections()
     pipe.vae.fuse_qkv_projections()
     weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
@@ -88,11 +91,13 @@ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:

 class EndpointHandler:
     def __init__(self, path=""):
-        repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "
+        repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
         #dtype = torch.bfloat16
         dtype = torch.float16 # for older nVidia GPUs
         if IS_COMPILE: load_pipeline_compile(repo_id, dtype)
         else: self.pipeline = load_pipeline_stable(repo_id, dtype)
+        gc.collect()
+        torch.cuda.empty_cache()

     def __call__(self, data: Dict[str, Any]) -> Image.Image:
         logger.info(f"Received incoming request with {data=}")
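In both turbo loaders the Hyper-SD adapter is merged with fuse_lora() and the now-redundant LoRA tensors are then dropped with unload_lora_weights() before quantization; the new gc.collect() / torch.cuda.empty_cache() calls in __init__ release the freed memory before the first request. The quantize_ call that consumes the selected weight config sits outside the visible hunks, so the following is only a sketch of the usual torchao pattern for that config; the quantize_pipeline name and the choice of which modules to quantize are assumptions, not code from this commit.

# Illustrative sketch only: applying the selected torchao config after LoRA
# fusion. The function name, the quantized modules, and the is_4bit flag
# handling are assumptions, not taken from handler.py.
import gc
import torch
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_weight,
)

def quantize_pipeline(pipe, is_4bit: bool = False) -> None:
    weight = int8_dynamic_activation_int4_weight() if is_4bit else int8_dynamic_activation_int8_weight()
    quantize_(pipe.transformer, weight)  # rewrites linear layers in place
    quantize_(pipe.vae, weight)
    gc.collect()
    torch.cuda.empty_cache()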
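For a quick local check of the handler outside an Inference Endpoint, something like the sketch below can be used. The "inputs" key follows the common custom-handler payload convention and is an assumption about what this handler's __call__ reads, since the request parsing is not shown in these hunks.

# Hypothetical local smoke test; the payload key is an assumption.
from handler import EndpointHandler

handler = EndpointHandler(path=".")
image = handler({"inputs": "a watercolor landscape at sunset"})  # __call__ returns a PIL.Image.Image
image.save("test.png")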