File size: 3,458 Bytes
f4a41d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from typing import List, Callable, Any, Optional
import torch
import tqdm
from comfy.sd import ModelPatcher, CLIP, VAE
class CondForModels(torch.Tensor):
    """Tensor whose data is ``x`` and which also carries a list of extras.

    Behaves as the primary tensor ``x`` while exposing ``ex`` (a list of
    tensors, e.g. one encoding per model) as an attribute for consumers that
    know to look for it.
    """

    @staticmethod
    def __new__(cls, x, ex, *args, **kwargs):
        if args or kwargs:
            # Keep the original pass-through for callers that supply extra
            # constructor arguments.
            return super().__new__(cls, x, *args, **kwargs)  # type: ignore
        # ``torch.Tensor.__new__(cls, tensor)`` uses the legacy constructor,
        # which raises for tensors whose dtype/device differ from the default
        # (e.g. float16 or CUDA). ``as_subclass`` re-types the existing tensor,
        # sharing its storage, for any dtype/device.
        return x.as_subclass(cls)

    def __init__(self, x, ex: List[torch.Tensor], *args, **kwargs):
        super().__init__()
        # Additional tensors travelling alongside the primary value.
        self.ex = ex
# Attribute name under which an "iterized" object keeps its list of
# zero-argument factories (read and written via hasattr/getattr/setattr).
ATTR_NAME = 'iter_fn'
def iterize_model(model: ModelPatcher) -> List[Callable[[],ModelPatcher]]:
    """Return the factory list attached to *model*, creating it if absent.

    A freshly created list holds one factory that yields *model* itself.
    The same list object is returned on every later call, so callers may
    append more factories to it.
    """
    if hasattr(model, ATTR_NAME):
        return getattr(model, ATTR_NAME)
    initial = [lambda: model]
    setattr(model, ATTR_NAME, initial)
    return getattr(model, ATTR_NAME)
def iterize_clip(clip: CLIP) -> List[Callable[[],CLIP]]:
    """Return the CLIP factory list for *clip*, patching ``clip.encode`` once.

    On the first call this stores ``[lambda: clip]`` under ``ATTR_NAME`` and
    replaces the instance's ``encode`` with a wrapper. The wrapper encodes with
    every CLIP yielded by the list and returns a ``CondForModels`` whose value
    is the first result and whose ``ex`` holds all results. Later calls just
    return the existing list.
    """
    if hasattr(clip, ATTR_NAME):
        return getattr(clip, ATTR_NAME)
    setattr(clip, ATTR_NAME, [lambda: clip])
    # Capture this instance's encode (bound) before patching, so a subclass
    # override or an earlier instance-level patch is honoured rather than
    # always falling back to the base ``CLIP.encode``.
    old_encode = clip.encode

    def new_encode(*args, **kwargs):
        xs = []
        clips = getattr(clip, ATTR_NAME)
        for fn in tqdm.tqdm(clips):
            clip_: CLIP = fn()
            # Identity check: the patched object itself must use the
            # pre-patch encode, otherwise it would recurse into this wrapper.
            if clip_ is clip:
                x = old_encode(*args, **kwargs)
            else:
                x = clip_.encode(*args, **kwargs)
            # Promote 2-D results to 3-D by adding a leading axis.
            if x.dim() == 2:
                x = x.unsqueeze(0)
            xs.append(x)
        return CondForModels(xs[0], xs)

    clip.encode = new_encode
    return getattr(clip, ATTR_NAME)
def iterize_vae(vae: VAE) -> List[Callable[[],VAE]]:
    """Return the VAE factory list for *vae*, patching ``vae.decode`` once.

    On the first call this stores ``[lambda: vae]`` under ``ATTR_NAME`` and
    replaces the instance's ``decode`` with a wrapper. The wrapper decodes with
    every VAE yielded by the list and concatenates the results along dim 0.
    Later calls just return the existing list.
    """
    if hasattr(vae, ATTR_NAME):
        return getattr(vae, ATTR_NAME)
    setattr(vae, ATTR_NAME, [lambda: vae])
    # Capture this instance's decode (bound) before patching, so a subclass
    # override or an earlier instance-level patch is honoured rather than
    # always falling back to the base ``VAE.decode``.
    old_decode = vae.decode

    def new_decode(*args, **kwargs):
        xs = []
        vaes = getattr(vae, ATTR_NAME)
        for fn in tqdm.tqdm(vaes):
            vae_: VAE = fn()
            # Identity check: the patched object itself must use the
            # pre-patch decode, otherwise it would recurse into this wrapper.
            if vae_ is vae:
                x = old_decode(*args, **kwargs)
            else:
                x = vae_.decode(*args, **kwargs)
            # Promote 3-D results to 4-D so every entry can be concatenated
            # along dim 0 below.
            if x.dim() == 3:
                x = x.unsqueeze(0)
            xs.append(x)
        return torch.cat(xs)

    vae.decode = new_decode
    return getattr(vae, ATTR_NAME)
def try_get_iter(obj) -> Optional[List[Callable[[],Any]]]:
    """Return the factory list stored on *obj*, or ``None`` when it has none."""
    try:
        return getattr(obj, ATTR_NAME)
    except AttributeError:
        return None
class ModelIter:
    """Node that registers ``model2`` as an extra factory on ``model1``.

    Returns ``model1`` itself; its factory list now also yields ``model2``.

    NOTE(review): ``execute`` appends to the list stored on ``model1`` each
    time it runs, so re-running with the same ``model1`` object appends
    ``model2`` again.
    """

    RETURN_TYPES = ('MODEL',)
    FUNCTION = 'execute'
    CATEGORY = 'model'

    @classmethod
    def INPUT_TYPES(cls):
        socket = ('MODEL', )
        return {'required': {'model1': socket, 'model2': socket}}

    def execute(self, model1, model2):
        iterize_model(model1).append(lambda: model2)
        return (model1,)
class CLIPIter:
    """Node that registers ``clip2`` as an extra factory on ``clip1``.

    Returns ``clip1`` itself; its factory list now also yields ``clip2``.

    NOTE(review): ``execute`` appends to the list stored on ``clip1`` each
    time it runs, so re-running with the same ``clip1`` object appends
    ``clip2`` again.
    """

    RETURN_TYPES = ('CLIP',)
    FUNCTION = 'execute'
    CATEGORY = 'model'

    @classmethod
    def INPUT_TYPES(cls):
        socket = ('CLIP', )
        return {'required': {'clip1': socket, 'clip2': socket}}

    def execute(self, clip1, clip2):
        iterize_clip(clip1).append(lambda: clip2)
        return (clip1,)
class VAEIter:
    """Node that registers ``vae2`` as an extra factory on ``vae1``.

    Returns ``vae1`` itself; its factory list now also yields ``vae2``.

    NOTE(review): ``execute`` appends to the list stored on ``vae1`` each
    time it runs, so re-running with the same ``vae1`` object appends
    ``vae2`` again.
    """

    RETURN_TYPES = ('VAE',)
    FUNCTION = 'execute'
    CATEGORY = 'model'

    @classmethod
    def INPUT_TYPES(cls):
        socket = ('VAE', )
        return {'required': {'vae1': socket, 'vae2': socket}}

    def execute(self, vae1, vae2):
        iterize_vae(vae1).append(lambda: vae2)
        return (vae1,)
|