Files changed (2) hide show
  1. handler.py +54 -0
  2. requirements.txt +7 -0
handler.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict
3
+
4
+ from diffusers import FluxPipeline, FluxTransformer2DModel
5
+ from torchao.quantization import int8_weight_only, quantize_
6
+ from PIL.Image import Image
7
+ import torch
8
+
9
+ from huggingface_inference_toolkit.logging import logger
10
+
11
class EndpointHandler:
    """Callable handler that turns a text prompt into an image with a `FluxPipeline`.

    The pipeline is built once in `__init__` and stored on `self.pipeline`;
    every call to the instance runs a single text-to-image generation.
    """

    def __init__(self, **kwargs: Any) -> None:  # type: ignore
        """Load, quantize and compile the FLUX pipeline on the CUDA device.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.
        """
        repo_id = "camenduru/FLUX.1-dev-diffusers"
        dtype = torch.bfloat16

        # The transformer is loaded on its own so that it can be quantized
        # (int8, weight-only) and compiled before being handed to the pipeline.
        transformer = FluxTransformer2DModel.from_pretrained(
            repo_id, subfolder="transformer", torch_dtype=dtype
        )
        quantize_(transformer, int8_weight_only(), device="cuda")
        transformer.to(memory_format=torch.channels_last)
        transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

        # Remaining pipeline components come from the same repo; the prepared
        # transformer replaces the one the pipeline would otherwise load.
        self.pipeline = FluxPipeline.from_pretrained(
            repo_id, transformer=transformer, torch_dtype=dtype
        ).to("cuda")
        self.pipeline.vae.to(memory_format=torch.channels_last)
        self.pipeline.vae.decode = torch.compile(
            self.pipeline.vae.decode, mode="max-autotune", fullgraph=True
        )

    @staticmethod
    def _extract_prompt(data: Dict[str, Any]) -> str:
        """Pop and return the prompt stored under `inputs` or, failing that, `prompt`.

        A key is only accepted when its value is a string that contains at
        least one non-whitespace character. The returned string is passed
        through unmodified (it is not stripped).

        Raises:
            ValueError: If neither key holds a non-empty string.
        """
        for key in ("inputs", "prompt"):
            candidate = data.get(key)
            if isinstance(candidate, str) and candidate.strip():
                return data.pop(key)
        raise ValueError(
            "Provided input body must contain either the key `inputs` or `prompt` with the"
            " prompt to use for the image generation, and it needs to be a non-empty string."
        )

    @staticmethod
    def _read_parameters(data: Dict[str, Any]) -> Dict[str, Any]:
        """Pop `parameters` from the payload and fill in the generation defaults.

        A missing or null `parameters` entry is treated as an empty mapping.

        Returns:
            A dict with the keys `num_inference_steps`, `width`, `height`,
            `guidance_scale` and `seed`.
        """
        parameters = data.pop("parameters", None) or {}
        return {
            "num_inference_steps": parameters.get("num_inference_steps", 30),
            "width": parameters.get("width", 1024),
            "height": parameters.get("height", 768),
            "guidance_scale": parameters.get("guidance_scale", 3.5),
            "seed": parameters.get("seed", 0),
        }

    def __call__(self, data: Dict[str, Any]) -> Image:
        """Generate one image for the prompt contained in the request payload.

        Args:
            data: Request body. The prompt is read from `inputs` or `prompt`;
                the optional `parameters` mapping may override
                `num_inference_steps` (30), `width` (1024), `height` (768),
                `guidance_scale` (3.5) and `seed` (0). The consumed keys are
                removed from `data`.

        Returns:
            The first image produced by the pipeline.

        Raises:
            ValueError: If the payload carries no non-empty string prompt.
        """
        logger.info(f"Received incoming request with {data=}")

        prompt = self._extract_prompt(data)
        options = self._read_parameters(data)

        # The seed cannot be passed to the pipeline directly, only through a
        # generator. A dedicated generator keeps the seed local to this
        # request; `torch.manual_seed` would reseed the process-wide default
        # generator instead.
        generator = torch.Generator(device="cpu").manual_seed(options.pop("seed"))

        return self.pipeline(  # type: ignore
            prompt,
            generator=generator,
            **options,
        ).images[0]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ diffusers
3
+ peft
4
+ accelerate
5
+ transformers
6
+ numpy
7
+ Pillow