English
Inference Endpoints
John6666 committed on
Commit
71d24b5
·
verified ·
1 Parent(s): 257dda1

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -0
handler.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
  import torch
9
  from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
10
  from huggingface_hub import hf_hub_download
 
11
 
12
  IS_COMPILE = False
13
  IS_TURBO = False
@@ -59,6 +60,7 @@ def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
59
  pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
60
  pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
61
  pipe.fuse_lora()
 
62
  pipe.transformer.fuse_qkv_projections()
63
  pipe.vae.fuse_qkv_projections()
64
  weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
@@ -73,6 +75,7 @@ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
73
  pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
74
  pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
75
  pipe.fuse_lora()
 
76
  pipe.transformer.fuse_qkv_projections()
77
  pipe.vae.fuse_qkv_projections()
78
  weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
@@ -93,6 +96,8 @@ class EndpointHandler:
93
  dtype = torch.float16 # for older nVidia GPUs
94
  if IS_COMPILE: load_pipeline_compile(repo_id, dtype)
95
  else: self.pipeline = load_pipeline_stable(repo_id, dtype)
 
 
96
 
97
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
98
  logger.info(f"Received incoming request with {data=}")
 
8
  import torch
9
  from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
10
  from huggingface_hub import hf_hub_download
11
+ import gc
12
 
13
  IS_COMPILE = False
14
  IS_TURBO = False
 
60
  pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
61
  pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
62
  pipe.fuse_lora()
63
+ pipe.unload_lora_weights()
64
  pipe.transformer.fuse_qkv_projections()
65
  pipe.vae.fuse_qkv_projections()
66
  weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
 
75
  pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
76
  pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
77
  pipe.fuse_lora()
78
+ pipe.unload_lora_weights()
79
  pipe.transformer.fuse_qkv_projections()
80
  pipe.vae.fuse_qkv_projections()
81
  weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
 
96
  dtype = torch.float16 # for older nVidia GPUs
97
  if IS_COMPILE: load_pipeline_compile(repo_id, dtype)
98
  else: self.pipeline = load_pipeline_stable(repo_id, dtype)
99
+ gc.collect()
100
+ torch.cuda.empty_cache()
101
 
102
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
103
  logger.info(f"Received incoming request with {data=}")