mkshing commited on
Commit
e2a20af
0 Parent(s):

first commit

Browse files
Files changed (18) hide show
  1. .gitattributes +34 -0
  2. .gitignore +165 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. LICENSE +21 -0
  6. README.md +15 -0
  7. app.py +78 -0
  8. app_inference.py +170 -0
  9. app_training.py +148 -0
  10. app_upload.py +100 -0
  11. constants.py +6 -0
  12. inference.py +103 -0
  13. requirements.txt +5 -0
  14. style.css +3 -0
  15. train_svdiff.py +1013 -0
  16. trainer.py +175 -0
  17. uploader.py +42 -0
  18. utils.py +58 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training_data/
2
+ experiments/
3
+ wandb/
4
+
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: train_dreambooth_lora.py
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SVDiff-pytorch Training UI
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.16.2
8
+ python_version: 3.10.9
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ duplicated_from: mshing/SVDiff-pytorch-UI
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+ from app_inference import create_inference_demo
11
+ from app_training import create_training_demo
12
+ from app_upload import create_upload_demo
13
+ from inference import InferencePipeline
14
+ from trainer import Trainer
15
+
16
+ TITLE = """# SVDiff-pytorch Training UI
17
+ This demo is based on https://github.com/mkshing/svdiff-pytorch, which is an implementation of "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning" by [mkshing](https://twitter.com/mk1stats)
18
+ """
19
+
20
+ ORIGINAL_SPACE_ID = 'mshing/SVDiff-pytorch-UI'
21
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
22
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
23
+
24
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
25
+ '''
26
+
27
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
28
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
29
+ else:
30
+ SETTINGS = 'Settings'
31
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
32
+ <center>
33
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
34
+ "T4 small" is sufficient to run this demo.
35
+ </center>
36
+ '''
37
+
38
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
39
+ <center>
40
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
41
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
42
+ </center>
43
+ '''
44
+
45
+ HF_TOKEN = os.getenv('HF_TOKEN')
46
+
47
+
48
+ def show_warning(warning_text: str) -> gr.Blocks:
49
+ with gr.Blocks() as demo:
50
+ with gr.Box():
51
+ gr.Markdown(warning_text)
52
+ return demo
53
+
54
+
55
+ pipe = InferencePipeline(HF_TOKEN)
56
+ trainer = Trainer(HF_TOKEN)
57
+
58
+ with gr.Blocks(css='style.css') as demo:
59
+ if os.getenv('IS_SHARED_UI'):
60
+ show_warning(SHARED_UI_WARNING)
61
+ if not torch.cuda.is_available():
62
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
63
+ if not HF_TOKEN:
64
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
65
+
66
+ gr.Markdown(TITLE)
67
+ with gr.Tabs():
68
+ with gr.TabItem('Train'):
69
+ create_training_demo(trainer, pipe)
70
+ with gr.TabItem('Test'):
71
+ create_inference_demo(pipe, HF_TOKEN)
72
+ with gr.TabItem('Upload'):
73
+ gr.Markdown('''
74
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
75
+ ''')
76
+ create_upload_demo(HF_TOKEN)
77
+
78
+ demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from inference import InferencePipeline
11
+ from utils import find_exp_dirs
12
+
13
+ SAMPLE_MODEL_IDS = [
14
+ 'svdiff-library/svdiff_dog_example',
15
+ 'mshing/svdiff_kumamon_example',
16
+ ]
17
+
18
+
19
+ class ModelSource(enum.Enum):
20
+ SAMPLE = 'Sample'
21
+ HUB_LIB = 'Hub (svdiff-library)'
22
+ LOCAL = 'Local'
23
+
24
+
25
+ class InferenceUtil:
26
+ def __init__(self, hf_token: str | None):
27
+ self.hf_token = hf_token
28
+
29
+ @staticmethod
30
+ def load_sample_model_list():
31
+ return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
32
+
33
+ def load_hub_model_list(self) -> dict:
34
+ api = HfApi(token=self.hf_token)
35
+ choices = [
36
+ info.modelId for info in api.list_models(author='svdiff-library')
37
+ ]
38
+ return gr.update(choices=choices,
39
+ value=choices[0] if choices else None)
40
+
41
+ @staticmethod
42
+ def load_local_model_list() -> dict:
43
+ choices = find_exp_dirs()
44
+ return gr.update(choices=choices,
45
+ value=choices[0] if choices else None)
46
+
47
+ def reload_model_list(self, model_source: str) -> dict:
48
+ if model_source == ModelSource.SAMPLE.value:
49
+ return self.load_sample_model_list()
50
+ elif model_source == ModelSource.HUB_LIB.value:
51
+ return self.load_hub_model_list()
52
+ elif model_source == ModelSource.LOCAL.value:
53
+ return self.load_local_model_list()
54
+ else:
55
+ raise ValueError
56
+
57
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
58
+ try:
59
+ card = InferencePipeline.get_model_card(model_id,
60
+ self.hf_token)
61
+ except Exception:
62
+ return '', ''
63
+ base_model = getattr(card.data, 'base_model', '')
64
+ instance_prompt = getattr(card.data, 'instance_prompt', '')
65
+ return base_model, instance_prompt
66
+
67
+ def reload_model_list_and_update_model_info(
68
+ self, model_source: str) -> tuple[dict, str, str]:
69
+ model_list_update = self.reload_model_list(model_source)
70
+ model_list = model_list_update['choices']
71
+ model_info = self.load_model_info(model_list[0] if model_list else '')
72
+ return model_list_update, *model_info
73
+
74
+
75
+ def create_inference_demo(pipe: InferencePipeline,
76
+ hf_token: str | None = None) -> gr.Blocks:
77
+ app = InferenceUtil(hf_token)
78
+
79
+ with gr.Blocks() as demo:
80
+ with gr.Row():
81
+ with gr.Column():
82
+ with gr.Box():
83
+ model_source = gr.Radio(
84
+ label='Model Source',
85
+ choices=[_.value for _ in ModelSource],
86
+ value=ModelSource.SAMPLE.value)
87
+ reload_button = gr.Button('Reload Model List')
88
+ model_id = gr.Dropdown(label='Model ID',
89
+ choices=SAMPLE_MODEL_IDS,
90
+ value=SAMPLE_MODEL_IDS[0])
91
+ with gr.Accordion(
92
+ label=
93
+ 'Model info (Base model and instance prompt used for training)',
94
+ open=False):
95
+ with gr.Row():
96
+ base_model_used_for_training = gr.Text(
97
+ label='Base model', interactive=False)
98
+ instance_prompt_used_for_training = gr.Text(
99
+ label='Instance prompt', interactive=False)
100
+ prompt = gr.Textbox(
101
+ label='Prompt',
102
+ max_lines=1,
103
+ placeholder='Example: "A picture of a sks dog in a bucket"'
104
+ )
105
+ seed = gr.Slider(label='Seed',
106
+ minimum=0,
107
+ maximum=100000,
108
+ step=1,
109
+ value=0)
110
+ with gr.Accordion('Other Parameters', open=False):
111
+ num_steps = gr.Slider(label='Number of Steps',
112
+ minimum=0,
113
+ maximum=100,
114
+ step=1,
115
+ value=25)
116
+ guidance_scale = gr.Slider(label='CFG Scale',
117
+ minimum=0,
118
+ maximum=50,
119
+ step=0.1,
120
+ value=7.5)
121
+
122
+ run_button = gr.Button('Generate')
123
+
124
+ gr.Markdown('''
125
+ - After training, you can press "Reload Model List" button to load your trained model names.
126
+ ''')
127
+ with gr.Column():
128
+ result = gr.Image(label='Result')
129
+
130
+ model_source.change(
131
+ fn=app.reload_model_list_and_update_model_info,
132
+ inputs=model_source,
133
+ outputs=[
134
+ model_id,
135
+ base_model_used_for_training,
136
+ instance_prompt_used_for_training,
137
+ ])
138
+ reload_button.click(
139
+ fn=app.reload_model_list_and_update_model_info,
140
+ inputs=model_source,
141
+ outputs=[
142
+ model_id,
143
+ base_model_used_for_training,
144
+ instance_prompt_used_for_training,
145
+ ])
146
+ model_id.change(fn=app.load_model_info,
147
+ inputs=model_id,
148
+ outputs=[
149
+ base_model_used_for_training,
150
+ instance_prompt_used_for_training,
151
+ ])
152
+ inputs = [
153
+ model_id,
154
+ prompt,
155
+ seed,
156
+ num_steps,
157
+ guidance_scale,
158
+ ]
159
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
160
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
161
+ return demo
162
+
163
+
164
+ if __name__ == '__main__':
165
+ import os
166
+
167
+ hf_token = os.getenv('HF_TOKEN')
168
+ pipe = InferencePipeline(hf_token)
169
+ demo = create_inference_demo(pipe, hf_token)
170
+ demo.queue(max_size=10).launch(share=True, debug=True)
app_training.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ with gr.Box():
20
+ gr.Markdown('Training Data')
21
+ instance_images = gr.Files(label='Instance images')
22
+ instance_prompt = gr.Textbox(label='Instance prompt',
23
+ max_lines=1)
24
+ gr.Markdown('''
25
+ - Upload images of the style you are planning on training on.
26
+ - For an instance prompt, use a unique, made up word to avoid collisions.
27
+ ''')
28
+ with gr.Box():
29
+ gr.Markdown('Output Model')
30
+ output_model_name = gr.Text(label='Name of your model',
31
+ max_lines=1)
32
+ delete_existing_model = gr.Checkbox(
33
+ label='Delete existing model of the same name',
34
+ value=False)
35
+ validation_prompt = gr.Text(label='Validation Prompt')
36
+ with gr.Box():
37
+ gr.Markdown('Upload Settings')
38
+ with gr.Row():
39
+ upload_to_hub = gr.Checkbox(
40
+ label='Upload model to Hub', value=True)
41
+ use_private_repo = gr.Checkbox(label='Private',
42
+ value=True)
43
+ delete_existing_repo = gr.Checkbox(
44
+ label='Delete existing repo of the same name',
45
+ value=False)
46
+ upload_to = gr.Radio(
47
+ label='Upload to',
48
+ choices=[_.value for _ in UploadTarget],
49
+ value=UploadTarget.SVDIFF_LIBRARY.value)
50
+ gr.Markdown('''
51
+ - By default, trained models will be uploaded to [SVDiff-pytorch Library](https://huggingface.co/svdiff-library).
52
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
53
+ ''')
54
+
55
+ with gr.Box():
56
+ gr.Markdown('Training Parameters')
57
+ with gr.Row():
58
+ base_model = gr.Text(
59
+ label='Base Model',
60
+ value='runwayml/stable-diffusion-v1-5',
61
+ max_lines=1)
62
+ resolution = gr.Dropdown(choices=['512', '768'],
63
+ value='512',
64
+ label='Resolution')
65
+ num_training_steps = gr.Number(
66
+ label='Number of Training Steps', value=1000, precision=0)
67
+ learning_rate = gr.Number(label='Learning Rate', value=0.005)
68
+ gradient_accumulation = gr.Number(
69
+ label='Number of Gradient Accumulation',
70
+ value=1,
71
+ precision=0)
72
+ seed = gr.Slider(label='Seed',
73
+ minimum=0,
74
+ maximum=100000,
75
+ step=1,
76
+ value=0)
77
+ fp16 = gr.Checkbox(label='FP16', value=False)
78
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
79
+ gradient_checkpointing = gr.Checkbox(label='Use gradient checkpointing', value=True)
80
+ # enable_xformers_memory_efficient_attention = gr.Checkbox(label='Use xformers', value=True)
81
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
82
+ value=200,
83
+ precision=0)
84
+ use_wandb = gr.Checkbox(label='Use W&B',
85
+ value=False,
86
+ interactive=bool(
87
+ os.getenv('WANDB_API_KEY')))
88
+ validation_epochs = gr.Number(label='Validation Epochs',
89
+ value=200,
90
+ precision=0)
91
+ gr.Markdown('''
92
+ - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
93
+ - It takes a few minutes to download the base model first.
94
+ - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
95
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
96
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
97
+ - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
98
+ - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
99
+ ''')
100
+
101
+ remove_gpu_after_training = gr.Checkbox(
102
+ label='Remove GPU after training',
103
+ value=False,
104
+ interactive=bool(os.getenv('SPACE_ID')),
105
+ visible=False)
106
+ run_button = gr.Button('Start Training')
107
+
108
+ with gr.Box():
109
+ gr.Markdown('Output message')
110
+ output_message = gr.Markdown()
111
+
112
+ if pipe is not None:
113
+ run_button.click(fn=pipe.clear)
114
+ run_button.click(fn=trainer.run,
115
+ inputs=[
116
+ instance_images,
117
+ instance_prompt,
118
+ output_model_name,
119
+ delete_existing_model,
120
+ validation_prompt,
121
+ base_model,
122
+ resolution,
123
+ num_training_steps,
124
+ learning_rate,
125
+ gradient_accumulation,
126
+ seed,
127
+ fp16,
128
+ use_8bit_adam,
129
+ gradient_checkpointing,
130
+ # enable_xformers_memory_efficient_attention,
131
+ checkpointing_steps,
132
+ use_wandb,
133
+ validation_epochs,
134
+ upload_to_hub,
135
+ use_private_repo,
136
+ delete_existing_repo,
137
+ upload_to,
138
+ remove_gpu_after_training,
139
+ ],
140
+ outputs=output_message)
141
+ return demo
142
+
143
+
144
+ if __name__ == '__main__':
145
+ hf_token = os.getenv('HF_TOKEN')
146
+ trainer = Trainer(hf_token)
147
+ demo = create_training_demo(trainer)
148
+ demo.queue(max_size=1).launch(share=True, debug=True)
app_upload.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
+ if not folder_path:
25
+ raise ValueError
26
+ if not repo_name:
27
+ repo_name = pathlib.Path(folder_path).name
28
+ repo_name = slugify.slugify(repo_name)
29
+
30
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
+ organization = ''
32
+ elif upload_to == UploadTarget.SVDIFF_LIBRARY.value:
33
+ organization = 'svdiff-library'
34
+ else:
35
+ raise ValueError
36
+
37
+ return self.upload(folder_path,
38
+ repo_name,
39
+ organization=organization,
40
+ private=private,
41
+ delete_existing_repo=delete_existing_repo)
42
+
43
+
44
+ def load_local_model_list() -> dict:
45
+ choices = find_exp_dirs(ignore_repo=True)
46
+ return gr.update(choices=choices, value=choices[0] if choices else None)
47
+
48
+
49
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = ModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs(ignore_repo=True)
52
+
53
+ with gr.Blocks() as demo:
54
+ with gr.Box():
55
+ gr.Markdown('Local Models')
56
+ reload_button = gr.Button('Reload Model List')
57
+ model_dir = gr.Dropdown(
58
+ label='Model names',
59
+ choices=model_dirs,
60
+ value=model_dirs[0] if model_dirs else None)
61
+ with gr.Box():
62
+ gr.Markdown('Upload Settings')
63
+ with gr.Row():
64
+ use_private_repo = gr.Checkbox(label='Private', value=True)
65
+ delete_existing_repo = gr.Checkbox(
66
+ label='Delete existing repo of the same name', value=False)
67
+ upload_to = gr.Radio(label='Upload to',
68
+ choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.SVDIFF_LIBRARY.value)
70
+ model_name = gr.Textbox(label='Model Name')
71
+ upload_button = gr.Button('Upload')
72
+ gr.Markdown('''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [SVDiff-pytorch Concepts Library](https://huggingface.co/svdiff-library) (i.e. https://huggingface.co/svdiff-library/{model_name}).
74
+ ''')
75
+ with gr.Box():
76
+ gr.Markdown('Output message')
77
+ output_message = gr.Markdown()
78
+
79
+ reload_button.click(fn=load_local_model_list,
80
+ inputs=None,
81
+ outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_model,
83
+ inputs=[
84
+ model_dir,
85
+ model_name,
86
+ upload_to,
87
+ use_private_repo,
88
+ delete_existing_repo,
89
+ ],
90
+ outputs=output_message)
91
+
92
+ return demo
93
+
94
+
95
+ if __name__ == '__main__':
96
+ import os
97
+
98
+ hf_token = os.getenv('HF_TOKEN')
99
+ demo = create_upload_demo(hf_token)
100
+ demo.queue(max_size=1).launch(share=True, debug=True)
constants.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ SVDIFF_LIBRARY = 'SVDiff Library'
inference.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+
6
+ import gradio as gr
7
+ import PIL.Image
8
+ import torch
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from huggingface_hub import ModelCard
11
+ from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING, image_grid
12
+
13
+
14
+
15
+ class InferencePipeline:
16
+ def __init__(self, hf_token: str | None = None):
17
+ self.hf_token = hf_token
18
+ self.pipe = None
19
+ self.device = torch.device(
20
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
21
+ self.model_id = None
22
+ self.base_model_id = None
23
+
24
+ def clear(self) -> None:
25
+ self.model_id = None
26
+ self.base_model_id = None
27
+ del self.pipe
28
+ self.pipe = None
29
+ torch.cuda.empty_cache()
30
+ gc.collect()
31
+
32
+ @staticmethod
33
+ def check_if_model_is_local(model_id: str) -> bool:
34
+ return pathlib.Path(model_id).exists()
35
+
36
+ @staticmethod
37
+ def get_model_card(model_id: str,
38
+ hf_token: str | None = None) -> ModelCard:
39
+ if InferencePipeline.check_if_model_is_local(model_id):
40
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
41
+ else:
42
+ card_path = model_id
43
+ return ModelCard.load(card_path, token=hf_token)
44
+
45
+ @staticmethod
46
+ def get_base_model_info(model_id: str,
47
+ hf_token: str | None = None) -> str:
48
+ card = InferencePipeline.get_model_card(model_id, hf_token)
49
+ return card.data.base_model
50
+
51
+ def load_pipe(self, model_id: str) -> None:
52
+ if model_id == self.model_id:
53
+ return
54
+
55
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
56
+ unet = load_unet_for_svdiff(base_model_id, spectral_shifts_ckpt=model_id, subfolder="unet").to(self.device)
57
+ # first perform svd and cache
58
+ for module in unet.modules():
59
+ if hasattr(module, "perform_svd"):
60
+ module.perform_svd()
61
+ unet = unet.to(self.device, dtype=torch.float16)
62
+ if base_model_id != self.base_model_id:
63
+ if self.device.type == 'cpu':
64
+ pipe = DiffusionPipeline.from_pretrained(
65
+ base_model_id,
66
+ unet=unet,
67
+ use_auth_token=self.hf_token
68
+ )
69
+ else:
70
+ pipe = DiffusionPipeline.from_pretrained(
71
+ base_model_id,
72
+ unet=unet,
73
+ torch_dtype=torch.float16,
74
+ use_auth_token=self.hf_token
75
+ )
76
+ pipe = pipe.to(self.device)
77
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
78
+ self.pipe = pipe
79
+
80
+ self.model_id = model_id # type: ignore
81
+ self.base_model_id = base_model_id # type: ignore
82
+
83
+ def run(
84
+ self,
85
+ model_id: str,
86
+ prompt: str,
87
+ seed: int,
88
+ n_steps: int,
89
+ guidance_scale: float,
90
+ ) -> PIL.Image.Image:
91
+ # if not torch.cuda.is_available():
92
+ # raise gr.Error('CUDA is not available.')
93
+
94
+ self.load_pipe(model_id)
95
+
96
+ generator = torch.Generator(device=self.device).manual_seed(seed)
97
+ out = self.pipe(
98
+ prompt,
99
+ num_inference_steps=n_steps,
100
+ guidance_scale=guidance_scale,
101
+ generator=generator,
102
+ ) # type: ignore
103
+ return out.images[0]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ svdiff-pytorch
2
+ bitsandbytes==0.35.0
3
+ python-slugify==7.0.0
4
+ tomesd
5
+ gradio==3.16.2
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
train_svdiff.py ADDED
@@ -0,0 +1,1013 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import hashlib
3
+ import logging
4
+ import math
5
+ import os
6
+ import warnings
7
+ from pathlib import Path
8
+ from typing import Optional
9
+ from packaging import version
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+ import transformers
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import ProjectConfiguration, set_seed
19
+ from huggingface_hub import create_repo, upload_folder
20
+ from packaging import version
21
+ from PIL import Image
22
+ from torch.utils.data import Dataset
23
+ from torchvision import transforms
24
+ from tqdm.auto import tqdm
25
+ from transformers import AutoTokenizer, PretrainedConfig
26
+
27
+ import diffusers
28
+ from diffusers import __version__
29
+ from diffusers import (
30
+ AutoencoderKL,
31
+ DDPMScheduler,
32
+ DiffusionPipeline,
33
+ StableDiffusionPipeline,
34
+ DPMSolverMultistepScheduler,
35
+ )
36
+ from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING
37
+ from diffusers.loaders import AttnProcsLayers
38
+ from diffusers.optimization import get_scheduler
39
+ from diffusers.utils import check_min_version, is_wandb_available
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+ from safetensors import safe_open
42
+ from safetensors.torch import save_file
43
+ if is_wandb_available():
44
+ import wandb
45
+
46
+
47
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
48
+ # check_min_version("0.15.0.dev0")
49
+ diffusers_version = "0.14.0"
50
+ if version.parse(__version__) != version.parse(diffusers_version):
51
+ error_message = f"This example requires a version of {diffusers_version},"
52
+ error_message += f" but the version found is {__version__}.\n"
53
+ raise ImportError(error_message)
54
+
55
+ logger = get_logger(__name__)
56
+
57
+
58
+ def save_model_card(repo_id: str, base_model=str, prompt=str, repo_folder=None):
59
+ yaml = f"""
60
+ ---
61
+ license: creativeml-openrail-m
62
+ base_model: {base_model}
63
+ instance_prompt: {prompt}
64
+ tags:
65
+ - stable-diffusion
66
+ - stable-diffusion-diffusers
67
+ - text-to-image
68
+ - diffusers
69
+ - svdiff
70
+ inference: true
71
+ ---
72
+ """
73
+ model_card = f"""
74
+ # SVDiff-pytorch - {repo_id}
75
+ These are SVDiff weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
76
+ """
77
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
78
+ f.write(yaml + model_card)
79
+
80
+
81
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
82
+ text_encoder_config = PretrainedConfig.from_pretrained(
83
+ pretrained_model_name_or_path,
84
+ subfolder="text_encoder",
85
+ revision=revision,
86
+ )
87
+ model_class = text_encoder_config.architectures[0]
88
+
89
+ if model_class == "CLIPTextModel":
90
+ from transformers import CLIPTextModel
91
+
92
+ return CLIPTextModel
93
+ elif model_class == "RobertaSeriesModelWithTransformation":
94
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
95
+
96
+ return RobertaSeriesModelWithTransformation
97
+ else:
98
+ raise ValueError(f"{model_class} is not supported.")
99
+
100
+
101
+ def parse_args(input_args=None):
102
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
103
+ parser.add_argument(
104
+ "--pretrained_model_name_or_path",
105
+ type=str,
106
+ default=None,
107
+ required=True,
108
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
109
+ )
110
+ parser.add_argument(
111
+ "--pretrained_vae_name_or_path",
112
+ type=str,
113
+ default=None,
114
+ help="Path to pretrained vae or vae identifier from huggingface.co/models. This will be used in prior generation",
115
+ )
116
+ parser.add_argument(
117
+ "--revision",
118
+ type=str,
119
+ default=None,
120
+ required=False,
121
+ help="Revision of pretrained model identifier from huggingface.co/models.",
122
+ )
123
+ parser.add_argument(
124
+ "--tokenizer_name",
125
+ type=str,
126
+ default=None,
127
+ help="Pretrained tokenizer name or path if not the same as model_name",
128
+ )
129
+ parser.add_argument(
130
+ "--instance_data_dir",
131
+ type=str,
132
+ default=None,
133
+ required=True,
134
+ help="A folder containing the training data of instance images.",
135
+ )
136
+ parser.add_argument(
137
+ "--class_data_dir",
138
+ type=str,
139
+ default=None,
140
+ required=False,
141
+ help="A folder containing the training data of class images.",
142
+ )
143
+ parser.add_argument(
144
+ "--instance_prompt",
145
+ type=str,
146
+ default=None,
147
+ required=True,
148
+ help="The prompt with identifier specifying the instance",
149
+ )
150
+ parser.add_argument(
151
+ "--class_prompt",
152
+ type=str,
153
+ default=None,
154
+ help="The prompt to specify images in the same class as provided instance images.",
155
+ )
156
+ parser.add_argument(
157
+ "--validation_prompt",
158
+ type=str,
159
+ default=None,
160
+ help="A prompt that is used during validation to verify that the model is learning.",
161
+ )
162
+ parser.add_argument(
163
+ "--num_validation_images",
164
+ type=int,
165
+ default=4,
166
+ help="Number of images that should be generated during validation with `validation_prompt`.",
167
+ )
168
+ parser.add_argument(
169
+ "--validation_epochs",
170
+ type=int,
171
+ default=50,
172
+ help=(
173
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
174
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
175
+ ),
176
+ )
177
+ parser.add_argument(
178
+ "--with_prior_preservation",
179
+ default=False,
180
+ action="store_true",
181
+ help="Flag to add prior preservation loss.",
182
+ )
183
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
184
+ parser.add_argument(
185
+ "--num_class_images",
186
+ type=int,
187
+ default=100,
188
+ help=(
189
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
190
+ " class_data_dir, additional images will be sampled with class_prompt."
191
+ ),
192
+ )
193
+ parser.add_argument(
194
+ "--output_dir",
195
+ type=str,
196
+ default="lora-dreambooth-model",
197
+ help="The output directory where the model predictions and checkpoints will be written.",
198
+ )
199
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
200
+ parser.add_argument(
201
+ "--resolution",
202
+ type=int,
203
+ default=512,
204
+ help=(
205
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
206
+ " resolution"
207
+ ),
208
+ )
209
+ parser.add_argument(
210
+ "--center_crop",
211
+ default=False,
212
+ action="store_true",
213
+ help=(
214
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
215
+ " cropped. The images will be resized to the resolution first before cropping."
216
+ ),
217
+ )
218
+ parser.add_argument(
219
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
220
+ )
221
+ parser.add_argument(
222
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
223
+ )
224
+ parser.add_argument("--num_train_epochs", type=int, default=1)
225
+ parser.add_argument(
226
+ "--max_train_steps",
227
+ type=int,
228
+ default=None,
229
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
230
+ )
231
+ parser.add_argument(
232
+ "--checkpointing_steps",
233
+ type=int,
234
+ default=500,
235
+ help=(
236
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
237
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
238
+ " training using `--resume_from_checkpoint`."
239
+ ),
240
+ )
241
+ parser.add_argument(
242
+ "--checkpoints_total_limit",
243
+ type=int,
244
+ default=None,
245
+ help=(
246
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
247
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
248
+ " for more docs"
249
+ ),
250
+ )
251
+ parser.add_argument(
252
+ "--resume_from_checkpoint",
253
+ type=str,
254
+ default=None,
255
+ help=(
256
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
257
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
258
+ ),
259
+ )
260
+ parser.add_argument(
261
+ "--gradient_accumulation_steps",
262
+ type=int,
263
+ default=1,
264
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
265
+ )
266
+ parser.add_argument(
267
+ "--gradient_checkpointing",
268
+ action="store_true",
269
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
270
+ )
271
+ parser.add_argument(
272
+ "--learning_rate",
273
+ type=float,
274
+ default=5e-4,
275
+ help="Initial learning rate (after the potential warmup period) to use.",
276
+ )
277
+ parser.add_argument(
278
+ "--scale_lr",
279
+ action="store_true",
280
+ default=False,
281
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
282
+ )
283
+ parser.add_argument(
284
+ "--lr_scheduler",
285
+ type=str,
286
+ default="constant",
287
+ help=(
288
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
289
+ ' "constant", "constant_with_warmup"]'
290
+ ),
291
+ )
292
+ parser.add_argument(
293
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
294
+ )
295
+ parser.add_argument(
296
+ "--lr_num_cycles",
297
+ type=int,
298
+ default=1,
299
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
300
+ )
301
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
302
+ parser.add_argument(
303
+ "--dataloader_num_workers",
304
+ type=int,
305
+ default=0,
306
+ help=(
307
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
308
+ ),
309
+ )
310
+ parser.add_argument(
311
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
312
+ )
313
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
314
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
315
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
316
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
317
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
318
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
319
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
320
+ parser.add_argument(
321
+ "--hub_model_id",
322
+ type=str,
323
+ default=None,
324
+ help="The name of the repository to keep in sync with the local `output_dir`.",
325
+ )
326
+ parser.add_argument(
327
+ "--logging_dir",
328
+ type=str,
329
+ default="logs",
330
+ help=(
331
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
332
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
333
+ ),
334
+ )
335
+ parser.add_argument(
336
+ "--allow_tf32",
337
+ action="store_true",
338
+ help=(
339
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
340
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
341
+ ),
342
+ )
343
+ parser.add_argument(
344
+ "--report_to",
345
+ type=str,
346
+ default="tensorboard",
347
+ help=(
348
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
349
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
350
+ ),
351
+ )
352
+ parser.add_argument(
353
+ "--mixed_precision",
354
+ type=str,
355
+ default=None,
356
+ choices=["no", "fp16", "bf16"],
357
+ help=(
358
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
359
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
360
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
361
+ ),
362
+ )
363
+ parser.add_argument(
364
+ "--prior_generation_precision",
365
+ type=str,
366
+ default=None,
367
+ choices=["no", "fp32", "fp16", "bf16"],
368
+ help=(
369
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
370
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
371
+ ),
372
+ )
373
+ parser.add_argument("--prior_generation_scheduler_type", type=str, choices=["ddim", "plms", "lms", "euler", "euler_ancestral", "dpm_solver++"], default="ddim", help="diffusion scheduler type")
374
+ parser.add_argument("--prior_generation_num_inference_steps", type=int, default=50, help="number of sampling steps")
375
+
376
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
377
+ parser.add_argument(
378
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
379
+ )
380
+ parser.add_argument(
381
+ "--enable_token_merging", action="store_true", help="Whether or not to use tomesd on prior generation"
382
+ )
383
+ if input_args is not None:
384
+ args = parser.parse_args(input_args)
385
+ else:
386
+ args = parser.parse_args()
387
+
388
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
389
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
390
+ args.local_rank = env_local_rank
391
+
392
+ if args.with_prior_preservation:
393
+ if args.class_data_dir is None:
394
+ raise ValueError("You must specify a data directory for class images.")
395
+ if args.class_prompt is None:
396
+ raise ValueError("You must specify prompt for class images.")
397
+ else:
398
+ # logger is not available yet
399
+ if args.class_data_dir is not None:
400
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
401
+ if args.class_prompt is not None:
402
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
403
+
404
+ return args
405
+
406
+
407
+ class DreamBoothDataset(Dataset):
408
+ """
409
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
410
+ It pre-processes the images and the tokenizes prompts.
411
+ """
412
+
413
+ def __init__(
414
+ self,
415
+ instance_data_root,
416
+ instance_prompt,
417
+ tokenizer,
418
+ class_data_root=None,
419
+ class_prompt=None,
420
+ class_num=None,
421
+ size=512,
422
+ center_crop=False,
423
+ ):
424
+ self.size = size
425
+ self.center_crop = center_crop
426
+ self.tokenizer = tokenizer
427
+
428
+ self.instance_data_root = Path(instance_data_root)
429
+ if not self.instance_data_root.exists():
430
+ raise ValueError("Instance images root doesn't exists.")
431
+
432
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
433
+ self.num_instance_images = len(self.instance_images_path)
434
+ self.instance_prompt = instance_prompt
435
+ self._length = self.num_instance_images
436
+
437
+ if class_data_root is not None:
438
+ self.class_data_root = Path(class_data_root)
439
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
440
+ self.class_images_path = list(self.class_data_root.iterdir())
441
+ if class_num is not None:
442
+ self.num_class_images = min(len(self.class_images_path), class_num)
443
+ else:
444
+ self.num_class_images = len(self.class_images_path)
445
+ self._length = max(self.num_class_images, self.num_instance_images)
446
+ self.class_prompt = class_prompt
447
+ else:
448
+ self.class_data_root = None
449
+
450
+ self.image_transforms = transforms.Compose(
451
+ [
452
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
453
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
454
+ transforms.ToTensor(),
455
+ transforms.Normalize([0.5], [0.5]),
456
+ ]
457
+ )
458
+
459
+ def __len__(self):
460
+ return self._length
461
+
462
+ def __getitem__(self, index):
463
+ example = {}
464
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
465
+ if not instance_image.mode == "RGB":
466
+ instance_image = instance_image.convert("RGB")
467
+ example["instance_images"] = self.image_transforms(instance_image)
468
+ example["instance_prompt_ids"] = self.tokenizer(
469
+ self.instance_prompt,
470
+ truncation=True,
471
+ padding="max_length",
472
+ max_length=self.tokenizer.model_max_length,
473
+ return_tensors="pt",
474
+ ).input_ids
475
+
476
+ if self.class_data_root:
477
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
478
+ if not class_image.mode == "RGB":
479
+ class_image = class_image.convert("RGB")
480
+ example["class_images"] = self.image_transforms(class_image)
481
+ example["class_prompt_ids"] = self.tokenizer(
482
+ self.class_prompt,
483
+ truncation=True,
484
+ padding="max_length",
485
+ max_length=self.tokenizer.model_max_length,
486
+ return_tensors="pt",
487
+ ).input_ids
488
+
489
+ return example
490
+
491
+
492
+ def collate_fn(examples, with_prior_preservation=False):
493
+ input_ids = [example["instance_prompt_ids"] for example in examples]
494
+ pixel_values = [example["instance_images"] for example in examples]
495
+
496
+ # Concat class and instance examples for prior preservation.
497
+ # We do this to avoid doing two forward passes.
498
+ if with_prior_preservation:
499
+ input_ids += [example["class_prompt_ids"] for example in examples]
500
+ pixel_values += [example["class_images"] for example in examples]
501
+
502
+ pixel_values = torch.stack(pixel_values)
503
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
504
+
505
+ input_ids = torch.cat(input_ids, dim=0)
506
+
507
+ batch = {
508
+ "input_ids": input_ids,
509
+ "pixel_values": pixel_values,
510
+ }
511
+ return batch
512
+
513
+
514
+ class PromptDataset(Dataset):
515
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
516
+
517
+ def __init__(self, prompt, num_samples):
518
+ self.prompt = prompt
519
+ self.num_samples = num_samples
520
+
521
+ def __len__(self):
522
+ return self.num_samples
523
+
524
+ def __getitem__(self, index):
525
+ example = {}
526
+ example["prompt"] = self.prompt
527
+ example["index"] = index
528
+ return example
529
+
530
+
531
+ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
532
+ logger.info(
533
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
534
+ f" {args.validation_prompt}."
535
+ )
536
+ # create pipeline (note: unet and vae are loaded again in float32)
537
+ pipeline = DiffusionPipeline.from_pretrained(
538
+ args.pretrained_model_name_or_path,
539
+ text_encoder=text_encoder,
540
+ tokenizer=tokenizer,
541
+ unet=accelerator.unwrap_model(unet),
542
+ vae=vae,
543
+ revision=args.revision,
544
+ torch_dtype=weight_dtype,
545
+ )
546
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
547
+ pipeline = pipeline.to(accelerator.device)
548
+ pipeline.set_progress_bar_config(disable=True)
549
+
550
+ # run inference
551
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
552
+ images = []
553
+ for _ in range(args.num_validation_images):
554
+ with torch.autocast("cuda"):
555
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
556
+ images.append(image)
557
+
558
+ for tracker in accelerator.trackers:
559
+ if tracker.name == "tensorboard":
560
+ np_images = np.stack([np.asarray(img) for img in images])
561
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
562
+ if tracker.name == "wandb":
563
+ tracker.log(
564
+ {
565
+ "validation": [
566
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
567
+ ]
568
+ }
569
+ )
570
+
571
+ del pipeline
572
+ torch.cuda.empty_cache()
573
+
574
+
575
+
576
+ def main(args):
577
+ logging_dir = Path(args.output_dir, args.logging_dir)
578
+
579
+ accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
580
+
581
+ accelerator = Accelerator(
582
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
583
+ mixed_precision=args.mixed_precision,
584
+ log_with=args.report_to,
585
+ logging_dir=logging_dir,
586
+ project_config=accelerator_project_config,
587
+ )
588
+
589
+ if args.report_to == "wandb":
590
+ if not is_wandb_available():
591
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
592
+ import wandb
593
+
594
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
595
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
596
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
597
+ # Make one log on every process with the configuration for debugging.
598
+ logging.basicConfig(
599
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
600
+ datefmt="%m/%d/%Y %H:%M:%S",
601
+ level=logging.INFO,
602
+ )
603
+ logger.info(accelerator.state, main_process_only=False)
604
+ if accelerator.is_local_main_process:
605
+ transformers.utils.logging.set_verbosity_warning()
606
+ diffusers.utils.logging.set_verbosity_info()
607
+ else:
608
+ transformers.utils.logging.set_verbosity_error()
609
+ diffusers.utils.logging.set_verbosity_error()
610
+
611
+ # If passed along, set the training seed now.
612
+ if args.seed is not None:
613
+ set_seed(args.seed)
614
+
615
+ # Generate class images if prior preservation is enabled.
616
+ if args.with_prior_preservation:
617
+ class_images_dir = Path(args.class_data_dir)
618
+ if not class_images_dir.exists():
619
+ class_images_dir.mkdir(parents=True)
620
+ cur_class_images = len(list(class_images_dir.iterdir()))
621
+
622
+ if cur_class_images < args.num_class_images:
623
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
624
+ if args.prior_generation_precision == "fp32":
625
+ torch_dtype = torch.float32
626
+ elif args.prior_generation_precision == "fp16":
627
+ torch_dtype = torch.float16
628
+ elif args.prior_generation_precision == "bf16":
629
+ torch_dtype = torch.bfloat16
630
+ pipeline = StableDiffusionPipeline.from_pretrained(
631
+ args.pretrained_model_name_or_path,
632
+ vae=AutoencoderKL.from_pretrained(
633
+ args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
634
+ subfolder=None if args.pretrained_vae_name_or_path else "vae",
635
+ revision=None if args.pretrained_vae_name_or_path else args.revision,
636
+ torch_dtype=torch_dtype
637
+ ),
638
+ torch_dtype=torch_dtype,
639
+ safety_checker=None,
640
+ revision=args.revision,
641
+ )
642
+ pipeline.scheduler = SCHEDULER_MAPPING[args.prior_generation_scheduler_type].from_config(pipeline.scheduler.config)
643
+ if is_xformers_available():
644
+ pipeline.enable_xformers_memory_efficient_attention()
645
+ if args.enable_token_merging:
646
+ try:
647
+ import tomesd
648
+ except ImportError:
649
+ raise ImportError(
650
+ "To use token merging (ToMe), please install the tomesd library: `pip install tomesd`."
651
+ )
652
+ tomesd.apply_patch(pipeline, ratio=0.5)
653
+
654
+ pipeline.set_progress_bar_config(disable=True)
655
+
656
+ num_new_images = args.num_class_images - cur_class_images
657
+ logger.info(f"Number of class images to sample: {num_new_images}.")
658
+
659
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
660
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
661
+
662
+ sample_dataloader = accelerator.prepare(sample_dataloader)
663
+ pipeline.to(accelerator.device)
664
+
665
+ for example in tqdm(
666
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
667
+ ):
668
+ images = pipeline(
669
+ example["prompt"],
670
+ num_inference_steps=args.prior_generation_num_inference_steps,
671
+ ).images
672
+
673
+ for i, image in enumerate(images):
674
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
675
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
676
+ image.save(image_filename)
677
+
678
+ del pipeline
679
+ if torch.cuda.is_available():
680
+ torch.cuda.empty_cache()
681
+
682
+ # Handle the repository creation
683
+ if accelerator.is_main_process:
684
+ if args.output_dir is not None:
685
+ os.makedirs(args.output_dir, exist_ok=True)
686
+
687
+ if args.push_to_hub:
688
+ repo_id = create_repo(
689
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
690
+ ).repo_id
691
+
692
+ # Load the tokenizer
693
+ if args.tokenizer_name:
694
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
695
+ elif args.pretrained_model_name_or_path:
696
+ tokenizer = AutoTokenizer.from_pretrained(
697
+ args.pretrained_model_name_or_path,
698
+ subfolder="tokenizer",
699
+ revision=args.revision,
700
+ use_fast=False,
701
+ )
702
+
703
+ # import correct text encoder class
704
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
705
+
706
+ # Load scheduler and models
707
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
708
+ text_encoder = text_encoder_cls.from_pretrained(
709
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
710
+ )
711
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
712
+ unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=True)
713
+
714
+ # We only train the additional spectral shifts
715
+ vae.requires_grad_(False)
716
+ text_encoder.requires_grad_(False)
717
+ unet.requires_grad_(False)
718
+ optim_params = []
719
+ for n, p in unet.named_parameters():
720
+ if "delta" in n:
721
+ p.requires_grad = True
722
+ optim_params.append(p)
723
+ total_params = sum(p.numel() for p in optim_params)
724
+ print(f"Number of Trainable Parameters: {total_params * 1.e-6:.2f} M")
725
+
726
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
727
+ # as these models are only used for inference, keeping weights in full precision is not required.
728
+ weight_dtype = torch.float32
729
+ if accelerator.mixed_precision == "fp16":
730
+ weight_dtype = torch.float16
731
+ elif accelerator.mixed_precision == "bf16":
732
+ weight_dtype = torch.bfloat16
733
+
734
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
735
+ # unet.to(accelerator.device, dtype=weight_dtype)
736
+ vae.to(accelerator.device, dtype=weight_dtype)
737
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
738
+
739
+ if args.enable_xformers_memory_efficient_attention:
740
+ if is_xformers_available():
741
+ import xformers
742
+
743
+ xformers_version = version.parse(xformers.__version__)
744
+ if xformers_version == version.parse("0.0.16"):
745
+ logger.warn(
746
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
747
+ )
748
+ unet.enable_xformers_memory_efficient_attention()
749
+ else:
750
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
751
+
752
+ if args.gradient_checkpointing:
753
+ unet.enable_gradient_checkpointing()
754
+
755
+ if args.scale_lr:
756
+ args.learning_rate = (
757
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
758
+ )
759
+
760
+ # Enable TF32 for faster training on Ampere GPUs,
761
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
762
+ if args.allow_tf32:
763
+ torch.backends.cuda.matmul.allow_tf32 = True
764
+
765
+ if args.scale_lr:
766
+ args.learning_rate = (
767
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
768
+ )
769
+
770
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
771
+ if args.use_8bit_adam:
772
+ try:
773
+ import bitsandbytes as bnb
774
+ except ImportError:
775
+ raise ImportError(
776
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
777
+ )
778
+
779
+ optimizer_class = bnb.optim.AdamW8bit
780
+ else:
781
+ optimizer_class = torch.optim.AdamW
782
+
783
+ # Optimizer creation
784
+ optimizer = optimizer_class(
785
+ optim_params,
786
+ lr=args.learning_rate,
787
+ betas=(args.adam_beta1, args.adam_beta2),
788
+ weight_decay=args.adam_weight_decay,
789
+ eps=args.adam_epsilon,
790
+ )
791
+
792
+ # Dataset and DataLoaders creation:
793
+ train_dataset = DreamBoothDataset(
794
+ instance_data_root=args.instance_data_dir,
795
+ instance_prompt=args.instance_prompt,
796
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
797
+ class_prompt=args.class_prompt,
798
+ class_num=args.num_class_images,
799
+ tokenizer=tokenizer,
800
+ size=args.resolution,
801
+ center_crop=args.center_crop,
802
+ )
803
+
804
+ train_dataloader = torch.utils.data.DataLoader(
805
+ train_dataset,
806
+ batch_size=args.train_batch_size,
807
+ shuffle=True,
808
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
809
+ num_workers=args.dataloader_num_workers,
810
+ )
811
+
812
+ # Scheduler and math around the number of training steps.
813
+ overrode_max_train_steps = False
814
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
815
+ if args.max_train_steps is None:
816
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
817
+ overrode_max_train_steps = True
818
+
819
+ lr_scheduler = get_scheduler(
820
+ args.lr_scheduler,
821
+ optimizer=optimizer,
822
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
823
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
824
+ num_cycles=args.lr_num_cycles,
825
+ power=args.lr_power,
826
+ )
827
+
828
+ # Prepare everything with our `accelerator`.
829
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
830
+ unet, optimizer, train_dataloader, lr_scheduler
831
+ )
832
+
833
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
834
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
835
+ if overrode_max_train_steps:
836
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
837
+ # Afterwards we recalculate our number of training epochs
838
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
839
+
840
+ # We need to initialize the trackers we use, and also store our configuration.
841
+ # The trackers initializes automatically on the main process.
842
+ if accelerator.is_main_process:
843
+ accelerator.init_trackers("svdiff-pytorch", config=vars(args))
844
+
845
+ def save_weights(step):
846
+ # Create the pipeline using using the trained modules and save it.
847
+ if accelerator.is_main_process:
848
+ save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
849
+ os.makedirs(save_path, exist_ok=True)
850
+ unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
851
+ state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
852
+ save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
853
+ print(f"[*] Weights saved at {save_path}")
854
+
855
+ # Train!
856
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
857
+
858
+ logger.info("***** Running training *****")
859
+ logger.info(f" Num examples = {len(train_dataset)}")
860
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
861
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
862
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
863
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
864
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
865
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
866
+ global_step = 0
867
+ first_epoch = 0
868
+
869
+ # Potentially load in the weights and states from a previous save
870
+ if args.resume_from_checkpoint:
871
+ if args.resume_from_checkpoint != "latest":
872
+ path = os.path.basename(args.resume_from_checkpoint)
873
+ else:
874
+ # Get the mos recent checkpoint
875
+ dirs = os.listdir(args.output_dir)
876
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
877
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
878
+ path = dirs[-1] if len(dirs) > 0 else None
879
+
880
+ if path is None:
881
+ accelerator.print(
882
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
883
+ )
884
+ args.resume_from_checkpoint = None
885
+ else:
886
+ accelerator.print(f"Resuming from checkpoint {path}")
887
+ accelerator.load_state(os.path.join(args.output_dir, path))
888
+ global_step = int(path.split("-")[1])
889
+
890
+ resume_global_step = global_step * args.gradient_accumulation_steps
891
+ first_epoch = global_step // num_update_steps_per_epoch
892
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
893
+
894
+ # Only show the progress bar once on each machine.
895
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
896
+ progress_bar.set_description("Steps")
897
+
898
+ for epoch in range(first_epoch, args.num_train_epochs):
899
+ unet.train()
900
+ for step, batch in enumerate(train_dataloader):
901
+ # Skip steps until we reach the resumed step
902
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
903
+ if step % args.gradient_accumulation_steps == 0:
904
+ progress_bar.update(1)
905
+ continue
906
+
907
+ with accelerator.accumulate(unet):
908
+ # Convert images to latent space
909
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
910
+ latents = latents * vae.config.scaling_factor
911
+
912
+ # Sample noise that we'll add to the latents
913
+ noise = torch.randn_like(latents)
914
+ bsz = latents.shape[0]
915
+ # Sample a random timestep for each image
916
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
917
+ timesteps = timesteps.long()
918
+
919
+ # Add noise to the latents according to the noise magnitude at each timestep
920
+ # (this is the forward diffusion process)
921
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
922
+
923
+ # Get the text embedding for conditioning
924
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
925
+
926
+ # Predict the noise residual
927
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
928
+
929
+ # Get the target for loss depending on the prediction type
930
+ if noise_scheduler.config.prediction_type == "epsilon":
931
+ target = noise
932
+ elif noise_scheduler.config.prediction_type == "v_prediction":
933
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
934
+ else:
935
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
936
+
937
+ if args.with_prior_preservation:
938
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
939
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
940
+ target, target_prior = torch.chunk(target, 2, dim=0)
941
+
942
+ # Compute instance loss
943
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
944
+
945
+ # Compute prior loss
946
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
947
+
948
+ # Add the prior loss to the instance loss.
949
+ loss = loss + args.prior_loss_weight * prior_loss
950
+ else:
951
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
952
+
953
+ accelerator.backward(loss)
954
+ if accelerator.sync_gradients:
955
+ params_to_clip = unet.parameters()
956
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
957
+ optimizer.step()
958
+ lr_scheduler.step()
959
+ optimizer.zero_grad()
960
+
961
+ # Checks if the accelerator has performed an optimization step behind the scenes
962
+ if accelerator.sync_gradients:
963
+ progress_bar.update(1)
964
+ global_step += 1
965
+
966
+ if global_step % args.checkpointing_steps == 0:
967
+ if accelerator.is_main_process:
968
+ save_weights(global_step)
969
+ # save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
970
+ # accelerator.save_state(save_path)
971
+ # logger.info(f"Saved state to {save_path}")
972
+
973
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
974
+ progress_bar.set_postfix(**logs)
975
+ accelerator.log(logs, step=global_step)
976
+
977
+ if global_step >= args.max_train_steps:
978
+ break
979
+
980
+ if accelerator.is_main_process:
981
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
982
+ log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
983
+
984
+ accelerator.wait_for_everyone()
985
+ save_weights(global_step)
986
+ # put the latest checkpoint to output-dir
987
+ save_path = args.output_dir
988
+ unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
989
+ state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
990
+ save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
991
+ print(f"[*] Weights saved at {save_path}")
992
+
993
+ if accelerator.is_main_process:
994
+ if args.push_to_hub:
995
+ save_model_card(
996
+ repo_id,
997
+ base_model=args.pretrained_model_name_or_path,
998
+ prompt=args.instance_prompt,
999
+ repo_folder=args.output_dir,
1000
+ )
1001
+ upload_folder(
1002
+ repo_id=repo_id,
1003
+ folder_path=args.output_dir,
1004
+ commit_message="End of training",
1005
+ ignore_patterns=["step_*", "epoch_*"],
1006
+ )
1007
+
1008
+ accelerator.end_training()
1009
+
1010
+
1011
+ if __name__ == "__main__":
1012
+ args = parse_args()
1013
+ main(args)
trainer.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+
10
+ import gradio as gr
11
+ import PIL.Image
12
+ import slugify
13
+ import torch
14
+ from huggingface_hub import HfApi
15
+ from accelerate.utils import write_basic_config
16
+
17
+
18
+ from app_upload import ModelUploader
19
+ from utils import save_model_card
20
+
21
+ URL_TO_JOIN_LIBRARY_ORG = 'https://huggingface.co/organizations/svdiff-library/share/PZBRRkosXikenXUdjMcvcoFmpWjcWnZjKL'
22
+
23
+
24
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
25
+ w, h = image.size
26
+ if w == h:
27
+ return image
28
+ elif w > h:
29
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
30
+ new_image.paste(image, (0, (w - h) // 2))
31
+ return new_image
32
+ else:
33
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
34
+ new_image.paste(image, ((h - w) // 2, 0))
35
+ return new_image
36
+
37
+
38
+ class Trainer:
39
+ def __init__(self, hf_token: str | None = None):
40
+ self.hf_token = hf_token
41
+ self.api = HfApi(token=hf_token)
42
+ self.model_uploader = ModelUploader(hf_token)
43
+
44
+ def prepare_dataset(self, instance_images: list, resolution: int,
45
+ instance_data_dir: pathlib.Path) -> None:
46
+ shutil.rmtree(instance_data_dir, ignore_errors=True)
47
+ instance_data_dir.mkdir(parents=True)
48
+ for i, temp_path in enumerate(instance_images):
49
+ image = PIL.Image.open(temp_path.name)
50
+ image = pad_image(image)
51
+ image = image.resize((resolution, resolution))
52
+ image = image.convert('RGB')
53
+ out_path = instance_data_dir / f'{i:03d}.jpg'
54
+ image.save(out_path, format='JPEG', quality=100)
55
+
56
+ def join_library_org(self) -> None:
57
+ subprocess.run(
58
+ shlex.split(
59
+ f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_LIBRARY_ORG}'
60
+ ))
61
+
62
+ def run(
63
+ self,
64
+ instance_images: list | None,
65
+ instance_prompt: str,
66
+ output_model_name: str,
67
+ overwrite_existing_model: bool,
68
+ validation_prompt: str,
69
+ base_model: str,
70
+ resolution_s: str,
71
+ n_steps: int,
72
+ learning_rate: float,
73
+ gradient_accumulation: int,
74
+ seed: int,
75
+ fp16: bool,
76
+ use_8bit_adam: bool,
77
+ gradient_checkpointing: bool,
78
+ # enable_xformers_memory_efficient_attention: bool,
79
+ checkpointing_steps: int,
80
+ use_wandb: bool,
81
+ validation_epochs: int,
82
+ upload_to_hub: bool,
83
+ use_private_repo: bool,
84
+ delete_existing_repo: bool,
85
+ upload_to: str,
86
+ remove_gpu_after_training: bool,
87
+ ) -> str:
88
+ if not torch.cuda.is_available():
89
+ raise gr.Error('CUDA is not available.')
90
+ if instance_images is None:
91
+ raise gr.Error('You need to upload images.')
92
+ if not instance_prompt:
93
+ raise gr.Error('The instance prompt is missing.')
94
+ if not validation_prompt:
95
+ raise gr.Error('The validation prompt is missing.')
96
+
97
+ resolution = int(resolution_s)
98
+
99
+ if not output_model_name:
100
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
101
+ output_model_name = f'svdiff-pytorch-{timestamp}'
102
+ output_model_name = slugify.slugify(output_model_name)
103
+
104
+ repo_dir = pathlib.Path(__file__).parent
105
+ output_dir = repo_dir / 'experiments' / output_model_name
106
+ if overwrite_existing_model or upload_to_hub:
107
+ shutil.rmtree(output_dir, ignore_errors=True)
108
+ output_dir.mkdir(parents=True)
109
+
110
+ instance_data_dir = repo_dir / 'training_data' / output_model_name
111
+ self.prepare_dataset(instance_images, resolution, instance_data_dir)
112
+
113
+ if upload_to_hub:
114
+ self.join_library_org()
115
+ # accelerate config
116
+ write_basic_config()
117
+ command = f'''
118
+ accelerate launch train_svdiff.py \
119
+ --pretrained_model_name_or_path={base_model} \
120
+ --instance_data_dir={instance_data_dir} \
121
+ --output_dir={output_dir} \
122
+ --instance_prompt="{instance_prompt}" \
123
+ --resolution={resolution} \
124
+ --train_batch_size=1 \
125
+ --gradient_accumulation_steps={gradient_accumulation} \
126
+ --learning_rate={learning_rate} \
127
+ --lr_scheduler=constant \
128
+ --lr_warmup_steps=0 \
129
+ --max_train_steps={n_steps} \
130
+ --checkpointing_steps={checkpointing_steps} \
131
+ --validation_prompt="{validation_prompt}" \
132
+ --validation_epochs={validation_epochs} \
133
+ --seed={seed}
134
+ '''
135
+ if fp16:
136
+ command += ' --mixed_precision="fp16"'
137
+ if use_8bit_adam:
138
+ command += ' --use_8bit_adam'
139
+ if gradient_checkpointing:
140
+ command += ' --gradient_checkpointing'
141
+ # if enable_xformers_memory_efficient_attention:
142
+ # command += ' --enable_xformers_memory_efficient_attention'
143
+ if use_wandb:
144
+ command += ' --report_to wandb'
145
+
146
+ with open(output_dir / 'train.sh', 'w') as f:
147
+ command_s = ' '.join(command.split())
148
+ f.write(command_s)
149
+ subprocess.run(shlex.split(command))
150
+ save_model_card(save_dir=output_dir,
151
+ base_model=base_model,
152
+ instance_prompt=instance_prompt,
153
+ test_prompt=validation_prompt,
154
+ test_image_dir='test_images')
155
+
156
+ message = 'Training completed!'
157
+ print(message)
158
+
159
+ if upload_to_hub:
160
+ upload_message = self.model_uploader.upload_model(
161
+ folder_path=output_dir.as_posix(),
162
+ repo_name=output_model_name,
163
+ upload_to=upload_to,
164
+ private=use_private_repo,
165
+ delete_existing_repo=delete_existing_repo)
166
+ print(upload_message)
167
+ message = message + '\n' + upload_message
168
+
169
+ if remove_gpu_after_training:
170
+ space_id = os.getenv('SPACE_ID')
171
+ if space_id:
172
+ self.api.request_space_hardware(repo_id=space_id,
173
+ hardware='cpu-basic')
174
+
175
+ return message
uploader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.api = HfApi(token=hf_token)
9
+
10
+ def get_username(self) -> str:
11
+ return self.api.whoami()['name']
12
+
13
+ def upload(self,
14
+ folder_path: str,
15
+ repo_name: str,
16
+ organization: str = '',
17
+ repo_type: str = 'model',
18
+ private: bool = True,
19
+ delete_existing_repo: bool = False) -> str:
20
+ if not folder_path:
21
+ raise ValueError
22
+ if not repo_name:
23
+ raise ValueError
24
+ if not organization:
25
+ organization = self.get_username()
26
+ repo_id = f'{organization}/{repo_name}'
27
+ if delete_existing_repo:
28
+ try:
29
+ self.api.delete_repo(repo_id, repo_type=repo_type)
30
+ except Exception:
31
+ pass
32
+ try:
33
+ self.api.create_repo(repo_id, repo_type=repo_type, private=private)
34
+ self.api.upload_folder(repo_id=repo_id,
35
+ folder_path=folder_path,
36
+ path_in_repo='.',
37
+ repo_type=repo_type)
38
+ url = f'https://huggingface.co/{repo_id}'
39
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
40
+ except Exception as e:
41
+ message = str(e)
42
+ return message
utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+
5
+
6
+ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
7
+ repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
+ if not exp_root_dir.exists():
10
+ return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'spectral_shifts.safetensors').exists()
15
+ ]
16
+ if ignore_repo:
17
+ exp_dirs = [
18
+ exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
19
+ ]
20
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
21
+
22
+
23
+ def save_model_card(
24
+ save_dir: pathlib.Path,
25
+ base_model: str,
26
+ instance_prompt: str,
27
+ test_prompt: str = '',
28
+ test_image_dir: str = '',
29
+ ) -> None:
30
+ image_str = ''
31
+ if test_prompt and test_image_dir:
32
+ image_paths = sorted((save_dir / test_image_dir).glob('*'))
33
+ if image_paths:
34
+ image_str = f'Test prompt: {test_prompt}\n'
35
+ for image_path in image_paths:
36
+ rel_path = image_path.relative_to(save_dir)
37
+ image_str += f'![{image_path.stem}]({rel_path})\n'
38
+
39
+ model_card = f'''---
40
+ license: creativeml-openrail-m
41
+ base_model: {base_model}
42
+ instance_prompt: {instance_prompt}
43
+ tags:
44
+ - stable-diffusion
45
+ - stable-diffusion-diffusers
46
+ - text-to-image
47
+ - diffusers
48
+ - lora
49
+ inference: true
50
+ ---
51
+ # SVDiff-pytorch - {save_dir.name}
52
+
53
+ These are SVDiff weights for {base_model}. The weights were trained on "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
54
+ {image_str}
55
+ '''
56
+
57
+ with open(save_dir / 'README.md', 'w') as f:
58
+ f.write(model_card)