import re
from itertools import product
from typing import Callable, List, Dict, Any, Union, Tuple, cast
import torch
import comfy.sample
import comfy.model_management
import comfy.samplers
from nodes import common_ksampler
from comfy.sd import ModelPatcher
from .model.iter import iterize_model, CondForModels
from .model import merge2
re_int = re.compile(r"\s*([+-]?\s*\d+)\s*")
re_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*")
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")
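# The patterns above accept a plain number ("20"), a range ("10-30"), or a range with an
# explicit step in parentheses ("10-30 (+5)"); the *_float variants also allow decimals.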
def frange(start, end, step):
    x = float(start)
    end = float(end)
    step = float(step)
    while x < end:
        yield x
        x += step
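# Note that frange above is end-exclusive, e.g. list(frange(6.0, 8.0, 0.5)) == [6.0, 6.5, 7.0, 7.5],
# whereas the integer range syntax parsed below includes its upper bound.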
def get_noise(seeds: List[int], latent_image: torch.Tensor, disable_noise: bool, skip: int):
    noises: List[torch.Tensor] = []
    latents: List[torch.Tensor] = []
    if latent_image.dim() == 3:
        latent_image = latent_image.unsqueeze(0)  # add batch dim
    if disable_noise:
        noise_ = torch.zeros([len(seeds)] + list(latent_image.size())[-3:], dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
        noises.append(noise_)
        latents.extend([latent_image] * (len(seeds) // latent_image.shape[0]))
    else:
        for s in seeds:
            noise_ = comfy.sample.prepare_noise(latent_image, s, skip)
            noises.append(noise_)
            latents.append(latent_image)
    return torch.cat(noises), torch.cat(latents)
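# get_noise stacks one noise tensor per seed and repeats the input latent to match, so the
# two concatenated batches line up sample-for-sample.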
def get_cfg(noises: torch.Tensor, latent_image: torch.Tensor, cfgs: List[float]):
    # batch_size = noises.shape[0] * len(cfgs)
    ns = [noises] * len(cfgs)
    lat = [latent_image] * len(cfgs)
    # one copy of each cfg per noise sample, matching the block order of torch.cat(ns)
    cf = torch.FloatTensor([c for c in cfgs for _ in range(noises.shape[0])])
    return torch.cat(ns), torch.cat(lat), cf[..., None, None, None]
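# e.g. 2 noise samples and cfgs=[7.0, 8.0] give a batch of 4; the cfg tensor is returned
# shaped [N, 1, 1, 1] so it broadcasts over the latent's channel and spatial dimensions.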
def process_cond_for_models(
    cond: List[List[Union[torch.Tensor, CondForModels, dict]]],
    model_index: int
):
    """
    select conditioning tensor for the current model
    """
    assert (
        all(isinstance(p[0], CondForModels) for p in cond)
        or not any(isinstance(p[0], CondForModels) for p in cond)
    )
    if not isinstance(cond[0][0], CondForModels):
        return cond
    sizes = set(len(cast(CondForModels, p[0]).ex) for p in cond)
    assert len(sizes) == 1, f'number of conditions: {sizes}'
    size = sizes.pop()
    assert model_index < size
    #
    # conds
    #  + [ CondForModels, dictA ]
    #  |    .ex + condA for model1
    #  |        + condA for model2
    #  |        ...
    #  |        L condA for model{size}
    #  + [ CondForModels, dictB ]
    #  |    .ex + condB for model1
    #  |        + condB for model2
    #  |        ...
    #  |        L condB for model{size}
    #  ...
    #
    # vvv
    #
    # conds
    #  + [ [ condA_for_model1, dictA ], [ condB_for_model1, dictB ], ... ]
    #  + [ [ condA_for_model2, dictA ], [ condB_for_model2, dictB ], ... ]  <- model_index
    #  ...
    #
    result = []
    for c, *rest in cond:
        assert isinstance(c, CondForModels)
        actual_cond = c.ex[model_index]
        result.append([actual_cond, *rest])
    return result
def xyz_args(
    model: ModelPatcher,
    samplers: List[str],
    schedulers: List[str],
    steps: List[int],
):
    for (model_index, model_fn), sampler, scheduler, step in product(enumerate(iterize_model(model)), samplers, schedulers, steps):
        if sampler not in comfy.samplers.KSampler.SAMPLERS:
            raise ValueError(f'unknown sampler name: {sampler}')
        if scheduler not in comfy.samplers.KSampler.SCHEDULERS:
            raise ValueError(f'unknown scheduler name: {scheduler}')
        yield (
            model_index,
            model_fn,
            step,
            sampler,
            scheduler,
        )
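# xyz_args yields one (model_index, model_fn, steps, sampler, scheduler) tuple per
# combination, iterating model variants outermost, then samplers, schedulers and steps.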
def common_ksampler_xyz(
    model: ModelPatcher,
    seed: Union[int, List[int]],
    steps: Union[int, List[int]],
    cfg: Union[float, List[float]],
    sampler_name: Union[str, List[str]],
    scheduler: Union[str, List[str]],
    positive,
    negative,
    latent,
    denoise=1.0,
    disable_noise=False,
    start_step=None,
    last_step=None,
    force_full_denoise=False
):
    if not isinstance(seed, list):
        seed = [seed]
    if not isinstance(steps, list):
        steps = [steps]
    if not isinstance(cfg, list):
        cfg = [cfg]
    if not isinstance(sampler_name, list):
        sampler_name = [sampler_name]
    if not isinstance(scheduler, list):
        scheduler = [scheduler]
    latent_image = latent["samples"]
    noise_mask = latent.get('noise_mask', None)
    noise, latent_image = get_noise(seed, latent_image, disable_noise, latent.get('batch_index', 0))
    noise, latent_image, cfg_ = get_cfg(noise, latent_image, cfg)
    # move the per-sample cfg onto the sampling device instead of assuming CUDA
    cfg_ = cfg_.to(comfy.model_management.get_torch_device())
    all_samples: List[torch.Tensor] = []
    for (
        model_index, model_fn, step, sampler, scheduler
    ) in xyz_args(model, sampler_name, scheduler, steps):
        current_model = model_fn()
        positive_copy = process_cond_for_models(positive, model_index)
        negative_copy = process_cond_for_models(negative, model_index)
        print(f'XYZ sampler=model@{model_index}/{sampler}/{scheduler} {step}steps')
        alphas = merge2.get_current_alpha(current_model.model)
        if alphas is not None:
            print(f'alpha = {alphas}')
        samples = comfy.sample.sample(
            current_model, noise, step, cfg_, sampler, scheduler,
            positive_copy, negative_copy, latent_image,
            denoise=denoise, disable_noise=disable_noise,
            start_step=start_step, last_step=last_step,
            force_full_denoise=force_full_denoise, noise_mask=noise_mask
        )
        samples = samples.cpu()
        all_samples.append(samples)
    out = latent.copy()
    out["samples"] = torch.cat(all_samples)
    return (out, )
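# The returned latent batch is the concatenation of one chunk per xyz_args combination,
# each chunk containing every seed/cfg pairing prepared by get_noise/get_cfg.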
class KSamplerSetting:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            'required': {
                'model': ('MODEL',),
                'seed': ('INT', {'default': 0, 'min': 0, 'max': 0xffffffffffffffff}),
                'steps': ('INT', {'default': 20, 'min': 1, 'max': 10000}),
                'cfg': ('FLOAT', {'default': 8.0, 'min': 0.0, 'max': 100.0}),
                'sampler_name': (comfy.samplers.KSampler.SAMPLERS, ),
                'scheduler': (comfy.samplers.KSampler.SCHEDULERS, ),
                'positive': ('CONDITIONING', ),
                'negative': ('CONDITIONING', ),
                'latent_image': ('LATENT', ),
                'denoise': ('FLOAT', {'default': 1.0, 'min': 0.0, 'max': 1.0, 'step': 0.01}),
            }
        }

    RETURN_TYPES = ('DICT',)
    FUNCTION = 'sample'
    CATEGORY = 'sampling'

    def sample(self, **kwargs):
        return kwargs,
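# KSamplerSetting performs no sampling itself; it bundles its inputs into a DICT so the
# override and XYZ nodes below can reuse one base configuration.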
class KSamplerOverrided:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            'required': {
                'setting': ('DICT',),
            },
            'optional': {
                'model': ('MODEL',),
                'seed': ('Integer', {'default': 0, 'min': 0, 'max': 0xffffffffffffffff}),
                'steps': ('Integer', {'default': 20, 'min': 1, 'max': 10000}),
                'cfg': ('Float', {'default': 8.0, 'min': 0.0, 'max': 100.0}),
                'sampler_name': ('SamplerName',),
                'scheduler': ('SchedulerName', ),
                'positive': ('CONDITIONING', ),
                'negative': ('CONDITIONING', ),
                'latent_image': ('LATENT', ),
                'denoise': ('Float', {'default': 1.0, 'min': 0.0, 'max': 1.0, 'step': 0.01}),
            }
        }

    RETURN_TYPES = ('LATENT',)
    FUNCTION = 'sample'
    CATEGORY = 'sampling'
    def sample(self, setting: dict, **kwargs):
        # common_ksampler expects the latent under 'latent', not 'latent_image'
        if 'latent_image' in setting:
            setting['latent'] = setting['latent_image']
            del setting['latent_image']
        if 'latent_image' in kwargs:
            kwargs['latent'] = kwargs.pop('latent_image')
        setting.update(kwargs)
        return common_ksampler(**setting)
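# In KSamplerOverrided, any optional input that is actually connected overrides the matching
# key from the setting DICT before the standard common_ksampler call.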
class KSamplerXYZ:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            'required': {
                'setting': ('DICT',),
            },
            'optional': {
                'model': ('MODEL',),
                'seed': ('STRING', { 'multiline': True, 'default': '' }),
                'steps': ('STRING', { 'multiline': True, 'default': '' }),
                'cfg': ('STRING', { 'multiline': True, 'default': '' }),
                'sampler_name': ('STRING', { 'multiline': True, 'default': '' }),
                'scheduler': ('STRING', { 'multiline': True, 'default': '' }),
            }
        }

    RETURN_TYPES = ('LATENT',)
    FUNCTION = 'sample'
    CATEGORY = 'sampling'
    def sample(self, setting: dict, **kwargs):
        if 'latent_image' in setting:
            setting['latent'] = setting['latent_image']
            del setting['latent_image']
        # ignore empty string
        kwargs = { k: v for k, v in kwargs.items() if not isinstance(v, str) or len(v) != 0 }
        setting = { **setting, **kwargs }
        if isinstance(setting.get('seed', None), str):
            setting['seed'] = self.parse(setting['seed'], self.parse_int)
        if isinstance(setting.get('steps', None), str):
            setting['steps'] = self.parse(setting['steps'], self.parse_int)
        if isinstance(setting.get('cfg', None), str):
            setting['cfg'] = self.parse(setting['cfg'], self.parse_float)
        if isinstance(setting.get('sampler_name', None), str):
            setting['sampler_name'] = self.parse(setting['sampler_name'], None)
            if len(setting['sampler_name']) == 1:
                setting['sampler_name'] = setting['sampler_name'][0]
        if isinstance(setting.get('scheduler', None), str):
            setting['scheduler'] = self.parse(setting['scheduler'], None)
            if len(setting['scheduler']) == 1:
                setting['scheduler'] = setting['scheduler'][0]
        for k, v in setting.items():
            if k in kwargs and isinstance(v, (list, tuple)):
                print(f'XYZ {k}: {v}')
        return common_ksampler_xyz(**setting)  # type: ignore
    def parse(self, input: str, cont: Union[Callable[[str], Any], None]):
        vs = [x.strip() for x in input.split(',')]
        if cont is not None:
            new_vs = []
            for v in vs:
                new_v = cont(v)
                if isinstance(new_v, list):
                    new_vs += new_v
                else:
                    new_vs.append(new_v)
            vs = new_vs
        return vs
    def parse_int(self, input: str):
        m = re_int.fullmatch(input)
        if m is not None:
            return int(m.group(1))
        m = re_range.fullmatch(input)
        if m is None:
            raise ValueError(f'failed to process: {input}')
        start, end, step = m.group(1), m.group(2), m.group(3)
        if step is None:
            step = 1
        return list(range(int(start), int(end) + 1, int(step)))
    def parse_float(self, input: str):
        m = re_float.fullmatch(input)
        if m is not None:
            return float(m.group(1))
        m = re_range_float.fullmatch(input)
        if m is None:
            raise ValueError(f'failed to process: {input}')
        start, end, step = m.group(1), m.group(2), m.group(3)
        if step is None:
            step = 1.0
        return list(frange(float(start), float(end), float(step)))
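# Illustrative examples of the string syntax KSamplerXYZ accepts (the values are made up):
#   seed:         "1, 5, 10-12"      -> [1, 5, 10, 11, 12]
#   steps:        "20-40 (+10)"      -> [20, 30, 40]
#   cfg:          "6.0-8.0 (+0.5)"   -> [6.0, 6.5, 7.0, 7.5]
#   sampler_name: "euler, dpmpp_2m"  -> ['euler', 'dpmpp_2m']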