John6666 committed · Commit d3841a2 · verified · Parent: d9431ee

Upload 3 files

Files changed (3):
  1. README.md +8 -0
  2. handler.py +203 -0
  3. requirements.txt +20 -0
README.md ADDED
@@ -0,0 +1,8 @@
---
license: other
license_name: flux-1-dev-non-commercial-license
license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.
language:
- en
inference: true
---
handler.py ADDED
@@ -0,0 +1,203 @@
# https://github.com/sayakpaul/diffusers-torchao
# https://github.com/pytorch/ao/releases
# https://developer.nvidia.com/cuda-gpus

import os
from typing import Any, Dict
import gc
from PIL import Image
from huggingface_hub import hf_hub_download
import torch
from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight
from torchao.quantization.quant_api import PerRow
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
from transformers import T5EncoderModel, BitsAndBytesConfig
from optimum.quanto import freeze, qfloat8, quantize
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from huggingface_inference_toolkit.logging import logger

import subprocess
subprocess.run("pip list", shell=True)

print(torch.cuda.get_device_name())
print(torch.cuda.get_device_capability())
print(torch.cuda.get_arch_list())

IS_NEW_GPU = torch.cuda.get_device_capability() >= (8, 9)  # compute capability 8.9+ (Ada/Hopper): float8 quantization is usable
IS_TURBO = False    # use the 8-step turbo checkpoint and default to 8 inference steps
IS_4BIT = True      # prefer int4 weight quantization (and NF4 for the text encoder) over int8/float8
IS_COMPILE = False  # torch.compile the transformer and VAE
IS_AUTOQ = False    # let torchao autoquant pick the quantization scheme
IS_PARA = True      # enable the para-attn first-block cache
IS_LVRAM = True     # low-VRAM mode: CPU offload or the component-quantized low-VRAM pipeline

# Set high precision for float32 matrix multiplications.
# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")

if IS_COMPILE:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

def offload_pipe(pipe) -> Any:
    if IS_LVRAM: pipe.enable_model_cpu_offload()
    return pipe

def load_te2(repo_id: str, dtype: torch.dtype) -> Any:
    if IS_4BIT:
        nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        text_encoder_2 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", torch_dtype=dtype, quantization_config=nf4_config)
    else:
        text_encoder_2 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", torch_dtype=dtype)
        quantize(text_encoder_2, weights=qfloat8)
        freeze(text_encoder_2)
    return text_encoder_2

def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
    quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_NEW_GPU else "int8wo")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
    pipe = offload_pipe(FluxPipeline.from_pretrained(repo_id, vae=vae, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=quantization_config))
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    return pipe

def load_pipeline_lowvram(repo_id: str, dtype: torch.dtype) -> Any:
    int4_config = TorchAoConfig("int4dq")
    float8_config = TorchAoConfig("float8dq")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
    transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=float8_config)
    pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, transformer=transformer, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=int4_config)
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    pipe.to("cuda")
    return pipe

def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
    quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_NEW_GPU else "int8wo")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
    pipe = offload_pipe(FluxPipeline.from_pretrained(repo_id, vae=vae, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=quantization_config))
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    pipe.vae.to(memory_format=torch.channels_last)
    pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
    return pipe

def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
    pipe = offload_pipe(FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype))
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    pipe.vae.to(memory_format=torch.channels_last)
    pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
    pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
    pipe.vae = autoquant(pipe.vae, error_on_unseen=False)
    return pipe

def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
    pipe = offload_pipe(FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype))
    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
    pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
    pipe.fuse_lora()
    pipe.unload_lora_weights()
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
    quantize_(pipe.transformer, weight, device="cuda")
    quantize_(pipe.vae, weight, device="cuda")
    return pipe

def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
    pipe = offload_pipe(FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype))
    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
    pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
    pipe.fuse_lora()
    pipe.unload_lora_weights()
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
    quantize_(pipe.transformer, weight, device="cuda")
    quantize_(pipe.vae, weight, device="cuda")
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    pipe.vae.to(memory_format=torch.channels_last)
    pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
    return pipe

def load_pipeline_opt(repo_id: str, dtype: torch.dtype) -> Any:
    quantization_config = TorchAoConfig("int4dq" if IS_4BIT else "float8dq" if IS_NEW_GPU else "int8wo")
    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
    transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype)
    transformer.fuse_qkv_projections()
    if IS_NEW_GPU: quantize_(transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
    else: quantize_(transformer, weight, device="cuda")
    transformer.to(memory_format=torch.channels_last)
    transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
    vae.fuse_qkv_projections()
    if IS_NEW_GPU: quantize_(vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
    else: quantize_(vae, weight, device="cuda")
    vae.to(memory_format=torch.channels_last)
    vae = torch.compile(vae, mode="max-autotune", fullgraph=True)
    pipe = offload_pipe(FluxPipeline.from_pretrained(repo_id, transformer=None, vae=None, text_encoder_2=load_te2(repo_id, dtype), torch_dtype=dtype, quantization_config=quantization_config))
    pipe.transformer = transformer
    pipe.vae = vae
    return pipe

class EndpointHandler:
    def __init__(self, path=""):
        repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
        dtype = torch.bfloat16
        # dtype = torch.float16  # for older NVIDIA GPUs
        if IS_AUTOQ: self.pipeline = load_pipeline_autoquant(repo_id, dtype)
        elif IS_COMPILE: self.pipeline = load_pipeline_opt(repo_id, dtype)
        elif IS_LVRAM and IS_NEW_GPU: self.pipeline = load_pipeline_lowvram(repo_id, dtype)
        else: self.pipeline = load_pipeline_stable(repo_id, dtype)
        if IS_PARA: apply_cache_on_pipe(self.pipeline, residual_diff_threshold=0.12)
        gc.collect()
        torch.cuda.empty_cache()
        self.pipeline.enable_vae_slicing()
        self.pipeline.enable_vae_tiling()
        if IS_LVRAM:
            self.pipeline.transformer.to("cuda")
            self.pipeline.vae.to("cuda")
        else: self.pipeline.to("cuda")
        print(self.pipeline)

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        logger.info(f"Received incoming request with {data=}")

        if "inputs" in data and isinstance(data["inputs"], str):
            prompt = data.pop("inputs")
        elif "prompt" in data and isinstance(data["prompt"], str):
            prompt = data.pop("prompt")
        else:
            raise ValueError(
                "Provided input body must contain either the key `inputs` or `prompt` with the"
                " prompt to use for the image generation, and it needs to be a non-empty string."
            )

        parameters = data.pop("parameters", {})

        num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
        width = parameters.get("width", 1024)
        height = parameters.get("height", 1024)
        guidance_scale = parameters.get("guidance_scale", 3.5)

        # The seed cannot be passed to the pipeline directly; it is supplied via a generator.
        seed = parameters.get("seed", 0)
        generator = torch.manual_seed(seed)

        return self.pipeline(  # type: ignore
            prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type="pil",
        ).images[0]
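For reference, a minimal client-side sketch of the request payload this handler parses in `__call__` (either `inputs` or `prompt`, plus an optional `parameters` object). The endpoint URL, token, and the assumption that the toolkit serializes the returned PIL image to raw image bytes are placeholders/assumptions, not part of this commit:

# Hypothetical client for the deployed Inference Endpoint; ENDPOINT_URL and HF_TOKEN are placeholders.
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # assumption: your endpoint URL
HF_TOKEN = "hf_..."  # assumption: a token with access to the endpoint

payload = {
    "inputs": "a photo of a cat wearing a spacesuit",  # the handler also accepts the key "prompt"
    "parameters": {  # optional; defaults are defined in handler.py
        "num_inference_steps": 28,
        "width": 1024,
        "height": 1024,
        "guidance_scale": 3.5,
        "seed": 42,
    },
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
    json=payload,
)
response.raise_for_status()
# Assumption: the response body is the generated image encoded as bytes.
with open("output.png", "wb") as f:
    f.write(response.content)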
requirements.txt ADDED
@@ -0,0 +1,20 @@
--extra-index-url https://download.pytorch.org/whl/cu126
torch>=2.6.0
torchvision
torchaudio
huggingface_hub
torchao>=0.9.0
diffusers>=0.32.2
peft
transformers==4.48.3
accelerate
numpy
scipy
Pillow
sentencepiece
protobuf
triton
gemlite
para-attn
bitsandbytes
optimum-quanto
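A quick sanity check after installing these pinned packages, in the same spirit as the handler's own startup diagnostics (a minimal sketch; it only prints versions and CUDA availability):

# Minimal environment check; mirrors the diagnostics handler.py runs at import time.
import torch, torchao, diffusers, transformers

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
print("torchao", torchao.__version__)
print("diffusers", diffusers.__version__)
print("transformers", transformers.__version__)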