Commit
•
ac586a8
1
Parent(s):
cb024a4
Update app.py
Browse files
app.py
CHANGED
@@ -16,21 +16,6 @@ css = '''
|
|
16 |
shutil.unpack_archive("mix.zip", "mix")
|
17 |
model_to_load = "multimodalart/sd-fine-tunable"
|
18 |
maximum_concepts = 3
|
19 |
-
def swap_values_files(*total_files):
|
20 |
-
file_counter = 0
|
21 |
-
for files in total_files:
|
22 |
-
if(files):
|
23 |
-
for file in files:
|
24 |
-
filename = Path(file.orig_name).stem
|
25 |
-
pt=''.join([i for i in filename if not i.isdigit()])
|
26 |
-
pt=pt.replace("_"," ")
|
27 |
-
pt=pt.replace("(","")
|
28 |
-
pt=pt.replace(")","")
|
29 |
-
instance_prompt = pt
|
30 |
-
print(instance_prompt)
|
31 |
-
file_counter += 1
|
32 |
-
training_steps = (file_counter*200)
|
33 |
-
return training_steps
|
34 |
|
35 |
def swap_text(option):
|
36 |
mandatory_liability = "You must have the right to do so and you are liable for the images you use"
|
@@ -47,6 +32,24 @@ def swap_text(option):
|
|
47 |
freeze_for = 10
|
48 |
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def train(*inputs):
|
51 |
if os.path.exists("diffusers_model.zip"): os.remove("diffusers_model.zip")
|
52 |
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
@@ -164,7 +167,7 @@ def train(*inputs):
|
|
164 |
shutil.rmtree('instance_images')
|
165 |
shutil.make_archive("diffusers_model", 'zip', "output_model")
|
166 |
torch.cuda.empty_cache()
|
167 |
-
return [gr.update(visible=True, value=["diffusers_model.zip"]), gr.update(visible=True), gr.update(visible=True)]
|
168 |
|
169 |
def generate(prompt):
|
170 |
from diffusers import StableDiffusionPipeline
|
@@ -177,7 +180,7 @@ def generate(prompt):
|
|
177 |
def push(path):
|
178 |
pass
|
179 |
|
180 |
-
def
|
181 |
convert("output_model", "model.ckpt")
|
182 |
return gr.update(visible=True, value=["diffusers_model.zip", "model.ckpt"])
|
183 |
|
@@ -192,6 +195,13 @@ with gr.Blocks(css=css) as demo:
|
|
192 |
<img class="arrow" src="file/arrow.png" />
|
193 |
</div>
|
194 |
''')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
gr.Markdown("# Dreambooth training")
|
196 |
gr.Markdown("Customize Stable Diffusion by giving it with few-shot examples")
|
197 |
with gr.Row():
|
@@ -253,10 +263,11 @@ with gr.Blocks(css=css) as demo:
|
|
253 |
steps = gr.Number(label="How many steps", value=800)
|
254 |
perc_txt_encoder = gr.Number(label="Percentage of the training steps the text-encoder should be trained as well", value=30)
|
255 |
|
256 |
-
|
257 |
-
|
258 |
|
259 |
type_of_thing.change(fn=swap_text, inputs=[type_of_thing], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder], queue=False)
|
|
|
260 |
train_btn = gr.Button("Start Training")
|
261 |
with gr.Box(visible=False) as try_your_model:
|
262 |
gr.Markdown("Try your model")
|
@@ -268,11 +279,11 @@ with gr.Blocks(css=css) as demo:
|
|
268 |
gr.Markdown("Push to Hugging Face Hub")
|
269 |
model_repo_tag = gr.Textbox(label="Model name or URL", placeholder="username/model_name")
|
270 |
push_button = gr.Button("Push to the Hub")
|
271 |
-
result = gr.File(label="Download the uploaded models in the diffusers format
|
272 |
-
convert_button = gr.Button("Convert to CKPT")
|
273 |
|
274 |
-
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub])
|
275 |
generate_button.click(fn=generate, inputs=prompt, outputs=result)
|
276 |
push_button.click(fn=push, inputs=model_repo_tag, outputs=[])
|
277 |
-
convert_button.click(fn=
|
278 |
demo.launch()
|
|
|
16 |
shutil.unpack_archive("mix.zip", "mix")
|
17 |
model_to_load = "multimodalart/sd-fine-tunable"
|
18 |
maximum_concepts = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def swap_text(option):
|
21 |
mandatory_liability = "You must have the right to do so and you are liable for the images you use"
|
|
|
32 |
freeze_for = 10
|
33 |
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
|
34 |
|
35 |
+
def count_files(*inputs):
|
36 |
+
file_counter = 0
|
37 |
+
for i, input in enumerate(inputs):
|
38 |
+
if(i < maximum_concepts-1):
|
39 |
+
if(input):
|
40 |
+
files = inputs[i+(maximum_concepts*2)]
|
41 |
+
for j, tile_temp in enumerate(files):
|
42 |
+
file_counter+= 1
|
43 |
+
uses_custom = inputs[-1]
|
44 |
+
type_of_thing = inputs[-4]
|
45 |
+
if(uses_custom):
|
46 |
+
Training_Steps = int(inputs[-3])
|
47 |
+
else:
|
48 |
+
if(type_of_thing == "person"):
|
49 |
+
Training_Steps = file_counter*200*2
|
50 |
+
else:
|
51 |
+
Training_Steps = file_counter*200
|
52 |
+
return(gr.update(visible=True, value=f"You are going to train {file_counter} files for {Training_Steps} steps. This should take around {round(Training_Steps/1.5, 2)} seconds, or {round((Training_Steps/1.5)/3600, 2)}. The T4 GPU costs US$0.60 for 1h, so the estimated costs for this training run should be {round(((Training_Steps/1.5)/3600)*0.6, 2)}"))
|
53 |
def train(*inputs):
|
54 |
if os.path.exists("diffusers_model.zip"): os.remove("diffusers_model.zip")
|
55 |
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
|
|
167 |
shutil.rmtree('instance_images')
|
168 |
shutil.make_archive("diffusers_model", 'zip', "output_model")
|
169 |
torch.cuda.empty_cache()
|
170 |
+
return [gr.update(visible=True, value=["diffusers_model.zip"]), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
|
171 |
|
172 |
def generate(prompt):
|
173 |
from diffusers import StableDiffusionPipeline
|
|
|
180 |
def push(path):
|
181 |
pass
|
182 |
|
183 |
+
def convert_to_ckpt():
|
184 |
convert("output_model", "model.ckpt")
|
185 |
return gr.update(visible=True, value=["diffusers_model.zip", "model.ckpt"])
|
186 |
|
|
|
195 |
<img class="arrow" src="file/arrow.png" />
|
196 |
</div>
|
197 |
''')
|
198 |
+
else:
|
199 |
+
gr.HTML('''
|
200 |
+
<div class="gr-prose" style="max-width: 80%">
|
201 |
+
<h2>You have successfully cloned the Dreambooth Training Space</h2>
|
202 |
+
<p><a href="#">Now you can attribute a T4 GPU to it</a> (by going to the Settings tab) and run the training below. The GPU will be automatically unassigned after training is over. So you will be billed by the minute between when you activate the GPU and when it finishes training.</p>
|
203 |
+
</div>
|
204 |
+
''')
|
205 |
gr.Markdown("# Dreambooth training")
|
206 |
gr.Markdown("Customize Stable Diffusion by giving it with few-shot examples")
|
207 |
with gr.Row():
|
|
|
263 |
steps = gr.Number(label="How many steps", value=800)
|
264 |
perc_txt_encoder = gr.Number(label="Percentage of the training steps the text-encoder should be trained as well", value=30)
|
265 |
|
266 |
+
for file in file_collection:
|
267 |
+
file.change(fn=count_files, inputs=file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[training_summary, training_summary])
|
268 |
|
269 |
type_of_thing.change(fn=swap_text, inputs=[type_of_thing], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder], queue=False)
|
270 |
+
training_summary = gr.Textbox("", visible=False, label="Training Summary")
|
271 |
train_btn = gr.Button("Start Training")
|
272 |
with gr.Box(visible=False) as try_your_model:
|
273 |
gr.Markdown("Try your model")
|
|
|
279 |
gr.Markdown("Push to Hugging Face Hub")
|
280 |
model_repo_tag = gr.Textbox(label="Model name or URL", placeholder="username/model_name")
|
281 |
push_button = gr.Button("Push to the Hub")
|
282 |
+
result = gr.File(label="Download the uploaded models in the diffusers format", visible=True)
|
283 |
+
convert_button = gr.Button("Convert to CKPT", visible=False)
|
284 |
|
285 |
+
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub, convert_button])
|
286 |
generate_button.click(fn=generate, inputs=prompt, outputs=result)
|
287 |
push_button.click(fn=push, inputs=model_repo_tag, outputs=[])
|
288 |
+
convert_button.click(fn=convert_to_ckpt, inputs=[], outputs=result)
|
289 |
demo.launch()
|