John6666 committed
Commit f32b8e8 · verified · 1 Parent(s): 7e7c650

Upload handler.py

Files changed (1):
  1. handler.py +12 -7
handler.py CHANGED

@@ -6,11 +6,12 @@ from typing import Any, Dict
 from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
 from PIL import Image
 import torch
-from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight
+from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
 from huggingface_hub import hf_hub_download
 
 IS_COMPILE = False
 IS_TURBO = False
+IS_4BIT = True
 
 if IS_COMPILE:
     import torch._dynamo
@@ -19,7 +20,7 @@ if IS_COMPILE:
 from huggingface_inference_toolkit.logging import logger
 
 def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
-    quantization_config = TorchAoConfig("int8dq")
+    quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
     pipe.transformer.fuse_qkv_projections()
@@ -28,7 +29,7 @@ def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
     return pipe
 
 def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
-    quantization_config = TorchAoConfig("int8dq")
+    quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "int8dq")
     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
     pipe.transformer.fuse_qkv_projections()
@@ -60,8 +61,10 @@ def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
     pipe.fuse_lora()
     pipe.transformer.fuse_qkv_projections()
     pipe.vae.fuse_qkv_projections()
-    quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
-    quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
+    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
+    quantize_(pipe.transformer, weight, device="cuda")
+    quantize_(pipe.vae, weight, device="cuda")
+    quantize_(pipe.text_encoder_2, weight, device="cuda")
     pipe.to("cuda")
     return pipe
 
@@ -72,8 +75,10 @@ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
     pipe.fuse_lora()
     pipe.transformer.fuse_qkv_projections()
     pipe.vae.fuse_qkv_projections()
-    quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
-    quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
+    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
+    quantize_(pipe.transformer, weight, device="cuda")
+    quantize_(pipe.vae, weight, device="cuda")
+    quantize_(pipe.text_encoder_2, weight, device="cuda")
     pipe.transformer.to(memory_format=torch.channels_last)
     pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
     pipe.vae.to(memory_format=torch.channels_last)
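For context, a minimal sketch of how the four loaders touched by this commit could be smoke-tested once it is applied. The flag dispatch, the repo id ("black-forest-labs/FLUX.1-dev"), and the prompt are illustrative assumptions, not part of this diff; the handler's actual request path is not shown here.

# Hypothetical smoke test for the loaders above; not part of the commit.
import torch
from handler import (  # assumes this handler.py is importable from the working directory
    load_pipeline_stable,
    load_pipeline_compile,
    load_pipeline_turbo,
    load_pipeline_turbo_compile,
    IS_COMPILE,
    IS_TURBO,
)

def load(repo_id: str, dtype: torch.dtype):
    # Assumed dispatch mirroring the IS_COMPILE / IS_TURBO flags; the diff
    # does not show how the handler itself selects among the four loaders.
    if IS_TURBO:
        return (load_pipeline_turbo_compile if IS_COMPILE else load_pipeline_turbo)(repo_id, dtype)
    return (load_pipeline_compile if IS_COMPILE else load_pipeline_stable)(repo_id, dtype)

pipe = load("black-forest-labs/FLUX.1-dev", torch.bfloat16)  # placeholder repo id
image = pipe("an astronaut riding a horse on the moon", num_inference_steps=28).images[0]
image.save("sample.png")

With IS_4BIT = True this path quantizes the transformer, VAE, and second text encoder with int8 dynamic activations over int4 weights, trading some quality for a smaller memory footprint than the previous int8dq-only configuration.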