from collections import defaultdict
from typing import Dict, List

import torch
from tqdm import trange

from .model.iter import try_get_iter


class VAEDecodeBatched:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT",),
                "vae": ("VAE",),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 32, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"
    CATEGORY = "latent"

    def decode(self, vae, samples, batch_size: int):
        s = samples["samples"]
        n = s.shape[0]

        # If the VAE wraps several sub-VAEs, each decoded chunk contains one
        # slice per sub-VAE; otherwise treat it as a single VAE.
        iters = try_get_iter(vae)
        vae_num = 1 if iters is None else len(iters)

        # Decode in batches, collecting slices per sub-VAE so the final
        # output keeps all images from one sub-VAE contiguous and in order.
        vae_results: Dict[int, List[torch.Tensor]] = defaultdict(list)
        for i in trange(0, n, batch_size):
            e = min(i + batch_size, n)
            t = s[i:e, ...]
            v = vae.decode(t)
            for vn, vv in enumerate(torch.chunk(v, vae_num)):
                vae_results[vn].append(vv)

        results = []
        for k in sorted(vae_results.keys()):
            results.extend(vae_results[k])
        vs = torch.cat(results).contiguous()
        return (vs,)


class VAEEncodeBatched:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "pixels": ("IMAGE",),
                "vae": ("VAE",),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 32, "step": 1}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"
    CATEGORY = "latent"

    def encode(self, vae, pixels, batch_size: int):
        n = pixels.shape[0]

        # Crop height/width down to the nearest multiple of 64 so the VAE
        # receives dimensions it can downsample cleanly.
        x = (pixels.shape[1] // 64) * 64
        y = (pixels.shape[2] // 64) * 64
        if pixels.shape[1] != x or pixels.shape[2] != y:
            pixels = pixels[:, :x, :y, :]
        # Drop any alpha channel; the VAE expects RGB input.
        pixels = pixels[:, :, :, :3]

        # Encode in batches of at most batch_size images.
        results = []
        for i in trange(0, n, batch_size):
            e = min(i + batch_size, n)
            t = pixels[i:e, ...]
            results.append(vae.encode(t))

        vs = torch.cat(results)
        return ({"samples": vs},)
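

# --- Registration sketch (an assumption, not part of the original file) ---
# ComfyUI discovers custom nodes through a NODE_CLASS_MAPPINGS dict exported
# by the package. A minimal wiring for the two classes above might look like
# this; the display names are illustrative placeholders, and the repo may
# instead do this mapping in its __init__.py.
NODE_CLASS_MAPPINGS = {
    "VAEDecodeBatched": VAEDecodeBatched,
    "VAEEncodeBatched": VAEEncodeBatched,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VAEDecodeBatched": "VAE Decode (Batched)",
    "VAEEncodeBatched": "VAE Encode (Batched)",
}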