Froddan commited on
Commit
eecb6ed
1 Parent(s): 3d988d8

Delete generate_model_grid_backup.py

Browse files
Files changed (1) hide show
  1. generate_model_grid_backup.py +0 -294
generate_model_grid_backup.py DELETED
@@ -1,294 +0,0 @@
1
- from collections import namedtuple
2
- from copy import copy
3
- from itertools import permutations, chain
4
- import random
5
- import csv
6
- from io import StringIO
7
- from PIL import Image
8
- import numpy as np
9
-
10
- import modules.scripts as scripts
11
- import gradio as gr
12
-
13
- from modules import images, sd_samplers
14
- from modules.hypernetworks import hypernetwork
15
- from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
16
- from modules.shared import opts, cmd_opts, state
17
- import modules.shared as shared
18
- import modules.sd_samplers
19
- import modules.sd_models
20
- import re
21
-
22
-
23
- def apply_field(field):
24
- def fun(p, x, xs):
25
- setattr(p, field, x)
26
-
27
- return fun
28
-
29
-
30
- def apply_prompt(p, x, xs):
31
- if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
32
- raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
33
-
34
- p.prompt = p.prompt.replace(xs[0], x)
35
- p.negative_prompt = p.negative_prompt.replace(xs[0], x)
36
-
37
- def edit_prompt(p,x,z):
38
- p.prompt = z + " " + x
39
-
40
-
41
- def apply_order(p, x, xs):
42
- token_order = []
43
-
44
- # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
45
- for token in x:
46
- token_order.append((p.prompt.find(token), token))
47
-
48
- token_order.sort(key=lambda t: t[0])
49
-
50
- prompt_parts = []
51
-
52
- # Split the prompt up, taking out the tokens
53
- for _, token in token_order:
54
- n = p.prompt.find(token)
55
- prompt_parts.append(p.prompt[0:n])
56
- p.prompt = p.prompt[n + len(token):]
57
-
58
- # Rebuild the prompt with the tokens in the order we want
59
- prompt_tmp = ""
60
- for idx, part in enumerate(prompt_parts):
61
- prompt_tmp += part
62
- prompt_tmp += x[idx]
63
- p.prompt = prompt_tmp + p.prompt
64
-
65
-
66
- def build_samplers_dict():
67
- samplers_dict = {}
68
- for i, sampler in enumerate(sd_samplers.all_samplers):
69
- samplers_dict[sampler.name.lower()] = i
70
- for alias in sampler.aliases:
71
- samplers_dict[alias.lower()] = i
72
- return samplers_dict
73
-
74
-
75
- def apply_sampler(p, x, xs):
76
- sampler_index = build_samplers_dict().get(x.lower(), None)
77
- if sampler_index is None:
78
- raise RuntimeError(f"Unknown sampler: {x}")
79
-
80
- p.sampler_index = sampler_index
81
-
82
-
83
- def confirm_samplers(p, xs):
84
- samplers_dict = build_samplers_dict()
85
- for x in xs:
86
- if x.lower() not in samplers_dict.keys():
87
- raise RuntimeError(f"Unknown sampler: {x}")
88
-
89
-
90
- def apply_checkpoint(p, x, xs):
91
- info = modules.sd_models.get_closet_checkpoint_match(x)
92
- if info is None:
93
- raise RuntimeError(f"Unknown checkpoint: {x}")
94
- modules.sd_models.reload_model_weights(shared.sd_model, info)
95
- p.sd_model = shared.sd_model
96
-
97
-
98
- def confirm_checkpoints(p, xs):
99
- for x in xs:
100
- if modules.sd_models.get_closet_checkpoint_match(x) is None:
101
- raise RuntimeError(f"Unknown checkpoint: {x}")
102
-
103
-
104
- def apply_hypernetwork(p, x, xs):
105
- if x.lower() in ["", "none"]:
106
- name = None
107
- else:
108
- name = hypernetwork.find_closest_hypernetwork_name(x)
109
- if not name:
110
- raise RuntimeError(f"Unknown hypernetwork: {x}")
111
- hypernetwork.load_hypernetwork(name)
112
-
113
-
114
- def apply_hypernetwork_strength(p, x, xs):
115
- hypernetwork.apply_strength(x)
116
-
117
-
118
- def confirm_hypernetworks(p, xs):
119
- for x in xs:
120
- if x.lower() in ["", "none"]:
121
- continue
122
- if not hypernetwork.find_closest_hypernetwork_name(x):
123
- raise RuntimeError(f"Unknown hypernetwork: {x}")
124
-
125
-
126
- def apply_clip_skip(p, x, xs):
127
- opts.data["CLIP_stop_at_last_layers"] = x
128
-
129
-
130
- def format_value_add_label(p, opt, x):
131
- if type(x) == float:
132
- x = round(x, 8)
133
-
134
- return f"{opt.label}: {x}"
135
-
136
-
137
- def format_value(p, opt, x):
138
- if type(x) == float:
139
- x = round(x, 8)
140
- return x
141
-
142
-
143
- def format_value_join_list(p, opt, x):
144
- return ", ".join(x)
145
-
146
-
147
- def do_nothing(p, x, xs):
148
- pass
149
-
150
-
151
- def format_nothing(p, opt, x):
152
- return ""
153
-
154
-
155
- def str_permutations(x):
156
- """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
157
- return x
158
-
159
- # AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
160
- # AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
161
-
162
-
163
- def draw_xy_grid(p, xs, ys, zs, x_labels, y_labels, cell, draw_legend, include_lone_images):
164
- ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
165
- hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
166
-
167
- # Temporary list of all the images that are generated to be populated into the grid.
168
- # Will be filled with empty images for any individual step that fails to process properly
169
- image_cache = []
170
-
171
- processed_result = None
172
- cell_mode = "P"
173
- cell_size = (1,1)
174
-
175
- state.job_count = len(xs) * len(ys) * p.n_iter
176
-
177
- for iy, y in enumerate(ys):
178
- for ix, x in enumerate(xs):
179
- state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
180
- z = zs[iy]
181
- processed:Processed = cell(x, y, z)
182
- try:
183
- # this dereference will throw an exception if the image was not processed
184
- # (this happens in cases such as if the user stops the process from the UI)
185
- processed_image = processed.images[0]
186
-
187
- if processed_result is None:
188
- # Use our first valid processed result as a template container to hold our full results
189
- processed_result = copy(processed)
190
- cell_mode = processed_image.mode
191
- cell_size = processed_image.size
192
- processed_result.images = [Image.new(cell_mode, cell_size)]
193
-
194
- image_cache.append(processed_image)
195
- if include_lone_images:
196
- processed_result.images.append(processed_image)
197
- processed_result.all_prompts.append(processed.prompt)
198
- processed_result.all_seeds.append(processed.seed)
199
- processed_result.infotexts.append(processed.infotexts[0])
200
- except:
201
- image_cache.append(Image.new(cell_mode, cell_size))
202
-
203
- if not processed_result:
204
- print("Unexpected error: draw_xy_grid failed to return even a single processed image")
205
- return Processed()
206
-
207
- grid = images.image_grid(image_cache, rows=len(ys))
208
- if draw_legend:
209
- grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
210
-
211
- processed_result.images[0] = grid
212
-
213
- return processed_result
214
-
215
-
216
- class SharedSettingsStackHelper(object):
217
- def __enter__(self):
218
- self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
219
- self.hypernetwork = opts.sd_hypernetwork
220
- self.model = shared.sd_model
221
-
222
- def __exit__(self, exc_type, exc_value, tb):
223
- modules.sd_models.reload_model_weights(self.model)
224
-
225
- hypernetwork.load_hypernetwork(self.hypernetwork)
226
- hypernetwork.apply_strength()
227
-
228
- opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
229
-
230
-
231
- re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
232
- re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
233
-
234
- re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
235
- re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
236
-
237
- class Script(scripts.Script):
238
- def title(self):
239
- return "Generate Model Grid"
240
-
241
- def ui(self, is_img2img):
242
-
243
- with gr.Row():
244
- x_values = gr.Textbox(label="Prompts, separated with &", lines=1)
245
-
246
- with gr.Row():
247
- y_values = gr.Textbox(label="Checkpoint file names, including file ending", lines=1)
248
-
249
-
250
- with gr.Row():
251
- z_values = gr.Textbox(label="Model tokens", lines=1)
252
-
253
- draw_legend = gr.Checkbox(label='Draw legend', value=True)
254
- include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
255
- no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
256
-
257
- return [x_values, y_values, z_values, draw_legend, include_lone_images, no_fixed_seeds]
258
-
259
- def run(self, p, x_values, y_values, z_values, draw_legend, include_lone_images, no_fixed_seeds):
260
- if not no_fixed_seeds:
261
- modules.processing.fix_seed(p)
262
-
263
- if not opts.return_grid:
264
- p.batch_size = 1
265
-
266
- xs = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(x_values), delimiter='&'))]
267
- ys = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(y_values)))]
268
- zs = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(z_values)))]
269
-
270
- def cell(x, y, z):
271
- pc = copy(p)
272
- edit_prompt(pc, x, z)
273
- confirm_checkpoints(pc,ys)
274
- apply_checkpoint(pc, y, ys)
275
-
276
- return process_images(pc)
277
-
278
- with SharedSettingsStackHelper():
279
- processed = draw_xy_grid(
280
- p,
281
- xs=xs,
282
- ys=ys,
283
- zs=zs,
284
- x_labels=xs,
285
- y_labels=ys,
286
- cell=cell,
287
- draw_legend=draw_legend,
288
- include_lone_images=include_lone_images
289
- )
290
-
291
- if opts.grid_save:
292
- images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
293
-
294
- return processed