John6666 committed on
Commit
acf2d5f
·
verified ·
1 Parent(s): fecd4a4

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -6
handler.py CHANGED
@@ -4,10 +4,8 @@ from typing import Any, Dict
4
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
5
  from PIL import Image
6
  import torch
7
- from accelerate import PartialState
8
- distributed_state = PartialState()
9
 
10
- IS_COMPILE = False
11
 
12
  if IS_COMPILE:
13
  import torch._dynamo
@@ -16,8 +14,10 @@ if IS_COMPILE:
16
  #from huggingface_inference_toolkit.logging import logger
17
 
18
  def compile_pipeline(pipe) -> Any:
 
19
  pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
20
- pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
 
21
  return pipe
22
 
23
  class EndpointHandler:
@@ -30,9 +30,7 @@ class EndpointHandler:
30
  #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config).to("cuda")
31
  self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
32
  self.pipeline.transformer.fuse_qkv_projections()
33
- self.pipeline.transformer.to(memory_format=torch.channels_last)
34
  self.pipeline.vae.fuse_qkv_projections()
35
- self.pipeline.vae.to(memory_format=torch.channels_last)
36
  if IS_COMPILE: self.pipeline = compile_pipeline(self.pipeline)
37
  self.pipeline.to(distributed_state.device)
38
 
 
4
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
5
  from PIL import Image
6
  import torch
 
 
7
 
8
+ IS_COMPILE = True
9
 
10
  if IS_COMPILE:
11
  import torch._dynamo
 
14
  #from huggingface_inference_toolkit.logging import logger
15
 
16
  def compile_pipeline(pipe) -> Any:
17
+ pipe.transformer.to(memory_format=torch.channels_last)
18
  pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
19
+ #pipe.vae.to(memory_format=torch.channels_last)
20
+ #pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
21
  return pipe
22
 
23
  class EndpointHandler:
 
30
  #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config).to("cuda")
31
  self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
32
  self.pipeline.transformer.fuse_qkv_projections()
 
33
  self.pipeline.vae.fuse_qkv_projections()
 
34
  if IS_COMPILE: self.pipeline = compile_pipeline(self.pipeline)
35
  self.pipeline.to(distributed_state.device)
36