Nupur Kumari commited on
Commit
dbc579c
·
1 Parent(s): f4d0eb6
Files changed (3) hide show
  1. app.py +86 -25
  2. inference.py +1 -0
  3. trainer.py +54 -24
app.py CHANGED
@@ -25,8 +25,7 @@ It is recommended to upgrade to GPU in Settings after duplicating this space to
25
  DETAILDESCRIPTION='''
26
  Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
27
  We fine-tune only a subset of model parameters, namely key and value projection matrices, in the cross-attention layers and the modifier token used to represent the object.
28
- This also reduces the extra storage for each additional concept to 75MB.
29
- Our method further allows you to use a combination of concepts. Demo for multiple concepts will be added soon.
30
  <center>
31
  <img src="https://huggingface.co/spaces/nupurkmr9/custom-diffusion/resolve/main/method.jpg" width="600" align="center" >
32
  </center>
@@ -81,27 +80,82 @@ def create_training_demo(trainer: Trainer,
81
 
82
  with gr.Row():
83
  with gr.Box():
84
- gr.Markdown('Training Data')
85
- concept_images = gr.Files(label='Images for your concept')
86
- with gr.Row():
87
- class_prompt = gr.Textbox(label='Class Prompt',
88
- max_lines=1, placeholder='Example: "cat"')
89
- with gr.Column():
90
- modifier_token = gr.Checkbox(label='modifier token',
91
- value=True)
92
- train_text_encoder = gr.Checkbox(label='Train Text Encoder',
93
- value=False)
94
- concept_prompt = gr.Textbox(label='Concept Prompt',
95
- max_lines=1, placeholder='Example: "photo of a \<new1\> cat"')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  gr.Markdown('''
97
- - We use "\<new1\>" modifier token in front of the concept, e.g., "\<new1\> cat". By default modifier_token is enabled.
98
- - If "Train Text Encoder", disable "modifier token" and use any unique text to describe the concept e.g. "ktn cat".
99
- - For a new concept an e.g. concept prompt is "photo of a \<new1\> cat" and "cat" for class prompt.
100
- - For a style concept, use "painting in the style of \<new1\> art" for concept prompt and "art" for class prompt.
101
- - Class prompt should be the object category.
102
- ''')
103
  with gr.Box():
104
  gr.Markdown('Training Parameters')
 
 
 
 
 
105
  num_training_steps = gr.Number(
106
  label='Number of Training Steps', value=1000, precision=0)
107
  learning_rate = gr.Number(label='Learning Rate', value=0.00001)
@@ -115,6 +169,10 @@ def create_training_demo(trainer: Trainer,
115
  label='Number of Gradient Accumulation',
116
  value=1,
117
  precision=0)
 
 
 
 
118
  gen_images = gr.Checkbox(label='Generated images as regularization',
119
  value=False)
120
  gr.Markdown('''
@@ -122,6 +180,7 @@ def create_training_demo(trainer: Trainer,
122
  - Our results in the paper are trained with batch-size 4 (8 including class regularization samples).
123
  - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of slower backward pass.
124
  - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
 
125
  ''')
126
 
127
  run_button = gr.Button('Start Training')
@@ -141,9 +200,6 @@ def create_training_demo(trainer: Trainer,
141
  inputs=[
142
  base_model,
143
  resolution,
144
- concept_images,
145
- concept_prompt,
146
- class_prompt,
147
  num_training_steps,
148
  learning_rate,
149
  train_text_encoder,
@@ -152,8 +208,13 @@ def create_training_demo(trainer: Trainer,
152
  batch_size,
153
  use_8bit_adam,
154
  gradient_checkpointing,
155
- gen_images
156
- ],
 
 
 
 
 
157
  outputs=[
158
  training_status,
159
  output_files,
 
25
  DETAILDESCRIPTION='''
26
  Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
27
  We fine-tune only a subset of model parameters, namely key and value projection matrices, in the cross-attention layers and the modifier token used to represent the object.
28
+ This also reduces the extra storage for each additional concept to 75MB. Our method also allows you to use a combination of concepts. There's still limitations on which compositions work. For more analysis please refer to our [website](https://www.cs.cmu.edu/~custom-diffusion/).
 
29
  <center>
30
  <img src="https://huggingface.co/spaces/nupurkmr9/custom-diffusion/resolve/main/method.jpg" width="600" align="center" >
31
  </center>
 
80
 
81
  with gr.Row():
82
  with gr.Box():
83
+ concept_images_collection = []
84
+ concept_prompt_collection = []
85
+ class_prompt_collection = []
86
+ buttons_collection = []
87
+ delete_collection = []
88
+ is_visible = []
89
+ maximum_concepts = 3
90
+ row = [None] * maximum_concepts
91
+ for x in range(maximum_concepts):
92
+ ordinal = lambda n: "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])
93
+ ordinal_concept = ["<new1> cat", "<new2> wooden pot", "<new3> chair"]
94
+ if(x == 0):
95
+ visible = True
96
+ is_visible.append(gr.State(value=True))
97
+ else:
98
+ visible = False
99
+ is_visible.append(gr.State(value=False))
100
+
101
+ concept_images_collection.append(gr.Files(label=f'''Upload the images for your {ordinal(x+1) if (x>0) else ""} concept''', visible=visible))
102
+ with gr.Column(visible=visible) as row[x]:
103
+ concept_prompt_collection.append(
104
+ gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} concept prompt ''', max_lines=1,
105
+ placeholder=f'''Example: "photo of a {ordinal_concept[x]}"''' )
106
+ )
107
+ class_prompt_collection.append(
108
+ gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} class prompt ''',
109
+ max_lines=1, placeholder=f'''Example: "{ordinal_concept[x][7:]}"''')
110
+ )
111
+ with gr.Row():
112
+ if(x < maximum_concepts-1):
113
+ buttons_collection.append(gr.Button(value=f"Add {ordinal(x+2)} concept", visible=visible))
114
+ if(x > 0):
115
+ delete_collection.append(gr.Button(value=f"Delete {ordinal(x+1)} concept"))
116
+
117
+ counter_add = 1
118
+ for button in buttons_collection:
119
+ if(counter_add < len(buttons_collection)):
120
+ button.click(lambda:
121
+ [gr.update(visible=True),gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), True, None],
122
+ None,
123
+ [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], buttons_collection[counter_add], is_visible[counter_add], concept_images_collection[counter_add]], queue=False)
124
+ else:
125
+ button.click(lambda:
126
+ [gr.update(visible=True),gr.update(visible=True), gr.update(visible=False), True],
127
+ None,
128
+ [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], is_visible[counter_add]], queue=False)
129
+ counter_add += 1
130
+
131
+ counter_delete = 1
132
+ for delete_button in delete_collection:
133
+ if(counter_delete < len(delete_collection)+1):
134
+ if counter_delete == 1:
135
+ delete_button.click(lambda:
136
+ [gr.update(visible=False, value=None),gr.update(visible=False), gr.update(visible=True), gr.update(visible=False),False],
137
+ None,
138
+ [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], buttons_collection[counter_delete], is_visible[counter_delete]], queue=False)
139
+ else:
140
+ delete_button.click(lambda:
141
+ [gr.update(visible=False, value=None),gr.update(visible=False), gr.update(visible=True), False],
142
+ None,
143
+ [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], is_visible[counter_delete]], queue=False)
144
+ counter_delete += 1
145
  gr.Markdown('''
146
+ - We use "\<new1\>" modifier_token in front of the concept, e.g., "\<new1\> cat". For multiple concepts use "\<new2\>", "\<new3\>" etc. Increase the number of steps with more concepts.
147
+ - For a new concept an e.g. concept prompt is "photo of a \<new1\> cat" and "cat" for class prompt.
148
+ - For a style concept, use "painting in the style of \<new1\> art" for concept prompt and "art" for class prompt.
149
+ - Class prompt should be the object category.
150
+ - If "Train Text Encoder", disable "modifier token" and use any unique text to describe the concept e.g. "ktn cat".
151
+ ''')
152
  with gr.Box():
153
  gr.Markdown('Training Parameters')
154
+ with gr.Row():
155
+ modifier_token = gr.Checkbox(label='modifier token',
156
+ value=True)
157
+ train_text_encoder = gr.Checkbox(label='Train Text Encoder',
158
+ value=False)
159
  num_training_steps = gr.Number(
160
  label='Number of Training Steps', value=1000, precision=0)
161
  learning_rate = gr.Number(label='Learning Rate', value=0.00001)
 
169
  label='Number of Gradient Accumulation',
170
  value=1,
171
  precision=0)
172
+ num_reg_images = gr.Number(
173
+ label='Number of Class Concept images',
174
+ value=200,
175
+ precision=0)
176
  gen_images = gr.Checkbox(label='Generated images as regularization',
177
  value=False)
178
  gr.Markdown('''
 
180
  - Our results in the paper are trained with batch-size 4 (8 including class regularization samples).
181
  - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of slower backward pass.
182
  - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
183
+ - We retrieve real images for class concept using clip_retireval library which can take some time.
184
  ''')
185
 
186
  run_button = gr.Button('Start Training')
 
200
  inputs=[
201
  base_model,
202
  resolution,
 
 
 
203
  num_training_steps,
204
  learning_rate,
205
  train_text_encoder,
 
208
  batch_size,
209
  use_8bit_adam,
210
  gradient_checkpointing,
211
+ gen_images,
212
+ num_reg_images,
213
+ ] +
214
+ concept_images_collection +
215
+ concept_prompt_collection +
216
+ class_prompt_collection
217
+ ,
218
  outputs=[
219
  training_status,
220
  output_files,
inference.py CHANGED
@@ -75,6 +75,7 @@ class InferencePipeline:
75
  height=resolution, width=resolution,
76
  eta = eta,
77
  generator=generator) # type: ignore
 
78
  out = out.images
79
  out = PIL.Image.fromarray(np.hstack([np.array(x) for x in out]))
80
  return out
 
75
  height=resolution, width=resolution,
76
  eta = eta,
77
  generator=generator) # type: ignore
78
+ torch.cuda.empty_cache()
79
  out = out.images
80
  out = PIL.Image.fromarray(np.hstack([np.array(x) for x in out]))
81
  return out
trainer.py CHANGED
@@ -9,6 +9,7 @@ import subprocess
9
  import gradio as gr
10
  import PIL.Image
11
  import torch
 
12
 
13
  os.environ['PYTHONPATH'] = f'custom-diffusion:{os.getenv("PYTHONPATH", "")}'
14
 
@@ -45,23 +46,41 @@ class Trainer:
45
  def cleanup_dirs(self) -> None:
46
  shutil.rmtree(self.output_dir, ignore_errors=True)
47
 
48
- def prepare_dataset(self, concept_images: list, resolution: int) -> None:
49
  self.instance_data_dir.mkdir(parents=True)
50
- for i, temp_path in enumerate(concept_images):
51
- image = PIL.Image.open(temp_path.name)
52
- image = pad_image(image)
53
- image = image.resize((resolution, resolution))
54
- image = image.convert('RGB')
55
- out_path = self.instance_data_dir / f'{i:03d}.jpg'
56
- image.save(out_path, format='JPEG', quality=100)
57
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def run(
59
  self,
60
  base_model: str,
61
  resolution_s: str,
62
- concept_images: list | None,
63
- concept_prompt: str,
64
- class_prompt: str,
65
  n_steps: int,
66
  learning_rate: float,
67
  train_text_encoder: bool,
@@ -71,32 +90,40 @@ class Trainer:
71
  use_8bit_adam: bool,
72
  gradient_checkpointing: bool,
73
  gen_images: bool,
 
 
74
  ) -> tuple[dict, list[pathlib.Path]]:
75
  if not torch.cuda.is_available():
76
  raise gr.Error('CUDA is not available.')
77
 
 
 
 
 
 
 
 
 
 
78
  if self.is_running:
79
  return gr.update(value=self.is_running_message), []
80
 
81
- if concept_images is None:
82
  raise gr.Error('You need to upload images.')
83
- if not concept_prompt:
84
  raise gr.Error('The concept prompt is missing.')
85
 
86
  resolution = int(resolution_s)
87
 
88
  self.cleanup_dirs()
89
- self.prepare_dataset(concept_images, resolution)
90
-
91
  command = f'''
92
  accelerate launch custom-diffusion/src/diffuser_training.py \
93
  --pretrained_model_name_or_path={base_model} \
94
- --instance_data_dir={self.instance_data_dir} \
95
  --output_dir={self.output_dir} \
96
- --instance_prompt="{concept_prompt}" \
97
- --class_data_dir={self.class_data_dir} \
98
  --with_prior_preservation --prior_loss_weight=1.0 \
99
- --class_prompt="{class_prompt}" \
100
  --resolution={resolution} \
101
  --train_batch_size={batch_size} \
102
  --gradient_accumulation_steps={gradient_accumulation} \
@@ -104,11 +131,14 @@ class Trainer:
104
  --lr_scheduler="constant" \
105
  --lr_warmup_steps=0 \
106
  --max_train_steps={n_steps} \
107
- --num_class_images=200 \
108
- --scale_lr
 
109
  '''
110
  if modifier_token:
111
- command += ' --modifier_token "<new1>"'
 
 
112
  if not gen_images:
113
  command += ' --real_prior'
114
  if use_8bit_adam:
@@ -117,7 +147,7 @@ class Trainer:
117
  command += f' --train_text_encoder'
118
  if gradient_checkpointing:
119
  command += f' --gradient_checkpointing'
120
-
121
  with open(self.output_dir / 'train.sh', 'w') as f:
122
  command_s = ' '.join(command.split())
123
  f.write(command_s)
 
9
  import gradio as gr
10
  import PIL.Image
11
  import torch
12
+ import json
13
 
14
  os.environ['PYTHONPATH'] = f'custom-diffusion:{os.getenv("PYTHONPATH", "")}'
15
 
 
46
  def cleanup_dirs(self) -> None:
47
  shutil.rmtree(self.output_dir, ignore_errors=True)
48
 
49
+ def prepare_dataset(self, concept_images_collection: list, concept_prompt_collection: list, class_prompt_collection: list, resolution: int) -> None:
50
  self.instance_data_dir.mkdir(parents=True)
51
+ concepts_list = []
52
+
53
+ for i in range(len(concept_images_collection)):
54
+ concept_dir = self.instance_data_dir / f'{i}'
55
+ class_dir = self.class_data_dir / f'{i}'
56
+ concept_dir.mkdir(parents=True)
57
+ concept_images = concept_images_collection[i]
58
+
59
+ concepts_list.append(
60
+ {
61
+ "instance_prompt": concept_prompt_collection[i],
62
+ "class_prompt": class_prompt_collection[i],
63
+ "instance_data_dir": f'{concept_dir}',
64
+ "class_data_dir": f'{class_dir}'
65
+ }
66
+ )
67
+
68
+ for i, temp_path in enumerate(concept_images):
69
+ image = PIL.Image.open(temp_path.name)
70
+ image = pad_image(image)
71
+ # image = image.resize((resolution, resolution))
72
+ image = image.convert('RGB')
73
+ out_path = concept_dir / f'{i:03d}.jpg'
74
+ image.save(out_path, format='JPEG', quality=100)
75
+
76
+ print(concepts_list)
77
+ json.dump(concepts_list, open( f'{self.output_dir}/temp.json' , 'w') )
78
+
79
+
80
  def run(
81
  self,
82
  base_model: str,
83
  resolution_s: str,
 
 
 
84
  n_steps: int,
85
  learning_rate: float,
86
  train_text_encoder: bool,
 
90
  use_8bit_adam: bool,
91
  gradient_checkpointing: bool,
92
  gen_images: bool,
93
+ num_reg_images: int,
94
+ *inputs,
95
  ) -> tuple[dict, list[pathlib.Path]]:
96
  if not torch.cuda.is_available():
97
  raise gr.Error('CUDA is not available.')
98
 
99
+ num_concept = 0
100
+ for i in range(len(inputs) // 3):
101
+ if inputs[i] != None:
102
+ num_concept +=1
103
+
104
+ print(num_concept, inputs)
105
+ concept_images_collection = inputs[: num_concept]
106
+ concept_prompt_collection = inputs[3: 3 + num_concept]
107
+ class_prompt_collection = inputs[6: 6+num_concept]
108
  if self.is_running:
109
  return gr.update(value=self.is_running_message), []
110
 
111
+ if concept_images_collection is None:
112
  raise gr.Error('You need to upload images.')
113
+ if not concept_prompt_collection:
114
  raise gr.Error('The concept prompt is missing.')
115
 
116
  resolution = int(resolution_s)
117
 
118
  self.cleanup_dirs()
119
+ self.prepare_dataset(concept_images_collection, concept_prompt_collection, class_prompt_collection, resolution)
120
+ torch.cuda.empty_cache()
121
  command = f'''
122
  accelerate launch custom-diffusion/src/diffuser_training.py \
123
  --pretrained_model_name_or_path={base_model} \
 
124
  --output_dir={self.output_dir} \
125
+ --concepts_list={f'{self.output_dir}/temp.json'} \
 
126
  --with_prior_preservation --prior_loss_weight=1.0 \
 
127
  --resolution={resolution} \
128
  --train_batch_size={batch_size} \
129
  --gradient_accumulation_steps={gradient_accumulation} \
 
131
  --lr_scheduler="constant" \
132
  --lr_warmup_steps=0 \
133
  --max_train_steps={n_steps} \
134
+ --num_class_images={num_reg_images} \
135
+ --initializer_token="ktn+pll+ucd" \
136
+ --scale_lr --hflip
137
  '''
138
  if modifier_token:
139
+ tokens = '+'.join([f'<new{i+1}>' for i in range(num_concept)])
140
+ command += f' --modifier_token {tokens}'
141
+
142
  if not gen_images:
143
  command += ' --real_prior'
144
  if use_8bit_adam:
 
147
  command += f' --train_text_encoder'
148
  if gradient_checkpointing:
149
  command += f' --gradient_checkpointing'
150
+
151
  with open(self.output_dir / 'train.sh', 'w') as f:
152
  command_s = ' '.join(command.split())
153
  f.write(command_s)