John6666 committed
Commit 136905c · verified · 1 Parent(s): efad4b5

Upload 3 files

Files changed (3):
  1. README.md +8 -0
  2. handler.py +130 -0
  3. requirements.txt +14 -0
README.md ADDED
@@ -0,0 +1,8 @@
+ ---
+ license: other
+ license_name: flux-1-dev-non-commercial-license
+ license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.
+ language:
+ - en
+ inference: true
+ ---
handler.py ADDED
@@ -0,0 +1,130 @@
+ # https://github.com/sayakpaul/diffusers-torchao
+
+ import os
+ from typing import Any, Dict
+
+ from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
+ from PIL import Image
+ import torch
+ from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight
+ from huggingface_hub import hf_hub_download
+
+ IS_COMPILE = False  # toggle the torch.compile variants
+ IS_TURBO = True  # toggle the Hyper-SD 8-step LoRA variants
+
+ if IS_COMPILE:
+     import torch._dynamo
+     torch._dynamo.config.suppress_errors = True
+
+ from huggingface_inference_toolkit.logging import logger
+
+ def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:  # int8 weight/activation quantization, no compilation
+     quantization_config = TorchAoConfig("int8dq")
+     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
+     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
+     pipe.transformer.fuse_qkv_projections()
+     pipe.vae.fuse_qkv_projections()
+     pipe.to("cuda")
+     return pipe
+
+ def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:  # int8 quantization plus torch.compile on transformer and VAE
+     quantization_config = TorchAoConfig("int8dq")
+     vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
+     pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
+     pipe.transformer.fuse_qkv_projections()
+     pipe.vae.fuse_qkv_projections()
+     pipe.transformer.to(memory_format=torch.channels_last)
+     pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
+     pipe.vae.to(memory_format=torch.channels_last)
+     pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False)
+     pipe.to("cuda")
+     return pipe
+
+ def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:  # torchao autoquant variant (not selected by EndpointHandler)
+     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+     pipe.transformer.fuse_qkv_projections()
+     pipe.vae.fuse_qkv_projections()
+     pipe.transformer.to(memory_format=torch.channels_last)
+     pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+     pipe.vae.to(memory_format=torch.channels_last)
+     pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
+     pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
+     pipe.vae = autoquant(pipe.vae, error_on_unseen=False)
+     pipe.to("cuda")
+     return pipe
+
+ def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:  # Hyper-SD 8-step LoRA fused, then int8 quantization
+     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
+     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
+     pipe.fuse_lora()
+     pipe.transformer.fuse_qkv_projections()
+     pipe.vae.fuse_qkv_projections()
+     quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
+     quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
+     pipe.to("cuda")
+     return pipe
+
+ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:  # turbo variant plus torch.compile
+     pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+     pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
+     pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
+     pipe.fuse_lora()
+     pipe.transformer.fuse_qkv_projections()
+     pipe.vae.fuse_qkv_projections()
+     quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
+     quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
+     pipe.transformer.to(memory_format=torch.channels_last)
+     pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
+     pipe.vae.to(memory_format=torch.channels_last)
+     pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False)
+     pipe.to("cuda")
+     return pipe
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         repo_id = "camenduru/FLUX.1-dev-diffusers"
+         #dtype = torch.bfloat16
+         dtype = torch.float16  # for older NVIDIA GPUs
+         if IS_COMPILE:
+             if IS_TURBO: self.pipeline = load_pipeline_turbo_compile(repo_id, dtype)
+             else: self.pipeline = load_pipeline_compile(repo_id, dtype)
+         else:
+             if IS_TURBO: self.pipeline = load_pipeline_turbo(repo_id, dtype)
+             else: self.pipeline = load_pipeline_stable(repo_id, dtype)
+
+     def __call__(self, data: Dict[str, Any]) -> Image.Image:
+         logger.info(f"Received incoming request with {data=}")
+
+         if "inputs" in data and isinstance(data["inputs"], str):
+             prompt = data.pop("inputs")
+         elif "prompt" in data and isinstance(data["prompt"], str):
+             prompt = data.pop("prompt")
+         else:
+             raise ValueError(
+                 "Provided input body must contain either the key `inputs` or `prompt` with the"
+                 " prompt to use for the image generation, and it needs to be a non-empty string."
+             )
+
+         parameters = data.pop("parameters", {})
+
+         num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
+         width = parameters.get("width", 1024)
+         height = parameters.get("height", 1024)
+         guidance_scale = parameters.get("guidance_scale", 3.5)
+
+         # seed generator (the seed cannot be passed directly; it is supplied via a torch.Generator)
+         seed = parameters.get("seed", 0)
+         generator = torch.manual_seed(seed)
+
+         return self.pipeline(  # type: ignore
+             prompt,
+             height=height,
+             width=width,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             output_type="pil",
+         ).images[0]
+
+
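For reference, the handler above can be smoke-tested locally before deployment. This is a minimal sketch, not part of the commit: it assumes a CUDA GPU with enough VRAM for FLUX.1-dev and the pinned requirements below installed; the prompt and output filename are illustrative.

# Hypothetical local smoke test for handler.py (illustrative, not in the commit).
from handler import EndpointHandler

handler = EndpointHandler()  # loads, quantizes, and optionally compiles the pipeline
image = handler({
    "inputs": "a watercolor fox in a snowy forest",
    "parameters": {"num_inference_steps": 8, "guidance_scale": 3.5, "seed": 42},
})
image.save("sample.png")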
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ huggingface_hub
+ torch==2.4.0
+ torchvision
+ torchaudio
+ torchao==0.9.0
+ diffusers==0.32.2
+ peft
+ transformers
+ numpy
+ scipy
+ Pillow
+ sentencepiece
+ protobuf
+ triton
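Once deployed as a Hugging Face Inference Endpoint, the handler is called over HTTP. A sketch of a client request follows, assuming the inference toolkit serializes the returned PIL image to raw PNG bytes; the endpoint URL is a placeholder and HF_TOKEN is assumed to be set in the environment.

# Hypothetical client call to the deployed endpoint (URL is a placeholder).
import io
import os

import requests
from PIL import Image

ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"  # placeholder

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
        "Content-Type": "application/json",
        "Accept": "image/png",  # assumption: the toolkit returns the image as raw bytes
    },
    json={"inputs": "a watercolor fox in a snowy forest", "parameters": {"seed": 42}},
)
response.raise_for_status()
Image.open(io.BytesIO(response.content)).save("sample.png")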