1lint committed on
Commit 919fef8 · 1 Parent(s): a4b1962

refactor code and fix cpu support
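In practice the CPU fix comes down to deriving the device and dtype once and threading them through pipeline loading and generation instead of assuming CUDA. A minimal sketch of the pattern the refactored utils/functions.py applies:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fp16 weights only make sense on GPU; fall back to fp32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32

    # build the RNG on the active device instead of hard-coding "cuda"
    generator = torch.Generator(device).manual_seed(0)

    # the pipeline is moved to CUDA (and given xformers / attention slicing) only when a GPU exists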

app.py CHANGED
@@ -1,283 +1,19 @@
1
-
2
- # inpaint pipeline with fix to avoid noise added to latents during final iteration of denoising loop
3
- from inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
4
-
5
- from diffusers import (
6
- StableDiffusionPipeline,
7
- StableDiffusionImg2ImgPipeline,
8
- )
9
-
10
- import diffusers.schedulers
11
  import gradio as gr
12
- import torch
13
- import random
14
  from multiprocessing import cpu_count
15
- import json
16
- from PIL import Image
17
- import os
18
- import argparse
19
- import shutil
20
- import gc
21
-
22
- import importlib
23
-
24
- from textual_inversion import main as run_textual_inversion
25
-
26
- def pad_image(image):
27
- w, h = image.size
28
- if w == h:
29
- return image
30
- elif w > h:
31
- new_image = Image.new(image.mode, (w, w), (0, 0, 0))
32
- new_image.paste(image, (0, (w - h) // 2))
33
- return new_image
34
- else:
35
- new_image = Image.new(image.mode, (h, h), (0, 0, 0))
36
- new_image.paste(image, ((h - w) // 2, 0))
37
- return new_image
38
-
39
- _xformers_available = importlib.util.find_spec("xformers") is not None
40
- device = "cuda" if torch.cuda.is_available() else "cpu"
41
- low_vram_mode = False
42
-
43
- # scheduler dict includes superclass SchedulerMixin (it still generates reasonable images)
44
- scheduler_dict = {
45
- k: v
46
- for k, v in diffusers.schedulers.__dict__.items()
47
- if "Scheduler" in k and "Flax" not in k
48
- }
49
- scheduler_dict.pop(
50
- "VQDiffusionScheduler"
51
- ) # requires unique parameter, unlike other schedulers
52
- scheduler_names = list(scheduler_dict.keys())
53
- default_scheduler = scheduler_names[3] # expected to be DPM Multistep
54
-
55
- model_ids = [
56
- "andite/anything-v4.0",
57
- "hakurei/waifu-diffusion",
58
- "prompthero/openjourney-v2",
59
- "runwayml/stable-diffusion-v1-5",
60
- "johnslegers/epic-diffusion",
61
- "stabilityai/stable-diffusion-2-1",
62
- ]
63
-
64
- loaded_model_id = ""
65
-
66
-
67
- def load_pipe(
68
- model_id, scheduler_name, pipe_class=StableDiffusionPipeline, pipe_kwargs="{}"
69
- ):
70
- global pipe, loaded_model_id
71
-
72
- scheduler = scheduler_dict[scheduler_name]
73
-
74
- # load new weights from disk only when changing model_id
75
- if model_id != loaded_model_id:
76
- pipe = pipe_class.from_pretrained(
77
- model_id,
78
- torch_dtype=torch.float16,
79
- safety_checker=None,
80
- requires_safety_checker=False,
81
- scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
82
- **json.loads(pipe_kwargs),
83
- )
84
- loaded_model_id = model_id
85
-
86
- # if same model_id, instantiate new pipeline with same underlying pytorch objects to avoid reloading weights from disk
87
- elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
88
- pipe.components["scheduler"] = scheduler.from_pretrained(
89
- model_id, subfolder="scheduler"
90
- )
91
- pipe = pipe_class(**pipe.components)
92
-
93
- if device == 'cuda':
94
- pipe = pipe.to(device)
95
- if _xformers_available:
96
- pipe.enable_xformers_memory_efficient_attention()
97
- print("using xformers")
98
- if low_vram_mode:
99
- pipe.enable_attention_slicing()
100
- print("using attention slicing to lower VRAM")
101
-
102
- return pipe
103
-
104
-
105
- pipe = None
106
- pipe = load_pipe(model_ids[0], default_scheduler)
107
-
108
- @torch.autocast(device)
109
- @torch.no_grad()
110
- def generate(
111
- model_name,
112
- scheduler_name,
113
- prompt,
114
- guidance,
115
- steps,
116
- n_images=1,
117
- width=512,
118
- height=512,
119
- seed=0,
120
- image=None,
121
- strength=0.5,
122
- inpaint_image=None,
123
- inpaint_strength=0.5,
124
- inpaint_radio='',
125
- neg_prompt="",
126
- pipe_class=StableDiffusionPipeline,
127
- pipe_kwargs="{}",
128
- ):
129
-
130
- if seed == -1:
131
- seed = random.randint(0, 2147483647)
132
-
133
- generator = torch.Generator("cuda").manual_seed(seed)
134
-
135
- pipe = load_pipe(
136
- model_id=model_name,
137
- scheduler_name=scheduler_name,
138
- pipe_class=pipe_class,
139
- pipe_kwargs=pipe_kwargs,
140
- )
141
-
142
- status_message = (
143
- f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"
144
- )
145
-
146
- if pipe_class == StableDiffusionPipeline:
147
- status_message = "Text to Image " + status_message
148
-
149
- result = pipe(
150
- prompt,
151
- negative_prompt=neg_prompt,
152
- num_images_per_prompt=n_images,
153
- num_inference_steps=int(steps),
154
- guidance_scale=guidance,
155
- width=width,
156
- height=height,
157
- generator=generator,
158
- )
159
-
160
- elif pipe_class == StableDiffusionImg2ImgPipeline:
161
-
162
- status_message = "Image to Image " + status_message
163
- print(image.size)
164
- image = image.resize((width, height))
165
- print(image.size)
166
-
167
- result = pipe(
168
- prompt,
169
- negative_prompt=neg_prompt,
170
- num_images_per_prompt=n_images,
171
- image=image,
172
- num_inference_steps=int(steps),
173
- strength=strength,
174
- guidance_scale=guidance,
175
- generator=generator,
176
- )
177
-
178
- elif pipe_class == StableDiffusionInpaintPipelineLegacy:
179
- status_message = "Inpainting " + status_message
180
-
181
- init_image = inpaint_image["image"].resize((width, height))
182
- mask = inpaint_image["mask"].resize((width, height))
183
-
184
-
185
- result = pipe(
186
- prompt,
187
- negative_prompt=neg_prompt,
188
- num_images_per_prompt=n_images,
189
- image=init_image,
190
- mask_image=mask,
191
- num_inference_steps=int(steps),
192
- strength=inpaint_strength,
193
- preserve_unmasked_image=(inpaint_radio == inpaint_options[0]),
194
- guidance_scale=guidance,
195
- generator=generator,
196
- )
197
-
198
- else:
199
- return None, f"Unhandled pipeline class: {pipe_class}", -1
200
-
201
- return result.images, status_message, seed
202
-
203
-
204
- # based on lvkaokao/textual-inversion-training
205
- def train_textual_inversion(model_name, scheduler_name, type_of_thing, files, concept_word, init_word, text_train_steps, text_train_bsz, text_learning_rate, progress=gr.Progress(track_tqdm=True)):
206
-
207
- pipe = load_pipe(
208
- model_id=model_name,
209
- scheduler_name=scheduler_name,
210
- pipe_class=StableDiffusionPipeline,
211
- )
212
-
213
- pipe.disable_xformers_memory_efficient_attention() # xformers handled by textual inversion script
214
-
215
- concept_dir = 'concept_images'
216
- output_dir = 'output_model'
217
- training_resolution = 512
218
-
219
- if os.path.exists(output_dir): shutil.rmtree('output_model')
220
- if os.path.exists(concept_dir): shutil.rmtree('concept_images')
221
-
222
- os.makedirs(concept_dir, exist_ok=True)
223
- os.makedirs(output_dir, exist_ok=True)
224
-
225
- gc.collect()
226
- torch.cuda.empty_cache()
227
-
228
- if(prompt == "" or prompt == None):
229
- raise gr.Error("You forgot to define your concept prompt")
230
-
231
- for j, file_temp in enumerate(files):
232
- file = Image.open(file_temp.name)
233
- image = pad_image(file)
234
- image = image.resize((training_resolution, training_resolution))
235
- extension = file_temp.name.split(".")[1]
236
- image = image.convert('RGB')
237
- image.save(f'{concept_dir}/{j+1}.{extension}', quality=100)
238
-
239
-
240
- args_general = argparse.Namespace(
241
- train_data_dir=concept_dir,
242
- learnable_property=type_of_thing,
243
- placeholder_token=concept_word,
244
- initializer_token=init_word,
245
- resolution=training_resolution,
246
- train_batch_size=text_train_bsz,
247
- gradient_accumulation_steps=1,
248
- gradient_checkpointing=True,
249
- mixed_precision='fp16',
250
- use_bf16=False,
251
- max_train_steps=int(text_train_steps),
252
- learning_rate=text_learning_rate,
253
- scale_lr=True,
254
- lr_scheduler="constant",
255
- lr_warmup_steps=0,
256
- output_dir=output_dir,
257
- )
258
-
259
- try:
260
- final_result = run_textual_inversion(pipe, args_general)
261
- except Exception as e:
262
- raise gr.Error(e)
263
-
264
- gc.collect()
265
- torch.cuda.empty_cache()
266
-
267
- return f'Finished training! Check the {output_dir} directory for saved model weights'
268
-
269
 
270
  default_img_size = 512
271
 
272
- with open("header.html") as fp:
273
  header = fp.read()
274
 
275
- with open("footer.html") as fp:
276
  footer = fp.read()
277
 
278
- with gr.Blocks(css="style.css") as demo:
279
 
280
- pipe_state = gr.State(lambda: StableDiffusionPipeline)
281
 
282
  gr.HTML(header)
283
 
@@ -293,7 +29,7 @@ with gr.Blocks(css="style.css") as demo:
293
 
294
  with gr.Column(scale=30):
295
  model_name = gr.Dropdown(
296
- label="Model", choices=model_ids, value=loaded_model_id
297
  )
298
  scheduler_name = gr.Dropdown(
299
  label="Scheduler", choices=scheduler_names, value=default_scheduler
@@ -305,10 +41,10 @@ with gr.Blocks(css="style.css") as demo:
305
  with gr.Column():
306
 
307
  with gr.Tab("Text to Image") as tab:
308
- tab.select(lambda: StableDiffusionPipeline, [], pipe_state)
309
 
310
  with gr.Tab("Image to image") as tab:
311
- tab.select(lambda: StableDiffusionImg2ImgPipeline, [], pipe_state)
312
 
313
  image = gr.Image(
314
  label="Image to Image",
@@ -326,7 +62,7 @@ with gr.Blocks(css="style.css") as demo:
326
  )
327
 
328
  with gr.Tab("Inpainting") as tab:
329
- tab.select(lambda: StableDiffusionInpaintPipelineLegacy, [], pipe_state)
330
 
331
  inpaint_image = gr.Image(
332
  label="Inpainting",
@@ -342,13 +78,26 @@ with gr.Blocks(css="style.css") as demo:
342
  step=0.02,
343
  value=0.8,
344
  )
345
- inpaint_options = ["preserve non-masked portions of image", "output entire inpainted image"]
346
- inpaint_radio = gr.Radio(inpaint_options, value=inpaint_options[0], show_label=False, interactive=True)
 
 
 
 
 
 
 
 
347
 
348
  with gr.Tab("Textual Inversion") as tab:
349
- tab.select(lambda: StableDiffusionPipeline, [], pipe_state)
350
 
351
- type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
 
 
 
 
 
352
 
353
  text_train_bsz = gr.Slider(
354
  label="Training Batch Size",
@@ -358,14 +107,23 @@ with gr.Blocks(css="style.css") as demo:
358
  value=1,
359
  )
360
 
361
- files = gr.File(label=f'''Upload the images for your concept''', file_count="multiple", interactive=True, visible=True)
 
 
 
 
 
362
 
363
  text_train_steps = gr.Number(label="How many steps", value=1000)
364
 
365
- text_learning_rate = gr.Number(label="Learning Rate", value=5.e-4)
366
 
367
- concept_word = gr.Textbox(label=f'''concept word - use a unique, made up word to avoid collisions''')
368
- init_word = gr.Textbox(label=f'''initial word - to init the concept embedding''')
 
 
 
 
369
 
370
  textual_inversion_button = gr.Button(value="Train Textual Inversion")
371
 
@@ -436,17 +194,31 @@ with gr.Blocks(css="style.css") as demo:
436
  pipe_state,
437
  pipe_kwargs,
438
  ]
439
- outputs = [gallery, generation_details, seed]
440
 
441
  prompt.submit(generate, inputs=inputs, outputs=outputs)
442
  generate_button.click(generate, inputs=inputs, outputs=outputs)
443
 
444
- textual_inversion_inputs = [model_name, scheduler_name, type_of_thing, files, concept_word, init_word, text_train_steps, text_train_bsz, text_learning_rate]
 
 
 
 
 
 
 
 
 
 
445
 
446
- textual_inversion_button.click(train_textual_inversion, inputs=textual_inversion_inputs, outputs=[training_status])
 
 
 
 
447
 
448
 
449
- #demo = gr.TabbedInterface([demo, dreambooth_tab], ["Main", "Dreambooth"])
450
 
451
  demo.queue(concurrency_count=cpu_count())
452
 
 
 
 
1
  import gradio as gr
 
 
2
  from multiprocessing import cpu_count
3
+ from utils.functions import generate, train_textual_inversion
4
+ from utils.shared import model_ids, scheduler_names, default_scheduler
 
 
 
 
5
 
6
  default_img_size = 512
7
 
8
+ with open("html/header.html") as fp:
9
  header = fp.read()
10
 
11
+ with open("html/footer.html") as fp:
12
  footer = fp.read()
13
 
14
+ with gr.Blocks(css="html/style.css") as demo:
15
 
16
+ pipe_state = gr.State(lambda: 1)
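(pipe_state now stores the index of the active tab rather than a pipeline class: the tab.select callbacks below set it to 1-4 for Text to Image, Image to image, Inpainting and Textual Inversion, and utils/functions.py resolves indices 1-3 to pipeline classes through its tab_to_pipeline mapping.)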
17
 
18
  gr.HTML(header)
19
 
 
29
 
30
  with gr.Column(scale=30):
31
  model_name = gr.Dropdown(
32
+ label="Model", choices=model_ids, value=model_ids[0]
33
  )
34
  scheduler_name = gr.Dropdown(
35
  label="Scheduler", choices=scheduler_names, value=default_scheduler
 
41
  with gr.Column():
42
 
43
  with gr.Tab("Text to Image") as tab:
44
+ tab.select(lambda: 1, [], pipe_state)
45
 
46
  with gr.Tab("Image to image") as tab:
47
+ tab.select(lambda: 2, [], pipe_state)
48
 
49
  image = gr.Image(
50
  label="Image to Image",
 
62
  )
63
 
64
  with gr.Tab("Inpainting") as tab:
65
+ tab.select(lambda: 3, [], pipe_state)
66
 
67
  inpaint_image = gr.Image(
68
  label="Inpainting",
 
78
  step=0.02,
79
  value=0.8,
80
  )
81
+ inpaint_options = [
82
+ "preserve non-masked portions of image",
83
+ "output entire inpainted image",
84
+ ]
85
+ inpaint_radio = gr.Radio(
86
+ inpaint_options,
87
+ value=inpaint_options[0],
88
+ show_label=False,
89
+ interactive=True,
90
+ )
91
 
92
  with gr.Tab("Textual Inversion") as tab:
93
+ tab.select(lambda: 4, [], pipe_state)
94
 
95
+ type_of_thing = gr.Dropdown(
96
+ label="What would you like to train?",
97
+ choices=["object", "person", "style"],
98
+ value="object",
99
+ interactive=True,
100
+ )
101
 
102
  text_train_bsz = gr.Slider(
103
  label="Training Batch Size",
 
107
  value=1,
108
  )
109
 
110
+ files = gr.File(
111
+ label=f"""Upload the images for your concept""",
112
+ file_count="multiple",
113
+ interactive=True,
114
+ visible=True,
115
+ )
116
 
117
  text_train_steps = gr.Number(label="How many steps", value=1000)
118
 
119
+ text_learning_rate = gr.Number(label="Learning Rate", value=5.0e-4)
120
 
121
+ concept_word = gr.Textbox(
122
+ label=f"""concept word - use a unique, made up word to avoid collisions"""
123
+ )
124
+ init_word = gr.Textbox(
125
+ label=f"""initial word - to init the concept embedding"""
126
+ )
127
 
128
  textual_inversion_button = gr.Button(value="Train Textual Inversion")
129
 
 
194
  pipe_state,
195
  pipe_kwargs,
196
  ]
197
+ outputs = [gallery, generation_details]
198
 
199
  prompt.submit(generate, inputs=inputs, outputs=outputs)
200
  generate_button.click(generate, inputs=inputs, outputs=outputs)
201
 
202
+ textual_inversion_inputs = [
203
+ model_name,
204
+ scheduler_name,
205
+ type_of_thing,
206
+ files,
207
+ concept_word,
208
+ init_word,
209
+ text_train_steps,
210
+ text_train_bsz,
211
+ text_learning_rate,
212
+ ]
213
 
214
+ textual_inversion_button.click(
215
+ train_textual_inversion,
216
+ inputs=textual_inversion_inputs,
217
+ outputs=[training_status],
218
+ )
219
 
220
 
221
+ # demo = gr.TabbedInterface([demo, dreambooth_tab], ["Main", "Dreambooth"])
222
 
223
  demo.queue(concurrency_count=cpu_count())
224
 
footer.html → html/footer.html RENAMED
File without changes
header.html → html/header.html RENAMED
File without changes
style.css → html/style.css RENAMED
File without changes
model_ids.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
1
+ andite/anything-v4.0
2
+ hakurei/waifu-diffusion
3
+ prompthero/openjourney-v2
4
+ runwayml/stable-diffusion-v1-5
5
+ johnslegers/epic-diffusion
6
+ stabilityai/stable-diffusion-2-1
test.ipynb ADDED
@@ -0,0 +1,73 @@
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "with open('model_ids.txt', 'r') as fp:\n",
10
+ " model_ids = fp.read().splitlines() "
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 4,
16
+ "metadata": {},
17
+ "outputs": [
18
+ {
19
+ "data": {
20
+ "text/plain": [
21
+ "['andite/anything-v4.0',\n",
22
+ " 'hakurei/waifu-diffusion',\n",
23
+ " 'prompthero/openjourney-v2',\n",
24
+ " 'runwayml/stable-diffusion-v1-5',\n",
25
+ " 'johnslegers/epic-diffusion',\n",
26
+ " 'stabilityai/stable-diffusion-2-1']"
27
+ ]
28
+ },
29
+ "execution_count": 4,
30
+ "metadata": {},
31
+ "output_type": "execute_result"
32
+ }
33
+ ],
34
+ "source": [
35
+ "model_ids"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": []
44
+ }
45
+ ],
46
+ "metadata": {
47
+ "kernelspec": {
48
+ "display_name": "ml",
49
+ "language": "python",
50
+ "name": "python3"
51
+ },
52
+ "language_info": {
53
+ "codemirror_mode": {
54
+ "name": "ipython",
55
+ "version": 3
56
+ },
57
+ "file_extension": ".py",
58
+ "mimetype": "text/x-python",
59
+ "name": "python",
60
+ "nbconvert_exporter": "python",
61
+ "pygments_lexer": "ipython3",
62
+ "version": "3.10.8"
63
+ },
64
+ "orig_nbformat": 4,
65
+ "vscode": {
66
+ "interpreter": {
67
+ "hash": "cbbcdde725e9a65f1cb734ac4223fed46e03daf1eb62d8ccb3c48face3871521"
68
+ }
69
+ }
70
+ },
71
+ "nbformat": 4,
72
+ "nbformat_minor": 2
73
+ }
utils/__init__.py ADDED
File without changes
utils/functions.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import random
4
+ from PIL import Image
5
+ import os
6
+ import argparse
7
+ import shutil
8
+ import gc
9
+ import importlib
10
+ import json
11
+
12
+ from diffusers import (
13
+ StableDiffusionPipeline,
14
+ StableDiffusionImg2ImgPipeline,
15
+ )
16
+
17
+
18
+ from .inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
19
+
20
+ from .textual_inversion import main as run_textual_inversion
21
+ from .shared import default_scheduler, scheduler_dict, model_ids
22
+
23
+
24
+ _xformers_available = importlib.util.find_spec("xformers") is not None
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ # device = 'cpu'
27
+ dtype = torch.float16 if device == "cuda" else torch.float32
28
+ low_vram_mode = False
29
+
30
+
31
+ tab_to_pipeline = {
32
+ 1: StableDiffusionPipeline,
33
+ 2: StableDiffusionImg2ImgPipeline,
34
+ 3: StableDiffusionInpaintPipelineLegacy,
35
+ }
36
+
37
+
38
+ def load_pipe(model_id, scheduler_name, tab_index=1, pipe_kwargs="{}"):
39
+ global pipe, loaded_model_id
40
+
41
+ scheduler = scheduler_dict[scheduler_name]
42
+
43
+ pipe_class = tab_to_pipeline[tab_index]
44
+
45
+ # load new weights from disk only when changing model_id
46
+ if model_id != loaded_model_id:
47
+ pipe = pipe_class.from_pretrained(
48
+ model_id,
49
+ torch_dtype=dtype,
50
+ safety_checker=None,
51
+ requires_safety_checker=False,
52
+ scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
53
+ **json.loads(pipe_kwargs),
54
+ )
55
+ loaded_model_id = model_id
56
+
57
+ # if same model_id, instantiate new pipeline with same underlying pytorch objects to avoid reloading weights from disk
58
+ elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
59
+ pipe.components["scheduler"] = scheduler.from_pretrained(
60
+ model_id, subfolder="scheduler"
61
+ )
62
+ pipe = pipe_class(**pipe.components)
63
+
64
+ if device == "cuda":
65
+ pipe = pipe.to(device)
66
+ if _xformers_available:
67
+ pipe.enable_xformers_memory_efficient_attention()
68
+ print("using xformers")
69
+ if low_vram_mode:
70
+ pipe.enable_attention_slicing()
71
+ print("using attention slicing to lower VRAM")
72
+
73
+ return pipe
74
+
75
+
76
+ pipe = None
77
+ loaded_model_id = ""
78
+ pipe = load_pipe(model_ids[0], default_scheduler)
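For reference, a minimal sketch of calling load_pipe for another tab; the model id is just one of the entries from model_ids.txt and the scheduler is whatever default_scheduler resolved to:

    # reuses the already-loaded weights when the model id is unchanged,
    # only swapping the pipeline class and scheduler
    pipe = load_pipe(
        "runwayml/stable-diffusion-v1-5",  # any entry of model_ids.txt
        default_scheduler,
        tab_index=2,                       # 2 -> StableDiffusionImg2ImgPipeline
    )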
79
+
80
+
81
+ def pad_image(image):
82
+ w, h = image.size
83
+ if w == h:
84
+ return image
85
+ elif w > h:
86
+ new_image = Image.new(image.mode, (w, w), (0, 0, 0))
87
+ new_image.paste(image, (0, (w - h) // 2))
88
+ return new_image
89
+ else:
90
+ new_image = Image.new(image.mode, (h, h), (0, 0, 0))
91
+ new_image.paste(image, ((h - w) // 2, 0))
92
+ return new_image
93
+
94
+
95
+ @torch.no_grad()
96
+ def generate(
97
+ model_name,
98
+ scheduler_name,
99
+ prompt,
100
+ guidance,
101
+ steps,
102
+ n_images=1,
103
+ width=512,
104
+ height=512,
105
+ seed=0,
106
+ image=None,
107
+ strength=0.5,
108
+ inpaint_image=None,
109
+ inpaint_strength=0.5,
110
+ inpaint_radio="",
111
+ neg_prompt="",
112
+ tab_index=1,
113
+ pipe_kwargs="{}",
114
+ progress=gr.Progress(track_tqdm=True),
115
+ ):
116
+
117
+ if seed == -1:
118
+ seed = random.randint(0, 2147483647)
119
+
120
+ generator = torch.Generator(device).manual_seed(seed)
121
+
122
+ pipe = load_pipe(
123
+ model_id=model_name,
124
+ scheduler_name=scheduler_name,
125
+ tab_index=tab_index,
126
+ pipe_kwargs=pipe_kwargs,
127
+ )
128
+
129
+ status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"
130
+
131
+ if tab_index == 1:
132
+ status_message = "Text to Image " + status_message
133
+
134
+ result = pipe(
135
+ prompt,
136
+ negative_prompt=neg_prompt,
137
+ num_images_per_prompt=n_images,
138
+ num_inference_steps=int(steps),
139
+ guidance_scale=guidance,
140
+ width=width,
141
+ height=height,
142
+ generator=generator,
143
+ )
144
+
145
+ elif tab_index == 2:
146
+
147
+ status_message = "Image to Image " + status_message
148
+ print(image.size)
149
+ image = image.resize((width, height))
150
+ print(image.size)
151
+
152
+ result = pipe(
153
+ prompt,
154
+ negative_prompt=neg_prompt,
155
+ num_images_per_prompt=n_images,
156
+ image=image,
157
+ num_inference_steps=int(steps),
158
+ strength=strength,
159
+ guidance_scale=guidance,
160
+ generator=generator,
161
+ )
162
+
163
+ elif tab_index == 3:
164
+ status_message = "Inpainting " + status_message
165
+
166
+ init_image = inpaint_image["image"].resize((width, height))
167
+ mask = inpaint_image["mask"].resize((width, height))
168
+
169
+ result = pipe(
170
+ prompt,
171
+ negative_prompt=neg_prompt,
172
+ num_images_per_prompt=n_images,
173
+ image=init_image,
174
+ mask_image=mask,
175
+ num_inference_steps=int(steps),
176
+ strength=inpaint_strength,
177
+ preserve_unmasked_image=(
178
+ inpaint_radio == "preserve non-masked portions of image"
179
+ ),
180
+ guidance_scale=guidance,
181
+ generator=generator,
182
+ )
183
+
184
+ else:
185
+ return None, f"Unhandled tab index: {tab_index}"
186
+
187
+ return result.images, status_message
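A minimal direct call of generate for the text-to-image path, with illustrative values (the prompt and settings are placeholders; anything not passed keeps the defaults above):

    images, status = generate(
        model_name=model_ids[0],
        scheduler_name=default_scheduler,
        prompt="a watercolor painting of a fox",  # placeholder prompt
        guidance=7.5,
        steps=20,
        seed=-1,      # -1 -> a random seed is drawn
        tab_index=1,  # text-to-image branch
    )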
188
+
189
+
190
+ # based on lvkaokao/textual-inversion-training
191
+ def train_textual_inversion(
192
+ model_name,
193
+ scheduler_name,
194
+ type_of_thing,
195
+ files,
196
+ concept_word,
197
+ init_word,
198
+ text_train_steps,
199
+ text_train_bsz,
200
+ text_learning_rate,
201
+ progress=gr.Progress(track_tqdm=True),
202
+ ):
203
+
204
+ if device == "cpu":
205
+ raise gr.Error("Textual inversion training not supported on CPU")
206
+
207
+ pipe = load_pipe(
208
+ model_id=model_name,
209
+ scheduler_name=scheduler_name,
210
+ tab_index=1,
211
+ )
212
+
213
+ pipe.disable_xformers_memory_efficient_attention() # xformers handled by textual inversion script
214
+
215
+ concept_dir = "concept_images"
216
+ output_dir = "output_model"
217
+ training_resolution = 512
218
+
219
+ if os.path.exists(output_dir):
220
+ shutil.rmtree("output_model")
221
+ if os.path.exists(concept_dir):
222
+ shutil.rmtree("concept_images")
223
+
224
+ os.makedirs(concept_dir, exist_ok=True)
225
+ os.makedirs(output_dir, exist_ok=True)
226
+
227
+ gc.collect()
228
+ torch.cuda.empty_cache()
229
+
230
+ if concept_word == "" or concept_word == None:
231
+ raise gr.Error("You forgot to define your concept prompt")
232
+
233
+ for j, file_temp in enumerate(files):
234
+ file = Image.open(file_temp.name)
235
+ image = pad_image(file)
236
+ image = image.resize((training_resolution, training_resolution))
237
+ extension = file_temp.name.split(".")[1]
238
+ image = image.convert("RGB")
239
+ image.save(f"{concept_dir}/{j+1}.{extension}", quality=100)
240
+
241
+ args_general = argparse.Namespace(
242
+ train_data_dir=concept_dir,
243
+ learnable_property=type_of_thing,
244
+ placeholder_token=concept_word,
245
+ initializer_token=init_word,
246
+ resolution=training_resolution,
247
+ train_batch_size=text_train_bsz,
248
+ gradient_accumulation_steps=1,
249
+ gradient_checkpointing=True,
250
+ mixed_precision="fp16",
251
+ use_bf16=False,
252
+ max_train_steps=int(text_train_steps),
253
+ learning_rate=text_learning_rate,
254
+ scale_lr=True,
255
+ lr_scheduler="constant",
256
+ lr_warmup_steps=0,
257
+ output_dir=output_dir,
258
+ )
259
+
260
+ try:
261
+ final_result = run_textual_inversion(pipe, args_general)
262
+ except Exception as e:
263
+ raise gr.Error(e)
264
+
265
+ pipe.text_encoder = pipe.text_encoder.eval().to(device, dtype=dtype)
266
+ pipe.unet = pipe.unet.eval().to(device, dtype=dtype)
267
+
268
+ gc.collect()
269
+ torch.cuda.empty_cache()
270
+
271
+ return (
272
+ f"Finished training! Check the {output_dir} directory for saved model weights"
273
+ )
inpaint_pipeline.py → utils/inpaint_pipeline.py RENAMED
@@ -1,4 +1,3 @@
1
-
2
  # Licensed under the Apache License, Version 2.0 (the "License");
3
  # you may not use this file except in compliance with the License.
4
  # You may obtain a copy of the License at
@@ -16,21 +15,30 @@ from typing import Optional, Union, List, Callable
16
  import PIL
17
  import numpy as np
18
 
19
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint_legacy import preprocess_image, deprecate, StableDiffusionInpaintPipelineLegacy, StableDiffusionPipelineOutput, PIL_INTERPOLATION
 
 
 
 
 
 
 
20
 
21
  def preprocess_mask(mask, scale_factor=8):
22
  mask = mask.convert("L")
23
  w, h = mask.size
24
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
25
 
26
- #input_mask = mask.resize((w, h), resample=PIL_INTERPOLATION["nearest"])
27
  input_mask = np.array(mask).astype(np.float32) / 255.0
28
  input_mask = np.tile(input_mask, (3, 1, 1))
29
  input_mask = input_mask[None].transpose(0, 1, 2, 3) # add batch dimension
30
  input_mask = 1 - input_mask # repaint white, keep black
31
  input_mask = torch.round(torch.from_numpy(input_mask))
32
 
33
- mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
 
 
34
  mask = np.array(mask).astype(np.float32) / 255.0
35
  mask = np.tile(mask, (4, 1, 1))
36
  mask = mask[None].transpose(0, 1, 2, 3) # add batch dimension
@@ -40,7 +48,6 @@ def preprocess_mask(mask, scale_factor=8):
40
  return mask, input_mask
41
 
42
 
43
-
44
  class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
45
 
46
  # forward call is same as StableDiffusionInpaintPipelineLegacy, but with line added to avoid noise added to final latents right before decoding step
@@ -120,8 +127,8 @@ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
120
  The frequency at which the `callback` function will be called. If not specified, the callback will be
121
  called at every step.
122
  preserve_unmasked_image (`bool`, *optional*, defaults to `True`):
123
- Whether or not to preserve the unmasked portions of the original image in the inpainted output. If False,
124
- inpainting of the masked latents may produce noticeable distortion of unmasked portions of the decoded
125
  image.
126
 
127
  Returns:
@@ -148,7 +155,11 @@ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
148
 
149
  # 3. Encode input prompt
150
  text_embeddings = self._encode_prompt(
151
- prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
 
 
 
 
152
  )
153
 
154
  # 4. Preprocess image and mask
@@ -157,17 +168,27 @@ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
157
 
158
  # get mask corresponding to input latents as well as image
159
  if not isinstance(mask_image, torch.FloatTensor):
160
- mask_image, input_mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
 
 
161
 
162
  # 5. set timesteps
163
  self.scheduler.set_timesteps(num_inference_steps, device=device)
164
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
 
165
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
166
 
167
  # 6. Prepare latent variables
168
  # encode the init image into latents and scale the latents
169
  latents, init_latents_orig, noise = self.prepare_latents(
170
- image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
 
 
 
 
 
 
171
  )
172
 
173
  # 7. Prepare mask latent
@@ -181,33 +202,47 @@ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
181
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
182
  with self.progress_bar(total=num_inference_steps) as progress_bar:
183
  for i, t in enumerate(timesteps):
184
-
185
  # expand the latents if we are doing classifier free guidance
186
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
187
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
 
 
 
188
 
189
  # predict the noise residual
190
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
 
191
 
192
  # perform guidance
193
  if do_classifier_free_guidance:
194
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
195
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
196
 
197
  # compute the previous noisy sample x_t -> x_t-1
198
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
 
199
  # masking
200
  if add_predicted_noise:
201
  init_latents_proper = self.scheduler.add_noise(
202
  init_latents_orig, noise_pred_uncond, torch.tensor([t])
203
  )
204
  else:
205
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
 
 
206
 
207
  latents = (init_latents_proper * mask) + (latents * (1 - mask))
208
 
209
  # call the callback, if provided
210
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
 
211
  progress_bar.update()
212
  if callback is not None and i % callback_steps == 0:
213
  callback(i, t, latents)
@@ -225,7 +260,9 @@ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
225
  # restore unmasked parts of image with original image
226
  input_mask_image = input_mask_image.to(inpaint_image)
227
  image = image.to(inpaint_image)
228
- image = (image * input_mask_image) + (inpaint_image * (1 - input_mask_image)) # use original unmasked portions of image to avoid degradation
 
 
229
 
230
  # post-processing of image
231
  image = (image / 2 + 0.5).clamp(0, 1)
@@ -235,7 +272,9 @@ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
235
  image = self.decode_latents(latents)
236
 
237
  # 11. Run safety checker
238
- image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
 
 
239
 
240
  # 12. Convert to PIL
241
  if output_type == "pil":
@@ -244,4 +283,6 @@ class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
244
  if not return_dict:
245
  return (image, has_nsfw_concept)
246
 
247
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
 
 
 
1
  # Licensed under the Apache License, Version 2.0 (the "License");
2
  # you may not use this file except in compliance with the License.
3
  # You may obtain a copy of the License at
 
15
  import PIL
16
  import numpy as np
17
 
18
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint_legacy import (
19
+ preprocess_image,
20
+ deprecate,
21
+ StableDiffusionInpaintPipelineLegacy,
22
+ StableDiffusionPipelineOutput,
23
+ PIL_INTERPOLATION,
24
+ )
25
+
26
 
27
  def preprocess_mask(mask, scale_factor=8):
28
  mask = mask.convert("L")
29
  w, h = mask.size
30
  w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
31
 
32
+ # input_mask = mask.resize((w, h), resample=PIL_INTERPOLATION["nearest"])
33
  input_mask = np.array(mask).astype(np.float32) / 255.0
34
  input_mask = np.tile(input_mask, (3, 1, 1))
35
  input_mask = input_mask[None].transpose(0, 1, 2, 3) # add batch dimension
36
  input_mask = 1 - input_mask # repaint white, keep black
37
  input_mask = torch.round(torch.from_numpy(input_mask))
38
 
39
+ mask = mask.resize(
40
+ (w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]
41
+ )
42
  mask = np.array(mask).astype(np.float32) / 255.0
43
  mask = np.tile(mask, (4, 1, 1))
44
  mask = mask[None].transpose(0, 1, 2, 3) # add batch dimension
 
48
  return mask, input_mask
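To make the two return values concrete, a quick shape check under the default scale_factor of 8 and a 512x512 mask:

    # mask.shape       -> (1, 4, 64, 64)   latent-resolution mask, blended into the latents every denoising step
    # input_mask.shape -> (1, 3, 512, 512) image-resolution mask, used at the end to paste back the unmasked pixels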
49
 
50
 
 
51
  class SDInpaintPipeline(StableDiffusionInpaintPipelineLegacy):
52
 
53
  # forward call is same as StableDiffusionInpaintPipelineLegacy, but with line added to avoid noise added to final latents right before decoding step
 
127
  The frequency at which the `callback` function will be called. If not specified, the callback will be
128
  called at every step.
129
  preserve_unmasked_image (`bool`, *optional*, defaults to `True`):
130
+ Whether or not to preserve the unmasked portions of the original image in the inpainted output. If False,
131
+ inpainting of the masked latents may produce noticeable distortion of unmasked portions of the decoded
132
  image.
133
 
134
  Returns:
 
155
 
156
  # 3. Encode input prompt
157
  text_embeddings = self._encode_prompt(
158
+ prompt,
159
+ device,
160
+ num_images_per_prompt,
161
+ do_classifier_free_guidance,
162
+ negative_prompt,
163
  )
164
 
165
  # 4. Preprocess image and mask
 
168
 
169
  # get mask corresponding to input latents as well as image
170
  if not isinstance(mask_image, torch.FloatTensor):
171
+ mask_image, input_mask_image = preprocess_mask(
172
+ mask_image, self.vae_scale_factor
173
+ )
174
 
175
  # 5. set timesteps
176
  self.scheduler.set_timesteps(num_inference_steps, device=device)
177
+ timesteps, num_inference_steps = self.get_timesteps(
178
+ num_inference_steps, strength, device
179
+ )
180
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
181
 
182
  # 6. Prepare latent variables
183
  # encode the init image into latents and scale the latents
184
  latents, init_latents_orig, noise = self.prepare_latents(
185
+ image,
186
+ latent_timestep,
187
+ batch_size,
188
+ num_images_per_prompt,
189
+ text_embeddings.dtype,
190
+ device,
191
+ generator,
192
  )
193
 
194
  # 7. Prepare mask latent
 
202
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
203
  with self.progress_bar(total=num_inference_steps) as progress_bar:
204
  for i, t in enumerate(timesteps):
205
+
206
  # expand the latents if we are doing classifier free guidance
207
+ latent_model_input = (
208
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
209
+ )
210
+ latent_model_input = self.scheduler.scale_model_input(
211
+ latent_model_input, t
212
+ )
213
 
214
  # predict the noise residual
215
+ noise_pred = self.unet(
216
+ latent_model_input, t, encoder_hidden_states=text_embeddings
217
+ ).sample
218
 
219
  # perform guidance
220
  if do_classifier_free_guidance:
221
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
222
+ noise_pred = noise_pred_uncond + guidance_scale * (
223
+ noise_pred_text - noise_pred_uncond
224
+ )
225
 
226
  # compute the previous noisy sample x_t -> x_t-1
227
+ latents = self.scheduler.step(
228
+ noise_pred, t, latents, **extra_step_kwargs
229
+ ).prev_sample
230
  # masking
231
  if add_predicted_noise:
232
  init_latents_proper = self.scheduler.add_noise(
233
  init_latents_orig, noise_pred_uncond, torch.tensor([t])
234
  )
235
  else:
236
+ init_latents_proper = self.scheduler.add_noise(
237
+ init_latents_orig, noise, torch.tensor([t])
238
+ )
239
 
240
  latents = (init_latents_proper * mask) + (latents * (1 - mask))
241
 
242
  # call the callback, if provided
243
+ if i == len(timesteps) - 1 or (
244
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
245
+ ):
246
  progress_bar.update()
247
  if callback is not None and i % callback_steps == 0:
248
  callback(i, t, latents)
 
260
  # restore unmasked parts of image with original image
261
  input_mask_image = input_mask_image.to(inpaint_image)
262
  image = image.to(inpaint_image)
263
+ image = (image * input_mask_image) + (
264
+ inpaint_image * (1 - input_mask_image)
265
+ ) # use original unmasked portions of image to avoid degradation
266
 
267
  # post-processing of image
268
  image = (image / 2 + 0.5).clamp(0, 1)
 
272
  image = self.decode_latents(latents)
273
 
274
  # 11. Run safety checker
275
+ image, has_nsfw_concept = self.run_safety_checker(
276
+ image, device, text_embeddings.dtype
277
+ )
278
 
279
  # 12. Convert to PIL
280
  if output_type == "pil":
 
283
  if not return_dict:
284
  return (image, has_nsfw_concept)
285
 
286
+ return StableDiffusionPipelineOutput(
287
+ images=image, nsfw_content_detected=has_nsfw_concept
288
+ )
utils/shared.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
1
+ import diffusers.schedulers
2
+
3
+ # scheduler dict includes superclass SchedulerMixin (it still generates reasonable images)
4
+ scheduler_dict = {
5
+ k: v
6
+ for k, v in diffusers.schedulers.__dict__.items()
7
+ if "Scheduler" in k and "Flax" not in k
8
+ }
9
+ scheduler_dict.pop(
10
+ "VQDiffusionScheduler"
11
+ ) # requires unique parameter, unlike other schedulers
12
+ scheduler_names = list(scheduler_dict.keys())
13
+ default_scheduler = scheduler_names[3] # expected to be DPM Multistep
14
+
15
+ with open("model_ids.txt", "r") as fp:
16
+ model_ids = fp.read().splitlines()
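A short sketch of how these module-level values are consumed downstream (the concrete scheduler class is only an expectation, as the comment above notes):

    from utils.shared import scheduler_dict, default_scheduler, model_ids

    scheduler_cls = scheduler_dict[default_scheduler]  # expected to be the DPM multistep scheduler
    scheduler = scheduler_cls.from_pretrained(model_ids[0], subfolder="scheduler")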
textual_inversion.py → utils/textual_inversion.py RENAMED
@@ -34,7 +34,12 @@ import transformers
34
  from accelerate import Accelerator
35
  from accelerate.logging import get_logger
36
  from accelerate.utils import set_seed
37
- from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 
 
 
 
 
38
  from diffusers.optimization import get_scheduler
39
  from diffusers.utils import check_min_version
40
  from diffusers.utils.import_utils import is_xformers_available
@@ -76,7 +81,11 @@ logger = get_logger(__name__)
76
 
77
  def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
78
  logger.info("Saving embeddings")
79
- learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
 
 
 
 
80
  learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
81
  torch.save(learned_embeds_dict, save_path)
82
 
@@ -114,7 +123,10 @@ def parse_args():
114
  help="Pretrained tokenizer name or path if not the same as model_name",
115
  )
116
  parser.add_argument(
117
- "--train_data_dir", type=str, default=None, help="A folder containing the training data."
 
 
 
118
  )
119
  parser.add_argument(
120
  "--placeholder_token",
@@ -123,18 +135,33 @@ def parse_args():
123
  help="A token to use as a placeholder for the concept.",
124
  )
125
  parser.add_argument(
126
- "--initializer_token", type=str, default=None, help="A token to use as initializer word."
 
 
 
127
  )
128
 
129
- parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
130
- parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
 
 
 
 
 
 
 
 
 
 
131
  parser.add_argument(
132
  "--output_dir",
133
  type=str,
134
  default="text-inversion-model",
135
  help="The output directory where the model predictions and checkpoints will be written.",
136
  )
137
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
 
 
138
  parser.add_argument(
139
  "--resolution",
140
  type=int,
@@ -145,10 +172,15 @@ def parse_args():
145
  ),
146
  )
147
  parser.add_argument(
148
- "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
 
 
149
  )
150
  parser.add_argument(
151
- "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
 
 
 
152
  )
153
  parser.add_argument("--num_train_epochs", type=int, default=100)
154
  parser.add_argument(
@@ -190,14 +222,43 @@ def parse_args():
190
  ),
191
  )
192
  parser.add_argument(
193
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
 
 
 
 
194
  )
195
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
196
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
197
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
198
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
199
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
200
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
201
  parser.add_argument(
202
  "--hub_model_id",
203
  type=str,
@@ -241,7 +302,12 @@ def parse_args():
241
  ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
242
  ),
243
  )
244
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
 
 
 
 
245
  parser.add_argument(
246
  "--checkpointing_steps",
247
  type=int,
@@ -261,7 +327,9 @@ def parse_args():
261
  ),
262
  )
263
  parser.add_argument(
264
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
 
 
265
  )
266
 
267
  args = parser.parse_args()
@@ -269,7 +337,7 @@ def parse_args():
269
  if env_local_rank != -1 and env_local_rank != args.local_rank:
270
  args.local_rank = env_local_rank
271
 
272
- #if args.train_data_dir is None:
273
  # raise ValueError("You must specify a train data directory.")
274
 
275
  return args
@@ -350,7 +418,10 @@ class TextualInversionDataset(Dataset):
350
  self.center_crop = center_crop
351
  self.flip_p = flip_p
352
 
353
- self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
 
 
 
354
 
355
  self.num_images = len(self.image_paths)
356
  self._length = self.num_images
@@ -365,7 +436,11 @@ class TextualInversionDataset(Dataset):
365
  "lanczos": PIL_INTERPOLATION["lanczos"],
366
  }[interpolation]
367
 
368
- self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
 
 
 
 
369
  self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
370
 
371
  def __len__(self):
@@ -394,14 +469,13 @@ class TextualInversionDataset(Dataset):
394
 
395
  if self.center_crop:
396
  crop = min(img.shape[0], img.shape[1])
397
- (
398
- h,
399
- w,
400
- ) = (
401
  img.shape[0],
402
  img.shape[1],
403
  )
404
- img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
 
 
405
 
406
  image = Image.fromarray(img)
407
  image = image.resize((self.size, self.size), resample=self.interpolation)
@@ -414,7 +488,9 @@ class TextualInversionDataset(Dataset):
414
  return example
415
 
416
 
417
- def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
 
 
418
  if token is None:
419
  token = HfFolder.get_token()
420
  if organization is None:
@@ -424,7 +500,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
424
  return f"{organization}/{model_id}"
425
 
426
 
427
-
428
  def main(pipe, args_imported):
429
 
430
  args = parse_args()
@@ -464,11 +539,15 @@ def main(pipe, args_imported):
464
  if accelerator.is_main_process:
465
  if args.push_to_hub:
466
  if args.hub_model_id is None:
467
- repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
 
 
468
  else:
469
  repo_name = args.hub_model_id
470
  create_repo(repo_name, exist_ok=True, token=args.hub_token)
471
- repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
 
 
472
 
473
  with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
474
  if "step_*" not in gitignore:
@@ -530,7 +609,9 @@ def main(pipe, args_imported):
530
  if is_xformers_available():
531
  unet.enable_xformers_memory_efficient_attention()
532
  else:
533
- raise ValueError("xformers is not available. Make sure it is installed correctly")
 
 
534
 
535
  # Enable TF32 for faster training on Ampere GPUs,
536
  # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@@ -539,7 +620,10 @@ def main(pipe, args_imported):
539
 
540
  if args.scale_lr:
541
  args.learning_rate = (
542
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
 
 
 
543
  )
544
 
545
  # Initialize the optimizer
@@ -562,11 +646,15 @@ def main(pipe, args_imported):
562
  center_crop=args.center_crop,
563
  set="train",
564
  )
565
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
 
 
566
 
567
  # Scheduler and math around the number of training steps.
568
  overrode_max_train_steps = False
569
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
 
570
  if args.max_train_steps is None:
571
  args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
572
  overrode_max_train_steps = True
@@ -597,7 +685,9 @@ def main(pipe, args_imported):
597
  text_encoder.to(accelerator.device, dtype=torch.float32)
598
 
599
  # We need to recalculate our total training steps as the size of the training dataloader may have changed.
600
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
 
601
  if overrode_max_train_steps:
602
  args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
603
  # Afterwards we recalculate our number of training epochs
@@ -609,13 +699,19 @@ def main(pipe, args_imported):
609
  accelerator.init_trackers("textual_inversion", config=vars(args))
610
 
611
  # Train!
612
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
 
 
 
613
 
614
  logger.info("***** Running training *****")
615
  logger.info(f" Num examples = {len(train_dataset)}")
616
  logger.info(f" Num Epochs = {args.num_train_epochs}")
617
  logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
618
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
 
 
619
  logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
620
  logger.info(f" Total optimization steps = {args.max_train_steps}")
621
  global_step = 0
@@ -640,31 +736,51 @@ def main(pipe, args_imported):
640
  resume_step = resume_global_step % num_update_steps_per_epoch
641
 
642
  # Only show the progress bar once on each machine.
643
- progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
 
 
 
644
  progress_bar.set_description("Steps")
645
 
646
  # keep original embeddings as reference
647
- orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
 
 
 
 
648
 
649
- for epoch in (range(first_epoch, args.num_train_epochs)):
650
  text_encoder.train()
651
  for step, batch in enumerate(train_dataloader):
652
  # Skip steps until we reach the resumed step
653
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
 
 
 
 
654
  if step % args.gradient_accumulation_steps == 0:
655
  progress_bar.update(1)
656
  continue
657
 
658
  with accelerator.accumulate(text_encoder):
659
  # Convert images to latent space
660
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
 
 
 
 
661
  latents = latents * 0.18215
662
 
663
  # Sample noise that we'll add to the latents
664
  noise = torch.randn_like(latents)
665
  bsz = latents.shape[0]
666
  # Sample a random timestep for each image
667
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
 
 
 
 
 
668
  timesteps = timesteps.long()
669
 
670
  # Add noise to the latents according to the noise magnitude at each timestep
@@ -672,10 +788,14 @@ def main(pipe, args_imported):
672
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
673
 
674
  # Get the text embedding for conditioning
675
- encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
 
 
676
 
677
  # Predict the noise residual
678
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
 
679
 
680
  # Get the target for loss depending on the prediction type
681
  if noise_scheduler.config.prediction_type == "epsilon":
@@ -683,7 +803,9 @@ def main(pipe, args_imported):
683
  elif noise_scheduler.config.prediction_type == "v_prediction":
684
  target = noise_scheduler.get_velocity(latents, noise, timesteps)
685
  else:
686
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
 
687
 
688
  loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
689
 
@@ -694,8 +816,12 @@ def main(pipe, args_imported):
694
  else:
695
  grads = text_encoder.get_input_embeddings().weight.grad
696
  # Get the index for tokens that we want to zero the grads for
697
- index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
698
- grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
 
 
 
 
699
 
700
  optimizer.step()
701
  lr_scheduler.step()
@@ -704,21 +830,31 @@ def main(pipe, args_imported):
704
  # Let's make sure we don't update any embedding weights besides the newly added token
705
  index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
706
  with torch.no_grad():
707
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
 
 
 
 
708
  index_no_updates
709
- ] = orig_embeds_params[index_no_updates]
710
 
711
  # Checks if the accelerator has performed an optimization step behind the scenes
712
  if accelerator.sync_gradients:
713
  progress_bar.update(1)
714
  global_step += 1
715
  if global_step % args.save_steps == 0:
716
- save_path = os.path.join(args.output_dir, f"{args.placeholder_token}-{global_step}.bin")
717
- save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
 
 
 
 
718
 
719
  if global_step % args.checkpointing_steps == 0:
720
  if accelerator.is_main_process:
721
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
 
 
722
  accelerator.save_state(save_path)
723
  logger.info(f"Saved state to {save_path}")
724
 
@@ -733,7 +869,9 @@ def main(pipe, args_imported):
733
  accelerator.wait_for_everyone()
734
  if accelerator.is_main_process:
735
  if args.push_to_hub and args.only_save_embeds:
736
- logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
 
 
737
  save_full_model = True
738
  else:
739
  save_full_model = not args.only_save_embeds
@@ -744,35 +882,35 @@ def main(pipe, args_imported):
744
  save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
745
 
746
  if args.push_to_hub:
747
- repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
 
748
 
749
  accelerator.end_training()
750
- text_encoder.eval()
751
- unet.eval()
752
- vae.eval()
753
 
754
 
755
  if __name__ == "__main__":
756
- pipeline = StableDiffusionPipeline.from_pretrained('andite/anything-v4.0', torch_dtype=torch.float16)
757
-
758
- imported_args = argparse.Namespace(
759
- train_data_dir="concept_images",
760
- learnable_property='object',
761
- placeholder_token='redeyegirl',
762
- initializer_token='girl',
763
- resolution=512,
764
- train_batch_size=1,
765
- gradient_accumulation_steps=1,
766
- gradient_checkpointing=True,
767
- mixed_precision='fp16',
768
- use_bf16=False,
769
- max_train_steps=1000,
770
- learning_rate=5.0e-4,
771
- scale_lr=False,
772
- lr_scheduler="constant",
773
- lr_warmup_steps=0,
774
- output_dir="output_model",
775
- )
776
 
 
 
 
 
777
 
778
  main(pipeline, imported_args)
 
34
  from accelerate import Accelerator
35
  from accelerate.logging import get_logger
36
  from accelerate.utils import set_seed
37
+ from diffusers import (
38
+ AutoencoderKL,
39
+ DDPMScheduler,
40
+ StableDiffusionPipeline,
41
+ UNet2DConditionModel,
42
+ )
43
  from diffusers.optimization import get_scheduler
44
  from diffusers.utils import check_min_version
45
  from diffusers.utils.import_utils import is_xformers_available
 
81
 
82
  def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
83
  logger.info("Saving embeddings")
84
+ learned_embeds = (
85
+ accelerator.unwrap_model(text_encoder)
86
+ .get_input_embeddings()
87
+ .weight[placeholder_token_id]
88
+ )
89
  learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
90
  torch.save(learned_embeds_dict, save_path)
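The saved .bin is just a one-entry dict mapping the placeholder token to its embedding tensor. A hedged sketch of loading it back into a pipeline by hand (file name and pipe are hypothetical; this mirrors the usual textual-inversion loading recipe rather than anything in this repo):

    learned = torch.load("output_model/myconcept-1000.bin")  # hypothetical path
    token, embedding = next(iter(learned.items()))

    pipe.tokenizer.add_tokens(token)
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    token_id = pipe.tokenizer.convert_tokens_to_ids(token)
    pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding.to(
        pipe.text_encoder.get_input_embeddings().weight.dtype
    )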
91
 
 
123
  help="Pretrained tokenizer name or path if not the same as model_name",
124
  )
125
  parser.add_argument(
126
+ "--train_data_dir",
127
+ type=str,
128
+ default=None,
129
+ help="A folder containing the training data.",
130
  )
131
  parser.add_argument(
132
  "--placeholder_token",
 
135
  help="A token to use as a placeholder for the concept.",
136
  )
137
  parser.add_argument(
138
+ "--initializer_token",
139
+ type=str,
140
+ default=None,
141
+ help="A token to use as initializer word.",
142
  )
143
 
144
+ parser.add_argument(
145
+ "--learnable_property",
146
+ type=str,
147
+ default="object",
148
+ help="Choose between 'object' and 'style'",
149
+ )
150
+ parser.add_argument(
151
+ "--repeats",
152
+ type=int,
153
+ default=100,
154
+ help="How many times to repeat the training data.",
155
+ )
156
  parser.add_argument(
157
  "--output_dir",
158
  type=str,
159
  default="text-inversion-model",
160
  help="The output directory where the model predictions and checkpoints will be written.",
161
  )
162
+ parser.add_argument(
163
+ "--seed", type=int, default=None, help="A seed for reproducible training."
164
+ )
165
  parser.add_argument(
166
  "--resolution",
167
  type=int,
 
172
  ),
173
  )
174
  parser.add_argument(
175
+ "--center_crop",
176
+ action="store_true",
177
+ help="Whether to center crop images before resizing to resolution",
178
  )
179
  parser.add_argument(
180
+ "--train_batch_size",
181
+ type=int,
182
+ default=16,
183
+ help="Batch size (per device) for the training dataloader.",
184
  )
185
  parser.add_argument("--num_train_epochs", type=int, default=100)
186
  parser.add_argument(
 
  ),
  )
  parser.add_argument(
+ "--lr_warmup_steps",
+ type=int,
+ default=500,
+ help="Number of steps for the warmup in the lr scheduler.",
+ )
+ parser.add_argument(
+ "--adam_beta1",
+ type=float,
+ default=0.9,
+ help="The beta1 parameter for the Adam optimizer.",
+ )
+ parser.add_argument(
+ "--adam_beta2",
+ type=float,
+ default=0.999,
+ help="The beta2 parameter for the Adam optimizer.",
+ )
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
+ )
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the model to the Hub.",
+ )
+ parser.add_argument(
+ "--hub_token",
+ type=str,
+ default=None,
+ help="The token to use to push to the Model Hub.",
  )
  parser.add_argument(
  "--hub_model_id",
  type=str,
 
  ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
  ),
  )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ default=-1,
+ help="For distributed training: local_rank",
+ )
  parser.add_argument(
  "--checkpointing_steps",
  type=int,
 
  ),
  )
  parser.add_argument(
+ "--enable_xformers_memory_efficient_attention",
+ action="store_true",
+ help="Whether or not to use xformers.",
  )

  args = parser.parse_args()
 
  if env_local_rank != -1 and env_local_rank != args.local_rank:
  args.local_rank = env_local_rank

+ # if args.train_data_dir is None:
  # raise ValueError("You must specify a train data directory.")

  return args
 
  self.center_crop = center_crop
  self.flip_p = flip_p

+ self.image_paths = [
+ os.path.join(self.data_root, file_path)
+ for file_path in os.listdir(self.data_root)
+ ]

  self.num_images = len(self.image_paths)
  self._length = self.num_images
 
  "lanczos": PIL_INTERPOLATION["lanczos"],
  }[interpolation]

+ self.templates = (
+ imagenet_style_templates_small
+ if learnable_property == "style"
+ else imagenet_templates_small
+ )
  self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

  def __len__(self):
 
  if self.center_crop:
  crop = min(img.shape[0], img.shape[1])
+ (h, w,) = (
  img.shape[0],
  img.shape[1],
  )
+ img = img[
+ (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
+ ]

  image = Image.fromarray(img)
  image = image.resize((self.size, self.size), resample=self.interpolation)
 
  return example


+ def get_full_repo_name(
+ model_id: str, organization: Optional[str] = None, token: Optional[str] = None
+ ):
  if token is None:
  token = HfFolder.get_token()
  if organization is None:
 
  return f"{organization}/{model_id}"


  def main(pipe, args_imported):

  args = parse_args()
 
  if accelerator.is_main_process:
  if args.push_to_hub:
  if args.hub_model_id is None:
+ repo_name = get_full_repo_name(
+ Path(args.output_dir).name, token=args.hub_token
+ )
  else:
  repo_name = args.hub_model_id
  create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(
+ args.output_dir, clone_from=repo_name, token=args.hub_token
+ )

  with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
  if "step_*" not in gitignore:
 
  if is_xformers_available():
  unet.enable_xformers_memory_efficient_attention()
  else:
+ raise ValueError(
+ "xformers is not available. Make sure it is installed correctly"
+ )

  # Enable TF32 for faster training on Ampere GPUs,
  # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
 
  if args.scale_lr:
  args.learning_rate = (
+ args.learning_rate
+ * args.gradient_accumulation_steps
+ * args.train_batch_size
+ * accelerator.num_processes
  )

  # Initialize the optimizer
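Illustrative example (not from the commit): when args.scale_lr is set, the effective learning rate is learning_rate * gradient_accumulation_steps * train_batch_size * num_processes.

# Worked example with assumed values: the single-process settings used at the
# bottom of this file (batch size 1, accumulation 1, scale_lr=False) leave the
# rate unchanged, while the argparse default batch size of 16 would scale it.
base_lr = 5.0e-4
print(base_lr * 1 * 1 * 1)    # 0.0005
print(base_lr * 1 * 16 * 1)   # 0.008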
 
  center_crop=args.center_crop,
  set="train",
  )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True
+ )

  # Scheduler and math around the number of training steps.
  overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(
+ len(train_dataloader) / args.gradient_accumulation_steps
+ )
  if args.max_train_steps is None:
  args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
  overrode_max_train_steps = True
 
  text_encoder.to(accelerator.device, dtype=torch.float32)

  # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(
+ len(train_dataloader) / args.gradient_accumulation_steps
+ )
  if overrode_max_train_steps:
  args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
  # Afterwards we recalculate our number of training epochs
 
  accelerator.init_trackers("textual_inversion", config=vars(args))

  # Train!
+ total_batch_size = (
+ args.train_batch_size
+ * accelerator.num_processes
+ * args.gradient_accumulation_steps
+ )

  logger.info("***** Running training *****")
  logger.info(f" Num examples = {len(train_dataset)}")
  logger.info(f" Num Epochs = {args.num_train_epochs}")
  logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
+ )
  logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
  logger.info(f" Total optimization steps = {args.max_train_steps}")
  global_step = 0
 
  resume_step = resume_global_step % num_update_steps_per_epoch

  # Only show the progress bar once on each machine.
+ progress_bar = tqdm(
+ range(global_step, args.max_train_steps),
+ disable=not accelerator.is_local_main_process,
+ )
  progress_bar.set_description("Steps")

  # keep original embeddings as reference
+ orig_embeds_params = (
+ accelerator.unwrap_model(text_encoder)
+ .get_input_embeddings()
+ .weight.data.clone()
+ )

+ for epoch in range(first_epoch, args.num_train_epochs):
  text_encoder.train()
  for step, batch in enumerate(train_dataloader):
  # Skip steps until we reach the resumed step
+ if (
+ args.resume_from_checkpoint
+ and epoch == first_epoch
+ and step < resume_step
+ ):
  if step % args.gradient_accumulation_steps == 0:
  progress_bar.update(1)
  continue

  with accelerator.accumulate(text_encoder):
  # Convert images to latent space
+ latents = (
+ vae.encode(batch["pixel_values"].to(dtype=weight_dtype))
+ .latent_dist.sample()
+ .detach()
+ )
  latents = latents * 0.18215

  # Sample noise that we'll add to the latents
  noise = torch.randn_like(latents)
  bsz = latents.shape[0]
  # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
+ )
  timesteps = timesteps.long()

  # Add noise to the latents according to the noise magnitude at each timestep
 
  noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

  # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(
+ dtype=weight_dtype
+ )

  # Predict the noise residual
+ model_pred = unet(
+ noisy_latents, timesteps, encoder_hidden_states
+ ).sample

  # Get the target for loss depending on the prediction type
  if noise_scheduler.config.prediction_type == "epsilon":
 
  elif noise_scheduler.config.prediction_type == "v_prediction":
  target = noise_scheduler.get_velocity(latents, noise, timesteps)
  else:
+ raise ValueError(
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+ )

  loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
 
  else:
  grads = text_encoder.get_input_embeddings().weight.grad
  # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = (
+ torch.arange(len(tokenizer)) != placeholder_token_id
+ )
+ grads.data[index_grads_to_zero, :] = grads.data[
+ index_grads_to_zero, :
+ ].fill_(0)

  optimizer.step()
  lr_scheduler.step()
 
  # Let's make sure we don't update any embedding weights besides the newly added token
  index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
  with torch.no_grad():
+ accelerator.unwrap_model(
+ text_encoder
+ ).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[
  index_no_updates
+ ]

  # Checks if the accelerator has performed an optimization step behind the scenes
  if accelerator.sync_gradients:
  progress_bar.update(1)
  global_step += 1
  if global_step % args.save_steps == 0:
+ save_path = os.path.join(
+ args.output_dir, f"{args.placeholder_token}-{global_step}.bin"
+ )
+ save_progress(
+ text_encoder, placeholder_token_id, accelerator, args, save_path
+ )

  if global_step % args.checkpointing_steps == 0:
  if accelerator.is_main_process:
+ save_path = os.path.join(
+ args.output_dir, f"checkpoint-{global_step}"
+ )
  accelerator.save_state(save_path)
  logger.info(f"Saved state to {save_path}")
 
 
  accelerator.wait_for_everyone()
  if accelerator.is_main_process:
  if args.push_to_hub and args.only_save_embeds:
+ logger.warn(
+ "Enabling full model saving because --push_to_hub=True was specified."
+ )
  save_full_model = True
  else:
  save_full_model = not args.only_save_embeds
 
  save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

  if args.push_to_hub:
+ repo.push_to_hub(
+ commit_message="End of training", blocking=False, auto_lfs_prune=True
+ )

  accelerator.end_training()


  if __name__ == "__main__":
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ "andite/anything-v4.0", torch_dtype=torch.float16
+ )

+ imported_args = argparse.Namespace(
+ train_data_dir="concept_images",
+ learnable_property="object",
+ placeholder_token="redeyegirl",
+ initializer_token="girl",
+ resolution=512,
+ train_batch_size=1,
+ gradient_accumulation_steps=1,
+ gradient_checkpointing=True,
+ mixed_precision="fp16",
+ use_bf16=False,
+ max_train_steps=1000,
+ learning_rate=5.0e-4,
+ scale_lr=False,
+ lr_scheduler="constant",
+ lr_warmup_steps=0,
+ output_dir="output_model",
+ )

  main(pipeline, imported_args)
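In line with the commit's CPU-support theme, a hypothetical variation (not part of this commit) of the entry point above, using a full-precision pipeline and no mixed precision; whether main() runs end to end without a GPU is an assumption here.

import argparse
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical CPU-only invocation: reuse the arguments above but disable fp16.
cpu_pipeline = StableDiffusionPipeline.from_pretrained(
    "andite/anything-v4.0", torch_dtype=torch.float32
)
cpu_args = argparse.Namespace(**{**vars(imported_args), "mixed_precision": "no"})
main(cpu_pipeline, cpu_args)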