jennyzzt commited on
Commit
59da1c6
·
1 Parent(s): f5845a2
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +69 -0
  3. generate_examples.py +46 -0
  4. qdhf_things.py +558 -0
  5. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from qdhf_things import run_qdhf, many_pictures
3
+ from generate_examples import EXAMPLE_PROMPTS
4
+ import os
5
+ import io
6
+
7
+ # Get the absolute path to the examples directory
8
+ EXAMPLES_DIR = os.path.abspath("./examples")
9
+
10
+ def generate_images(prompt, init_pop, total_itrs):
11
+ init_pop = int(init_pop)
12
+ total_itrs = int(total_itrs)
13
+
14
+ # Use placeholder if prompt is empty
15
+ if not prompt.strip():
16
+ prompt = "a duck crossing the street"
17
+
18
+ archive_plots = []
19
+ for archive, plt_fig in run_qdhf(prompt, init_pop, total_itrs):
20
+ buf = io.BytesIO()
21
+ plt_fig.savefig(buf, format='png')
22
+ buf.seek(0)
23
+ archive_plots.append(buf.getvalue())
24
+
25
+ final_archive_plot = archive_plots[-1]
26
+ generated_images = many_pictures(archive, prompt)
27
+
28
+ # Save the final archive plot and generated images as temporary files
29
+ temp_archive_file = "temp_archive_plot.png"
30
+ temp_images_file = "temp_generated_images.png"
31
+
32
+ with open(temp_archive_file, 'wb') as f:
33
+ f.write(final_archive_plot)
34
+
35
+ generated_images.savefig(temp_images_file)
36
+
37
+ return temp_archive_file, temp_images_file
38
+
39
+ def show_example(example):
40
+ index = EXAMPLE_PROMPTS.index(example)
41
+ archive_plot_path = os.path.join(EXAMPLES_DIR, f"archive_{index}.mp4")
42
+ images_path = os.path.join(EXAMPLES_DIR, f"archive_pics_{index}.png")
43
+ return archive_plot_path, images_path
44
+
45
+ with gr.Blocks() as demo:
46
+ gr.Markdown("# Quality Diversity through Human Feedback")
47
+
48
+ with gr.Row():
49
+ with gr.Column(scale=1):
50
+ prompt_input = gr.Textbox(label="Enter your prompt here", placeholder="a duck crossing the street")
51
+ init_pop = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Initial Population")
52
+ total_itrs = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Total Iterations")
53
+ generate_button = gr.Button("Generate", variant="primary")
54
+
55
+ with gr.Column(scale=2):
56
+ archive_output = gr.Video(label="Archive Plot")
57
+ images_output = gr.Image(label="Generated Pictures")
58
+
59
+ generate_button.click(generate_images,
60
+ inputs=[prompt_input, init_pop, total_itrs],
61
+ outputs=[archive_output, images_output])
62
+
63
+ gr.Markdown("## Examples:")
64
+ for example in EXAMPLE_PROMPTS:
65
+ example_button = gr.Button(example)
66
+ example_button.click(show_example, inputs=example_button, outputs=[archive_output, images_output])
67
+
68
+ if __name__ == "__main__":
69
+ demo.launch()
generate_examples.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import imageio
4
+ import matplotlib.pyplot as plt
5
+ from qdhf_things import run_qdhf, many_pictures
6
+
7
+ EXAMPLE_PROMPTS = [
8
+ 'an image of a cat on the sofa',
9
+ 'an image of a bear in a national park',
10
+ 'a photo of an astronaut riding a horse on mars',
11
+ 'a drawing of a tree behind a fence',
12
+ 'a painting of a sunset over the ocean',
13
+ 'a sketch of a racoon sitting on a mushroom',
14
+ 'a picture of a dragon flying over a castle',
15
+ 'a photo of a robot playing the guitar',
16
+ ]
17
+
18
+ if __name__ == '__main__':
19
+ print('Hello! I am a script!')
20
+
21
+ for i, example_prompt in enumerate(EXAMPLE_PROMPTS):
22
+ # Initialize list to store images for GIF
23
+ images = []
24
+
25
+ # Run QDHF
26
+ for archive, plt in run_qdhf(example_prompt):
27
+ # Save current plot to a temporary file
28
+ temp_filename = f'./examples/temp_plot_{i}.png'
29
+ plt.savefig(temp_filename)
30
+ plt.close()
31
+
32
+ # Read the saved image and append to images list
33
+ images.append(imageio.imread(temp_filename))
34
+ os.remove(temp_filename)
35
+
36
+ # Create a GIF from the images
37
+ gif_filename = f'./examples/archive_{i}.gif'
38
+ imageio.mimsave(gif_filename, images, duration=0.5) # Adjust duration as needed
39
+
40
+ # Save archive with pickle
41
+ pickle.dump(archive, open(f'./examples/archive_{i}.pkl', 'wb'))
42
+
43
+ # Save the final archive plot
44
+ plt = many_pictures(archive, example_prompt)
45
+ plt.savefig(f'./examples/archive_pics_{i}.png')
46
+ plt.close()
qdhf_things.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import pydantic
3
+ import time
4
+ import numpy as np
5
+ from tqdm import tqdm, trange
6
+ import torch
7
+ from torch import nn
8
+ from diffusers import StableDiffusionPipeline
9
+ import clip
10
+ from dreamsim import dreamsim
11
+ from ribs.archives import GridArchive
12
+ from ribs.schedulers import Scheduler
13
+ from ribs.emitters import GaussianEmitter
14
+ import itertools
15
+ from ribs.visualize import grid_archive_heatmap
16
+
17
+
18
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ torch.cuda.empty_cache()
20
+ print("Torch device:", DEVICE)
21
+
22
+ # Use float16 for GPU, float32 for CPU.
23
+ TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
24
+ print("Torch dtype:", TORCH_DTYPE)
25
+ IMG_WIDTH = 256
26
+ IMG_HEIGHT = 256
27
+ SD_IN_HEIGHT = 32
28
+ SD_IN_WIDTH = 32
29
+ SD_CHECKPOINT = "lambdalabs/miniSD-diffusers"
30
+
31
+ BATCH_SIZE = 4
32
+ SD_IN_CHANNELS = 4
33
+ SD_IN_SHAPE = (
34
+ BATCH_SIZE,
35
+ SD_IN_CHANNELS,
36
+ SD_IN_HEIGHT,
37
+ SD_IN_WIDTH,
38
+ )
39
+
40
+ SDPIPE = StableDiffusionPipeline.from_pretrained(
41
+ SD_CHECKPOINT,
42
+ torch_dtype=TORCH_DTYPE,
43
+ safety_checker=None, # For faster inference.
44
+ requires_safety_checker=False,
45
+ )
46
+
47
+ SDPIPE.set_progress_bar_config(disable=True)
48
+ SDPIPE = SDPIPE.to(DEVICE)
49
+
50
+ GRID_SIZE = (20, 20)
51
+ SEED = 123
52
+ np.random.seed(SEED)
53
+ torch.manual_seed(SEED)
54
+
55
+ # INIT_POP = 200 # Initial population.
56
+ # TOTAL_ITRS = 200 # Total number of iterations.
57
+
58
+
59
+ class DivProj(nn.Module):
60
+ def __init__(self, input_dim, latent_dim=2):
61
+ super().__init__()
62
+ self.proj = nn.Sequential(
63
+ nn.Linear(in_features=input_dim, out_features=latent_dim),
64
+ )
65
+
66
+ def forward(self, x):
67
+ """Get diversity representations."""
68
+ x = self.proj(x)
69
+ return x
70
+
71
+ def calc_dis(self, x1, x2):
72
+ """Calculate diversity distance as (squared) L2 distance."""
73
+ x1 = self.forward(x1)
74
+ x2 = self.forward(x2)
75
+ return torch.sum(torch.square(x1 - x2), -1)
76
+
77
+ def triplet_delta_dis(self, ref, x1, x2):
78
+ """Calculate delta distance comparing x1 and x2 to ref."""
79
+ x1 = self.forward(x1)
80
+ x2 = self.forward(x2)
81
+ ref = self.forward(ref)
82
+ return (torch.sum(torch.square(ref - x1), -1) -
83
+ torch.sum(torch.square(ref - x2), -1))
84
+
85
+
86
+ # Triplet loss with margin 0.05.
87
+ # The binary preference labels are scaled to y = 1 or -1 for the loss, where y = 1 means x2 is more similar to ref than x1.
88
+ loss_fn = lambda y, delta_dis: torch.max(
89
+ torch.tensor([0.0]).to(DEVICE), 0.05 - (y * 2 - 1) * delta_dis
90
+ ).mean()
91
+
92
+
93
+ def fit_div_proj(inputs, dreamsim_features, latent_dim, batch_size=32):
94
+ """Trains the DivProj model on ground-truth labels."""
95
+ t = time.time()
96
+ model = DivProj(input_dim=inputs.shape[-1], latent_dim=latent_dim)
97
+ model.to(DEVICE)
98
+
99
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
100
+
101
+ n_pref_data = inputs.shape[0]
102
+ ref = inputs[:, 0]
103
+ x1 = inputs[:, 1]
104
+ x2 = inputs[:, 2]
105
+
106
+ n_train = int(n_pref_data * 0.75)
107
+ n_val = n_pref_data - n_train
108
+
109
+ # Split data into train and val.
110
+ ref_train = ref[:n_train]
111
+ x1_train = x1[:n_train]
112
+ x2_train = x2[:n_train]
113
+ ref_val = ref[n_train:]
114
+ x1_val = x1[n_train:]
115
+ x2_val = x2[n_train:]
116
+
117
+ # Split DreamSim features into train and val.
118
+ ref_dreamsim_features = dreamsim_features[:, 0]
119
+ x1_dreamsim_features = dreamsim_features[:, 1]
120
+ x2_dreamsim_features = dreamsim_features[:, 2]
121
+ ref_gt_train = ref_dreamsim_features[:n_train]
122
+ x1_gt_train = x1_dreamsim_features[:n_train]
123
+ x2_gt_train = x2_dreamsim_features[:n_train]
124
+ ref_gt_val = ref_dreamsim_features[n_train:]
125
+ x1_gt_val = x1_dreamsim_features[n_train:]
126
+ x2_gt_val = x2_dreamsim_features[n_train:]
127
+
128
+ val_acc = []
129
+ n_iters_per_epoch = max((n_train) // batch_size, 1)
130
+ for epoch in range(200):
131
+ for _ in range(n_iters_per_epoch):
132
+ optimizer.zero_grad()
133
+
134
+ idx = np.random.choice(n_train, batch_size)
135
+ batch_ref = ref_train[idx].float()
136
+ batch1 = x1_train[idx].float()
137
+ batch2 = x2_train[idx].float()
138
+
139
+ # Get delta distance from model.
140
+ delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)
141
+
142
+ # Get preference labels from DreamSim features.
143
+ gt_dis = torch.nn.functional.cosine_similarity(
144
+ ref_gt_train[idx], x2_gt_train[idx], dim=-1
145
+ ) - torch.nn.functional.cosine_similarity(
146
+ ref_gt_train[idx], x1_gt_train[idx], dim=-1
147
+ )
148
+ gt = (gt_dis > 0).to(TORCH_DTYPE) # if distance from the two sims are greater than 0, convert gt to torch_type
149
+
150
+ loss = loss_fn(gt, delta_dis)
151
+ loss.backward()
152
+ optimizer.step()
153
+
154
+ # Validate.
155
+ n_correct = 0
156
+ n_total = 0
157
+ with torch.no_grad():
158
+ idx = np.arange(n_val)
159
+ batch_ref = ref_val[idx].float()
160
+ batch1 = x1_val[idx].float()
161
+ batch2 = x2_val[idx].float()
162
+ delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)
163
+ pred = delta_dis > 0
164
+ gt_dis = torch.nn.functional.cosine_similarity(
165
+ ref_gt_val[idx], x2_gt_val[idx], dim=-1
166
+ ) - torch.nn.functional.cosine_similarity(
167
+ ref_gt_val[idx], x1_gt_val[idx], dim=-1
168
+ )
169
+ gt = gt_dis > 0
170
+ n_correct += (pred == gt).sum().item()
171
+ n_total += len(idx)
172
+
173
+ acc = n_correct / n_total
174
+ val_acc.append(acc)
175
+
176
+ # Early stopping if val_acc does not improve for 10 epochs.
177
+ if epoch > 10 and np.mean(val_acc[-10:]) < np.mean(val_acc[-11:-1]):
178
+ break
179
+
180
+ print(
181
+ f"{np.round(time.time()- t, 1)}s ({epoch+1} epochs) | DivProj (n={n_pref_data}) fitted with val acc.: {acc}"
182
+ )
183
+
184
+ return model.to(TORCH_DTYPE), acc
185
+
186
+
187
+ def compute_diversity_measures(clip_features, diversity_model):
188
+ with torch.no_grad():
189
+ measures = diversity_model(clip_features).detach().cpu().numpy()
190
+ return measures
191
+
192
+
193
+ def tensor_to_list(tensor):
194
+ sols = tensor.detach().cpu().numpy().astype(np.float32)
195
+ return sols.reshape(sols.shape[0], -1)
196
+
197
+
198
+ def list_to_tensor(list_):
199
+ sols = np.array(list_).reshape(
200
+ len(list_), 4, SD_IN_HEIGHT, SD_IN_WIDTH
201
+ ) # Hard-coded for now.
202
+ return torch.tensor(sols, dtype=TORCH_DTYPE, device=DEVICE)
203
+
204
+
205
+ def create_scheduler(
206
+ sols,
207
+ objs,
208
+ clip_features,
209
+ diversity_model,
210
+ seed=None,
211
+ ):
212
+ measures = compute_diversity_measures(clip_features, diversity_model)
213
+ archive_bounds = np.array(
214
+ [np.quantile(measures, 0.01, axis=0), np.quantile(measures, 0.99, axis=0)]
215
+ ).T
216
+
217
+ sols = tensor_to_list(sols)
218
+
219
+ # Set up archive.
220
+ archive = GridArchive(
221
+ solution_dim=len(sols[0]), dims=GRID_SIZE, ranges=archive_bounds, seed=SEED
222
+ )
223
+
224
+ # Add initial solutions to the archive.
225
+ archive.add(sols, objs, measures)
226
+
227
+ # Set up the GaussianEmitter.
228
+ emitters = [
229
+ GaussianEmitter(
230
+ archive=archive,
231
+ sigma=0.1,
232
+ initial_solutions=archive.sample_elites(BATCH_SIZE)["solution"],
233
+ batch_size=BATCH_SIZE,
234
+ seed=SEED,
235
+ )
236
+ ]
237
+
238
+ # Return the archive and scheduler.
239
+ return archive, Scheduler(archive, emitters)
240
+
241
+
242
+ def plot_archive(archive):
243
+ plt.figure(figsize=(6, 4.5))
244
+ grid_archive_heatmap(archive, vmin=0, vmax=100)
245
+ plt.xlabel("Diversity Metric 1")
246
+ plt.ylabel("Diversity Metric 2")
247
+ return plt
248
+
249
+
250
+ def run_qdhf(prompt:str, init_pop: int=200, total_itrs: int=200):
251
+ INIT_POP = init_pop
252
+ TOTAL_ITRS = total_itrs
253
+
254
+ # This tutorial uses ViT-B/32, you may use other checkpoints depending on your resources and need.
255
+ CLIP_MODEL, CLIP_PREPROCESS = clip.load("ViT-B/32", device=DEVICE)
256
+ CLIP_MODEL.eval()
257
+ for p in CLIP_MODEL.parameters():
258
+ p.requires_grad_(False)
259
+
260
+ def compute_clip_scores(imgs, text, return_clip_features=False):
261
+ """Computes CLIP scores for a batch of images and a given text prompt."""
262
+ img_tensor = torch.stack([CLIP_PREPROCESS(img) for img in imgs]).to(DEVICE)
263
+ tokenized_text = clip.tokenize([text]).to(DEVICE)
264
+ img_logits, _text_logits = CLIP_MODEL(img_tensor, tokenized_text)
265
+ img_logits = img_logits.detach().cpu().numpy().astype(np.float32)[:, 0]
266
+ img_logits = 1 / img_logits * 100
267
+ # Remap the objective from minimizing [0, 10] to maximizing [0, 100]
268
+ img_logits = (10.0 - img_logits) * 10.0
269
+
270
+ if return_clip_features:
271
+ clip_features = CLIP_MODEL.encode_image(img_tensor).to(TORCH_DTYPE)
272
+ return img_logits, clip_features
273
+ else:
274
+ return img_logits
275
+
276
+ DREAMSIM_MODEL, DREAMSIM_PREPROCESS = dreamsim(
277
+ pretrained=True, dreamsim_type="open_clip_vitb32", device=DEVICE
278
+ )
279
+
280
+ def evaluate_lsi(
281
+ latents,
282
+ prompt,
283
+ return_features=False,
284
+ diversity_model=None,
285
+ ):
286
+ """Evaluates the objective of LSI for a batch of latents and a given text prompt."""
287
+
288
+ images = SDPIPE(
289
+ prompt,
290
+ num_images_per_prompt=latents.shape[0],
291
+ latents=latents,
292
+ # num_inference_steps=1, # For testing.
293
+ ).images
294
+
295
+ objs, clip_features = compute_clip_scores(
296
+ images,
297
+ prompt,
298
+ return_clip_features=True,
299
+ )
300
+
301
+ images = torch.cat([DREAMSIM_PREPROCESS(img) for img in images]).to(DEVICE)
302
+ dreamsim_features = DREAMSIM_MODEL.embed(images)
303
+
304
+ if diversity_model is not None:
305
+ measures = compute_diversity_measures(clip_features, diversity_model)
306
+ else:
307
+ measures = None
308
+
309
+ if return_features:
310
+ return objs, measures, clip_features, dreamsim_features
311
+ else:
312
+ return objs, measures
313
+
314
+
315
+ update_schedule = [1, 21, 51, 101] # Iterations on which to update the archive.
316
+ n_pref_data = 1000 # Number of preferences used in each update.
317
+
318
+ archive = None
319
+
320
+ best = 0.0
321
+ for itr in trange(1, TOTAL_ITRS + 1):
322
+ # Update archive and scheduler if needed.
323
+ if itr in update_schedule:
324
+ if archive is None:
325
+ tqdm.write("Initializing archive and diversity projection.")
326
+
327
+ all_sols = []
328
+ all_clip_features = []
329
+ all_dreamsim_features = []
330
+ all_objs = []
331
+
332
+ # Sample random solutions and get judgment on similarity.
333
+ n_batches = INIT_POP // BATCH_SIZE
334
+ for _ in range(n_batches):
335
+ sols = torch.randn(SD_IN_SHAPE, device=DEVICE, dtype=TORCH_DTYPE)
336
+ objs, _, clip_features, dreamsim_features = evaluate_lsi(
337
+ sols, prompt, return_features=True
338
+ )
339
+ all_sols.append(sols)
340
+ all_clip_features.append(clip_features)
341
+ all_dreamsim_features.append(dreamsim_features)
342
+ all_objs.append(objs)
343
+ all_sols = torch.concat(all_sols, dim=0)
344
+ all_clip_features = torch.concat(all_clip_features, dim=0)
345
+ all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
346
+ all_objs = np.concatenate(all_objs, axis=0)
347
+
348
+ # Initialize the diversity projection model.
349
+ div_proj_data = []
350
+ div_proj_labels = []
351
+ for _ in range(n_pref_data):
352
+ idx = np.random.choice(all_sols.shape[0], 3)
353
+ div_proj_data.append(all_clip_features[idx])
354
+ div_proj_labels.append(all_dreamsim_features[idx])
355
+ div_proj_data = torch.concat(div_proj_data, dim=0)
356
+ div_proj_labels = torch.concat(div_proj_labels, dim=0)
357
+ div_proj_data = div_proj_data.reshape(n_pref_data, 3, -1)
358
+ div_proj_label = div_proj_labels.reshape(n_pref_data, 3, -1)
359
+ diversity_model, div_proj_acc = fit_div_proj(
360
+ div_proj_data,
361
+ div_proj_label,
362
+ latent_dim=2,
363
+ )
364
+
365
+ else:
366
+ tqdm.write("Updating archive and diversity projection.")
367
+
368
+ # Get all the current solutions and collect feedback.
369
+ all_sols = list_to_tensor(archive.data("solution"))
370
+ n_batches = np.ceil(len(all_sols) / BATCH_SIZE).astype(int)
371
+ all_clip_features = []
372
+ all_dreamsim_features = []
373
+ all_objs = []
374
+ for i in range(n_batches):
375
+ sols = all_sols[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
376
+ objs, _, clip_features, dreamsim_features = evaluate_lsi(
377
+ sols, prompt, return_features=True
378
+ )
379
+ all_clip_features.append(clip_features)
380
+ all_dreamsim_features.append(dreamsim_features)
381
+ all_objs.append(objs)
382
+ all_clip_features = torch.concat(
383
+ all_clip_features, dim=0
384
+ ) # n_pref_data * 3, dim
385
+ all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
386
+ all_objs = np.concatenate(all_objs, axis=0)
387
+
388
+ # Update the diversity projection model.
389
+ additional_features = []
390
+ additional_labels = []
391
+ for _ in range(n_pref_data):
392
+ idx = np.random.choice(all_sols.shape[0], 3)
393
+ additional_features.append(all_clip_features[idx])
394
+ additional_labels.append(all_dreamsim_features[idx])
395
+ additional_features = torch.concat(additional_features, dim=0)
396
+ additional_labels = torch.concat(additional_labels, dim=0)
397
+ additional_div_proj_data = additional_features.reshape(n_pref_data, 3, -1)
398
+ additional_div_proj_label = additional_labels.reshape(n_pref_data, 3, -1)
399
+ div_proj_data = torch.concat(
400
+ (div_proj_data, additional_div_proj_data), axis=0
401
+ )
402
+ div_proj_label = torch.concat(
403
+ (div_proj_label, additional_div_proj_label), axis=0
404
+ )
405
+ diversity_model, div_proj_acc = fit_div_proj(
406
+ div_proj_data,
407
+ div_proj_label,
408
+ latent_dim=2,
409
+ )
410
+
411
+ archive, scheduler = create_scheduler(
412
+ all_sols,
413
+ all_objs,
414
+ all_clip_features,
415
+ diversity_model,
416
+ seed=SEED,
417
+ )
418
+
419
+ # Primary QD loop.
420
+ sols = scheduler.ask()
421
+ sols = list_to_tensor(sols)
422
+ objs, measures, clip_features, dreamsim_features = evaluate_lsi(
423
+ sols, prompt, return_features=True, diversity_model=diversity_model
424
+ )
425
+ best = max(best, max(objs))
426
+ scheduler.tell(objs, measures)
427
+
428
+ # This can be used as a flag to save on the final iteration, but note that
429
+ # we do not save results in this tutorial.
430
+ final_itr = itr == TOTAL_ITRS
431
+
432
+ # Update the summary statistics for the archive.
433
+ qd_score, coverage = archive.stats.norm_qd_score, archive.stats.coverage
434
+
435
+ tqdm.write(f"QD score: {np.round(qd_score, 2)} Coverage: {coverage * 100}")
436
+
437
+ plt = plot_archive(archive)
438
+ yield archive, plt
439
+
440
+ plt = plot_archive(archive)
441
+ return archive, plt
442
+
443
+
444
+ def many_pictures(archive, prompt:str):
445
+ # Modify this to determine how many images to plot along each dimension.
446
+ img_freq = (
447
+ 4, # Number of columns of images.
448
+ 4, # Number of rows of images.
449
+ )
450
+
451
+ # List of images.
452
+ imgs = []
453
+
454
+ # Convert archive to a df with solutions available.
455
+ df = archive.data(return_type="pandas")
456
+
457
+ # Compute the min and max measures for which solutions were found.
458
+ measure_bounds = np.array(
459
+ [
460
+ (df["measures_0"].min(), df["measures_0"].max()),
461
+ (df["measures_1"].min(), df["measures_1"].max()),
462
+ ]
463
+ )
464
+
465
+ archive_bounds = np.array(
466
+ [archive.boundaries[0][[0, -1]], archive.boundaries[1][[0, -1]]]
467
+ )
468
+
469
+
470
+ delta_measures_0 = (archive_bounds[0][1] - archive_bounds[0][0]) / img_freq[0]
471
+ delta_measures_1 = (archive_bounds[1][1] - archive_bounds[1][0]) / img_freq[1]
472
+
473
+
474
+ for col, row in itertools.product(range(img_freq[1]), range(img_freq[0])):
475
+ # Compute bounds of a box in measure space.
476
+ measures_0_low = archive_bounds[0][0] + delta_measures_0 * row
477
+ measures_0_high = archive_bounds[0][0] + delta_measures_0 * (row + 1)
478
+ measures_1_low = archive_bounds[1][0] + delta_measures_1 * col
479
+ measures_1_high = archive_bounds[1][0] + delta_measures_1 * (col + 1)
480
+
481
+ if row == 0:
482
+ measures_0_low = measure_bounds[0][0]
483
+ if col == 0:
484
+ measures_1_low = measure_bounds[1][0]
485
+ if row == img_freq[0] - 1:
486
+ measures_0_high = measure_bounds[0][1]
487
+ if col == img_freq[1] - 1:
488
+ measures_0_high = measure_bounds[1][1]
489
+
490
+ # Query for a solution with measures within this box.
491
+ query_string = (
492
+ f"{measures_0_low} <= measures_0 & measures_0 <= {measures_0_high} & "
493
+ f"{measures_1_low} <= measures_1 & measures_1 <= {measures_1_high}"
494
+ )
495
+ df_box = df.query(query_string)
496
+
497
+ if not df_box.empty:
498
+ # Randomly sample a solution from the box.
499
+ # Stable Diffusion solutions have SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH
500
+ # dimensions, so the final solution col is solution_(x-1).
501
+ sol = (
502
+ df_box.loc[
503
+ :,
504
+ "solution_0" : "solution_{}".format(
505
+ SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH - 1
506
+ ),
507
+ ]
508
+ .sample(n=1)
509
+ .iloc[0]
510
+ )
511
+
512
+ # Convert the latent vector solution to an image.
513
+ latents = torch.tensor(sol.to_numpy()).reshape(
514
+ (1, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
515
+ )
516
+ latents = latents.to(TORCH_DTYPE).to(DEVICE)
517
+ img = SDPIPE(
518
+ prompt,
519
+ num_images_per_prompt=1,
520
+ latents=latents,
521
+ # num_inference_steps=1, # For testing.
522
+ ).images[0]
523
+
524
+ img = torch.from_numpy(np.array(img)).permute(2, 0, 1) / 255.0
525
+ imgs.append(img)
526
+ else:
527
+ imgs.append(torch.zeros((3, IMG_HEIGHT, IMG_WIDTH)))
528
+ from torchvision.utils import make_grid
529
+
530
+
531
+ def create_archive_tick_labels(measure_range, num_ticks):
532
+ delta = (measure_range[1] - measure_range[0]) / num_ticks
533
+ ticklabels = [round(delta * p + measure_range[0], 3) for p in range(num_ticks + 1)]
534
+ return ticklabels
535
+
536
+
537
+ plt.figure(figsize=(img_freq[0] * 2, img_freq[0] * 2))
538
+ img_grid = make_grid(imgs, nrow=img_freq[0], padding=0)
539
+ img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
540
+ plt.imshow(img_grid)
541
+
542
+ plt.xlabel("")
543
+ num_x_ticks = img_freq[0]
544
+ x_ticklabels = create_archive_tick_labels(measure_bounds[0], num_x_ticks)
545
+ x_tick_range = img_grid.shape[1]
546
+ x_ticks = np.arange(0, x_tick_range + 1e-9, step=x_tick_range / num_x_ticks)
547
+ plt.xticks(x_ticks, x_ticklabels)
548
+
549
+ plt.ylabel("")
550
+ num_y_ticks = img_freq[1]
551
+ y_ticklabels = create_archive_tick_labels(measure_bounds[1], num_y_ticks)
552
+ y_ticklabels.reverse()
553
+ y_tick_range = img_grid.shape[0]
554
+ y_ticks = np.arange(0, y_tick_range + 1e-9, step=y_tick_range / num_y_ticks)
555
+ plt.yticks(y_ticks, y_ticklabels)
556
+ plt.tight_layout()
557
+
558
+ return plt
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ matplotlib
3
+ imageio
4
+ pydantic
5
+ torch
6
+ diffusers
7
+ dreamsim
8
+ ribs
9
+ ftfy
10
+ regex
11
+ tqdm
12
+ git+https://github.com/openai/CLIP.git