Froddan committed on
Commit
6f2df7f
1 Parent(s): a0555f6

Upload generate_model_grid_backup.py

Browse files
Files changed (1) hide show
  1. generate_model_grid_backup.py +294 -0
generate_model_grid_backup.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from copy import copy
3
+ from itertools import permutations, chain
4
+ import random
5
+ import csv
6
+ from io import StringIO
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+ import modules.scripts as scripts
11
+ import gradio as gr
12
+
13
+ from modules import images, sd_samplers
14
+ from modules.hypernetworks import hypernetwork
15
+ from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
16
+ from modules.shared import opts, cmd_opts, state
17
+ import modules.shared as shared
18
+ import modules.sd_samplers
19
+ import modules.sd_models
20
+ import re
21
+
22
+
23
def apply_field(field):
    """Return an axis callback that stores the axis value into the named attribute of p."""
    def assign(p, x, xs):
        setattr(p, field, x)

    return assign
28
+
29
+
30
def apply_prompt(p, x, xs):
    """Prompt S/R: replace the search token (first axis value) with x in both prompts."""
    search = xs[0]
    if search not in p.prompt and search not in p.negative_prompt:
        raise RuntimeError(f"Prompt S/R did not find {search} in prompt or negative prompt.")

    p.prompt = p.prompt.replace(search, x)
    p.negative_prompt = p.negative_prompt.replace(search, x)
36
+
37
def edit_prompt(p, x, z):
    """Build the cell prompt by prepending the model token z to the base prompt x."""
    p.prompt = " ".join((z, x))
39
+
40
+
41
def apply_order(p, x, xs):
    """Rearrange the tokens of x inside p.prompt.

    The token slot that appears earliest in the prompt receives x[0], the next
    slot x[1], and so on; all surrounding text is preserved.
    """
    # Pair each token with where it first occurs, earliest occurrence first.
    positions = sorted(((p.prompt.find(token), token) for token in x),
                       key=lambda pair: pair[0])

    # Cut the prompt into the pieces preceding each token, consuming as we go.
    pieces = []
    for _, token in positions:
        at = p.prompt.find(token)
        pieces.append(p.prompt[:at])
        p.prompt = p.prompt[at + len(token):]

    # Stitch the pieces back together with the tokens in the requested order.
    reordered = "".join(piece + x[i] for i, piece in enumerate(pieces))
    p.prompt = reordered + p.prompt
64
+
65
+
66
def build_samplers_dict():
    """Map every sampler name and alias (lowercased) to its index in sd_samplers.all_samplers."""
    mapping = {}
    for index, sampler in enumerate(sd_samplers.all_samplers):
        for name in [sampler.name, *sampler.aliases]:
            mapping[name.lower()] = index
    return mapping
73
+
74
+
75
def apply_sampler(p, x, xs):
    """Set p.sampler_index from a sampler name or alias; raise if unrecognized."""
    index = build_samplers_dict().get(x.lower())
    if index is None:
        raise RuntimeError(f"Unknown sampler: {x}")

    p.sampler_index = index
81
+
82
+
83
def confirm_samplers(p, xs):
    """Fail early if any requested sampler name in xs is not recognized."""
    known = build_samplers_dict()
    for x in xs:
        if x.lower() not in known:
            raise RuntimeError(f"Unknown sampler: {x}")
88
+
89
+
90
def apply_checkpoint(p, x, xs):
    """Swap the loaded model to the checkpoint best matching x and record it on p."""
    match = modules.sd_models.get_closet_checkpoint_match(x)
    if match is None:
        raise RuntimeError(f"Unknown checkpoint: {x}")

    modules.sd_models.reload_model_weights(shared.sd_model, match)
    p.sd_model = shared.sd_model
96
+
97
+
98
def confirm_checkpoints(p, xs):
    """Validate that every checkpoint name in xs resolves; raise on the first miss."""
    for x in xs:
        if modules.sd_models.get_closet_checkpoint_match(x) is None:
            raise RuntimeError(f"Unknown checkpoint: {x}")
102
+
103
+
104
def apply_hypernetwork(p, x, xs):
    """Load the hypernetwork named (approximately) x; empty/'none' unloads it."""
    if x.lower() in ("", "none"):
        name = None
    else:
        name = hypernetwork.find_closest_hypernetwork_name(x)
        if not name:
            raise RuntimeError(f"Unknown hypernetwork: {x}")
    hypernetwork.load_hypernetwork(name)
112
+
113
+
114
def apply_hypernetwork_strength(p, x, xs):
    """Set the global hypernetwork strength to the axis value x."""
    hypernetwork.apply_strength(x)
116
+
117
+
118
def confirm_hypernetworks(p, xs):
    """Validate hypernetwork names; empty/'none' entries mean 'no hypernetwork' and are accepted."""
    for x in xs:
        if x.lower() in ("", "none"):
            continue
        if not hypernetwork.find_closest_hypernetwork_name(x):
            raise RuntimeError(f"Unknown hypernetwork: {x}")
124
+
125
+
126
def apply_clip_skip(p, x, xs):
    """Override the global CLIP skip setting for this cell (restored by SharedSettingsStackHelper)."""
    opts.data["CLIP_stop_at_last_layers"] = x
128
+
129
+
130
def format_value_add_label(p, opt, x):
    """Format an axis value as 'label: value' for grid annotations.

    Floats are rounded to 8 decimal places to avoid noisy repr output.
    Uses isinstance instead of `type(x) == float` so float subclasses
    (e.g. numpy scalars that subclass float) are rounded too.
    """
    if isinstance(x, float):
        x = round(x, 8)

    return f"{opt.label}: {x}"
135
+
136
+
137
def format_value(p, opt, x):
    """Return the axis value itself, rounding floats to 8 decimal places.

    Uses isinstance instead of `type(x) == float` so float subclasses are
    rounded as well.
    """
    if isinstance(x, float):
        x = round(x, 8)
    return x
141
+
142
+
143
def format_value_join_list(p, opt, x):
    """Render a list-valued axis entry as a comma-separated string."""
    separator = ", "
    return separator.join(x)
145
+
146
+
147
def do_nothing(p, x, xs):
    """Placeholder apply-function for axes that require no change to p."""
    return None
149
+
150
+
151
def format_nothing(p, opt, x):
    """Formatter that hides the axis value entirely (pairs with do_nothing)."""
    return ""
153
+
154
+
155
def str_permutations(x):
    """Identity function; used as an AxisOption 'type' marker meaning
    'expand the value list into its permutations'."""
    return x
158
+
159
+ # AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
160
+ # AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
161
+
162
+
163
def draw_xy_grid(p, xs, ys, zs, x_labels, y_labels, cell, draw_legend, include_lone_images):
    """Render every (x, y) cell via cell(x, y, z) and assemble the images into one grid.

    xs/ys are the axis values; zs supplies one extra value per row (indexed by
    the row number iy). x_labels/y_labels annotate the grid when draw_legend is
    set, and include_lone_images additionally keeps every individual image in
    the returned Processed.
    """
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]

    # Temporary list of all the images that are generated to be populated into the grid.
    # Will be filled with empty images for any individual step that fails to process properly
    image_cache = []

    processed_result = None
    cell_mode = "P"
    cell_size = (1, 1)

    state.job_count = len(xs) * len(ys) * p.n_iter

    for iy, y in enumerate(ys):
        for ix, x in enumerate(xs):
            state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
            z = zs[iy]  # one model token per row
            processed: Processed = cell(x, y, z)
            try:
                # this dereference will throw an exception if the image was not processed
                # (this happens in cases such as if the user stops the process from the UI)
                processed_image = processed.images[0]

                if processed_result is None:
                    # Use our first valid processed result as a template container to hold our full results
                    processed_result = copy(processed)
                    cell_mode = processed_image.mode
                    cell_size = processed_image.size
                    processed_result.images = [Image.new(cell_mode, cell_size)]

                image_cache.append(processed_image)
                if include_lone_images:
                    processed_result.images.append(processed_image)
                    processed_result.all_prompts.append(processed.prompt)
                    processed_result.all_seeds.append(processed.seed)
                    processed_result.infotexts.append(processed.infotexts[0])
            except Exception:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit;
                # Exception keeps the "pad failed cells with a blank image" behavior.
                image_cache.append(Image.new(cell_mode, cell_size))

    if not processed_result:
        print("Unexpected error: draw_xy_grid failed to return even a single processed image")
        # `Processed()` with no arguments raises TypeError; return an empty result instead.
        return Processed(p, [])

    grid = images.image_grid(image_cache, rows=len(ys))
    if draw_legend:
        grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)

    # Slot 0 was reserved for the grid in the template container above.
    processed_result.images[0] = grid

    return processed_result
214
+
215
+
216
class SharedSettingsStackHelper(object):
    """Context manager that snapshots shared global settings on entry and restores them on exit.

    The grid cells mutate global state (loaded checkpoint, hypernetwork, CLIP
    skip); this guarantees the user's settings survive the run, even when a
    cell raises.
    """

    def __enter__(self):
        # Snapshot the three globals the grid cells are allowed to change.
        self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
        self.hypernetwork = opts.sd_hypernetwork
        self.model = shared.sd_model

    def __exit__(self, exc_type, exc_value, tb):
        # Reload the original model weights first; apply_checkpoint may have swapped them.
        modules.sd_models.reload_model_weights(self.model)

        # Restore the original hypernetwork and reset its strength to the configured value.
        hypernetwork.load_hypernetwork(self.hypernetwork)
        hypernetwork.apply_strength()

        opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
229
+
230
+
231
# "N-M" integer range with an optional "(+S)"/"(-S)" step, e.g. "1-10 (+2)".
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
# Float variant. The decimal point is escaped (`\.`): the original `(?:.\d*)?`
# let "." match any character, so strings like "1a5-2" parsed as float ranges.
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")

# "N-M [C]" range with a count in brackets instead of a step, e.g. "1-10 [5]".
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*")
236
+
237
class Script(scripts.Script):
    """Grid script: one prompt per column (X), one checkpoint per row (Y),
    with a per-checkpoint model token (Z) prepended to each prompt."""

    def title(self):
        return "Generate Model Grid"

    def ui(self, is_img2img):
        # X axis: prompts, split on '&' so commas can appear inside a prompt.
        with gr.Row():
            x_values = gr.Textbox(label="Prompts, separated with &", lines=1)

        # Y axis: checkpoint file names (comma-separated).
        with gr.Row():
            y_values = gr.Textbox(label="Checkpoint file names, including file ending", lines=1)

        # Z: one model token per checkpoint (comma-separated, same order as Y).
        with gr.Row():
            z_values = gr.Textbox(label="Model tokens", lines=1)

        draw_legend = gr.Checkbox(label='Draw legend', value=True)
        include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
        no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)

        return [x_values, y_values, z_values, draw_legend, include_lone_images, no_fixed_seeds]

    def run(self, p, x_values, y_values, z_values, draw_legend, include_lone_images, no_fixed_seeds):
        """Parse the axis textboxes, generate every cell, and return the assembled grid."""
        if not no_fixed_seeds:
            modules.processing.fix_seed(p)

        if not opts.return_grid:
            p.batch_size = 1

        # csv.reader handles quoting; '&' delimits prompts, commas delimit the rest.
        xs = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(x_values), delimiter='&'))]
        ys = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(y_values)))]
        zs = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(z_values)))]

        # draw_xy_grid indexes zs by row; a short zs would fail mid-run with an
        # opaque IndexError, so validate up front.
        if len(zs) != len(ys):
            raise RuntimeError(f"Number of model tokens ({len(zs)}) must match number of checkpoints ({len(ys)})")

        # Validate every checkpoint once here, instead of re-validating the whole
        # list inside every cell (the original called confirm_checkpoints per cell).
        confirm_checkpoints(p, ys)

        def cell(x, y, z):
            # Work on a copy so per-cell mutations never leak into p.
            pc = copy(p)
            edit_prompt(pc, x, z)
            apply_checkpoint(pc, y, ys)

            return process_images(pc)

        with SharedSettingsStackHelper():
            processed = draw_xy_grid(
                p,
                xs=xs,
                ys=ys,
                zs=zs,
                x_labels=xs,
                y_labels=ys,
                cell=cell,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images
            )

        if opts.grid_save:
            images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)

        return processed