Spaces:
Runtime error
Runtime error
import json | |
import os | |
import re | |
import logging | |
from collections import defaultdict | |
from modules import errors | |
# Registered ExtraNetwork objects, keyed by each object's ``name`` attribute.
extra_network_registry = dict()
# Alternate names mapped to the ExtraNetwork object they should resolve to.
extra_network_aliases = dict()
def initialize():
    """Forget every registered extra network and every registered alias."""
    for table in (extra_network_registry, extra_network_aliases):
        table.clear()
def register_extra_network(extra_network):
    """Store ``extra_network`` in the registry under its ``name``.

    A network previously registered under the same name is replaced.
    """
    key = extra_network.name
    extra_network_registry[key] = extra_network
def register_extra_network_alias(extra_network, alias):
    """Make ``alias`` resolve to ``extra_network``.

    An earlier mapping for the same alias is replaced.
    """
    aliases = extra_network_aliases
    aliases[alias] = extra_network
def register_default_extra_networks():
    """Register the built-in hypernetwork extra network."""
    # The import is deferred to call time, presumably to avoid a circular
    # import when this module is loaded. TODO: confirm.
    from modules.extra_networks_hypernet import ExtraNetworkHypernet

    hypernet = ExtraNetworkHypernet()
    register_extra_network(hypernet)
class ExtraNetworkParams:
    """Arguments of one extra network tag in a prompt, e.g. ``<name:arg1:key=value>``.

    ``items`` holds the raw colon-separated arguments. A string item that splits
    into exactly two parts around '=' is stored in ``named`` (key -> value).
    Every other item is appended to ``positional``.
    """

    # __eq__ compares the mutable ``items`` list, so instances are unhashable.
    # Defining __eq__ already implies this; it is spelled out for readers.
    __hash__ = None

    def __init__(self, items=None):
        self.items = items or []
        self.positional = []
        self.named = {}

        for item in self.items:
            # NOTE(review): maxsplit=2 means "k=v=w" yields three parts, so it is
            # kept positional rather than becoming named["k"] = "v=w". Confirm
            # whether split('=', 1) was intended.
            parts = item.split('=', 2) if isinstance(item, str) else [item]
            if len(parts) == 2:
                self.named[parts[0]] = parts[1]
            else:
                self.positional.append(item)

    def __eq__(self, other):
        # Only compare against other parameter objects. For anything else, defer
        # to Python's default handling instead of raising AttributeError on
        # ``other.items``.
        if not isinstance(other, ExtraNetworkParams):
            return NotImplemented
        return self.items == other.items
class ExtraNetwork:
    """Base class for an extra network type that users reference in prompts.

    Subclasses must override both ``activate`` and ``deactivate``. The base
    implementations only raise NotImplementedError.
    """

    def __init__(self, name):
        # The name used in prompt tags (<name:...>) to address this network.
        self.name = name

    def activate(self, p, params_list):
        """Apply this extra network for the current processing run.

        Processing calls this on every run. ``params_list`` holds the arguments
        that target this network. Users supply arguments with a prompt tag of
        the form::

            <name:arg1:arg2:arg3>

        Here ``name`` matches ``self.name``, and ``arg1:arg2:arg3`` is any
        number of text arguments separated by colons.

        The call is made even when the prompt does not mention this network.
        In that case ``params_list`` is empty, and every effect of this network
        should be disabled. This may be called several times before
        ``deactivate()``; each call must completely override the previous one.

        For example, if ``self.name`` is 'hypernet' and the prompt is::

            1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>

        then ``params_list`` is::

            [
                ExtraNetworkParams(items=["agm", "1.1"]),
                ExtraNetworkParams(items=["ray"]),
            ]
        """
        raise NotImplementedError

    def deactivate(self, p):
        """Tidy up at the end of processing.

        Subclasses must override this, because the base implementation raises
        NotImplementedError. The override may simply do nothing when there is
        nothing to clean up.
        """
        raise NotImplementedError
def lookup_extra_networks(extra_network_data):
    """Group parsed prompt arguments by the ExtraNetwork object that handles them.

    Each name in ``extra_network_data`` is resolved through the registry first.
    The alias table is consulted only when no network is registered under that
    name. Unknown names are logged and skipped.

    When several names resolve to the same object, their argument lists are
    concatenated. In the example below, 'lora' and 'lyco' both map to the lora
    handler. Keys keep the order of their first appearance.

    Example input:
        {
            'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
            'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
            'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
        }

    Example output:
        {
            <extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
            <modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
        }
    """
    grouped = {}
    for tag_name, params in extra_network_data.items():
        handler = extra_network_registry.get(tag_name)
        if handler is None:
            # Fall back to an alias only when nothing is registered under this name.
            handler = extra_network_aliases.get(tag_name)
        if handler is None:
            logging.info(f"Skipping unknown extra network: {tag_name}")
            continue
        grouped.setdefault(handler, []).extend(params)
    return grouped
def activate(p, extra_network_data):
    """Activate extra networks for the processing object ``p``.

    The steps are:

    1. Networks mentioned in ``extra_network_data`` are activated with their
       arguments, in the order given by ``lookup_extra_networks``.
    2. Every registered network that was not successfully activated in step 1
       is activated with an empty argument list. This includes networks whose
       step-1 activation raised.
    3. If ``p.scripts`` is set, its ``after_extra_networks_activate`` hook runs.
       It receives ``p.iteration``, ``p.prompts``, ``p.seeds`` and ``p.subseeds``.

    Exceptions from individual networks are reported via ``errors.display`` and
    do not stop the remaining activations.
    """
    succeeded = []
    mentioned = lookup_extra_networks(extra_network_data)

    for network, params in mentioned.items():
        try:
            network.activate(p, params)
        except Exception as e:
            errors.display(e, f"activating extra network {network.name} with arguments {params}")
        else:
            succeeded.append(network)

    for name, network in extra_network_registry.items():
        if network not in succeeded:
            try:
                network.activate(p, [])
            except Exception as e:
                errors.display(e, f"activating extra network {name}")

    if p.scripts is None:
        return
    p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
def deactivate(p, extra_network_data):
    """Deactivate extra networks for the processing object ``p``.

    Networks mentioned in ``extra_network_data`` are deactivated first, in the
    order given by ``lookup_extra_networks``. Every other registered network is
    deactivated afterwards. Exceptions are reported via ``errors.display`` and
    do not stop the remaining calls.
    """
    mentioned = lookup_extra_networks(extra_network_data)

    for network in mentioned:
        try:
            network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating extra network {network.name}")

    leftovers = ((name, net) for name, net in extra_network_registry.items() if net not in mentioned)
    for name, network in leftovers:
        try:
            network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating unmentioned extra network {name}")
# Matches a prompt tag <name:args>. ``name`` is one or more word characters;
# ``args`` is every character up to the next '>' (colons included).
re_extra_net = re.compile(r"<(\w+):([^>]+)>")


def parse_prompt(prompt):
    """Remove extra network tags from ``prompt`` and collect their arguments.

    Returns ``(prompt_without_tags, data)``. ``data`` is a ``defaultdict(list)``
    that maps each tag name to one ExtraNetworkParams per occurrence, in prompt
    order. Each tag's argument text is split on ':'.
    """
    collected = defaultdict(list)

    def _take(match):
        tag_name, raw_args = match.groups()
        collected[tag_name].append(ExtraNetworkParams(items=raw_args.split(":")))
        return ""

    return re_extra_net.sub(_take, prompt), collected
def parse_prompts(prompts):
    """Remove extra network tags from every prompt in ``prompts``.

    Returns ``(stripped_prompts, extra_data)``. Only the first prompt's tags are
    returned as ``extra_data``. Tags in later prompts are removed from the text,
    but their arguments are discarded. ``extra_data`` is None when ``prompts``
    is empty.
    """
    stripped = []
    first_data = None
    for index, prompt in enumerate(prompts):
        text, data = parse_prompt(prompt)
        if index == 0:
            first_data = data
        stripped.append(text)
    return stripped, first_data
def get_user_metadata(filename, lister=None):
    """Load the user metadata JSON that sits next to an extra network file.

    The metadata path is ``filename`` with its extension replaced by ``.json``.

    Args:
        filename: Path of the extra network file, or None.
        lister: Optional object with an ``exists(path)`` method, used instead of
            ``os.path.exists`` when given.

    Returns:
        The parsed metadata dict. The result is ``{}`` when any of these holds:

        - ``filename`` is None;
        - the metadata file does not exist;
        - checking or reading it raises (reported via ``errors.display``);
        - its top-level JSON value is not an object.
    """
    if filename is None:
        return {}

    basename, _ = os.path.splitext(filename)
    metadata_filename = basename + '.json'

    metadata = {}
    try:
        exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
        if exists:
            with open(metadata_filename, "r", encoding="utf8") as file:
                metadata = json.load(file)
    except Exception as e:
        # Best effort: report the problem and fall back to empty metadata
        # instead of raising.
        errors.display(e, f"reading extra network user metadata from {metadata_filename}")
        return {}

    # A JSON list, string or number would break callers that expect a mapping.
    if not isinstance(metadata, dict):
        logging.warning(
            "Ignoring extra network user metadata in %s: expected a JSON object, got %s",
            metadata_filename,
            type(metadata).__name__,
        )
        return {}

    return metadata