hysts HF staff commited on
Commit
b083a19
0 Parent(s):

Duplicate from hysts/LoRA-SD-training

Browse files
Files changed (14) hide show
  1. .gitattributes +34 -0
  2. .gitignore +164 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +35 -0
  5. .style.yapf +5 -0
  6. LICENSE +21 -0
  7. README.md +14 -0
  8. app.py +280 -0
  9. inference.py +93 -0
  10. lora +1 -0
  11. requirements.txt +10 -0
  12. style.css +3 -0
  13. trainer.py +121 -0
  14. uploader.py +20 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training_data/
2
+ results/
3
+
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "lora"]
2
+ path = lora
3
+ url = https://github.com/cloneofsimo/lora
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.10.1
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: LoRA + SD Training
3
+ emoji: 🏢
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.12.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: hysts/LoRA-SD-training
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Unofficial demo app for https://github.com/cloneofsimo/lora.
3
+
4
+ The code in this repo is partly adapted from the following repository:
5
+ https://huggingface.co/spaces/multimodalart/dreambooth-training/tree/a00184917aa273c6d8adab08d5deb9b39b997938
6
+ The license of the original code is MIT, which is specified in the README.md.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import pathlib
13
+
14
+ import gradio as gr
15
+ import torch
16
+
17
+ from inference import InferencePipeline
18
+ from trainer import Trainer
19
+ from uploader import upload
20
+
21
+ TITLE = '# LoRA + StableDiffusion Training UI'
22
+ DESCRIPTION = 'This is an unofficial demo for [https://github.com/cloneofsimo/lora](https://github.com/cloneofsimo/lora).'
23
+
24
+ ORIGINAL_SPACE_ID = 'hysts/LoRA-SD-training'
25
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
26
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
27
+
28
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
29
+ '''
30
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
31
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
32
+
33
+ else:
34
+ SETTINGS = 'Settings'
35
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
36
+ <center>
37
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
38
+ "T4 small" is sufficient to run this demo.
39
+ </center>
40
+ '''
41
+
42
+
43
+ def show_warning(warning_text: str) -> gr.Blocks:
44
+ with gr.Blocks() as demo:
45
+ with gr.Box():
46
+ gr.Markdown(warning_text)
47
+ return demo
48
+
49
+
50
+ def update_output_files() -> dict:
51
+ paths = sorted(pathlib.Path('results').glob('*.pt'))
52
+ paths = [path.as_posix() for path in paths] # type: ignore
53
+ return gr.update(value=paths or None)
54
+
55
+
56
+ def create_training_demo(trainer: Trainer,
57
+ pipe: InferencePipeline) -> gr.Blocks:
58
+ with gr.Blocks() as demo:
59
+ base_model = gr.Dropdown(
60
+ choices=['stabilityai/stable-diffusion-2-1-base'],
61
+ value='stabilityai/stable-diffusion-2-1-base',
62
+ label='Base Model',
63
+ visible=False)
64
+ resolution = gr.Dropdown(choices=['512'],
65
+ value='512',
66
+ label='Resolution',
67
+ visible=False)
68
+
69
+ with gr.Row():
70
+ with gr.Box():
71
+ gr.Markdown('Training Data')
72
+ concept_images = gr.Files(label='Images for your concept')
73
+ concept_prompt = gr.Textbox(label='Concept Prompt',
74
+ max_lines=1)
75
+ gr.Markdown('''
76
+ - Upload images of the style you are planning on training on.
77
+ - For a concept prompt, use a unique, made up word to avoid collisions.
78
+ ''')
79
+ with gr.Box():
80
+ gr.Markdown('Training Parameters')
81
+ num_training_steps = gr.Number(
82
+ label='Number of Training Steps', value=1000, precision=0)
83
+ learning_rate = gr.Number(label='Learning Rate', value=0.0001)
84
+ train_text_encoder = gr.Checkbox(label='Train Text Encoder',
85
+ value=True)
86
+ learning_rate_text = gr.Number(
87
+ label='Learning Rate for Text Encoder', value=0.00005)
88
+ gradient_accumulation = gr.Number(
89
+ label='Number of Gradient Accumulation',
90
+ value=1,
91
+ precision=0)
92
+ fp16 = gr.Checkbox(label='FP16', value=True)
93
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
94
+ gr.Markdown('''
95
+ - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
96
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
97
+ - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
98
+ ''')
99
+
100
+ run_button = gr.Button('Start Training')
101
+ with gr.Box():
102
+ with gr.Row():
103
+ check_status_button = gr.Button('Check Training Status')
104
+ with gr.Column():
105
+ with gr.Box():
106
+ gr.Markdown('Message')
107
+ training_status = gr.Markdown()
108
+ output_files = gr.Files(label='Trained Weight Files')
109
+
110
+ run_button.click(fn=pipe.clear)
111
+ run_button.click(fn=trainer.run,
112
+ inputs=[
113
+ base_model,
114
+ resolution,
115
+ concept_images,
116
+ concept_prompt,
117
+ num_training_steps,
118
+ learning_rate,
119
+ train_text_encoder,
120
+ learning_rate_text,
121
+ gradient_accumulation,
122
+ fp16,
123
+ use_8bit_adam,
124
+ ],
125
+ outputs=[
126
+ training_status,
127
+ output_files,
128
+ ],
129
+ queue=False)
130
+ check_status_button.click(fn=trainer.check_if_running,
131
+ inputs=None,
132
+ outputs=training_status,
133
+ queue=False)
134
+ check_status_button.click(fn=update_output_files,
135
+ inputs=None,
136
+ outputs=output_files,
137
+ queue=False)
138
+ return demo
139
+
140
+
141
+ def find_weight_files() -> list[str]:
142
+ curr_dir = pathlib.Path(__file__).parent
143
+ paths = sorted(curr_dir.rglob('*.pt'))
144
+ paths = [path for path in paths if not path.stem.endswith('.text_encoder')]
145
+ return [path.relative_to(curr_dir).as_posix() for path in paths]
146
+
147
+
148
+ def reload_lora_weight_list() -> dict:
149
+ return gr.update(choices=find_weight_files())
150
+
151
+
152
+ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
153
+ with gr.Blocks() as demo:
154
+ with gr.Row():
155
+ with gr.Column():
156
+ base_model = gr.Dropdown(
157
+ choices=['stabilityai/stable-diffusion-2-1-base'],
158
+ value='stabilityai/stable-diffusion-2-1-base',
159
+ label='Base Model',
160
+ visible=False)
161
+ reload_button = gr.Button('Reload Weight List')
162
+ lora_weight_name = gr.Dropdown(choices=find_weight_files(),
163
+ value='lora/lora_disney.pt',
164
+ label='LoRA Weight File')
165
+ prompt = gr.Textbox(
166
+ label='Prompt',
167
+ max_lines=1,
168
+ placeholder='Example: "style of sks, baby lion"')
169
+ alpha = gr.Slider(label='Alpha',
170
+ minimum=0,
171
+ maximum=2,
172
+ step=0.05,
173
+ value=1)
174
+ alpha_for_text = gr.Slider(label='Alpha for Text Encoder',
175
+ minimum=0,
176
+ maximum=2,
177
+ step=0.05,
178
+ value=1)
179
+ seed = gr.Slider(label='Seed',
180
+ minimum=0,
181
+ maximum=100000,
182
+ step=1,
183
+ value=1)
184
+ with gr.Accordion('Other Parameters', open=False):
185
+ num_steps = gr.Slider(label='Number of Steps',
186
+ minimum=0,
187
+ maximum=100,
188
+ step=1,
189
+ value=50)
190
+ guidance_scale = gr.Slider(label='CFG Scale',
191
+ minimum=0,
192
+ maximum=50,
193
+ step=0.1,
194
+ value=7)
195
+
196
+ run_button = gr.Button('Generate')
197
+
198
+ gr.Markdown('''
199
+ - Models with names starting with "lora/" are the pretrained models provided in the [original repo](https://github.com/cloneofsimo/lora), and the ones with names starting with "results/" are your trained models.
200
+ - After training, you can press "Reload Weight List" button to load your trained model names.
201
+ - The pretrained models for "disney", "illust" and "pop" are trained with the concept prompt "style of sks".
202
+ - The pretrained model for "kiriko" is trained with the concept prompt "game character bnha". For this model, the text encoder is also trained.
203
+ ''')
204
+ with gr.Column():
205
+ result = gr.Image(label='Result')
206
+
207
+ reload_button.click(fn=reload_lora_weight_list,
208
+ inputs=None,
209
+ outputs=lora_weight_name)
210
+ prompt.submit(fn=pipe.run,
211
+ inputs=[
212
+ base_model,
213
+ lora_weight_name,
214
+ prompt,
215
+ alpha,
216
+ alpha_for_text,
217
+ seed,
218
+ num_steps,
219
+ guidance_scale,
220
+ ],
221
+ outputs=result,
222
+ queue=False)
223
+ run_button.click(fn=pipe.run,
224
+ inputs=[
225
+ base_model,
226
+ lora_weight_name,
227
+ prompt,
228
+ alpha,
229
+ alpha_for_text,
230
+ seed,
231
+ num_steps,
232
+ guidance_scale,
233
+ ],
234
+ outputs=result,
235
+ queue=False)
236
+ return demo
237
+
238
+
239
+ def create_upload_demo() -> gr.Blocks:
240
+ with gr.Blocks() as demo:
241
+ model_name = gr.Textbox(label='Model Name')
242
+ hf_token = gr.Textbox(
243
+ label='Hugging Face Token (with write permission)')
244
+ upload_button = gr.Button('Upload')
245
+ with gr.Box():
246
+ gr.Markdown('Message')
247
+ result = gr.Markdown()
248
+ gr.Markdown('''
249
+ - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
250
+ - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
251
+ ''')
252
+
253
+ upload_button.click(fn=upload,
254
+ inputs=[model_name, hf_token],
255
+ outputs=result)
256
+
257
+ return demo
258
+
259
+
260
+ pipe = InferencePipeline()
261
+ trainer = Trainer()
262
+
263
+ with gr.Blocks(css='style.css') as demo:
264
+ if os.getenv('IS_SHARED_UI'):
265
+ show_warning(SHARED_UI_WARNING)
266
+ if not torch.cuda.is_available():
267
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
268
+
269
+ gr.Markdown(TITLE)
270
+ gr.Markdown(DESCRIPTION)
271
+
272
+ with gr.Tabs():
273
+ with gr.TabItem('Train'):
274
+ create_training_demo(trainer, pipe)
275
+ with gr.TabItem('Test'):
276
+ create_inference_demo(pipe)
277
+ with gr.TabItem('Upload'):
278
+ create_upload_demo()
279
+
280
+ demo.queue(default_enabled=False).launch(share=False)
inference.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+ import sys
6
+
7
+ import gradio as gr
8
+ import PIL.Image
9
+ import torch
10
+ from diffusers import StableDiffusionPipeline
11
+
12
+ sys.path.insert(0, 'lora')
13
+ from lora_diffusion import monkeypatch_lora, tune_lora_scale
14
+
15
+
16
+ class InferencePipeline:
17
+ def __init__(self):
18
+ self.pipe = None
19
+ self.device = torch.device(
20
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
21
+ self.weight_path = None
22
+
23
+ def clear(self) -> None:
24
+ self.weight_path = None
25
+ del self.pipe
26
+ self.pipe = None
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+ @staticmethod
31
+ def get_lora_weight_path(name: str) -> pathlib.Path:
32
+ curr_dir = pathlib.Path(__file__).parent
33
+ return curr_dir / name
34
+
35
+ @staticmethod
36
+ def get_lora_text_encoder_weight_path(path: pathlib.Path) -> str:
37
+ parent_dir = path.parent
38
+ stem = path.stem
39
+ text_encoder_filename = f'{stem}.text_encoder.pt'
40
+ path = parent_dir / text_encoder_filename
41
+ return path.as_posix() if path.exists() else ''
42
+
43
+ def load_pipe(self, model_id: str, lora_filename: str) -> None:
44
+ weight_path = self.get_lora_weight_path(lora_filename)
45
+ if weight_path == self.weight_path:
46
+ return
47
+ self.weight_path = weight_path
48
+ lora_weight = torch.load(self.weight_path, map_location=self.device)
49
+
50
+ if self.device.type == 'cpu':
51
+ pipe = StableDiffusionPipeline.from_pretrained(model_id)
52
+ else:
53
+ pipe = StableDiffusionPipeline.from_pretrained(
54
+ model_id, torch_dtype=torch.float16)
55
+ pipe = pipe.to(self.device)
56
+
57
+ monkeypatch_lora(pipe.unet, lora_weight)
58
+
59
+ lora_text_encoder_weight_path = self.get_lora_text_encoder_weight_path(
60
+ weight_path)
61
+ if lora_text_encoder_weight_path:
62
+ lora_text_encoder_weight = torch.load(
63
+ lora_text_encoder_weight_path, map_location=self.device)
64
+ monkeypatch_lora(pipe.text_encoder,
65
+ lora_text_encoder_weight,
66
+ target_replace_module=['CLIPAttention'])
67
+
68
+ self.pipe = pipe
69
+
70
+ def run(
71
+ self,
72
+ base_model: str,
73
+ lora_weight_name: str,
74
+ prompt: str,
75
+ alpha: float,
76
+ alpha_for_text: float,
77
+ seed: int,
78
+ n_steps: int,
79
+ guidance_scale: float,
80
+ ) -> PIL.Image.Image:
81
+ if not torch.cuda.is_available():
82
+ raise gr.Error('CUDA is not available.')
83
+
84
+ self.load_pipe(base_model, lora_weight_name)
85
+
86
+ generator = torch.Generator(device=self.device).manual_seed(seed)
87
+ tune_lora_scale(self.pipe.unet, alpha) # type: ignore
88
+ tune_lora_scale(self.pipe.text_encoder, alpha_for_text) # type: ignore
89
+ out = self.pipe(prompt,
90
+ num_inference_steps=n_steps,
91
+ guidance_scale=guidance_scale,
92
+ generator=generator) # type: ignore
93
+ return out.images[0]
lora ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 26787a09bff4ebcb08f0ad4e848b67bce4389a7a
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.35.4
3
+ diffusers==0.10.2
4
+ ftfy==6.1.1
5
+ Pillow==9.3.0
6
+ torch==1.13.0
7
+ torchvision==0.14.0
8
+ transformers==4.25.1
9
+ triton==2.0.0.dev20220701
10
+ xformers==0.0.13
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
trainer.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import shlex
6
+ import shutil
7
+ import subprocess
8
+
9
+ import gradio as gr
10
+ import PIL.Image
11
+ import torch
12
+
13
+ os.environ['PYTHONPATH'] = f'lora:{os.getenv("PYTHONPATH", "")}'
14
+
15
+
16
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
17
+ w, h = image.size
18
+ if w == h:
19
+ return image
20
+ elif w > h:
21
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
22
+ new_image.paste(image, (0, (w - h) // 2))
23
+ return new_image
24
+ else:
25
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
26
+ new_image.paste(image, ((h - w) // 2, 0))
27
+ return new_image
28
+
29
+
30
+ class Trainer:
31
+ def __init__(self):
32
+ self.is_running = False
33
+ self.is_running_message = 'Another training is in progress.'
34
+
35
+ self.output_dir = pathlib.Path('results')
36
+ self.instance_data_dir = self.output_dir / 'training_data'
37
+
38
+ def check_if_running(self) -> dict:
39
+ if self.is_running:
40
+ return gr.update(value=self.is_running_message)
41
+ else:
42
+ return gr.update(value='No training is running.')
43
+
44
+ def cleanup_dirs(self) -> None:
45
+ shutil.rmtree(self.output_dir, ignore_errors=True)
46
+
47
+ def prepare_dataset(self, concept_images: list, resolution: int) -> None:
48
+ self.instance_data_dir.mkdir(parents=True)
49
+ for i, temp_path in enumerate(concept_images):
50
+ image = PIL.Image.open(temp_path.name)
51
+ image = pad_image(image)
52
+ image = image.resize((resolution, resolution))
53
+ image = image.convert('RGB')
54
+ out_path = self.instance_data_dir / f'{i:03d}.jpg'
55
+ image.save(out_path, format='JPEG', quality=100)
56
+
57
+ def run(
58
+ self,
59
+ base_model: str,
60
+ resolution_s: str,
61
+ concept_images: list | None,
62
+ concept_prompt: str,
63
+ n_steps: int,
64
+ learning_rate: float,
65
+ train_text_encoder: bool,
66
+ learning_rate_text: float,
67
+ gradient_accumulation: int,
68
+ fp16: bool,
69
+ use_8bit_adam: bool,
70
+ ) -> tuple[dict, list[pathlib.Path]]:
71
+ if not torch.cuda.is_available():
72
+ raise gr.Error('CUDA is not available.')
73
+
74
+ if self.is_running:
75
+ return gr.update(value=self.is_running_message), []
76
+
77
+ if concept_images is None:
78
+ raise gr.Error('You need to upload images.')
79
+ if not concept_prompt:
80
+ raise gr.Error('The concept prompt is missing.')
81
+
82
+ resolution = int(resolution_s)
83
+
84
+ self.cleanup_dirs()
85
+ self.prepare_dataset(concept_images, resolution)
86
+
87
+ command = f'''
88
+ accelerate launch lora/train_lora_dreambooth.py \
89
+ --pretrained_model_name_or_path={base_model} \
90
+ --instance_data_dir={self.instance_data_dir} \
91
+ --output_dir={self.output_dir} \
92
+ --instance_prompt="{concept_prompt}" \
93
+ --resolution={resolution} \
94
+ --train_batch_size=1 \
95
+ --gradient_accumulation_steps={gradient_accumulation} \
96
+ --learning_rate={learning_rate} \
97
+ --lr_scheduler=constant \
98
+ --lr_warmup_steps=0 \
99
+ --max_train_steps={n_steps}
100
+ '''
101
+ if fp16:
102
+ command += ' --mixed_precision fp16'
103
+ if use_8bit_adam:
104
+ command += ' --use_8bit_adam'
105
+ if train_text_encoder:
106
+ command += f' --train_text_encoder --learning_rate_text={learning_rate_text} --color_jitter'
107
+
108
+ with open(self.output_dir / 'train.sh', 'w') as f:
109
+ command_s = ' '.join(command.split())
110
+ f.write(command_s)
111
+
112
+ self.is_running = True
113
+ res = subprocess.run(shlex.split(command))
114
+ self.is_running = False
115
+
116
+ if res.returncode == 0:
117
+ result_message = 'Training Completed!'
118
+ else:
119
+ result_message = 'Training Failed!'
120
+ weight_paths = sorted(self.output_dir.glob('*.pt'))
121
+ return gr.update(value=result_message), weight_paths
uploader.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi
3
+
4
+
5
+ def upload(model_name: str, hf_token: str) -> None:
6
+ api = HfApi(token=hf_token)
7
+ user_name = api.whoami()['name']
8
+ model_id = f'{user_name}/{model_name}'
9
+ try:
10
+ api.create_repo(model_id, repo_type='model', private=True)
11
+ api.upload_folder(repo_id=model_id,
12
+ folder_path='results',
13
+ path_in_repo='results',
14
+ repo_type='model')
15
+ url = f'https://huggingface.co/{model_id}'
16
+ message = f'Your model was successfully uploaded to [{url}]({url}).'
17
+ except Exception as e:
18
+ message = str(e)
19
+
20
+ return gr.update(value=message, visible=True)