Spaces: Build error
Nupur Kumari committed dbc579c
Parent(s): f4d0eb6
update
Browse files:
- app.py +86 -25
- inference.py +1 -0
- trainer.py +54 -24
app.py
CHANGED
@@ -25,8 +25,7 @@ It is recommended to upgrade to GPU in Settings after duplicating this space to
 DETAILDESCRIPTION='''
 Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
 We fine-tune only a subset of model parameters, namely key and value projection matrices, in the cross-attention layers and the modifier token used to represent the object.
-This also reduces the extra storage for each additional concept to 75MB.
-Our method further allows you to use a combination of concepts. Demo for multiple concepts will be added soon.
+This also reduces the extra storage for each additional concept to 75MB. Our method also allows you to use a combination of concepts. There's still limitations on which compositions work. For more analysis please refer to our [website](https://www.cs.cmu.edu/~custom-diffusion/).
 <center>
 <img src="https://huggingface.co/spaces/nupurkmr9/custom-diffusion/resolve/main/method.jpg" width="600" align="center" >
 </center>
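The description above notes that only the key and value projection matrices of the cross-attention layers (plus the modifier token embedding) are fine-tuned. As a rough standalone sketch of what that parameter subset looks like in diffusers (the model id and the attn2.to_k / attn2.to_v naming are assumptions for illustration, not part of this commit):

```python
# Illustrative only (not part of this commit): selecting the cross-attention
# key/value projection weights that Custom Diffusion fine-tunes. The model id
# and the 'attn2.to_k'/'attn2.to_v' parameter names are assumptions based on
# the standard diffusers UNet2DConditionModel layout.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet')

trainable = []
for name, param in unet.named_parameters():
    # 'attn2' blocks are the cross-attention layers; to_k/to_v are the key and
    # value projection matrices mentioned in the description above.
    if 'attn2.to_k' in name or 'attn2.to_v' in name:
        param.requires_grad_(True)
        trainable.append(name)
    else:
        param.requires_grad_(False)

print(f'{len(trainable)} key/value projection tensors selected for fine-tuning')
```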
@@ -81,27 +80,82 @@ def create_training_demo(trainer: Trainer,
 
         with gr.Row():
             with gr.Box():
+                concept_images_collection = []
+                concept_prompt_collection = []
+                class_prompt_collection = []
+                buttons_collection = []
+                delete_collection = []
+                is_visible = []
+                maximum_concepts = 3
+                row = [None] * maximum_concepts
+                for x in range(maximum_concepts):
+                    ordinal = lambda n: "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])
+                    ordinal_concept = ["<new1> cat", "<new2> wooden pot", "<new3> chair"]
+                    if(x == 0):
+                        visible = True
+                        is_visible.append(gr.State(value=True))
+                    else:
+                        visible = False
+                        is_visible.append(gr.State(value=False))
+
+                    concept_images_collection.append(gr.Files(label=f'''Upload the images for your {ordinal(x+1) if (x>0) else ""} concept''', visible=visible))
+                    with gr.Column(visible=visible) as row[x]:
+                        concept_prompt_collection.append(
+                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} concept prompt ''', max_lines=1,
+                                       placeholder=f'''Example: "photo of a {ordinal_concept[x]}"''' )
+                        )
+                        class_prompt_collection.append(
+                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} class prompt ''',
+                                       max_lines=1, placeholder=f'''Example: "{ordinal_concept[x][7:]}"''')
+                        )
+                    with gr.Row():
+                        if(x < maximum_concepts-1):
+                            buttons_collection.append(gr.Button(value=f"Add {ordinal(x+2)} concept", visible=visible))
+                        if(x > 0):
+                            delete_collection.append(gr.Button(value=f"Delete {ordinal(x+1)} concept"))
+
+                counter_add = 1
+                for button in buttons_collection:
+                    if(counter_add < len(buttons_collection)):
+                        button.click(lambda:
+                                     [gr.update(visible=True),gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), True, None],
+                                     None,
+                                     [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], buttons_collection[counter_add], is_visible[counter_add], concept_images_collection[counter_add]], queue=False)
+                    else:
+                        button.click(lambda:
+                                     [gr.update(visible=True),gr.update(visible=True), gr.update(visible=False), True],
+                                     None,
+                                     [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], is_visible[counter_add]], queue=False)
+                    counter_add += 1
+
+                counter_delete = 1
+                for delete_button in delete_collection:
+                    if(counter_delete < len(delete_collection)+1):
+                        if counter_delete == 1:
+                            delete_button.click(lambda:
+                                                [gr.update(visible=False, value=None),gr.update(visible=False), gr.update(visible=True), gr.update(visible=False),False],
+                                                None,
+                                                [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], buttons_collection[counter_delete], is_visible[counter_delete]], queue=False)
+                        else:
+                            delete_button.click(lambda:
+                                                [gr.update(visible=False, value=None),gr.update(visible=False), gr.update(visible=True), False],
+                                                None,
+                                                [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], is_visible[counter_delete]], queue=False)
+                    counter_delete += 1
                 gr.Markdown('''
+                - We use "\<new1\>" modifier_token in front of the concept, e.g., "\<new1\> cat". For multiple concepts use "\<new2\>", "\<new3\>" etc. Increase the number of steps with more concepts.
+                - For a new concept an e.g. concept prompt is "photo of a \<new1\> cat" and "cat" for class prompt.
+                - For a style concept, use "painting in the style of \<new1\> art" for concept prompt and "art" for class prompt.
+                - Class prompt should be the object category.
+                - If "Train Text Encoder", disable "modifier token" and use any unique text to describe the concept e.g. "ktn cat".
+                ''')
             with gr.Box():
                 gr.Markdown('Training Parameters')
+                with gr.Row():
+                    modifier_token = gr.Checkbox(label='modifier token',
+                                                 value=True)
+                    train_text_encoder = gr.Checkbox(label='Train Text Encoder',
+                                                     value=False)
                 num_training_steps = gr.Number(
                     label='Number of Training Steps', value=1000, precision=0)
                 learning_rate = gr.Number(label='Learning Rate', value=0.00001)
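The one-line `ordinal` helper added in this hunk is the compact ordinal-suffix trick used for the widget labels ("Add 2nd concept", "Delete 3rd concept"). Run standalone, it behaves like this:

```python
# Same expression as in the hunk above, shown standalone to make its output obvious.
ordinal = lambda n: "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])

print([ordinal(n) for n in (1, 2, 3, 4, 11, 21)])
# ['1st', '2nd', '3rd', '4th', '11th', '21st']
```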
@@ -115,6 +169,10 @@ def create_training_demo(trainer: Trainer,
                     label='Number of Gradient Accumulation',
                     value=1,
                     precision=0)
+                num_reg_images = gr.Number(
+                    label='Number of Class Concept images',
+                    value=200,
+                    precision=0)
                 gen_images = gr.Checkbox(label='Generated images as regularization',
                                          value=False)
                 gr.Markdown('''
@@ -122,6 +180,7 @@ def create_training_demo(trainer: Trainer,
                 - Our results in the paper are trained with batch-size 4 (8 including class regularization samples).
                 - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of slower backward pass.
                 - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
+                - We retrieve real images for class concept using clip_retireval library which can take some time.
                 ''')
 
         run_button = gr.Button('Start Training')
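The last bullet above refers to fetching real class-regularization images with the clip-retrieval library. A minimal sketch of what such a query can look like using clip-retrieval's client; the service URL, index name, and query text are assumptions for illustration, and the Space's training script performs the actual retrieval when `--real_prior` is set:

```python
# Minimal sketch, assuming the public LAION knn service exposed by clip-retrieval.
# Not the code used by this Space; it only illustrates the kind of query behind
# "retrieve real images for class concept".
from clip_retrieval.clip_client import ClipClient

client = ClipClient(
    url='https://knn.laion.ai/knn-service',  # assumed public endpoint
    indice_name='laion5B-L-14',              # assumed index name
    num_images=200,                          # matches the Number of Class Concept images default
)
results = client.query(text='photo of a cat')  # the class prompt, e.g. "cat"
print(results[0].keys())  # entries typically carry 'url', 'caption', 'similarity'
```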
@@ -141,9 +200,6 @@ def create_training_demo(trainer: Trainer,
                         inputs=[
                             base_model,
                             resolution,
-                            concept_images,
-                            concept_prompt,
-                            class_prompt,
                             num_training_steps,
                             learning_rate,
                             train_text_encoder,
@@ -152,8 +208,13 @@ def create_training_demo(trainer: Trainer,
                             batch_size,
                             use_8bit_adam,
                             gradient_checkpointing,
-                            gen_images
+                            gen_images,
+                            num_reg_images,
+                        ] +
+                        concept_images_collection +
+                        concept_prompt_collection +
+                        class_prompt_collection
+                        ,
                         outputs=[
                             training_status,
                             output_files,
inference.py
CHANGED
@@ -75,6 +75,7 @@ class InferencePipeline:
                        height=resolution, width=resolution,
                        eta = eta,
                        generator=generator)  # type: ignore
+        torch.cuda.empty_cache()
         out = out.images
         out = PIL.Image.fromarray(np.hstack([np.array(x) for x in out]))
         return out
trainer.py
CHANGED
@@ -9,6 +9,7 @@ import subprocess
 import gradio as gr
 import PIL.Image
 import torch
+import json
 
 os.environ['PYTHONPATH'] = f'custom-diffusion:{os.getenv("PYTHONPATH", "")}'
 
@@ -45,23 +46,41 @@ class Trainer:
     def cleanup_dirs(self) -> None:
         shutil.rmtree(self.output_dir, ignore_errors=True)
 
+    def prepare_dataset(self, concept_images_collection: list, concept_prompt_collection: list, class_prompt_collection: list, resolution: int) -> None:
         self.instance_data_dir.mkdir(parents=True)
+        concepts_list = []
+
+        for i in range(len(concept_images_collection)):
+            concept_dir = self.instance_data_dir / f'{i}'
+            class_dir = self.class_data_dir / f'{i}'
+            concept_dir.mkdir(parents=True)
+            concept_images = concept_images_collection[i]
+
+            concepts_list.append(
+                {
+                    "instance_prompt": concept_prompt_collection[i],
+                    "class_prompt": class_prompt_collection[i],
+                    "instance_data_dir": f'{concept_dir}',
+                    "class_data_dir": f'{class_dir}'
+                }
+            )
+
+            for i, temp_path in enumerate(concept_images):
+                image = PIL.Image.open(temp_path.name)
+                image = pad_image(image)
+                # image = image.resize((resolution, resolution))
+                image = image.convert('RGB')
+                out_path = concept_dir / f'{i:03d}.jpg'
+                image.save(out_path, format='JPEG', quality=100)
+
+        print(concepts_list)
+        json.dump(concepts_list, open( f'{self.output_dir}/temp.json' , 'w') )
+
+
     def run(
         self,
         base_model: str,
         resolution_s: str,
-        concept_images: list | None,
-        concept_prompt: str,
-        class_prompt: str,
         n_steps: int,
         learning_rate: float,
         train_text_encoder: bool,
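prepare_dataset writes the per-concept metadata to `{self.output_dir}/temp.json`, which run() later hands to the training script via `--concepts_list`. For two uploaded concepts the dumped list looks roughly like this (prompts follow the UI placeholders; directory names are hypothetical, the real paths come from `self.instance_data_dir / i` and `self.class_data_dir / i`):

```python
# Illustrative value of concepts_list for two uploaded concepts.
concepts_list = [
    {
        "instance_prompt": "photo of a <new1> cat",
        "class_prompt": "cat",
        "instance_data_dir": "training_data/0",   # hypothetical path
        "class_data_dir": "class_data/0",         # hypothetical path
    },
    {
        "instance_prompt": "photo of a <new2> wooden pot",
        "class_prompt": "wooden pot",
        "instance_data_dir": "training_data/1",
        "class_data_dir": "class_data/1",
    },
]
```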
@@ -71,32 +90,40 @@ class Trainer:
         use_8bit_adam: bool,
         gradient_checkpointing: bool,
         gen_images: bool,
+        num_reg_images: int,
+        *inputs,
     ) -> tuple[dict, list[pathlib.Path]]:
         if not torch.cuda.is_available():
             raise gr.Error('CUDA is not available.')
 
+        num_concept = 0
+        for i in range(len(inputs) // 3):
+            if inputs[i] != None:
+                num_concept +=1
+
+        print(num_concept, inputs)
+        concept_images_collection = inputs[: num_concept]
+        concept_prompt_collection = inputs[3: 3 + num_concept]
+        class_prompt_collection = inputs[6: 6+num_concept]
         if self.is_running:
             return gr.update(value=self.is_running_message), []
 
+        if concept_images_collection is None:
             raise gr.Error('You need to upload images.')
+        if not concept_prompt_collection:
             raise gr.Error('The concept prompt is missing.')
 
         resolution = int(resolution_s)
 
         self.cleanup_dirs()
+        self.prepare_dataset(concept_images_collection, concept_prompt_collection, class_prompt_collection, resolution)
+        torch.cuda.empty_cache()
         command = f'''
         accelerate launch custom-diffusion/src/diffuser_training.py \
           --pretrained_model_name_or_path={base_model} \
-          --instance_data_dir={self.instance_data_dir} \
           --output_dir={self.output_dir} \
+          --concepts_list={f'{self.output_dir}/temp.json'} \
-          --class_data_dir={self.class_data_dir} \
           --with_prior_preservation --prior_loss_weight=1.0 \
-          --class_prompt="{class_prompt}" \
           --resolution={resolution} \
           --train_batch_size={batch_size} \
           --gradient_accumulation_steps={gradient_accumulation} \
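run() now receives the concept widgets through `*inputs`. Since the UI always registers maximum_concepts = 3 file-upload components, 3 concept prompts, and 3 class prompts (in that order), the hard-coded offsets 3 and 6 in the slices above pick out the prompt groups, and num_concept counts how many upload slots were actually filled. A small standalone illustration with made-up values:

```python
# Hypothetical flat tuple as Gradio would pass it for maximum_concepts = 3 with
# only the first two concepts filled in (real uploads arrive as tempfile
# wrappers, not plain strings; unused prompt boxes come through as '').
inputs = (
    ['cat1.jpg', 'cat2.jpg'], ['pot1.jpg'], None,                 # 3 image uploads
    'photo of a <new1> cat', 'photo of a <new2> wooden pot', '',  # 3 concept prompts
    'cat', 'wooden pot', '',                                      # 3 class prompts
)

# Equivalent to the counting loop in the hunk above.
num_concept = sum(1 for x in inputs[:3] if x is not None)  # -> 2

concept_images_collection = inputs[:num_concept]
concept_prompt_collection = inputs[3:3 + num_concept]
class_prompt_collection = inputs[6:6 + num_concept]
print(num_concept, concept_prompt_collection, class_prompt_collection)
```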
@@ -104,11 +131,14 @@ class Trainer:
           --lr_scheduler="constant" \
           --lr_warmup_steps=0 \
           --max_train_steps={n_steps} \
+          --num_class_images={num_reg_images} \
+          --initializer_token="ktn+pll+ucd" \
+          --scale_lr --hflip
         '''
         if modifier_token:
+            tokens = '+'.join([f'<new{i+1}>' for i in range(num_concept)])
+            command += f' --modifier_token {tokens}'
+
         if not gen_images:
             command += ' --real_prior'
         if use_8bit_adam:
@@ -117,7 +147,7 @@ class Trainer:
             command += f' --train_text_encoder'
         if gradient_checkpointing:
             command += f' --gradient_checkpointing'
+
         with open(self.output_dir / 'train.sh', 'w') as f:
             command_s = ' '.join(command.split())
             f.write(command_s)
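Joined with `' '.join(command.split())` and written to train.sh, the assembled command for a single concept with the defaults shown in this diff comes out roughly as below. The base model, output directory, batch size, and resolution are illustrative, and flags that live in unchanged lines of trainer.py (such as the learning-rate flag) are omitted:

```python
# Roughly what ' '.join(command.split()) produces for one concept with the
# defaults shown in this diff; values marked above as illustrative are not
# taken from the commit.
command_s = (
    'accelerate launch custom-diffusion/src/diffuser_training.py '
    '--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4 '
    '--output_dir=results --concepts_list=results/temp.json '
    '--with_prior_preservation --prior_loss_weight=1.0 '
    '--resolution=512 --train_batch_size=1 --gradient_accumulation_steps=1 '
    '--lr_scheduler="constant" --lr_warmup_steps=0 --max_train_steps=1000 '
    '--num_class_images=200 --initializer_token="ktn+pll+ucd" '
    '--scale_lr --hflip --modifier_token <new1> --real_prior'
)
```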