hlky HF staff commited on
Commit
6691b0e
·
verified ·
1 Parent(s): b9608a0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -4
handler.py CHANGED
@@ -7,10 +7,9 @@ from diffusers.image_processor import VaeImageProcessor
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
- kwargs = {"torch_dtype": torch.float16}
11
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
- self.dtype = kwargs["torch_dtype"]
13
- self.vae = AutoencoderKL.from_pretrained(path, **kwargs).to(self.device).eval()
14
 
15
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
16
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
+ self.device = "cuda"
11
+ self.dtype = torch.float16
12
+ self.vae = AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval()
 
13
 
14
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
15
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)