Spaces:
Running
Running
File size: 5,263 Bytes
19dc0f3 33473a0 19dc0f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import functools
import pprint
import random
from pathlib import Path
import yaml
from modules import shared
from modules.loaders import loaders_samplers
from modules.logging_colors import logger
def default_preset():
    """Build a fresh dict holding the neutral value of every sampling parameter.

    A new dict is created on each call, so callers may mutate the result
    freely. The insertion order of the keys is significant: other helpers
    iterate over it to produce ordered parameter lists.
    """
    return dict(
        temperature=1,
        temperature_last=False,
        dynamic_temperature=False,
        dynatemp_low=1,
        dynatemp_high=1,
        dynatemp_exponent=1,
        smoothing_factor=0,
        smoothing_curve=1,
        top_p=1,
        min_p=0,
        top_k=0,
        repetition_penalty=1,
        presence_penalty=0,
        frequency_penalty=0,
        repetition_penalty_range=1024,
        typical_p=1,
        tfs=1,
        top_a=0,
        epsilon_cutoff=0,
        eta_cutoff=0,
        guidance_scale=1,
        penalty_alpha=0,
        mirostat_mode=0,
        mirostat_tau=5,
        mirostat_eta=0.1,
        do_sample=True,
        encoder_repetition_penalty=1,
        no_repeat_ngram_size=0,
        dry_multiplier=0,
        dry_base=1.75,
        dry_allowed_length=2,
        dry_sequence_breakers='"\\n", ":", "\\"", "*"',
        # Newline-separated list of sampler stage names.
        sampler_priority='temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat',
    )
def presets_params():
    """Return the names of all preset parameters, in default_preset() order."""
    # Iterating a dict yields its keys in insertion order; list() materializes
    # them directly, with no comprehension needed.
    return list(default_preset())
def load_preset(name, verbose=False):
    """Return the generation parameters for the preset called *name*.

    The result starts from ``default_preset()``. If ``presets/<name>.yaml``
    exists and contains a YAML mapping, every key in that mapping overrides
    (or is added to) the defaults. If *name* is ``'None'``, ``None`` or ``''``,
    the defaults are returned unchanged.

    If the file is missing, or its top level is not a mapping, an error is
    logged and the defaults are used. An empty file is treated as a preset
    with no overrides. YAML syntax errors raised by ``yaml.safe_load``
    propagate to the caller.

    Args:
        name: Preset name; resolved relative to the ``presets`` directory.
        verbose: When true, log the preset name and pretty-print the
            parameters that differ from the defaults.

    Returns:
        A new dict of generation parameters.
    """
    generate_params = default_preset()
    if name not in ['None', None, '']:
        # NOTE(review): *name* is interpolated into the path unchecked; if it
        # can come from an external API, values such as "../x" would read
        # .yaml files outside presets/ — confirm callers restrict it.
        path = Path(f'presets/{name}.yaml')
        if path.exists():
            # Preset files are YAML text; read them as UTF-8 rather than the
            # platform's locale encoding.
            with open(path, 'r', encoding='utf-8') as infile:
                preset = yaml.safe_load(infile)

            # An empty file parses to None: treat it as "no overrides"
            # instead of crashing when iterating over it.
            if preset is None:
                preset = {}

            if isinstance(preset, dict):
                generate_params.update(preset)
            else:
                logger.error(f"The preset \"{name}\" at \"{path}\" is not a mapping of parameters. Using the default parameters.")
        else:
            logger.error(f"The preset \"{name}\" does not exist under \"{path}\". Using the default parameters.")

    if verbose:
        logger.info(f"\"{name}\" preset:")
        pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(generate_params))

    return generate_params
@functools.lru_cache(maxsize=None)
def load_preset_memoized(name):
    """Cached, non-verbose variant of ``load_preset(name)``.

    The cache is unbounded and keyed on *name*, so each preset is read from
    disk at most once until ``load_preset_memoized.cache_clear()`` is called.

    NOTE: repeated calls with the same *name* return the very same dict
    object, so callers should treat the result as read-only.
    """
    params = load_preset(name)
    return params
def load_preset_for_ui(name, state):
    """Load preset *name*, merge it into *state*, and return values for the UI.

    The preset is loaded with ``verbose=True``, so its non-default values
    are logged. *state* is updated in place.

    Returns:
        A tuple ``(state, v1, v2, ...)``, with one value per parameter in
        ``presets_params()`` order.
    """
    params = load_preset(name, verbose=True)
    state.update(params)
    ordered_values = [params[key] for key in presets_params()]
    return (state, *ordered_values)
def random_preset(state):
    """Generate a random preset, merge it into *state*, and return UI values.

    For each category below, one parameter is picked at random, and then
    one of that parameter's candidate values is picked at random. Every
    other parameter keeps its ``default_preset()`` value. When a loader is
    set (``shared.args.loader`` is not None), only parameters accepted by
    ``loader_contains`` are eligible. A category with no eligible parameter
    is skipped.

    *state* is updated in place, and its non-default parameters are logged.

    Returns:
        A tuple ``(state, v1, v2, ...)``, with one generated value per
        parameter in ``presets_params()`` order.
    """
    candidate_values = {
        'remove_tail_tokens': {
            'top_p': [0.5, 0.8, 0.9, 0.95, 0.99],
            'min_p': [0.5, 0.2, 0.1, 0.05, 0.01],
            'top_k': [3, 5, 10, 20, 30, 40],
            'typical_p': [0.2, 0.575, 0.95],
            'tfs': [0.5, 0.8, 0.9, 0.95, 0.99],
            'top_a': [0.5, 0.2, 0.1, 0.05, 0.01],
            'epsilon_cutoff': [1, 3, 5, 7, 9],
            'eta_cutoff': [3, 6, 9, 12, 15, 18],
        },
        'flatten_distribution': {
            'temperature': [0.1, 0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 5.0],
            # Each entry is a [low, high] pair for dynamic temperature.
            'dynamic_temperature': [
                [0.1, 1],
                [0.1, 1.5],
                [0.1, 2],
                [0.1, 5],
                [0.5, 1],
                [0.5, 1.5],
                [0.5, 2],
                [0.5, 5],
                [0.8, 1],
                [0.8, 1.5],
                [0.8, 2],
                [0.8, 5],
                [1, 1.5],
                [1, 2],
                [1, 5]
            ],
            'smoothing_factor': [0.2, 0.3, 0.6, 1.2],
        },
        'repetition': {
            'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
            'presence_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
            'frequency_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
        },
        'other': {
            'temperature_last': [True, False],
        }
    }

    sampled = default_preset()
    for category, options in candidate_values.items():
        names = list(options.keys())
        if shared.args.loader is not None:
            names = [name for name in names if loader_contains(name)]

        if not names:
            continue

        picked = random.choice(names)
        value = random.choice(options[picked])

        if picked != 'dynamic_temperature':
            sampled[picked] = value
            continue

        # Dynamic temperature is a toggle plus a (low, high) range.
        sampled['dynamic_temperature'] = True
        sampled['dynatemp_low'] = value[0]
        sampled['dynatemp_high'] = value[1]

    state.update(sampled)
    logger.info("GENERATED_PRESET=")
    pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(state))

    ordered_values = [sampled[key] for key in presets_params()]
    return (state, *ordered_values)
def loader_contains(sampler):
    """Return True if the current loader lists *sampler* among its samplers.

    The lookup uses ``loaders_samplers[shared.args.loader]``, which presumably
    maps a loader name to the parameter names it supports (defined in
    ``modules.loaders``; TODO confirm). ``'dynamic_temperature'`` is also
    accepted when the loader lists ``'dynatemp_low'``. Presumably this is
    because loaders list the dynatemp range knobs rather than the toggle.
    An unknown loader name raises whatever the mapping raises (``KeyError``
    for a plain dict).
    """
    supported = loaders_samplers[shared.args.loader]
    return (sampler == 'dynamic_temperature' and 'dynatemp_low' in supported) or sampler in supported
def remove_defaults(state):
    """Return the preset parameters in *state* that differ from their defaults.

    Only the keys from ``presets_params()`` are considered; any other keys in
    *state* are ignored. The result keeps ``presets_params()`` order.
    A ``KeyError`` is raised if *state* lacks one of the preset parameters.
    """
    baseline = default_preset()
    overrides = {}
    for key in presets_params():
        value = state[key]
        if value == baseline[key]:
            continue
        overrides[key] = value
    return overrides
def generate_preset_yaml(state):
    """Serialize the non-default preset parameters of *state* to YAML text.

    ``sort_keys=False`` keeps the keys in ``presets_params()`` order rather
    than sorting them alphabetically.
    """
    return yaml.dump(remove_defaults(state), sort_keys=False)
|