English
Inference Endpoints
John6666 committed on
Commit
b452566
·
verified ·
1 Parent(s): 67e2f7f

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -9
handler.py CHANGED
@@ -5,10 +5,10 @@ from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, Torch
5
  from PIL.Image import Image
6
  import torch
7
 
8
- #import torch._dynamo
9
- #torch._dynamo.config.suppress_errors = True
10
 
11
- #from huggingface_inference_toolkit.logging import logger
12
 
13
  def compile_pipeline(pipe) -> Any:
14
  pipe.transformer.fuse_qkv_projections()
@@ -17,10 +17,10 @@ def compile_pipeline(pipe) -> Any:
17
  return pipe
18
 
19
  class EndpointHandler:
20
- def __init__(self, path="NoMoreCopyright/FLUX.1-dev-test"):
21
- is_compile = False
22
- repo_id = "camenduru/FLUX.1-dev-diffusers"
23
- #repo_id = "NoMoreCopyright/FLUX.1-dev-test"
24
  dtype = torch.bfloat16
25
  quantization_config = TorchAoConfig("int4dq")
26
  vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
@@ -29,9 +29,9 @@ class EndpointHandler:
29
  if is_compile: self.pipeline = compile_pipeline(self.pipeline)
30
  self.pipeline.to("cuda")
31
 
32
- #@torch.inference_mode()
33
  def __call__(self, data: Dict[str, Any]) -> Image:
34
- #logger.info(f"Received incoming request with {data=}")
35
 
36
  if "inputs" in data and isinstance(data["inputs"], str):
37
  prompt = data.pop("inputs")
 
5
  from PIL.Image import Image
6
  import torch
7
 
8
+ import torch._dynamo
9
+ torch._dynamo.config.suppress_errors = True
10
 
11
+ from huggingface_inference_toolkit.logging import logger
12
 
13
  def compile_pipeline(pipe) -> Any:
14
  pipe.transformer.fuse_qkv_projections()
 
17
  return pipe
18
 
19
  class EndpointHandler:
20
+ def __init__(self, path=""):
21
+ is_compile = True
22
+ #repo_id = "camenduru/FLUX.1-dev-diffusers"
23
+ repo_id = "NoMoreCopyright/FLUX.1-dev-test"
24
  dtype = torch.bfloat16
25
  quantization_config = TorchAoConfig("int4dq")
26
  vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
 
29
  if is_compile: self.pipeline = compile_pipeline(self.pipeline)
30
  self.pipeline.to("cuda")
31
 
32
+ @torch.inference_mode()
33
  def __call__(self, data: Dict[str, Any]) -> Image:
34
+ logger.info(f"Received incoming request with {data=}")
35
 
36
  if "inputs" in data and isinstance(data["inputs"], str):
37
  prompt = data.pop("inputs")