Update handler.py
Browse files- handler.py +5 -4
handler.py
CHANGED
@@ -41,10 +41,10 @@ class EndpointHandler:
|
|
41 |
"""
|
42 |
tensor = cast(torch.Tensor, data["inputs"])
|
43 |
parameters = cast(dict, data.get("parameters", {}))
|
44 |
-
if "height" not in parameters or "width" not in parameters:
|
45 |
raise ValueError("Expected `height` and `width` in parameters.")
|
46 |
-
height = cast(int, parameters.get("height"))
|
47 |
-
width = cast(int, parameters.get("width"))
|
48 |
do_scaling = cast(bool, parameters.get("do_scaling", True))
|
49 |
output_type = cast(str, parameters.get("output_type", "pil"))
|
50 |
partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
|
@@ -52,7 +52,8 @@ class EndpointHandler:
|
|
52 |
output_type = "pt"
|
53 |
|
54 |
tensor = tensor.to(self.device, self.dtype)
|
55 |
-
tensor
|
|
|
56 |
|
57 |
if do_scaling:
|
58 |
tensor = (
|
|
|
41 |
"""
|
42 |
tensor = cast(torch.Tensor, data["inputs"])
|
43 |
parameters = cast(dict, data.get("parameters", {}))
|
44 |
+
if tensor.ndim == 3 and ("height" not in parameters or "width" not in parameters):
|
45 |
raise ValueError("Expected `height` and `width` in parameters.")
|
46 |
+
height = cast(int, parameters.get("height", 0))
|
47 |
+
width = cast(int, parameters.get("width", 0))
|
48 |
do_scaling = cast(bool, parameters.get("do_scaling", True))
|
49 |
output_type = cast(str, parameters.get("output_type", "pil"))
|
50 |
partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
|
|
|
52 |
output_type = "pt"
|
53 |
|
54 |
tensor = tensor.to(self.device, self.dtype)
|
55 |
+
if tensor.ndim == 3:
|
56 |
+
tensor = self._unpack_latents(tensor, height, width, self.vae_scale_factor)
|
57 |
|
58 |
if do_scaling:
|
59 |
tensor = (
|