from typing import List, Callable, Any, Optional

import torch
import tqdm

from comfy.sd import ModelPatcher, CLIP, VAE


class CondForModels(torch.Tensor):
    """Tensor that additionally carries a list of tensors in ``ex``.

    The tensor itself is constructed from ``x``; ``ex`` is only stored as an
    attribute on the instance. ``new_encode`` (see ``iterize_clip``) builds
    one of these from the first encoder result plus the list of all results.
    """

    @staticmethod
    def __new__(cls, x, ex, *args, **kwargs):
        # ``ex`` is intentionally not forwarded to the tensor constructor;
        # it is picked up by __init__ instead.
        return super().__new__(cls, x, *args, **kwargs)  # type: ignore

    def __init__(self, x, ex: List[torch.Tensor], *args, **kwargs):
        super().__init__()
        self.ex = ex


# Name of the attribute under which the list of zero-argument factories is
# stored on a model / clip / vae object.
ATTR_NAME = 'iter_fn'


def iterize_model(model: ModelPatcher) -> List[Callable[[], ModelPatcher]]:
    """Return the factory list attached to ``model``, creating it if missing.

    On first use the list is initialised with a single factory returning
    ``model`` itself. The returned list is the very object stored on
    ``model``, so appending to it registers further models.
    """
    if not hasattr(model, ATTR_NAME):
        setattr(model, ATTR_NAME, [lambda: model])
    return getattr(model, ATTR_NAME)


def iterize_clip(clip: CLIP) -> List[Callable[[], CLIP]]:
    """Return the factory list attached to ``clip``, patching it on first use.

    On first use:
      * a list containing a factory for ``clip`` itself is stored on it, and
      * ``clip.encode`` is replaced (on this instance only) by a wrapper that
        runs every registered CLIP and returns a ``CondForModels`` whose
        tensor is the first result and whose ``ex`` holds all results.

    Subsequent calls just return the already-attached list.
    """
    if hasattr(clip, ATTR_NAME):
        return getattr(clip, ATTR_NAME)

    setattr(clip, ATTR_NAME, [lambda: clip])

    # Unpatched implementation, looked up on the class so that the wrapper
    # below can still reach it after ``clip.encode`` has been replaced.
    old_encode = CLIP.encode

    def new_encode(*args, **kwargs):
        xs = []
        # Read the list at call time so factories appended after patching
        # are honoured.
        clips = getattr(clip, ATTR_NAME)
        for fn in tqdm.tqdm(clips):
            clip_: CLIP = fn()
            if clip_ is clip:
                # Calling clip.encode here would re-enter this wrapper, so
                # dispatch to the original implementation instead.
                x = old_encode(clip_, *args, **kwargs)
            else:
                x = clip_.encode(*args, **kwargs)
            if x.dim() == 2:
                # Promote 2-D results to 3-D by adding a leading axis
                # (presumably the batch axis -- TODO confirm).
                x = x.unsqueeze(0)
            xs.append(x)
        return CondForModels(xs[0], xs)

    clip.encode = new_encode
    return getattr(clip, ATTR_NAME)


def iterize_vae(vae: VAE) -> List[Callable[[], VAE]]:
    """Return the factory list attached to ``vae``, patching it on first use.

    On first use:
      * a list containing a factory for ``vae`` itself is stored on it, and
      * ``vae.decode`` is replaced (on this instance only) by a wrapper that
        runs every registered VAE and concatenates the results with
        ``torch.cat`` along the first dimension.

    Subsequent calls just return the already-attached list.
    """
    if hasattr(vae, ATTR_NAME):
        return getattr(vae, ATTR_NAME)

    setattr(vae, ATTR_NAME, [lambda: vae])

    # Unpatched implementation, looked up on the class so that the wrapper
    # below can still reach it after ``vae.decode`` has been replaced.
    old_decode = VAE.decode

    def new_decode(*args, **kwargs):
        xs = []
        # Read the list at call time so factories appended after patching
        # are honoured.
        vaes = getattr(vae, ATTR_NAME)
        for fn in tqdm.tqdm(vaes):
            vae_: VAE = fn()
            if vae_ is vae:
                # Calling vae.decode here would re-enter this wrapper, so
                # dispatch to the original implementation instead.
                x = old_decode(vae_, *args, **kwargs)
            else:
                x = vae_.decode(*args, **kwargs)
            if x.dim() == 3:
                # Promote 3-D results to 4-D by adding a leading axis
                # (presumably the batch axis -- TODO confirm).
                x = x.unsqueeze(0)
            xs.append(x)
        return torch.cat(xs)

    vae.decode = new_decode
    return getattr(vae, ATTR_NAME)


def try_get_iter(obj) -> Optional[List[Callable[[], Any]]]:
    """Return the factory list attached to ``obj``, or ``None`` if absent."""
    return getattr(obj, ATTR_NAME, None)


class ModelIter:
    """Node that registers ``model2`` on ``model1``'s factory list."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            'required': {
                'model1': ('MODEL', ),
                'model2': ('MODEL', )
            }
        }

    RETURN_TYPES = ('MODEL',)
    FUNCTION = 'execute'
    CATEGORY = 'model'

    def execute(self, model1, model2):
        """Append a factory for ``model2`` to ``model1`` and return ``model1``."""
        # NOTE(review): ``fns`` is the list stored on ``model1``, so this
        # mutates the input in place and entries accumulate if the same
        # ``model1`` object is passed again.
        fns = iterize_model(model1)
        fns.append(lambda: model2)
        return (model1,)


class CLIPIter:
    """Node that registers ``clip2`` on ``clip1``'s factory list."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            'required': {
                'clip1': ('CLIP', ),
                'clip2': ('CLIP', )
            }
        }

    RETURN_TYPES = ('CLIP',)
    FUNCTION = 'execute'
    CATEGORY = 'model'

    def execute(self, clip1, clip2):
        """Append a factory for ``clip2`` to ``clip1`` and return ``clip1``."""
        # NOTE(review): ``fns`` is the list stored on ``clip1``, so this
        # mutates the input in place (including replacing ``clip1.encode``
        # on first use) and entries accumulate if the same ``clip1`` object
        # is passed again.
        fns = iterize_clip(clip1)
        fns.append(lambda: clip2)
        return (clip1,)


class VAEIter:
    """Node that registers ``vae2`` on ``vae1``'s factory list."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            'required': {
                'vae1': ('VAE', ),
                'vae2': ('VAE', )
            }
        }

    RETURN_TYPES = ('VAE',)
    FUNCTION = 'execute'
    CATEGORY = 'model'

    def execute(self, vae1, vae2):
        """Append a factory for ``vae2`` to ``vae1`` and return ``vae1``."""
        # NOTE(review): ``fns`` is the list stored on ``vae1``, so this
        # mutates the input in place (including replacing ``vae1.decode``
        # on first use) and entries accumulate if the same ``vae1`` object
        # is passed again.
        fns = iterize_vae(vae1)
        fns.append(lambda: vae2)
        return (vae1,)