File size: 6,038 Bytes
419d7f4
 
b748322
 
 
 
 
 
419d7f4
 
b748322
fb6d062
419d7f4
 
b748322
 
 
 
 
419d7f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b748322
419d7f4
 
 
 
 
 
 
 
 
 
 
acf2d5f
419d7f4
 
 
 
b748322
 
 
 
419d7f4
 
 
 
 
 
b748322
419d7f4
b748322
 
 
 
 
 
 
 
 
 
 
 
 
419d7f4
b748322
 
 
 
 
 
 
 
9b4d4c0
b748322
 
 
 
 
 
9b4d4c0
 
b748322
2393f58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# https://github.com/sayakpaul/diffusers-torchao

import os
from typing import Any, Dict

from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
from PIL import Image
import torch
from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int4_weight
from huggingface_hub import hf_hub_download

# Feature switches, fixed at import time.
# IS_COMPILE: EndpointHandler builds the pipeline with load_pipeline_compile (torch.compile'd).
IS_COMPILE: bool = False
# IS_TURBO: use the 8-step checkpoint and default to 8 inference steps instead of 28.
IS_TURBO: bool = False
# IS_4BIT: quantize weights to int4 rather than int8 (int8 dynamic activations either way).
IS_4BIT: bool = True

if IS_COMPILE:
    # Ask TorchDynamo to fall back to eager execution instead of raising when
    # it fails to compile a graph (torch._dynamo.config.suppress_errors).
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True

from huggingface_inference_toolkit.logging import logger

def load_pipeline_stable(repo_id: str, dtype: torch.dtype) -> Any:
    """Load a torchao-quantized FluxPipeline on CUDA, without torch.compile.

    Args:
        repo_id: Hub repository containing a Flux pipeline (with ``vae`` and
            ``transformer`` subfolders).
        dtype: Torch dtype used for every pipeline component.

    Returns:
        The pipeline, with QKV projections fused in the transformer and VAE,
        the transformer quantized (int8 dynamic activations with int4 weights
        when ``IS_4BIT`` is set, int8 weights otherwise), moved to ``"cuda"``.
    """
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
    # NOTE: a bare TorchAoConfig passed as FluxPipeline.from_pretrained(quantization_config=...)
    # is not applied to the transformer (diffusers ignores it or, in newer releases,
    # only accepts a PipelineQuantizationConfig there), so quantize explicitly below.
    pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype)
    # Fuse first so fusion operates on plain (unquantized) Linear weights,
    # matching the order used by the turbo loaders.
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
    quantize_(pipe.transformer, weight, device="cuda")
    pipe.to("cuda")
    return pipe

def load_pipeline_compile(repo_id: str, dtype: torch.dtype) -> Any:
    """Load a torchao-quantized FluxPipeline on CUDA with a torch.compile'd transformer and VAE.

    Args:
        repo_id: Hub repository containing a Flux pipeline (with ``vae`` and
            ``transformer`` subfolders).
        dtype: Torch dtype used for every pipeline component.

    Returns:
        The pipeline, with QKV projections fused, the transformer quantized
        (int4 weights when ``IS_4BIT`` is set, int8 otherwise), transformer and
        VAE converted to channels_last and wrapped in ``torch.compile``
        (``mode="reduce-overhead"``), moved to ``"cuda"``.
    """
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
    # NOTE: a bare TorchAoConfig passed as FluxPipeline.from_pretrained(quantization_config=...)
    # is not applied to the transformer, so quantize explicitly below.
    pipe = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype)
    # Same order as load_pipeline_turbo_compile: fuse -> quantize -> compile.
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()
    weight = int8_dynamic_activation_int4_weight() if IS_4BIT else int8_dynamic_activation_int8_weight()
    quantize_(pipe.transformer, weight, device="cuda")
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False)
    pipe.vae.to(memory_format=torch.channels_last)
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False)
    pipe.to("cuda")
    return pipe

def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
    """Load a FluxPipeline whose transformer and VAE are compiled, then autoquantized.

    The transformer and VAE get fused QKV projections and channels_last
    memory format, are wrapped in ``torch.compile(mode="max-autotune",
    fullgraph=True)``, and are then handed to torchao ``autoquant``
    (``error_on_unseen=False``). The pipeline ends up on ``"cuda"``.
    """
    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    targets = ("transformer", "vae")
    for attr in targets:
        component = getattr(pipe, attr)
        component.to(memory_format=torch.channels_last)
        setattr(pipe, attr, torch.compile(component, mode="max-autotune", fullgraph=True))
    # autoquant wraps the already-compiled modules (compile first, then autoquant).
    for attr in targets:
        setattr(pipe, attr, autoquant(getattr(pipe, attr), error_on_unseen=False))

    pipe.to("cuda")
    return pipe

def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
    """Load a FluxPipeline with the Hyper-SD 8-step LoRA fused in and torchao-quantized components.

    The LoRA is loaded as adapter ``"hyper-sd"`` with weight 0.125 and fused.
    QKV projections are then fused, and transformer, VAE and ``text_encoder_2``
    are quantized on CUDA (int8 dynamic activations with int4 weights when
    ``IS_4BIT`` is set, int8 weights otherwise).
    """
    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")

    adapter = "hyper-sd"
    lora_file = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
    pipe.load_lora_weights(lora_file, adapter_name=adapter)
    pipe.set_adapters([adapter], adapter_weights=[0.125])
    pipe.fuse_lora()

    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    if IS_4BIT:
        quant_config = int8_dynamic_activation_int4_weight()
    else:
        quant_config = int8_dynamic_activation_int8_weight()
    # quantize_ mutates each module in place; the same config object is reused.
    for component in (pipe.transformer, pipe.vae, pipe.text_encoder_2):
        quantize_(component, quant_config, device="cuda")

    pipe.to("cuda")
    return pipe

def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
    """Like ``load_pipeline_turbo``, then compile the transformer and VAE.

    After the Hyper-SD LoRA fusion, QKV fusion and torchao quantization of
    transformer, VAE and ``text_encoder_2``, the transformer and VAE are
    switched to channels_last and wrapped in
    ``torch.compile(mode="reduce-overhead", fullgraph=False, dynamic=False)``.
    """
    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")

    adapter = "hyper-sd"
    lora_file = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
    pipe.load_lora_weights(lora_file, adapter_name=adapter)
    pipe.set_adapters([adapter], adapter_weights=[0.125])
    pipe.fuse_lora()

    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    if IS_4BIT:
        quant_config = int8_dynamic_activation_int4_weight()
    else:
        quant_config = int8_dynamic_activation_int8_weight()
    for component in (pipe.transformer, pipe.vae, pipe.text_encoder_2):
        quantize_(component, quant_config, device="cuda")

    for attr in ("transformer", "vae"):
        module = getattr(pipe, attr)
        module.to(memory_format=torch.channels_last)
        setattr(pipe, attr, torch.compile(module, mode="reduce-overhead", fullgraph=False, dynamic=False))

    pipe.to("cuda")
    return pipe

class EndpointHandler:
    """Inference Endpoints handler that turns a text prompt into one Flux image."""

    def __init__(self, path=""):
        """Build the Flux pipeline once, at handler construction.

        Args:
            path: Presumably the local model directory passed in by the
                inference toolkit; unused here because the weights are always
                loaded from the hard-coded Hub repositories below.
        """
        repo_id = "NoMoreCopyrightOrg/flux-dev-8step" if IS_TURBO else "NoMoreCopyrightOrg/flux-dev"
        # float16 rather than bfloat16 so that older NVIDIA GPUs are supported.
        dtype = torch.float16
        # Both branches must assign self.pipeline: __call__ depends on it.
        # (Previously the compiled pipeline was built and then discarded.)
        loader = load_pipeline_compile if IS_COMPILE else load_pipeline_stable
        self.pipeline = loader(repo_id, dtype)

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Generate an image from a request payload.

        Args:
            data: Request body. The prompt is taken from ``inputs`` or, if that
                is not a non-empty string, from ``prompt``. Optional settings are
                read from ``parameters``: ``num_inference_steps`` (default 8 when
                ``IS_TURBO`` else 28), ``width``/``height`` (default 1024),
                ``guidance_scale`` (default 3.5) and ``seed`` (default 0). The
                prompt key and ``parameters`` are popped from ``data``.

        Returns:
            The first generated image, as a PIL image.

        Raises:
            ValueError: If neither ``inputs`` nor ``prompt`` holds a non-empty string.
        """
        logger.info(f"Received incoming request with {data=}")

        inputs = data.get("inputs")
        prompt_field = data.get("prompt")
        if isinstance(inputs, str) and inputs:
            prompt = data.pop("inputs")
        elif isinstance(prompt_field, str) and prompt_field:
            prompt = data.pop("prompt")
        else:
            raise ValueError(
                "Provided input body must contain either the key `inputs` or `prompt` with the"
                " prompt to use for the image generation, and it needs to be a non-empty string."
            )

        # Treat an explicit `"parameters": null` the same as an absent key.
        parameters = data.pop("parameters", None) or {}

        num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
        width = parameters.get("width", 1024)
        height = parameters.get("height", 1024)
        guidance_scale = parameters.get("guidance_scale", 3.5)

        # The pipeline takes a generator rather than a seed. A dedicated CPU
        # generator yields the same stream as torch.manual_seed(seed) without
        # reseeding torch's process-global RNG on every request.
        seed = parameters.get("seed", 0)
        generator = torch.Generator(device="cpu").manual_seed(seed)

        return self.pipeline(  # type: ignore
            prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type="pil",
        ).images[0]