File size: 3,458 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import List, Callable, Any, Optional
import torch
import tqdm
from comfy.sd import ModelPatcher, CLIP, VAE

class CondForModels(torch.Tensor):
    """Tensor subclass that additionally carries a list of tensors in ``ex``.

    ``x`` is forwarded to ``torch.Tensor.__new__`` to build the tensor itself,
    while ``ex`` is only stored as an attribute.  Elsewhere in this file it is
    constructed as ``CondForModels(xs[0], xs)``, i.e. the first result acts as
    the tensor value and the full list of results is kept in ``ex``.
    """
    
    @staticmethod
    def __new__(cls, x, ex, *args, **kwargs):
        # ``ex`` is intentionally not passed on to the base constructor; it is
        # attached to the instance in ``__init__`` instead.
        return super().__new__(cls, x, *args, **kwargs) # type: ignore
    
    def __init__(self, x, ex: List[torch.Tensor], *args, **kwargs):
        # ``x`` and the extra arguments were already consumed by ``__new__``.
        super().__init__()
        self.ex = ex

# Attribute name under which the iterize_* helpers store, directly on the
# wrapped object, its list of zero-argument factories.
ATTR_NAME = 'iter_fn'

def iterize_model(model: ModelPatcher) -> List[Callable[[],ModelPatcher]]:
    """Return the factory list attached to ``model``, creating it on first use.

    The list is stored on ``model`` itself under ``ATTR_NAME``.  When it does
    not exist yet it is initialised with a single zero-argument callable that
    returns ``model``.

    Args:
        model: Object to attach the factory list to.

    Returns:
        The list stored on ``model`` (not a copy), so appending to the
        returned list changes what is stored on ``model``.
    """
    # Fast path: the list was attached by an earlier call.
    if hasattr(model, ATTR_NAME):
        return getattr(model, ATTR_NAME)

    # First call for this object: seed the list with a factory for itself.
    setattr(model, ATTR_NAME, [lambda: model])
    return getattr(model, ATTR_NAME)

def iterize_clip(clip: CLIP) -> List[Callable[[],CLIP]]:
    """Attach a factory list to ``clip`` and patch its ``encode`` to fan out.

    On the first call for a given ``clip`` this stores a list under
    ``ATTR_NAME`` (seeded with a factory returning ``clip`` itself) and
    replaces ``clip.encode`` on the instance with a wrapper.  The wrapper
    calls ``encode`` on every object produced by the factories in that list
    and returns the results bundled in a ``CondForModels``.  Later calls just
    return the existing list and do not patch again.

    Args:
        clip: Object to patch.  It is modified in place.

    Returns:
        The list stored on ``clip`` (not a copy); factories appended to it
        are picked up by the patched ``encode``.
    """
    if hasattr(clip, ATTR_NAME):
        # Already patched by an earlier call; do not wrap ``encode`` twice.
        return getattr(clip, ATTR_NAME)

    setattr(clip, ATTR_NAME, [lambda: clip])

    # Class-level implementation, captured so that ``clip`` itself can still
    # be encoded without re-entering the wrapper installed below.
    old_encode = CLIP.encode

    def new_encode(*args, **kwargs):
        xs = []
        # Read the list at call time so factories appended after patching
        # are included.
        clips = getattr(clip, ATTR_NAME)
        for fn in tqdm.tqdm(clips):
            clip_: CLIP = fn()
            # Identity check (not ``==``): only the patched object itself
            # must bypass the wrapper, otherwise it would recurse.
            if clip_ is clip:
                x = old_encode(clip_, *args, **kwargs)
            else:
                x = clip_.encode(*args, **kwargs)
            # Give 2-D results a leading dimension of size 1.
            if x.dim() == 2:
                x = x.unsqueeze(0)
            xs.append(x)
        # First result is the tensor value; every result is kept in ``ex``.
        return CondForModels(xs[0], xs)

    clip.encode = new_encode

    return getattr(clip, ATTR_NAME)

def iterize_vae(vae: VAE) -> List[Callable[[],VAE]]:
    """Attach a factory list to ``vae`` and patch its ``decode`` to fan out.

    On the first call for a given ``vae`` this stores a list under
    ``ATTR_NAME`` (seeded with a factory returning ``vae`` itself) and
    replaces ``vae.decode`` on the instance with a wrapper.  The wrapper
    calls ``decode`` on every object produced by the factories in that list
    and concatenates the results with ``torch.cat``.  Later calls just return
    the existing list and do not patch again.

    Args:
        vae: Object to patch.  It is modified in place.

    Returns:
        The list stored on ``vae`` (not a copy); factories appended to it
        are picked up by the patched ``decode``.
    """
    if hasattr(vae, ATTR_NAME):
        # Already patched by an earlier call; do not wrap ``decode`` twice.
        return getattr(vae, ATTR_NAME)

    setattr(vae, ATTR_NAME, [lambda: vae])

    # Class-level implementation, captured so that ``vae`` itself can still
    # be decoded without re-entering the wrapper installed below.
    old_decode = VAE.decode

    def new_decode(*args, **kwargs):
        xs = []
        # Read the list at call time so factories appended after patching
        # are included.
        vaes = getattr(vae, ATTR_NAME)
        for fn in tqdm.tqdm(vaes):
            vae_: VAE = fn()
            # Identity check (not ``==``): only the patched object itself
            # must bypass the wrapper, otherwise it would recurse.
            if vae_ is vae:
                x = old_decode(vae_, *args, **kwargs)
            else:
                x = vae_.decode(*args, **kwargs)
            # Give 3-D results a leading dimension of size 1 so that all
            # results can be concatenated along the first dimension.
            if x.dim() == 3:
                x = x.unsqueeze(0)
            xs.append(x)
        return torch.cat(xs)

    vae.decode = new_decode

    return getattr(vae, ATTR_NAME)

def try_get_iter(obj) -> Optional[List[Callable[[],Any]]]:
    """Return the value stored on ``obj`` under ``ATTR_NAME``, or ``None``.

    Unlike the ``iterize_*`` helpers this never creates the attribute; an
    object that has none simply yields ``None``.
    """
    try:
        found = getattr(obj, ATTR_NAME)
    except AttributeError:
        return None
    return found

class ModelIter:
    """Node that chains a second ``MODEL`` onto the first one's factory list.

    Exposes the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
    ``CATEGORY`` node layout; ``FUNCTION`` names the method to run.
    """

    CATEGORY = 'model'
    FUNCTION = 'execute'
    RETURN_TYPES = ('MODEL',)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two required ``MODEL`` inputs, ``model1`` and ``model2``."""
        required = {name: ('MODEL',) for name in ('model1', 'model2')}
        return {'required': required}

    def execute(self, model1, model2):
        """Append a factory for ``model2`` to ``model1``'s list.

        ``model1`` is modified in place: the list returned by
        ``iterize_model`` is the one stored on ``model1``, so every call
        adds one more entry to it.

        Returns:
            A one-element tuple containing ``model1``.
        """
        factories = iterize_model(model1)
        factories.append(lambda: model2)
        return (model1,)


class CLIPIter:
    """Node that chains a second ``CLIP`` onto the first one's factory list.

    Exposes the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
    ``CATEGORY`` node layout; ``FUNCTION`` names the method to run.
    """

    CATEGORY = 'model'
    FUNCTION = 'execute'
    RETURN_TYPES = ('CLIP',)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two required ``CLIP`` inputs, ``clip1`` and ``clip2``."""
        required = {name: ('CLIP',) for name in ('clip1', 'clip2')}
        return {'required': required}

    def execute(self, clip1, clip2):
        """Append a factory for ``clip2`` to ``clip1``'s list.

        ``clip1`` is modified in place: ``iterize_clip`` patches its
        ``encode`` on first use and returns the list stored on ``clip1``,
        so every call adds one more entry to it.

        Returns:
            A one-element tuple containing ``clip1``.
        """
        factories = iterize_clip(clip1)
        factories.append(lambda: clip2)
        return (clip1,)


class VAEIter:
    """Node that chains a second ``VAE`` onto the first one's factory list.

    Exposes the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
    ``CATEGORY`` node layout; ``FUNCTION`` names the method to run.
    """

    CATEGORY = 'model'
    FUNCTION = 'execute'
    RETURN_TYPES = ('VAE',)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two required ``VAE`` inputs, ``vae1`` and ``vae2``."""
        required = {name: ('VAE',) for name in ('vae1', 'vae2')}
        return {'required': required}

    def execute(self, vae1, vae2):
        """Append a factory for ``vae2`` to ``vae1``'s list.

        ``vae1`` is modified in place: ``iterize_vae`` patches its
        ``decode`` on first use and returns the list stored on ``vae1``,
        so every call adds one more entry to it.

        Returns:
            A one-element tuple containing ``vae1``.
        """
        factories = iterize_vae(vae1)
        factories.append(lambda: vae2)
        return (vae1,)