Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -72,6 +72,8 @@ def create_training_demo(trainer: Trainer,
|
|
72 |
concept_images = gr.Files(label='Images for your concept')
|
73 |
concept_prompt = gr.Textbox(label='Concept Prompt',
|
74 |
max_lines=1)
|
|
|
|
|
75 |
gr.Markdown('''
|
76 |
- Upload images of the style you are planning on training on.
|
77 |
- For a concept prompt, use a unique, made up word to avoid collisions.
|
@@ -80,11 +82,13 @@ def create_training_demo(trainer: Trainer,
|
|
80 |
gr.Markdown('Training Parameters')
|
81 |
num_training_steps = gr.Number(
|
82 |
label='Number of Training Steps', value=1000, precision=0)
|
83 |
-
learning_rate = gr.Number(label='Learning Rate', value=0.
|
84 |
train_text_encoder = gr.Checkbox(label='Train Text Encoder',
|
85 |
value=True)
|
|
|
|
|
86 |
learning_rate_text = gr.Number(
|
87 |
-
label='Learning Rate for Text Encoder', value=0.
|
88 |
gradient_accumulation = gr.Number(
|
89 |
label='Number of Gradient Accumulation',
|
90 |
value=1,
|
@@ -145,7 +149,7 @@ def find_weight_files() -> list[str]:
|
|
145 |
return [path.relative_to(curr_dir).as_posix() for path in paths]
|
146 |
|
147 |
|
148 |
-
def
|
149 |
return gr.update(choices=find_weight_files())
|
150 |
|
151 |
|
@@ -159,23 +163,13 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
|
|
159 |
label='Base Model',
|
160 |
visible=False)
|
161 |
reload_button = gr.Button('Reload Weight List')
|
162 |
-
|
163 |
-
value='
|
164 |
-
label='
|
165 |
prompt = gr.Textbox(
|
166 |
label='Prompt',
|
167 |
max_lines=1,
|
168 |
-
placeholder='Example: "
|
169 |
-
alpha = gr.Slider(label='Alpha',
|
170 |
-
minimum=0,
|
171 |
-
maximum=2,
|
172 |
-
step=0.05,
|
173 |
-
value=1)
|
174 |
-
alpha_for_text = gr.Slider(label='Alpha for Text Encoder',
|
175 |
-
minimum=0,
|
176 |
-
maximum=2,
|
177 |
-
step=0.05,
|
178 |
-
value=1)
|
179 |
seed = gr.Slider(label='Seed',
|
180 |
minimum=0,
|
181 |
maximum=100000,
|
@@ -184,52 +178,53 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
|
|
184 |
with gr.Accordion('Other Parameters', open=False):
|
185 |
num_steps = gr.Slider(label='Number of Steps',
|
186 |
minimum=0,
|
187 |
-
maximum=
|
188 |
step=1,
|
189 |
-
value=
|
190 |
guidance_scale = gr.Slider(label='CFG Scale',
|
191 |
minimum=0,
|
192 |
maximum=50,
|
193 |
step=0.1,
|
194 |
-
value=
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
run_button = gr.Button('Generate')
|
197 |
|
198 |
gr.Markdown('''
|
199 |
-
- Models with names starting with "
|
200 |
- After training, you can press "Reload Weight List" button to load your trained model names.
|
201 |
-
- The pretrained models for "disney", "illust" and "pop" are trained with the concept prompt "style of sks".
|
202 |
-
- The pretrained model for "kiriko" is trained with the concept prompt "game character bnha". For this model, the text encoder is also trained.
|
203 |
''')
|
204 |
with gr.Column():
|
205 |
result = gr.Image(label='Result')
|
206 |
|
207 |
-
reload_button.click(fn=
|
208 |
inputs=None,
|
209 |
-
outputs=
|
210 |
prompt.submit(fn=pipe.run,
|
211 |
inputs=[
|
212 |
base_model,
|
213 |
-
|
214 |
prompt,
|
215 |
-
alpha,
|
216 |
-
alpha_for_text,
|
217 |
seed,
|
218 |
num_steps,
|
219 |
guidance_scale,
|
|
|
220 |
],
|
221 |
outputs=result,
|
222 |
queue=False)
|
223 |
run_button.click(fn=pipe.run,
|
224 |
inputs=[
|
225 |
base_model,
|
226 |
-
|
227 |
prompt,
|
228 |
-
alpha,
|
229 |
-
alpha_for_text,
|
230 |
seed,
|
231 |
num_steps,
|
232 |
guidance_scale,
|
|
|
233 |
],
|
234 |
outputs=result,
|
235 |
queue=False)
|
|
|
72 |
concept_images = gr.Files(label='Images for your concept')
|
73 |
concept_prompt = gr.Textbox(label='Concept Prompt',
|
74 |
max_lines=1)
|
75 |
+
class_prompt = gr.Textbox(label='Regularization set Prompt',
|
76 |
+
max_lines=1)
|
77 |
gr.Markdown('''
|
78 |
- Upload images of the style you are planning on training on.
|
79 |
- For a concept prompt, use a unique, made up word to avoid collisions.
|
|
|
82 |
gr.Markdown('Training Parameters')
|
83 |
num_training_steps = gr.Number(
|
84 |
label='Number of Training Steps', value=1000, precision=0)
|
85 |
+
learning_rate = gr.Number(label='Learning Rate', value=0.00001)
|
86 |
train_text_encoder = gr.Checkbox(label='Train Text Encoder',
|
87 |
value=True)
|
88 |
+
modifier_token = gr.Checkbox(label='modifier token',
|
89 |
+
value=True)
|
90 |
learning_rate_text = gr.Number(
|
91 |
+
label='Learning Rate for Text Encoder', value=0.00001)
|
92 |
gradient_accumulation = gr.Number(
|
93 |
label='Number of Gradient Accumulation',
|
94 |
value=1,
|
|
|
149 |
return [path.relative_to(curr_dir).as_posix() for path in paths]
|
150 |
|
151 |
|
152 |
+
def reload_custom_diffusion_weight_list() -> dict:
|
153 |
return gr.update(choices=find_weight_files())
|
154 |
|
155 |
|
|
|
163 |
label='Base Model',
|
164 |
visible=False)
|
165 |
reload_button = gr.Button('Reload Weight List')
|
166 |
+
weight_name = gr.Dropdown(choices=find_weight_files(),
|
167 |
+
value='custom-diffusion/cat.ckpt',
|
168 |
+
label='Custom Diffusion Weight File')
|
169 |
prompt = gr.Textbox(
|
170 |
label='Prompt',
|
171 |
max_lines=1,
|
172 |
+
placeholder='Example: "<new1> cat swimming in a pool"')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
seed = gr.Slider(label='Seed',
|
174 |
minimum=0,
|
175 |
maximum=100000,
|
|
|
178 |
with gr.Accordion('Other Parameters', open=False):
|
179 |
num_steps = gr.Slider(label='Number of Steps',
|
180 |
minimum=0,
|
181 |
+
maximum=500,
|
182 |
step=1,
|
183 |
+
value=200)
|
184 |
guidance_scale = gr.Slider(label='CFG Scale',
|
185 |
minimum=0,
|
186 |
maximum=50,
|
187 |
step=0.1,
|
188 |
+
value=6
|
189 |
+
eta = gr.Slider(label='CFG Scale',
|
190 |
+
minimum=0,
|
191 |
+
maximum=1.,
|
192 |
+
step=0.1,
|
193 |
+
value=1.)
|
194 |
|
195 |
run_button = gr.Button('Generate')
|
196 |
|
197 |
gr.Markdown('''
|
198 |
+
- Models with names starting with "custom-diffusion/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones with names starting with "results/" are your trained models.
|
199 |
- After training, you can press "Reload Weight List" button to load your trained model names.
|
|
|
|
|
200 |
''')
|
201 |
with gr.Column():
|
202 |
result = gr.Image(label='Result')
|
203 |
|
204 |
+
reload_button.click(fn=reload_custom_diffusion_weight_list,
|
205 |
inputs=None,
|
206 |
+
outputs=weight_name)
|
207 |
prompt.submit(fn=pipe.run,
|
208 |
inputs=[
|
209 |
base_model,
|
210 |
+
weight_name,
|
211 |
prompt,
|
|
|
|
|
212 |
seed,
|
213 |
num_steps,
|
214 |
guidance_scale,
|
215 |
+
eta,
|
216 |
],
|
217 |
outputs=result,
|
218 |
queue=False)
|
219 |
run_button.click(fn=pipe.run,
|
220 |
inputs=[
|
221 |
base_model,
|
222 |
+
weight_name,
|
223 |
prompt,
|
|
|
|
|
224 |
seed,
|
225 |
num_steps,
|
226 |
guidance_scale,
|
227 |
+
eta,
|
228 |
],
|
229 |
outputs=result,
|
230 |
queue=False)
|