hlky HF staff commited on
Commit
f0be81e
·
verified ·
1 Parent(s): 6b84ca5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -4
handler.py CHANGED
@@ -41,10 +41,10 @@ class EndpointHandler:
41
  """
42
  tensor = cast(torch.Tensor, data["inputs"])
43
  parameters = cast(dict, data.get("parameters", {}))
44
- if "height" not in parameters or "width" not in parameters:
45
  raise ValueError("Expected `height` and `width` in parameters.")
46
- height = cast(int, parameters.get("height"))
47
- width = cast(int, parameters.get("width"))
48
  do_scaling = cast(bool, parameters.get("do_scaling", True))
49
  output_type = cast(str, parameters.get("output_type", "pil"))
50
  partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
@@ -52,7 +52,8 @@ class EndpointHandler:
52
  output_type = "pt"
53
 
54
  tensor = tensor.to(self.device, self.dtype)
55
- tensor = self._unpack_latents(tensor, height, width, self.vae_scale_factor)
 
56
 
57
  if do_scaling:
58
  tensor = (
 
41
  """
42
  tensor = cast(torch.Tensor, data["inputs"])
43
  parameters = cast(dict, data.get("parameters", {}))
44
+ if tensor.ndim == 3 and ("height" not in parameters or "width" not in parameters):
45
  raise ValueError("Expected `height` and `width` in parameters.")
46
+ height = cast(int, parameters.get("height", 0))
47
+ width = cast(int, parameters.get("width", 0))
48
  do_scaling = cast(bool, parameters.get("do_scaling", True))
49
  output_type = cast(str, parameters.get("output_type", "pil"))
50
  partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
 
52
  output_type = "pt"
53
 
54
  tensor = tensor.to(self.device, self.dtype)
55
+ if tensor.ndim == 3:
56
+ tensor = self._unpack_latents(tensor, height, width, self.vae_scale_factor)
57
 
58
  if do_scaling:
59
  tensor = (