vendor cog_sdxl
This view is limited to 50 files because the diff contains too many changes. See the raw diff for the full change set.
- .gitattributes +2 -0
- cog_sdxl/.dockerignore +35 -0
- cog_sdxl/.gitignore +23 -0
- cog_sdxl/LICENSE +202 -0
- cog_sdxl/README.md +41 -0
- cog_sdxl/cog.yaml +33 -0
- cog_sdxl/dataset_and_utils.py +421 -0
- cog_sdxl/example_datasets/README.md +3 -0
- cog_sdxl/example_datasets/kiriko.png +3 -0
- cog_sdxl/example_datasets/kiriko/0.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/1.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/10.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/11.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/12.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/2.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/3.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/4.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/5.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/6.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/7.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/8.src.jpg +0 -0
- cog_sdxl/example_datasets/kiriko/9.src.jpg +0 -0
- cog_sdxl/example_datasets/monster.png +3 -0
- cog_sdxl/example_datasets/monster/caption.csv +6 -0
- cog_sdxl/example_datasets/monster/monstertoy (1).jpg +0 -0
- cog_sdxl/example_datasets/monster/monstertoy (2).jpg +0 -0
- cog_sdxl/example_datasets/monster/monstertoy (3).jpg +0 -0
- cog_sdxl/example_datasets/monster/monstertoy (4).jpg +0 -0
- cog_sdxl/example_datasets/monster/monstertoy (5).jpg +0 -0
- cog_sdxl/example_datasets/monster_uni.png +3 -0
- cog_sdxl/example_datasets/zeke.zip +3 -0
- cog_sdxl/example_datasets/zeke/0.src.jpg +0 -0
- cog_sdxl/example_datasets/zeke/1.src.jpg +0 -0
- cog_sdxl/example_datasets/zeke/2.src.jpg +0 -0
- cog_sdxl/example_datasets/zeke/3.src.jpg +0 -0
- cog_sdxl/example_datasets/zeke/4.src.jpg +0 -0
- cog_sdxl/example_datasets/zeke/5.src.jpg +0 -0
- cog_sdxl/example_datasets/zeke_unicorn.png +3 -0
- cog_sdxl/feature-extractor/preprocessor_config.json +20 -0
- cog_sdxl/no_init.py +121 -0
- cog_sdxl/predict.py +462 -0
- cog_sdxl/preprocess.py +599 -0
- cog_sdxl/requirements_test.txt +5 -0
- cog_sdxl/samples.py +155 -0
- cog_sdxl/script/download_preprocessing_weights.py +54 -0
- cog_sdxl/script/download_weights.py +50 -0
- cog_sdxl/tests/assets/out.png +3 -0
- cog_sdxl/tests/test_predict.py +205 -0
- cog_sdxl/tests/test_remote_train.py +69 -0
- cog_sdxl/tests/test_utils.py +105 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+cog_sdxl/tests/assets/ filter=lfs diff=lfs merge=lfs -text
cog_sdxl/.dockerignore
ADDED
@@ -0,0 +1,35 @@
+sdxl-cache/
+refiner-cache/
+safety-cache/
+trained-model/
+*.png
+cache/
+checkpoint/
+training_out/
+dreambooth/
+lora/
+ttemp/
+.git/
+cog_class_data/
+dataset/
+training_data/
+temp/
+temp_in/
+cog_instance_data/
+example_datasets/
+trained_model.tar
+zeke_data.tar
+data.tar
+zeke.zip
+sketch-mountains-input.jpeg
+training_out*
+weights
+inference_*
+trained-model
+*.zip
+tmp/
+blip-cache/
+clipseg-cache/
+swin2sr-cache/
+weights-cache/
+tests/
cog_sdxl/.gitignore
ADDED
@@ -0,0 +1,23 @@
+
+refiner-cache
+sdxl-cache
+safety-cache
+trained-model
+temp
+temp_in
+cache
+.cog
+__pycache__
+wandb
+ft*
+*.ipynb
+dataset
+training_data
+training_out
+output*
+training_out*
+trained_model.tar
+checkpoint*
+weights
+__*.zip
+**-cache
cog_sdxl/LICENSE
ADDED
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2023, Replicate, Inc.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
cog_sdxl/README.md
ADDED
@@ -0,0 +1,41 @@
+# Cog-SDXL
+
+[![Replicate demo and cloud API](https://replicate.com/stability-ai/sdxl/badge)](https://replicate.com/stability-ai/sdxl)
+
+This is an implementation of Stability AI's [SDXL](https://github.com/Stability-AI/generative-models) as a [Cog](https://github.com/replicate/cog) model.
+
+## Development
+
+Follow the [model pushing guide](https://replicate.com/docs/guides/push-a-model) to push your own fork of SDXL to [Replicate](https://replicate.com).
+
+## Basic Usage
+
+for prediction,
+
+```bash
+cog predict -i prompt="a photo of TOK"
+```
+
+```bash
+cog train -i input_images=@example_datasets/__data.zip -i use_face_detection_instead=True
+```
+
+```bash
+cog run -p 5000 python -m cog.server.http
+```
+
+## Update notes
+
+**2023-08-17**
+* ROI problem is fixed.
+* Now BLIP caption_prefix does not interfere with BLIP captioner.
+
+
+**2023-08-12**
+* Input types are inferred from input name extensions, or from the `input_images_filetype` argument
+* Preprocssing are now done with fp16, and if no mask is found, the model will use the whole image
+
+**2023-08-11**
+* Default to 768x768 resolution training
+* Rank as argument now, default to 32
+* Now uses Swin2SR `caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr` as default, and will upscale + downscale to 768x768
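
Beyond the `cog` commands shown in the vendored README, a fork pushed to Replicate can also be called from Python. The sketch below is illustrative only and is not part of the vendored repo: it assumes the `replicate` client package is installed, `REPLICATE_API_TOKEN` is set in the environment, and a hypothetical model name `your-user/your-sdxl-fork`; the input keys mirror the `Input` fields declared in `predict.py` further down.

```python
# Hedged sketch: calling a pushed fork of this model via the Replicate Python client.
# "your-user/your-sdxl-fork" is a placeholder; replace it with your own push target.
import replicate

output = replicate.run(
    "your-user/your-sdxl-fork",
    input={
        "prompt": "a photo of TOK",  # TOK is the token used by the example datasets
        "width": 1024,
        "height": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
        "scheduler": "K_EULER",
    },
)
print(output)  # typically a list of generated image URLs
```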
cog_sdxl/cog.yaml
ADDED
@@ -0,0 +1,33 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  gpu: true
+  cuda: "11.8"
+  python_version: "3.9"
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "ffmpeg"
+    - "libsm6"
+    - "libxext6"
+    - "wget"
+  python_packages:
+    - "diffusers<=0.25"
+    - "torch==2.0.1"
+    - "transformers==4.31.0"
+    - "invisible-watermark==0.2.0"
+    - "accelerate==0.21.0"
+    - "pandas==2.0.3"
+    - "torchvision==0.15.2"
+    - "numpy==1.25.1"
+    - "pandas==2.0.3"
+    - "fire==0.5.0"
+    - "opencv-python>=4.1.0.25"
+    - "mediapipe==0.10.2"
+
+  run:
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)" && chmod +x /usr/local/bin/pget
+    - wget http://thegiflibrary.tumblr.com/post/11565547760 -O face_landmarker_v2_with_blendshapes.task -q https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task
+
+predict: "predict.py:Predictor"
+train: "train.py:train"
cog_sdxl/dataset_and_utils.py
ADDED
@@ -0,0 +1,421 @@
+import os
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+import PIL
+import torch
+import torch.utils.checkpoint
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from PIL import Image
+from safetensors import safe_open
+from safetensors.torch import save_file
+from torch.utils.data import Dataset
+from transformers import AutoTokenizer, PretrainedConfig
+
+
+def prepare_image(
+    pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+) -> torch.Tensor:
+    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+    arr = np.array(pil_image.convert("RGB"))
+    arr = arr.astype(np.float32) / 127.5 - 1
+    arr = np.transpose(arr, [2, 0, 1])
+    image = torch.from_numpy(arr).unsqueeze(0)
+    return image
+
+
+def prepare_mask(
+    pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+) -> torch.Tensor:
+    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+    arr = np.array(pil_image.convert("L"))
+    arr = arr.astype(np.float32) / 255.0
+    arr = np.expand_dims(arr, 0)
+    image = torch.from_numpy(arr).unsqueeze(0)
+    return image
+
+
+class PreprocessedDataset(Dataset):
+    def __init__(
+        self,
+        csv_path: str,
+        tokenizer_1,
+        tokenizer_2,
+        vae_encoder,
+        text_encoder_1=None,
+        text_encoder_2=None,
+        do_cache: bool = False,
+        size: int = 512,
+        text_dropout: float = 0.0,
+        scale_vae_latents: bool = True,
+        substitute_caption_map: Dict[str, str] = {},
+    ):
+        super().__init__()
+
+        self.data = pd.read_csv(csv_path)
+        self.csv_path = csv_path
+
+        self.caption = self.data["caption"]
+        # make it lowercase
+        self.caption = self.caption.str.lower()
+        for key, value in substitute_caption_map.items():
+            self.caption = self.caption.str.replace(key.lower(), value)
+
+        self.image_path = self.data["image_path"]
+
+        if "mask_path" not in self.data.columns:
+            self.mask_path = None
+        else:
+            self.mask_path = self.data["mask_path"]
+
+        if text_encoder_1 is None:
+            self.return_text_embeddings = False
+        else:
+            self.text_encoder_1 = text_encoder_1
+            self.text_encoder_2 = text_encoder_2
+            self.return_text_embeddings = True
+            assert (
+                NotImplementedError
+            ), "Preprocessing Text Encoder is not implemented yet"
+
+        self.tokenizer_1 = tokenizer_1
+        self.tokenizer_2 = tokenizer_2
+
+        self.vae_encoder = vae_encoder
+        self.scale_vae_latents = scale_vae_latents
+        self.text_dropout = text_dropout
+
+        self.size = size
+
+        if do_cache:
+            self.vae_latents = []
+            self.tokens_tuple = []
+            self.masks = []
+
+            self.do_cache = True
+
+            print("Captions to train on: ")
+            for idx in range(len(self.data)):
+                token, vae_latent, mask = self._process(idx)
+                self.vae_latents.append(vae_latent)
+                self.tokens_tuple.append(token)
+                self.masks.append(mask)
+
+            del self.vae_encoder
+
+        else:
+            self.do_cache = False
+
+    @torch.no_grad()
+    def _process(
+        self, idx: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        image_path = self.image_path[idx]
+        image_path = os.path.join(os.path.dirname(self.csv_path), image_path)
+
+        image = PIL.Image.open(image_path).convert("RGB")
+        image = prepare_image(image, self.size, self.size).to(
+            dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+        )
+
+        caption = self.caption[idx]
+
+        print(caption)
+
+        # tokenizer_1
+        ti1 = self.tokenizer_1(
+            caption,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        ).input_ids
+
+        ti2 = self.tokenizer_2(
+            caption,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        ).input_ids
+
+        vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
+
+        if self.scale_vae_latents:
+            vae_latent = vae_latent * self.vae_encoder.config.scaling_factor
+
+        if self.mask_path is None:
+            mask = torch.ones_like(
+                vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+            )
+
+        else:
+            mask_path = self.mask_path[idx]
+            mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path)
+
+            mask = PIL.Image.open(mask_path)
+            mask = prepare_mask(mask, self.size, self.size).to(
+                dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+            )
+
+            mask = torch.nn.functional.interpolate(
+                mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
+            )
+            mask = mask.repeat(1, vae_latent.shape[1], 1, 1)
+
+        assert len(mask.shape) == 4 and len(vae_latent.shape) == 4
+
+        return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def atidx(
+        self, idx: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        if self.do_cache:
+            return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
+        else:
+            return self._process(idx)
+
+    def __getitem__(
+        self, idx: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        token, vae_latent, mask = self.atidx(idx)
+        return token, vae_latent, mask
+
+
+def import_model_class_from_model_name_or_path(
+    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+    )
+    model_class = text_encoder_config.architectures[0]
+
+    if model_class == "CLIPTextModel":
+        from transformers import CLIPTextModel
+
+        return CLIPTextModel
+    elif model_class == "CLIPTextModelWithProjection":
+        from transformers import CLIPTextModelWithProjection
+
+        return CLIPTextModelWithProjection
+    else:
+        raise ValueError(f"{model_class} is not supported.")
+
+
+def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
+    tokenizer_one = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=revision,
+        use_fast=False,
+    )
+    tokenizer_two = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=revision,
+        use_fast=False,
+    )
+
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    # import correct text encoder classes
+    text_encoder_cls_one = import_model_class_from_model_name_or_path(
+        pretrained_model_name_or_path, revision
+    )
+    text_encoder_cls_two = import_model_class_from_model_name_or_path(
+        pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
+    )
+    text_encoder_one = text_encoder_cls_one.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
+    )
+    text_encoder_two = text_encoder_cls_two.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
+    )
+
+    vae = AutoencoderKL.from_pretrained(
+        pretrained_model_name_or_path, subfolder="vae", revision=revision
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", revision=revision
+    )
+
+    vae.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
+
+    unet.to(device, dtype=weight_dtype)
+    vae.to(device, dtype=torch.float32)
+    text_encoder_one.to(device, dtype=weight_dtype)
+    text_encoder_two.to(device, dtype=weight_dtype)
+
+    return (
+        tokenizer_one,
+        tokenizer_two,
+        noise_scheduler,
+        text_encoder_one,
+        text_encoder_two,
+        vae,
+        unet,
+    )
+
+
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
+    """
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[
+                f"{attn_processor_key}.{parameter_key}"
+            ] = parameter
+
+    return attn_processors_state_dict
+
+
+class TokenEmbeddingsHandler:
+    def __init__(self, text_encoders, tokenizers):
+        self.text_encoders = text_encoders
+        self.tokenizers = tokenizers
+
+        self.train_ids: Optional[torch.Tensor] = None
+        self.inserting_toks: Optional[List[str]] = None
+        self.embeddings_settings = {}
+
+    def initialize_new_tokens(self, inserting_toks: List[str]):
+        idx = 0
+        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+            assert isinstance(
+                inserting_toks, list
+            ), "inserting_toks should be a list of strings."
+            assert all(
+                isinstance(tok, str) for tok in inserting_toks
+            ), "All elements in inserting_toks should be strings."
+
+            self.inserting_toks = inserting_toks
+            special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+            tokenizer.add_special_tokens(special_tokens_dict)
+            text_encoder.resize_token_embeddings(len(tokenizer))
+
+            self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+            # random initialization of new tokens
+
+            std_token_embedding = (
+                text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+            )
+
+            print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
+
+            text_encoder.text_model.embeddings.token_embedding.weight.data[
+                self.train_ids
+            ] = (
+                torch.randn(
+                    len(self.train_ids), text_encoder.text_model.config.hidden_size
+                )
+                .to(device=self.device)
+                .to(dtype=self.dtype)
+                * std_token_embedding
+            )
+            self.embeddings_settings[
+                f"original_embeddings_{idx}"
+            ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+            self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+            inu = torch.ones((len(tokenizer),), dtype=torch.bool)
+            inu[self.train_ids] = False
+
+            self.embeddings_settings[f"index_no_updates_{idx}"] = inu
+
+            print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+            idx += 1
+
+    def save_embeddings(self, file_path: str):
+        assert (
+            self.train_ids is not None
+        ), "Initialize new tokens before saving embeddings."
+        tensors = {}
+        for idx, text_encoder in enumerate(self.text_encoders):
+            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[
+                0
+            ] == len(self.tokenizers[0]), "Tokenizers should be the same."
+            new_token_embeddings = (
+                text_encoder.text_model.embeddings.token_embedding.weight.data[
+                    self.train_ids
+                ]
+            )
+            tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+        save_file(tensors, file_path)
+
+    @property
+    def dtype(self):
+        return self.text_encoders[0].dtype
+
+    @property
+    def device(self):
+        return self.text_encoders[0].device
+
+    def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
+        # Assuming new tokens are of the format <s_i>
+        self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
+        special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+        tokenizer.add_special_tokens(special_tokens_dict)
+        text_encoder.resize_token_embeddings(len(tokenizer))
+
+        self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+        assert self.train_ids is not None, "New tokens could not be converted to IDs."
+        text_encoder.text_model.embeddings.token_embedding.weight.data[
+            self.train_ids
+        ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
+
+    @torch.no_grad()
+    def retract_embeddings(self):
+        for idx, text_encoder in enumerate(self.text_encoders):
+            index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+            text_encoder.text_model.embeddings.token_embedding.weight.data[
+                index_no_updates
+            ] = (
+                self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+                .to(device=text_encoder.device)
+                .to(dtype=text_encoder.dtype)
+            )
+
+            # for the parts that were updated, we need to normalize them
+            # to have the same std as before
+            std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+            index_updates = ~index_no_updates
+            new_embeddings = (
+                text_encoder.text_model.embeddings.token_embedding.weight.data[
+                    index_updates
+                ]
+            )
+            off_ratio = std_token_embedding / new_embeddings.std()
+
+            new_embeddings = new_embeddings * (off_ratio**0.1)
+            text_encoder.text_model.embeddings.token_embedding.weight.data[
+                index_updates
+            ] = new_embeddings
+
+    def load_embeddings(self, file_path: str):
+        with safe_open(file_path, framework="pt", device=self.device.type) as f:
+            for idx in range(len(self.text_encoders)):
+                text_encoder = self.text_encoders[idx]
+                tokenizer = self.tokenizers[idx]
+
+                loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
+                self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
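
The following is a minimal usage sketch of the helpers defined in `dataset_and_utils.py` above; it is not part of the vendored code. It assumes an SDXL checkpoint directory (the `./sdxl-cache` path used by `predict.py`) and a CSV with `caption` and `image_path` columns, which is what `PreprocessedDataset` reads; the CSV path below is hypothetical.

```python
# Hedged sketch: wiring load_models() and PreprocessedDataset together.
import torch

from dataset_and_utils import PreprocessedDataset, load_models

(
    tokenizer_one,
    tokenizer_two,
    noise_scheduler,
    text_encoder_one,
    text_encoder_two,
    vae,
    unet,
) = load_models("./sdxl-cache", revision=None, device="cuda", weight_dtype=torch.float16)

dataset = PreprocessedDataset(
    "training_data/captions.csv",  # hypothetical preprocessed CSV with caption/image_path columns
    tokenizer_one,
    tokenizer_two,
    vae_encoder=vae,
    do_cache=True,  # pre-encodes every image through the VAE once, then frees the encoder
    size=768,       # matches the 768x768 default mentioned in the README
)

(tok1, tok2), latent, mask = dataset[0]
print(tok1.shape, latent.shape, mask.shape)
```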
cog_sdxl/example_datasets/README.md
ADDED
@@ -0,0 +1,3 @@
+## Example Datasets
+
+This folder contains three example datasets that were used to tune SDXL using the Replicate API, along with (at the top level) example outputs generated from those datasets.
cog_sdxl/example_datasets/kiriko.png
ADDED
Git LFS Details
cog_sdxl/example_datasets/kiriko/0.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/1.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/10.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/11.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/12.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/2.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/3.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/4.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/5.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/6.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/7.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/8.src.jpg
ADDED
cog_sdxl/example_datasets/kiriko/9.src.jpg
ADDED
cog_sdxl/example_datasets/monster.png
ADDED
Git LFS Details
cog_sdxl/example_datasets/monster/caption.csv
ADDED
@@ -0,0 +1,6 @@
+caption,image_file
+a TOK on a windowsill,monstertoy (1).jpg
+a photo of smiling TOK in an office,monstertoy (2).jpg
+a photo of TOK sitting by a window,monstertoy (3).jpg
+a photo of TOK on a car,monstertoy (4).jpg
+a photo of TOK smiling on the ground,monstertoy (5).jpg
cog_sdxl/example_datasets/monster/monstertoy (1).jpg
ADDED
cog_sdxl/example_datasets/monster/monstertoy (2).jpg
ADDED
cog_sdxl/example_datasets/monster/monstertoy (3).jpg
ADDED
cog_sdxl/example_datasets/monster/monstertoy (4).jpg
ADDED
cog_sdxl/example_datasets/monster/monstertoy (5).jpg
ADDED
cog_sdxl/example_datasets/monster_uni.png
ADDED
Git LFS Details
cog_sdxl/example_datasets/zeke.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64d655ee118eec386272a15c8e3c2522bc40155cd0f39f451596f7800df403e6
+size 860587
cog_sdxl/example_datasets/zeke/0.src.jpg
ADDED
cog_sdxl/example_datasets/zeke/1.src.jpg
ADDED
cog_sdxl/example_datasets/zeke/2.src.jpg
ADDED
cog_sdxl/example_datasets/zeke/3.src.jpg
ADDED
cog_sdxl/example_datasets/zeke/4.src.jpg
ADDED
cog_sdxl/example_datasets/zeke/5.src.jpg
ADDED
cog_sdxl/example_datasets/zeke_unicorn.png
ADDED
Git LFS Details
cog_sdxl/feature-extractor/preprocessor_config.json
ADDED
@@ -0,0 +1,20 @@
+{
+  "crop_size": 224,
+  "do_center_crop": true,
+  "do_convert_rgb": true,
+  "do_normalize": true,
+  "do_resize": true,
+  "feature_extractor_type": "CLIPFeatureExtractor",
+  "image_mean": [
+    0.48145466,
+    0.4578275,
+    0.40821073
+  ],
+  "image_std": [
+    0.26862954,
+    0.26130258,
+    0.27577711
+  ],
+  "resample": 3,
+  "size": 224
+}
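
This directory is a standard CLIP image-processor config; `predict.py` loads it via `CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)` to preprocess images for the safety checker. A small standalone sketch (not part of the vendored code), using one of the example images shipped in `example_datasets/`:

```python
# Hedged sketch: loading the vendored feature-extractor config and preprocessing an image.
from PIL import Image
from transformers import CLIPImageProcessor

feature_extractor = CLIPImageProcessor.from_pretrained("./feature-extractor")
pixel_values = feature_extractor(
    Image.open("example_datasets/zeke/0.src.jpg"), return_tensors="pt"
).pixel_values
print(pixel_values.shape)  # (1, 3, 224, 224), given size/crop_size of 224 above
```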
cog_sdxl/no_init.py
ADDED
@@ -0,0 +1,121 @@
+import contextlib
+import contextvars
+import threading
+from typing import (
+    Callable,
+    ContextManager,
+    NamedTuple,
+    Optional,
+    TypeVar,
+    Union,
+)
+
+import torch
+
+__all__ = ["no_init_or_tensor"]
+
+
+Model = TypeVar("Model")
+
+
+def no_init_or_tensor(
+    loading_code: Optional[Callable[..., Model]] = None
+) -> Union[Model, ContextManager]:
+    """
+    Suppress the initialization of weights while loading a model.
+
+    Can either directly be passed a callable containing model-loading code,
+    which will be evaluated with weight initialization suppressed,
+    or used as a context manager around arbitrary model-loading code.
+
+    Args:
+        loading_code: Either a callable to evaluate
+            with model weight initialization suppressed,
+            or None (the default) to use as a context manager.
+
+    Returns:
+        The return value of `loading_code`, if `loading_code` is callable.
+
+        Otherwise, if `loading_code` is None, returns a context manager
+        to be used in a `with`-statement.
+
+    Examples:
+        As a context manager::
+
+            from transformers import AutoConfig, AutoModelForCausalLM
+            config = AutoConfig("EleutherAI/gpt-j-6B")
+            with no_init_or_tensor():
+                model = AutoModelForCausalLM.from_config(config)
+
+        Or, directly passing a callable::
+
+            from transformers import AutoConfig, AutoModelForCausalLM
+            config = AutoConfig("EleutherAI/gpt-j-6B")
+            model = no_init_or_tensor(lambda: AutoModelForCausalLM.from_config(config))
+    """
+    if loading_code is None:
+        return _NoInitOrTensorImpl.context_manager()
+    elif callable(loading_code):
+        with _NoInitOrTensorImpl.context_manager():
+            return loading_code()
+    else:
+        raise TypeError(
+            "no_init_or_tensor() expected a callable to evaluate,"
+            " or None if being used as a context manager;"
+            f' got an object of type "{type(loading_code).__name__}" instead.'
+        )
+
+
+class _NoInitOrTensorImpl:
+    # Implementation of the thread-safe, async-safe, re-entrant context manager
+    # version of no_init_or_tensor().
+    # This class essentially acts as a namespace.
+    # It is not instantiable, because modifications to torch functions
+    # inherently affect the global scope, and thus there is no worthwhile data
+    # to store in the class instance scope.
+    _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
+    _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
+    _ORIGINAL_EMPTY = torch.empty
+
+    is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
+    _count_active: int = 0
+    _count_active_lock = threading.Lock()
+
+    @classmethod
+    @contextlib.contextmanager
+    def context_manager(cls):
+        if cls.is_active.get():
+            yield
+            return
+
+        with cls._count_active_lock:
+            cls._count_active += 1
+            if cls._count_active == 1:
+                for mod in cls._MODULES:
+                    mod.reset_parameters = cls._disable(mod.reset_parameters)
+                # When torch.empty is called, make it map to meta device by replacing
+                # the device in kwargs.
+                torch.empty = cls._ORIGINAL_EMPTY
+        reset_token = cls.is_active.set(True)
+
+        try:
+            yield
+        finally:
+            cls.is_active.reset(reset_token)
+            with cls._count_active_lock:
+                cls._count_active -= 1
+                if cls._count_active == 0:
+                    torch.empty = cls._ORIGINAL_EMPTY
+                    for mod, original in cls._MODULE_ORIGINALS:
+                        mod.reset_parameters = original
+
+    @staticmethod
+    def _disable(func):
+        def wrapper(*args, **kwargs):
+            # Behaves as normal except in an active context
+            if not _NoInitOrTensorImpl.is_active.get():
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    __init__ = None
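
In this repo, `no_init_or_tensor` is used in `predict.py` (below) to wrap the construction of `LoRAAttnProcessor2_0` modules, so freshly created weights are not initialized before being overwritten by the loaded LoRA tensors. A hedged sketch of the same pattern, with illustrative hyperparameters rather than values read from a real checkpoint:

```python
# Hedged sketch mirroring the usage in predict.py: skip weight initialization for a
# module whose parameters will immediately be replaced by loaded weights.
# hidden_size / cross_attention_dim / rank are illustrative values only.
from diffusers.models.attention_processor import LoRAAttnProcessor2_0

from no_init import no_init_or_tensor

with no_init_or_tensor():
    lora_proc = LoRAAttnProcessor2_0(
        hidden_size=1280,
        cross_attention_dim=2048,
        rank=32,  # the README notes the LoRA rank defaults to 32
    ).half()

print(sum(p.numel() for p in lora_proc.parameters()))
```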
cog_sdxl/predict.py
ADDED
@@ -0,0 +1,462 @@
+import hashlib
+import json
+import os
+import shutil
+import subprocess
+import time
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from weights import WeightsDownloadCache
+
+import numpy as np
+import torch
+from cog import BasePredictor, Input, Path
+from diffusers import (
+    DDIMScheduler,
+    DiffusionPipeline,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
+    PNDMScheduler,
+    StableDiffusionXLImg2ImgPipeline,
+    StableDiffusionXLInpaintPipeline,
+)
+from diffusers.models.attention_processor import LoRAAttnProcessor2_0
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from diffusers.utils import load_image
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers import CLIPImageProcessor
+
+from dataset_and_utils import TokenEmbeddingsHandler
+
+SDXL_MODEL_CACHE = "./sdxl-cache"
+REFINER_MODEL_CACHE = "./refiner-cache"
+SAFETY_CACHE = "./safety-cache"
+FEATURE_EXTRACTOR = "./feature-extractor"
+SDXL_URL = "https://weights.replicate.delivery/default/sdxl/sdxl-vae-upcast-fix.tar"
+REFINER_URL = (
+    "https://weights.replicate.delivery/default/sdxl/refiner-no-vae-no-encoder-1.0.tar"
+)
+SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
+
+
+class KarrasDPM:
+    def from_config(config):
+        return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)
+
+
+SCHEDULERS = {
+    "DDIM": DDIMScheduler,
+    "DPMSolverMultistep": DPMSolverMultistepScheduler,
+    "HeunDiscrete": HeunDiscreteScheduler,
+    "KarrasDPM": KarrasDPM,
+    "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
+    "K_EULER": EulerDiscreteScheduler,
+    "PNDM": PNDMScheduler,
+}
+
+
+def download_weights(url, dest):
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
+
+class Predictor(BasePredictor):
+    def load_trained_weights(self, weights, pipe):
+        from no_init import no_init_or_tensor
+
+        # weights can be a URLPath, which behaves in unexpected ways
+        weights = str(weights)
+        if self.tuned_weights == weights:
+            print("skipping loading .. weights already loaded")
+            return
+
+        # predictions can be cancelled while in this function, which
+        # interrupts this finishing. To protect against odd states we
+        # set tuned_weights to a value that lets the next prediction
+        # know if it should try to load weights or if loading completed
+        self.tuned_weights = 'loading'
+
+        local_weights_cache = self.weights_cache.ensure(weights)
+
+        # load UNET
+        print("Loading fine-tuned model")
+        self.is_lora = False
+
+        maybe_unet_path = os.path.join(local_weights_cache, "unet.safetensors")
+        if not os.path.exists(maybe_unet_path):
+            print("Does not have Unet. assume we are using LoRA")
+            self.is_lora = True
+
+        if not self.is_lora:
+            print("Loading Unet")
+
+            new_unet_params = load_file(
+                os.path.join(local_weights_cache, "unet.safetensors")
+            )
+            # this should return _IncompatibleKeys(missing_keys=[...], unexpected_keys=[])
+            pipe.unet.load_state_dict(new_unet_params, strict=False)
+
+        else:
+            print("Loading Unet LoRA")
+
+            unet = pipe.unet
+
+            tensors = load_file(os.path.join(local_weights_cache, "lora.safetensors"))
+
+            unet_lora_attn_procs = {}
+            name_rank_map = {}
+            for tk, tv in tensors.items():
+                # up is N, d
+                tensors[tk] = tv.half()
+                if tk.endswith("up.weight"):
+                    proc_name = ".".join(tk.split(".")[:-3])
+                    r = tv.shape[1]
+                    name_rank_map[proc_name] = r
+
+            for name, attn_processor in unet.attn_processors.items():
+                cross_attention_dim = (
+                    None
+                    if name.endswith("attn1.processor")
+                    else unet.config.cross_attention_dim
+                )
+                if name.startswith("mid_block"):
+                    hidden_size = unet.config.block_out_channels[-1]
+                elif name.startswith("up_blocks"):
+                    block_id = int(name[len("up_blocks.")])
+                    hidden_size = list(reversed(unet.config.block_out_channels))[
+                        block_id
+                    ]
+                elif name.startswith("down_blocks"):
+                    block_id = int(name[len("down_blocks.")])
+                    hidden_size = unet.config.block_out_channels[block_id]
+                with no_init_or_tensor():
+                    module = LoRAAttnProcessor2_0(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        rank=name_rank_map[name],
+                    ).half()
+                unet_lora_attn_procs[name] = module.to("cuda", non_blocking=True)
+
+            unet.set_attn_processor(unet_lora_attn_procs)
+            unet.load_state_dict(tensors, strict=False)
+
+        # load text
+        handler = TokenEmbeddingsHandler(
+            [pipe.text_encoder, pipe.text_encoder_2], [pipe.tokenizer, pipe.tokenizer_2]
+        )
+        handler.load_embeddings(os.path.join(local_weights_cache, "embeddings.pti"))
+
+        # load params
+        with open(os.path.join(local_weights_cache, "special_params.json"), "r") as f:
+            params = json.load(f)
+
+        self.token_map = params
+        self.tuned_weights = weights
+        self.tuned_model = True
+
+    def unload_trained_weights(self, pipe: DiffusionPipeline):
+        print("unloading loras")
+
+        def _recursive_unset_lora(module: torch.nn.Module):
+            if hasattr(module, "lora_layer"):
+                module.lora_layer = None
+
+            for _, child in module.named_children():
+                _recursive_unset_lora(child)
+
+        _recursive_unset_lora(pipe.unet)
+        self.tuned_weights = None
+        self.tuned_model = False
+
+    def setup(self, weights: Optional[Path] = None):
+        """Load the model into memory to make running multiple predictions efficient"""
+
+        start = time.time()
+        self.tuned_model = False
+        self.tuned_weights = None
+        if str(weights) == "weights":
+            weights = None
+
+        self.weights_cache = WeightsDownloadCache()
+
+        print("Loading safety checker...")
+        if not os.path.exists(SAFETY_CACHE):
+            download_weights(SAFETY_URL, SAFETY_CACHE)
+        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+            SAFETY_CACHE, torch_dtype=torch.float16
+        ).to("cuda")
+        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
+
+        if not os.path.exists(SDXL_MODEL_CACHE):
+            download_weights(SDXL_URL, SDXL_MODEL_CACHE)
+
+        print("Loading sdxl txt2img pipeline...")
+        self.txt2img_pipe = DiffusionPipeline.from_pretrained(
+            SDXL_MODEL_CACHE,
+            torch_dtype=torch.float16,
+            use_safetensors=True,
+            variant="fp16",
+        )
+        self.is_lora = False
+        if weights or os.path.exists("./trained-model"):
+            self.load_trained_weights(weights, self.txt2img_pipe)
+
+        self.txt2img_pipe.to("cuda")
+
+        print("Loading SDXL img2img pipeline...")
+        self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
+            vae=self.txt2img_pipe.vae,
+            text_encoder=self.txt2img_pipe.text_encoder,
+            text_encoder_2=self.txt2img_pipe.text_encoder_2,
+            tokenizer=self.txt2img_pipe.tokenizer,
+            tokenizer_2=self.txt2img_pipe.tokenizer_2,
+            unet=self.txt2img_pipe.unet,
+            scheduler=self.txt2img_pipe.scheduler,
+        )
+        self.img2img_pipe.to("cuda")
+
+        print("Loading SDXL inpaint pipeline...")
+        self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
+            vae=self.txt2img_pipe.vae,
+            text_encoder=self.txt2img_pipe.text_encoder,
+            text_encoder_2=self.txt2img_pipe.text_encoder_2,
+            tokenizer=self.txt2img_pipe.tokenizer,
+            tokenizer_2=self.txt2img_pipe.tokenizer_2,
+            unet=self.txt2img_pipe.unet,
+            scheduler=self.txt2img_pipe.scheduler,
+        )
+        self.inpaint_pipe.to("cuda")
+
+        print("Loading SDXL refiner pipeline...")
+        # FIXME(ja): should the vae/text_encoder_2 be loaded from SDXL always?
+        # - in the case of fine-tuned SDXL should we still?
+        # FIXME(ja): if the answer to above is use VAE/Text_Encoder_2 from fine-tune
+        # what does this imply about lora + refiner? does the refiner need to know about
+
+        if not os.path.exists(REFINER_MODEL_CACHE):
+            download_weights(REFINER_URL, REFINER_MODEL_CACHE)
+
+        print("Loading refiner pipeline...")
+        self.refiner = DiffusionPipeline.from_pretrained(
+            REFINER_MODEL_CACHE,
+            text_encoder_2=self.txt2img_pipe.text_encoder_2,
+            vae=self.txt2img_pipe.vae,
+            torch_dtype=torch.float16,
+            use_safetensors=True,
+            variant="fp16",
+        )
+        self.refiner.to("cuda")
+        print("setup took: ", time.time() - start)
+        # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt
+
+    def load_image(self, path):
+        shutil.copyfile(path, "/tmp/image.png")
+        return load_image("/tmp/image.png").convert("RGB")
+
+    def run_safety_checker(self, image):
+        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
+            "cuda"
+        )
+        np_image = [np.array(val) for val in image]
+        image, has_nsfw_concept = self.safety_checker(
+            images=np_image,
+            clip_input=safety_checker_input.pixel_values.to(torch.float16),
+        )
+        return image, has_nsfw_concept
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        prompt: str = Input(
+            description="Input prompt",
+            default="An astronaut riding a rainbow unicorn",
+        ),
+        negative_prompt: str = Input(
+            description="Input Negative Prompt",
+            default="",
+        ),
+        image: Path = Input(
+            description="Input image for img2img or inpaint mode",
+            default=None,
+        ),
+        mask: Path = Input(
+            description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
+            default=None,
+        ),
+        width: int = Input(
+            description="Width of output image",
+            default=1024,
+        ),
+        height: int = Input(
+            description="Height of output image",
+            default=1024,
+        ),
+        num_outputs: int = Input(
+            description="Number of images to output.",
+            ge=1,
+            le=4,
+            default=1,
+        ),
+        scheduler: str = Input(
+            description="scheduler",
+            choices=SCHEDULERS.keys(),
+            default="K_EULER",
+        ),
+        num_inference_steps: int = Input(
+            description="Number of denoising steps", ge=1, le=500, default=50
+        ),
+        guidance_scale: float = Input(
+            description="Scale for classifier-free guidance", ge=1, le=50, default=7.5
+        ),
+        prompt_strength: float = Input(
+            description="Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image",
+            ge=0.0,
+            le=1.0,
+            default=0.8,
+        ),
+        seed: int = Input(
+            description="Random seed. Leave blank to randomize the seed", default=None
|
326 |
+
),
|
327 |
+
refine: str = Input(
|
328 |
+
description="Which refine style to use",
|
329 |
+
choices=["no_refiner", "expert_ensemble_refiner", "base_image_refiner"],
|
330 |
+
default="no_refiner",
|
331 |
+
),
|
332 |
+
high_noise_frac: float = Input(
|
333 |
+
description="For expert_ensemble_refiner, the fraction of noise to use",
|
334 |
+
default=0.8,
|
335 |
+
le=1.0,
|
336 |
+
ge=0.0,
|
337 |
+
),
|
338 |
+
refine_steps: int = Input(
|
339 |
+
description="For base_image_refiner, the number of steps to refine, defaults to num_inference_steps",
|
340 |
+
default=None,
|
341 |
+
),
|
342 |
+
apply_watermark: bool = Input(
|
343 |
+
description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
|
344 |
+
default=True,
|
345 |
+
),
|
346 |
+
lora_scale: float = Input(
|
347 |
+
description="LoRA additive scale. Only applicable on trained models.",
|
348 |
+
ge=0.0,
|
349 |
+
le=1.0,
|
350 |
+
default=0.6,
|
351 |
+
),
|
352 |
+
replicate_weights: str = Input(
|
353 |
+
description="Replicate LoRA weights to use. Leave blank to use the default weights.",
|
354 |
+
default=None,
|
355 |
+
),
|
356 |
+
disable_safety_checker: bool = Input(
|
357 |
+
description="Disable safety checker for generated images. This feature is only available through the API. See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)",
|
358 |
+
default=False,
|
359 |
+
),
|
360 |
+
) -> List[Path]:
|
361 |
+
"""Run a single prediction on the model."""
|
362 |
+
if seed is None:
|
363 |
+
seed = int.from_bytes(os.urandom(2), "big")
|
364 |
+
print(f"Using seed: {seed}")
|
365 |
+
|
366 |
+
if replicate_weights:
|
367 |
+
self.load_trained_weights(replicate_weights, self.txt2img_pipe)
|
368 |
+
elif self.tuned_model:
|
369 |
+
self.unload_trained_weights(self.txt2img_pipe)
|
370 |
+
|
371 |
+
# OOMs can leave vae in bad state
|
372 |
+
if self.txt2img_pipe.vae.dtype == torch.float32:
|
373 |
+
self.txt2img_pipe.vae.to(dtype=torch.float16)
|
374 |
+
|
375 |
+
sdxl_kwargs = {}
|
376 |
+
if self.tuned_model:
|
377 |
+
# consistency with fine-tuning API
|
378 |
+
for k, v in self.token_map.items():
|
379 |
+
prompt = prompt.replace(k, v)
|
380 |
+
print(f"Prompt: {prompt}")
|
381 |
+
if image and mask:
|
382 |
+
print("inpainting mode")
|
383 |
+
sdxl_kwargs["image"] = self.load_image(image)
|
384 |
+
sdxl_kwargs["mask_image"] = self.load_image(mask)
|
385 |
+
sdxl_kwargs["strength"] = prompt_strength
|
386 |
+
sdxl_kwargs["width"] = width
|
387 |
+
sdxl_kwargs["height"] = height
|
388 |
+
pipe = self.inpaint_pipe
|
389 |
+
elif image:
|
390 |
+
print("img2img mode")
|
391 |
+
sdxl_kwargs["image"] = self.load_image(image)
|
392 |
+
sdxl_kwargs["strength"] = prompt_strength
|
393 |
+
pipe = self.img2img_pipe
|
394 |
+
else:
|
395 |
+
print("txt2img mode")
|
396 |
+
sdxl_kwargs["width"] = width
|
397 |
+
sdxl_kwargs["height"] = height
|
398 |
+
pipe = self.txt2img_pipe
|
399 |
+
|
400 |
+
if refine == "expert_ensemble_refiner":
|
401 |
+
sdxl_kwargs["output_type"] = "latent"
|
402 |
+
sdxl_kwargs["denoising_end"] = high_noise_frac
|
403 |
+
elif refine == "base_image_refiner":
|
404 |
+
sdxl_kwargs["output_type"] = "latent"
|
405 |
+
|
406 |
+
if not apply_watermark:
|
407 |
+
# toggles watermark for this prediction
|
408 |
+
watermark_cache = pipe.watermark
|
409 |
+
pipe.watermark = None
|
410 |
+
self.refiner.watermark = None
|
411 |
+
|
412 |
+
pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
|
413 |
+
generator = torch.Generator("cuda").manual_seed(seed)
|
414 |
+
|
415 |
+
common_args = {
|
416 |
+
"prompt": [prompt] * num_outputs,
|
417 |
+
"negative_prompt": [negative_prompt] * num_outputs,
|
418 |
+
"guidance_scale": guidance_scale,
|
419 |
+
"generator": generator,
|
420 |
+
"num_inference_steps": num_inference_steps,
|
421 |
+
}
|
422 |
+
|
423 |
+
if self.is_lora:
|
424 |
+
sdxl_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}
|
425 |
+
|
426 |
+
output = pipe(**common_args, **sdxl_kwargs)
|
427 |
+
|
428 |
+
if refine in ["expert_ensemble_refiner", "base_image_refiner"]:
|
429 |
+
refiner_kwargs = {
|
430 |
+
"image": output.images,
|
431 |
+
}
|
432 |
+
|
433 |
+
if refine == "expert_ensemble_refiner":
|
434 |
+
refiner_kwargs["denoising_start"] = high_noise_frac
|
435 |
+
if refine == "base_image_refiner" and refine_steps:
|
436 |
+
common_args["num_inference_steps"] = refine_steps
|
437 |
+
|
438 |
+
output = self.refiner(**common_args, **refiner_kwargs)
|
439 |
+
|
440 |
+
if not apply_watermark:
|
441 |
+
pipe.watermark = watermark_cache
|
442 |
+
self.refiner.watermark = watermark_cache
|
443 |
+
|
444 |
+
if not disable_safety_checker:
|
445 |
+
_, has_nsfw_content = self.run_safety_checker(output.images)
|
446 |
+
|
447 |
+
output_paths = []
|
448 |
+
for i, image in enumerate(output.images):
|
449 |
+
if not disable_safety_checker:
|
450 |
+
if has_nsfw_content[i]:
|
451 |
+
print(f"NSFW content detected in image {i}")
|
452 |
+
continue
|
453 |
+
output_path = f"/tmp/out-{i}.png"
|
454 |
+
image.save(output_path)
|
455 |
+
output_paths.append(Path(output_path))
|
456 |
+
|
457 |
+
if len(output_paths) == 0:
|
458 |
+
raise Exception(
|
459 |
+
f"NSFW content detected. Try running it again, or try a different prompt."
|
460 |
+
)
|
461 |
+
|
462 |
+
return output_paths
|
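For reference, a minimal sketch (not part of the vendored file) of exercising the LoRA load/unload path above against a local cog server, the same endpoint samples.py uses further down; the weight URL is a placeholder, not a real artifact:

import requests

def predict_once(payload):
    # Assumes a local server started with `cog run -p 5000 python -m cog.server.http`.
    return requests.post("http://localhost:5000/predictions", json={"input": payload}).json()

# The first request passes replicate_weights, so predict() calls load_trained_weights();
# the second omits it, so unload_trained_weights() runs before generating.
predict_once({"prompt": "A photo of TOK on the beach", "replicate_weights": "https://example.com/trained_model.tar", "seed": 1234})
predict_once({"prompt": "A photo of a dog on the beach", "seed": 1234})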
cog_sdxl/preprocess.py
ADDED
@@ -0,0 +1,599 @@
1 |
+
# Have SwinIR upsample
|
2 |
+
# Have BLIP auto caption
|
3 |
+
# Have CLIPSeg auto mask concept
|
4 |
+
|
5 |
+
import gc
|
6 |
+
import fnmatch
|
7 |
+
import mimetypes
|
8 |
+
import os
|
9 |
+
import re
|
10 |
+
import shutil
|
11 |
+
import tarfile
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import List, Literal, Optional, Tuple, Union
|
14 |
+
from zipfile import ZipFile
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import mediapipe as mp
|
18 |
+
import numpy as np
|
19 |
+
import pandas as pd
|
20 |
+
import torch
|
21 |
+
from PIL import Image, ImageFilter
|
22 |
+
from tqdm import tqdm
|
23 |
+
from transformers import (
|
24 |
+
BlipForConditionalGeneration,
|
25 |
+
BlipProcessor,
|
26 |
+
CLIPSegForImageSegmentation,
|
27 |
+
CLIPSegProcessor,
|
28 |
+
Swin2SRForImageSuperResolution,
|
29 |
+
Swin2SRImageProcessor,
|
30 |
+
)
|
31 |
+
|
32 |
+
from predict import download_weights
|
33 |
+
|
34 |
+
# model is fixed to Salesforce/blip-image-captioning-large
|
35 |
+
BLIP_URL = "https://weights.replicate.delivery/default/blip_large/blip_large.tar"
|
36 |
+
BLIP_PROCESSOR_URL = (
|
37 |
+
"https://weights.replicate.delivery/default/blip_processor/blip_processor.tar"
|
38 |
+
)
|
39 |
+
BLIP_PATH = "./blip-cache"
|
40 |
+
BLIP_PROCESSOR_PATH = "./blip-proc-cache"
|
41 |
+
|
42 |
+
# model is fixed to CIDAS/clipseg-rd64-refined
|
43 |
+
CLIPSEG_URL = "https://weights.replicate.delivery/default/clip_seg_rd64_refined/clip_seg_rd64_refined.tar"
|
44 |
+
CLIPSEG_PROCESSOR = "https://weights.replicate.delivery/default/clip_seg_processor/clip_seg_processor.tar"
|
45 |
+
CLIPSEG_PATH = "./clipseg-cache"
|
46 |
+
CLIPSEG_PROCESSOR_PATH = "./clipseg-proc-cache"
|
47 |
+
|
48 |
+
# model is fixed to caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr
|
49 |
+
SWIN2SR_URL = "https://weights.replicate.delivery/default/swin2sr_realworld_sr_x4_64_bsrgan_psnr/swin2sr_realworld_sr_x4_64_bsrgan_psnr.tar"
|
50 |
+
SWIN2SR_PATH = "./swin2sr-cache"
|
51 |
+
|
52 |
+
TEMP_OUT_DIR = "./temp/"
|
53 |
+
TEMP_IN_DIR = "./temp_in/"
|
54 |
+
|
55 |
+
CSV_MATCH = "caption"
|
56 |
+
|
57 |
+
|
58 |
+
def preprocess(
|
59 |
+
input_images_filetype: str,
|
60 |
+
input_zip_path: Path,
|
61 |
+
caption_text: str,
|
62 |
+
mask_target_prompts: str,
|
63 |
+
target_size: int,
|
64 |
+
crop_based_on_salience: bool,
|
65 |
+
use_face_detection_instead: bool,
|
66 |
+
temp: float,
|
67 |
+
substitution_tokens: List[str],
|
68 |
+
) -> Path:
|
69 |
+
# assert str(files).endswith(".zip"), "files must be a zip file"
|
70 |
+
|
71 |
+
# clear TEMP_IN_DIR first.
|
72 |
+
|
73 |
+
for path in [TEMP_OUT_DIR, TEMP_IN_DIR]:
|
74 |
+
if os.path.exists(path):
|
75 |
+
shutil.rmtree(path)
|
76 |
+
os.makedirs(path)
|
77 |
+
|
78 |
+
caption_csv = None
|
79 |
+
|
80 |
+
if input_images_filetype == "zip" or str(input_zip_path).endswith(".zip"):
|
81 |
+
with ZipFile(str(input_zip_path), "r") as zip_ref:
|
82 |
+
for zip_info in zip_ref.infolist():
|
83 |
+
if zip_info.filename[-1] == "/" or zip_info.filename.startswith(
|
84 |
+
"__MACOSX"
|
85 |
+
):
|
86 |
+
continue
|
87 |
+
mt = mimetypes.guess_type(zip_info.filename)
|
88 |
+
if mt and mt[0] and mt[0].startswith("image/"):
|
89 |
+
zip_info.filename = os.path.basename(zip_info.filename)
|
90 |
+
zip_ref.extract(zip_info, TEMP_IN_DIR)
|
91 |
+
if (
|
92 |
+
mt
|
93 |
+
and mt[0]
|
94 |
+
and mt[0] == "text/csv"
|
95 |
+
and CSV_MATCH in zip_info.filename
|
96 |
+
):
|
97 |
+
zip_info.filename = os.path.basename(zip_info.filename)
|
98 |
+
zip_ref.extract(zip_info, TEMP_IN_DIR)
|
99 |
+
caption_csv = os.path.join(TEMP_IN_DIR, zip_info.filename)
|
100 |
+
elif input_images_filetype == "tar" or str(input_zip_path).endswith(".tar"):
|
101 |
+
assert str(input_zip_path).endswith(
|
102 |
+
".tar"
|
103 |
+
), "files must be a tar file if not zip"
|
104 |
+
with tarfile.open(input_zip_path, "r") as tar_ref:
|
105 |
+
for tar_info in tar_ref:
|
106 |
+
if tar_info.name[-1] == "/" or tar_info.name.startswith("__MACOSX"):
|
107 |
+
continue
|
108 |
+
|
109 |
+
mt = mimetypes.guess_type(tar_info.name)
|
110 |
+
if mt and mt[0] and mt[0].startswith("image/"):
|
111 |
+
tar_info.name = os.path.basename(tar_info.name)
|
112 |
+
tar_ref.extract(tar_info, TEMP_IN_DIR)
|
113 |
+
if mt and mt[0] and mt[0] == "text/csv" and CSV_MATCH in tar_info.name:
|
114 |
+
tar_info.name = os.path.basename(tar_info.name)
|
115 |
+
tar_ref.extract(tar_info, TEMP_IN_DIR)
|
116 |
+
caption_csv = os.path.join(TEMP_IN_DIR, tar_info.name)
|
117 |
+
else:
|
118 |
+
assert False, "input_images_filetype must be zip or tar"
|
119 |
+
|
120 |
+
output_dir: str = TEMP_OUT_DIR
|
121 |
+
|
122 |
+
load_and_save_masks_and_captions(
|
123 |
+
files=TEMP_IN_DIR,
|
124 |
+
output_dir=output_dir,
|
125 |
+
caption_text=caption_text,
|
126 |
+
caption_csv=caption_csv,
|
127 |
+
mask_target_prompts=mask_target_prompts,
|
128 |
+
target_size=target_size,
|
129 |
+
crop_based_on_salience=crop_based_on_salience,
|
130 |
+
use_face_detection_instead=use_face_detection_instead,
|
131 |
+
temp=temp,
|
132 |
+
substitution_tokens=substitution_tokens,
|
133 |
+
)
|
134 |
+
|
135 |
+
return Path(TEMP_OUT_DIR)
|
136 |
+
|
137 |
+
|
138 |
+
@torch.no_grad()
|
139 |
+
@torch.cuda.amp.autocast()
|
140 |
+
def swin_ir_sr(
|
141 |
+
images: List[Image.Image],
|
142 |
+
target_size: Optional[Tuple[int, int]] = None,
|
143 |
+
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
144 |
+
**kwargs,
|
145 |
+
) -> List[Image.Image]:
|
146 |
+
"""
|
147 |
+
Upscales images using SwinIR. Returns a list of PIL images.
|
148 |
+
If the image is already larger than the target size, it will not be upscaled
|
149 |
+
and will be returned as is.
|
150 |
+
|
151 |
+
"""
|
152 |
+
if not os.path.exists(SWIN2SR_PATH):
|
153 |
+
download_weights(SWIN2SR_URL, SWIN2SR_PATH)
|
154 |
+
model = Swin2SRForImageSuperResolution.from_pretrained(SWIN2SR_PATH).to(device)
|
155 |
+
processor = Swin2SRImageProcessor()
|
156 |
+
|
157 |
+
out_images = []
|
158 |
+
|
159 |
+
for image in tqdm(images):
|
160 |
+
ori_w, ori_h = image.size
|
161 |
+
if target_size is not None:
|
162 |
+
if ori_w >= target_size[0] and ori_h >= target_size[1]:
|
163 |
+
out_images.append(image)
|
164 |
+
continue
|
165 |
+
|
166 |
+
inputs = processor(image, return_tensors="pt").to(device)
|
167 |
+
with torch.no_grad():
|
168 |
+
outputs = model(**inputs)
|
169 |
+
|
170 |
+
output = (
|
171 |
+
outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
172 |
+
)
|
173 |
+
output = np.moveaxis(output, source=0, destination=-1)
|
174 |
+
output = (output * 255.0).round().astype(np.uint8)
|
175 |
+
output = Image.fromarray(output)
|
176 |
+
|
177 |
+
out_images.append(output)
|
178 |
+
|
179 |
+
return out_images
|
180 |
+
|
181 |
+
|
182 |
+
@torch.no_grad()
|
183 |
+
@torch.cuda.amp.autocast()
|
184 |
+
def clipseg_mask_generator(
|
185 |
+
images: List[Image.Image],
|
186 |
+
target_prompts: Union[List[str], str],
|
187 |
+
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
188 |
+
bias: float = 0.01,
|
189 |
+
temp: float = 1.0,
|
190 |
+
**kwargs,
|
191 |
+
) -> List[Image.Image]:
|
192 |
+
"""
|
193 |
+
Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image
|
194 |
+
"""
|
195 |
+
|
196 |
+
if isinstance(target_prompts, str):
|
197 |
+
print(
|
198 |
+
f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images'
|
199 |
+
)
|
200 |
+
|
201 |
+
target_prompts = [target_prompts] * len(images)
|
202 |
+
if not os.path.exists(CLIPSEG_PROCESSOR_PATH):
|
203 |
+
download_weights(CLIPSEG_PROCESSOR, CLIPSEG_PROCESSOR_PATH)
|
204 |
+
if not os.path.exists(CLIPSEG_PATH):
|
205 |
+
download_weights(CLIPSEG_URL, CLIPSEG_PATH)
|
206 |
+
processor = CLIPSegProcessor.from_pretrained(CLIPSEG_PROCESSOR_PATH)
|
207 |
+
model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_PATH).to(device)
|
208 |
+
|
209 |
+
masks = []
|
210 |
+
|
211 |
+
for image, prompt in tqdm(zip(images, target_prompts)):
|
212 |
+
original_size = image.size
|
213 |
+
|
214 |
+
inputs = processor(
|
215 |
+
text=[prompt, ""],
|
216 |
+
images=[image] * 2,
|
217 |
+
padding="max_length",
|
218 |
+
truncation=True,
|
219 |
+
return_tensors="pt",
|
220 |
+
).to(device)
|
221 |
+
|
222 |
+
outputs = model(**inputs)
|
223 |
+
|
224 |
+
logits = outputs.logits
|
225 |
+
probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
|
226 |
+
probs = (probs + bias).clamp_(0, 1)
|
227 |
+
probs = 255 * probs / probs.max()
|
228 |
+
|
229 |
+
# make mask greyscale
|
230 |
+
mask = Image.fromarray(probs.cpu().numpy()).convert("L")
|
231 |
+
|
232 |
+
# resize mask to original size
|
233 |
+
mask = mask.resize(original_size)
|
234 |
+
|
235 |
+
masks.append(mask)
|
236 |
+
|
237 |
+
return masks
|
238 |
+
|
239 |
+
|
240 |
+
@torch.no_grad()
|
241 |
+
def blip_captioning_dataset(
|
242 |
+
images: List[Image.Image],
|
243 |
+
text: Optional[str] = None,
|
244 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
245 |
+
substitution_tokens: Optional[List[str]] = None,
|
246 |
+
**kwargs,
|
247 |
+
) -> List[str]:
|
248 |
+
"""
|
249 |
+
Returns a list of captions for the given images
|
250 |
+
"""
|
251 |
+
if not os.path.exists(BLIP_PROCESSOR_PATH):
|
252 |
+
download_weights(BLIP_PROCESSOR_URL, BLIP_PROCESSOR_PATH)
|
253 |
+
if not os.path.exists(BLIP_PATH):
|
254 |
+
download_weights(BLIP_URL, BLIP_PATH)
|
255 |
+
processor = BlipProcessor.from_pretrained(BLIP_PROCESSOR_PATH)
|
256 |
+
model = BlipForConditionalGeneration.from_pretrained(BLIP_PATH).to(device)
|
257 |
+
captions = []
|
258 |
+
text = text.strip()
|
259 |
+
print(f"Input captioning text: {text}")
|
260 |
+
for image in tqdm(images):
|
261 |
+
inputs = processor(image, return_tensors="pt").to("cuda")
|
262 |
+
out = model.generate(
|
263 |
+
**inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
|
264 |
+
)
|
265 |
+
caption = processor.decode(out[0], skip_special_tokens=True)
|
266 |
+
|
267 |
+
# BLIP 2 lowercases all caps tokens. This should properly replace them w/o messing up subwords. I'm sure there's a better way to do this.
|
268 |
+
for token in substitution_tokens:
|
269 |
+
print(token)
|
270 |
+
sub_cap = " " + caption + " "
|
271 |
+
print(sub_cap)
|
272 |
+
sub_cap = sub_cap.replace(" " + token.lower() + " ", " " + token + " ")
|
273 |
+
caption = sub_cap.strip()
|
274 |
+
|
275 |
+
captions.append(text + " " + caption)
|
276 |
+
print("Generated captions", captions)
|
277 |
+
return captions
|
278 |
+
|
279 |
+
|
280 |
+
def face_mask_google_mediapipe(
|
281 |
+
images: List[Image.Image], blur_amount: float = 0.0, bias: float = 50.0
|
282 |
+
) -> List[Image.Image]:
|
283 |
+
"""
|
284 |
+
Returns a list of images with masks on the face parts.
|
285 |
+
"""
|
286 |
+
mp_face_detection = mp.solutions.face_detection
|
287 |
+
mp_face_mesh = mp.solutions.face_mesh
|
288 |
+
|
289 |
+
face_detection = mp_face_detection.FaceDetection(
|
290 |
+
model_selection=1, min_detection_confidence=0.1
|
291 |
+
)
|
292 |
+
face_mesh = mp_face_mesh.FaceMesh(
|
293 |
+
static_image_mode=True, max_num_faces=1, min_detection_confidence=0.1
|
294 |
+
)
|
295 |
+
|
296 |
+
masks = []
|
297 |
+
for image in tqdm(images):
|
298 |
+
image_np = np.array(image)
|
299 |
+
|
300 |
+
# Perform face detection
|
301 |
+
results_detection = face_detection.process(image_np)
|
302 |
+
ih, iw, _ = image_np.shape
|
303 |
+
if results_detection.detections:
|
304 |
+
for detection in results_detection.detections:
|
305 |
+
bboxC = detection.location_data.relative_bounding_box
|
306 |
+
|
307 |
+
bbox = (
|
308 |
+
int(bboxC.xmin * iw),
|
309 |
+
int(bboxC.ymin * ih),
|
310 |
+
int(bboxC.width * iw),
|
311 |
+
int(bboxC.height * ih),
|
312 |
+
)
|
313 |
+
|
314 |
+
# make sure bbox is within image
|
315 |
+
bbox = (
|
316 |
+
max(0, bbox[0]),
|
317 |
+
max(0, bbox[1]),
|
318 |
+
min(iw - bbox[0], bbox[2]),
|
319 |
+
min(ih - bbox[1], bbox[3]),
|
320 |
+
)
|
321 |
+
|
322 |
+
print(bbox)
|
323 |
+
|
324 |
+
# Extract face landmarks
|
325 |
+
face_landmarks = face_mesh.process(
|
326 |
+
image_np[bbox[1] : bbox[1] + bbox[3], bbox[0] : bbox[0] + bbox[2]]
|
327 |
+
).multi_face_landmarks
|
328 |
+
|
329 |
+
# https://github.com/google/mediapipe/issues/1615
|
330 |
+
# This was def helpful
|
331 |
+
indexes = [
|
332 |
+
10,
|
333 |
+
338,
|
334 |
+
297,
|
335 |
+
332,
|
336 |
+
284,
|
337 |
+
251,
|
338 |
+
389,
|
339 |
+
356,
|
340 |
+
454,
|
341 |
+
323,
|
342 |
+
361,
|
343 |
+
288,
|
344 |
+
397,
|
345 |
+
365,
|
346 |
+
379,
|
347 |
+
378,
|
348 |
+
400,
|
349 |
+
377,
|
350 |
+
152,
|
351 |
+
148,
|
352 |
+
176,
|
353 |
+
149,
|
354 |
+
150,
|
355 |
+
136,
|
356 |
+
172,
|
357 |
+
58,
|
358 |
+
132,
|
359 |
+
93,
|
360 |
+
234,
|
361 |
+
127,
|
362 |
+
162,
|
363 |
+
21,
|
364 |
+
54,
|
365 |
+
103,
|
366 |
+
67,
|
367 |
+
109,
|
368 |
+
]
|
369 |
+
|
370 |
+
if face_landmarks:
|
371 |
+
mask = Image.new("L", (iw, ih), 0)
|
372 |
+
mask_np = np.array(mask)
|
373 |
+
|
374 |
+
for face_landmark in face_landmarks:
|
375 |
+
face_landmark = [face_landmark.landmark[idx] for idx in indexes]
|
376 |
+
landmark_points = [
|
377 |
+
(int(l.x * bbox[2]) + bbox[0], int(l.y * bbox[3]) + bbox[1])
|
378 |
+
for l in face_landmark
|
379 |
+
]
|
380 |
+
mask_np = cv2.fillPoly(
|
381 |
+
mask_np, [np.array(landmark_points)], 255
|
382 |
+
)
|
383 |
+
|
384 |
+
mask = Image.fromarray(mask_np)
|
385 |
+
|
386 |
+
# Apply blur to the mask
|
387 |
+
if blur_amount > 0:
|
388 |
+
mask = mask.filter(ImageFilter.GaussianBlur(blur_amount))
|
389 |
+
|
390 |
+
# Apply bias to the mask
|
391 |
+
if bias > 0:
|
392 |
+
mask = np.array(mask)
|
393 |
+
mask = mask + bias * np.ones(mask.shape, dtype=mask.dtype)
|
394 |
+
mask = np.clip(mask, 0, 255)
|
395 |
+
mask = Image.fromarray(mask)
|
396 |
+
|
397 |
+
# Convert mask to 'L' mode (grayscale) before saving
|
398 |
+
mask = mask.convert("L")
|
399 |
+
|
400 |
+
masks.append(mask)
|
401 |
+
else:
|
402 |
+
# If face landmarks are not available, add a black mask of the same size as the image
|
403 |
+
masks.append(Image.new("L", (iw, ih), 255))
|
404 |
+
|
405 |
+
else:
|
406 |
+
print("No face detected, adding full mask")
|
407 |
+
# If no face is detected, add a white mask of the same size as the image
|
408 |
+
masks.append(Image.new("L", (iw, ih), 255))
|
409 |
+
|
410 |
+
return masks
|
411 |
+
|
412 |
+
|
413 |
+
def _crop_to_square(
|
414 |
+
image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None
|
415 |
+
):
|
416 |
+
cx, cy = com
|
417 |
+
width, height = image.size
|
418 |
+
if width > height:
|
419 |
+
left_possible = max(cx - height / 2, 0)
|
420 |
+
left = min(left_possible, width - height)
|
421 |
+
right = left + height
|
422 |
+
top = 0
|
423 |
+
bottom = height
|
424 |
+
else:
|
425 |
+
left = 0
|
426 |
+
right = width
|
427 |
+
top_possible = max(cy - width / 2, 0)
|
428 |
+
top = min(top_possible, height - width)
|
429 |
+
bottom = top + width
|
430 |
+
|
431 |
+
image = image.crop((left, top, right, bottom))
|
432 |
+
|
433 |
+
if resize_to:
|
434 |
+
image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)
|
435 |
+
|
436 |
+
return image
|
437 |
+
|
438 |
+
|
439 |
+
def _center_of_mass(mask: Image.Image):
|
440 |
+
"""
|
441 |
+
Returns the center of mass of the mask
|
442 |
+
"""
|
443 |
+
x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))
|
444 |
+
mask_np = np.array(mask) + 0.01
|
445 |
+
x_ = x * mask_np
|
446 |
+
y_ = y * mask_np
|
447 |
+
|
448 |
+
x = np.sum(x_) / np.sum(mask_np)
|
449 |
+
y = np.sum(y_) / np.sum(mask_np)
|
450 |
+
|
451 |
+
return x, y
|
452 |
+
|
453 |
+
|
454 |
+
def load_and_save_masks_and_captions(
|
455 |
+
files: Union[str, List[str]],
|
456 |
+
output_dir: str = TEMP_OUT_DIR,
|
457 |
+
caption_text: Optional[str] = None,
|
458 |
+
caption_csv: Optional[str] = None,
|
459 |
+
mask_target_prompts: Optional[Union[List[str], str]] = None,
|
460 |
+
target_size: int = 1024,
|
461 |
+
crop_based_on_salience: bool = True,
|
462 |
+
use_face_detection_instead: bool = False,
|
463 |
+
temp: float = 1.0,
|
464 |
+
n_length: int = -1,
|
465 |
+
substitution_tokens: Optional[List[str]] = None,
|
466 |
+
):
|
467 |
+
"""
|
468 |
+
Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images
|
469 |
+
to output dir. If mask_target_prompts is given, it will generate kinda-segmentation-masks for the prompts and save them as well.
|
470 |
+
|
471 |
+
Example:
|
472 |
+
>>> x = load_and_save_masks_and_captions(
|
473 |
+
files="./data/images",
|
474 |
+
output_dir="./data/masks_and_captions",
|
475 |
+
caption_text="a photo of",
|
476 |
+
mask_target_prompts="cat",
|
477 |
+
target_size=768,
|
478 |
+
crop_based_on_salience=True,
|
479 |
+
use_face_detection_instead=False,
|
480 |
+
temp=1.0,
|
481 |
+
n_length=-1,
|
482 |
+
)
|
483 |
+
"""
|
484 |
+
os.makedirs(output_dir, exist_ok=True)
|
485 |
+
|
486 |
+
# load images
|
487 |
+
if isinstance(files, str):
|
488 |
+
# check if it is a directory
|
489 |
+
if os.path.isdir(files):
|
490 |
+
# get all the .png .jpg in the directory
|
491 |
+
files = (
|
492 |
+
_find_files("*.png", files)
|
493 |
+
+ _find_files("*.jpg", files)
|
494 |
+
+ _find_files("*.jpeg", files)
|
495 |
+
)
|
496 |
+
|
497 |
+
if len(files) == 0:
|
498 |
+
raise Exception(
|
499 |
+
f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg/jpeg files."
|
500 |
+
)
|
501 |
+
if n_length == -1:
|
502 |
+
n_length = len(files)
|
503 |
+
files = sorted(files)[:n_length]
|
504 |
+
print("Image files: ", files)
|
505 |
+
images = [Image.open(file).convert("RGB") for file in files]
|
506 |
+
|
507 |
+
# captions
|
508 |
+
if caption_csv:
|
509 |
+
print(f"Using provided captions")
|
510 |
+
caption_df = pd.read_csv(caption_csv)
|
511 |
+
# sort images to be consistent with 'sorted' above
|
512 |
+
caption_df = caption_df.sort_values("image_file")
|
513 |
+
captions = caption_df["caption"].values
|
514 |
+
print("Captions: ", captions)
|
515 |
+
if len(captions) != len(images):
|
516 |
+
print("Not the same number of captions as images!")
|
517 |
+
print(f"Num captions: {len(captions)}, Num images: {len(images)}")
|
518 |
+
print("Captions: ", captions)
|
519 |
+
print("Images: ", files)
|
520 |
+
raise Exception(
|
521 |
+
"Not the same number of captions as images! Check that all files passed in have a caption in your caption csv, and vice versa"
|
522 |
+
)
|
523 |
+
|
524 |
+
else:
|
525 |
+
print(f"Generating {len(images)} captions...")
|
526 |
+
captions = blip_captioning_dataset(
|
527 |
+
images, text=caption_text, substitution_tokens=substitution_tokens
|
528 |
+
)
|
529 |
+
|
530 |
+
if mask_target_prompts is None:
|
531 |
+
mask_target_prompts = ""
|
532 |
+
temp = 999
|
533 |
+
|
534 |
+
print(f"Generating {len(images)} masks...")
|
535 |
+
if not use_face_detection_instead:
|
536 |
+
seg_masks = clipseg_mask_generator(
|
537 |
+
images=images, target_prompts=mask_target_prompts, temp=temp
|
538 |
+
)
|
539 |
+
else:
|
540 |
+
seg_masks = face_mask_google_mediapipe(images=images)
|
541 |
+
|
542 |
+
# find the center of mass of the mask
|
543 |
+
if crop_based_on_salience:
|
544 |
+
coms = [_center_of_mass(mask) for mask in seg_masks]
|
545 |
+
else:
|
546 |
+
coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]
|
547 |
+
# based on the center of mass, crop the image to a square
|
548 |
+
images = [
|
549 |
+
_crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
|
550 |
+
]
|
551 |
+
|
552 |
+
print(f"Upscaling {len(images)} images...")
|
553 |
+
# upscale images anyways
|
554 |
+
images = swin_ir_sr(images, target_size=(target_size, target_size))
|
555 |
+
images = [
|
556 |
+
image.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
557 |
+
for image in images
|
558 |
+
]
|
559 |
+
|
560 |
+
seg_masks = [
|
561 |
+
_crop_to_square(mask, com, resize_to=target_size)
|
562 |
+
for mask, com in zip(seg_masks, coms)
|
563 |
+
]
|
564 |
+
|
565 |
+
data = []
|
566 |
+
|
567 |
+
# clean TEMP_OUT_DIR first
|
568 |
+
if os.path.exists(output_dir):
|
569 |
+
for file in os.listdir(output_dir):
|
570 |
+
os.remove(os.path.join(output_dir, file))
|
571 |
+
|
572 |
+
os.makedirs(output_dir, exist_ok=True)
|
573 |
+
|
574 |
+
# iterate through the images, masks, and captions and add a row to the dataframe for each
|
575 |
+
for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
|
576 |
+
image_name = f"{idx}.src.png"
|
577 |
+
mask_file = f"{idx}.mask.png"
|
578 |
+
|
579 |
+
# save the image and mask files
|
580 |
+
image.save(output_dir + image_name)
|
581 |
+
mask.save(output_dir + mask_file)
|
582 |
+
|
583 |
+
# add a new row to the dataframe with the file names and caption
|
584 |
+
data.append(
|
585 |
+
{"image_path": image_name, "mask_path": mask_file, "caption": caption},
|
586 |
+
)
|
587 |
+
|
588 |
+
df = pd.DataFrame(columns=["image_path", "mask_path", "caption"], data=data)
|
589 |
+
# save the dataframe to a CSV file
|
590 |
+
df.to_csv(os.path.join(output_dir, "captions.csv"), index=False)
|
591 |
+
|
592 |
+
|
593 |
+
def _find_files(pattern, dir="."):
|
594 |
+
"""Return list of files matching pattern in a given directory, in absolute format.
|
595 |
+
Unlike glob, this is case-insensitive.
|
596 |
+
"""
|
597 |
+
|
598 |
+
rule = re.compile(fnmatch.translate(pattern), re.IGNORECASE)
|
599 |
+
return [os.path.join(dir, f) for f in os.listdir(dir) if rule.match(f)]
|
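For reference, a minimal sketch (not part of the vendored file, meant to run inside the cog environment) of driving preprocess() from this module: it unpacks the archive, captions, masks, crops, and upscales the images, and returns the directory of processed outputs. The caption and mask prompt values below are illustrative:

from preprocess import preprocess

out_dir = preprocess(
    input_images_filetype="zip",
    input_zip_path="./example_datasets/zeke.zip",  # any zip of training images
    caption_text="a photo of TOK, ",
    mask_target_prompts="face",
    target_size=1024,
    crop_based_on_salience=True,
    use_face_detection_instead=False,
    temp=1.0,
    substitution_tokens=["TOK"],
)
print("processed dataset written to", out_dir)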
cog_sdxl/requirements_test.txt
ADDED
@@ -0,0 +1,5 @@
numpy
pytest
replicate
requests
Pillow
cog_sdxl/samples.py
ADDED
@@ -0,0 +1,155 @@
1 |
+
"""
|
2 |
+
A handy utility for verifying SDXL image generation locally.
|
3 |
+
To set up, first run a local cog server using:
|
4 |
+
cog run -p 5000 python -m cog.server.http
|
5 |
+
Then, in a separate terminal, generate samples
|
6 |
+
python samples.py
|
7 |
+
"""
|
8 |
+
|
9 |
+
|
10 |
+
import base64
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
|
14 |
+
import requests
|
15 |
+
|
16 |
+
|
17 |
+
def gen(output_fn, **kwargs):
|
18 |
+
if os.path.exists(output_fn):
|
19 |
+
return
|
20 |
+
|
21 |
+
print("Generating", output_fn)
|
22 |
+
url = "http://localhost:5000/predictions"
|
23 |
+
response = requests.post(url, json={"input": kwargs})
|
24 |
+
data = response.json()
|
25 |
+
|
26 |
+
try:
|
27 |
+
datauri = data["output"][0]
|
28 |
+
base64_encoded_data = datauri.split(",")[1]
|
29 |
+
data = base64.b64decode(base64_encoded_data)
|
30 |
+
except:
|
31 |
+
print("Error!")
|
32 |
+
print("input:", kwargs)
|
33 |
+
print(data["logs"])
|
34 |
+
sys.exit(1)
|
35 |
+
|
36 |
+
with open(output_fn, "wb") as f:
|
37 |
+
f.write(data)
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
SCHEDULERS = [
|
42 |
+
"DDIM",
|
43 |
+
"DPMSolverMultistep",
|
44 |
+
"HeunDiscrete",
|
45 |
+
"KarrasDPM",
|
46 |
+
"K_EULER_ANCESTRAL",
|
47 |
+
"K_EULER",
|
48 |
+
"PNDM",
|
49 |
+
]
|
50 |
+
|
51 |
+
gen(
|
52 |
+
f"sample.txt2img.png",
|
53 |
+
prompt="A studio portrait photo of a cat",
|
54 |
+
num_inference_steps=25,
|
55 |
+
guidance_scale=7,
|
56 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
57 |
+
seed=1000,
|
58 |
+
width=1024,
|
59 |
+
height=1024,
|
60 |
+
)
|
61 |
+
|
62 |
+
for refiner in ["base_image_refiner", "expert_ensemble_refiner", "no_refiner"]:
|
63 |
+
gen(
|
64 |
+
f"sample.img2img.{refiner}.png",
|
65 |
+
prompt="a photo of an astronaut riding a horse on mars",
|
66 |
+
image="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png",
|
67 |
+
prompt_strength=0.8,
|
68 |
+
num_inference_steps=25,
|
69 |
+
refine=refiner,
|
70 |
+
guidance_scale=7,
|
71 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
72 |
+
seed=42,
|
73 |
+
)
|
74 |
+
|
75 |
+
gen(
|
76 |
+
f"sample.inpaint.{refiner}.png",
|
77 |
+
prompt="A majestic tiger sitting on a bench",
|
78 |
+
image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png",
|
79 |
+
mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png",
|
80 |
+
prompt_strength=0.8,
|
81 |
+
num_inference_steps=25,
|
82 |
+
refine=refiner,
|
83 |
+
guidance_scale=7,
|
84 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
85 |
+
seed=42,
|
86 |
+
)
|
87 |
+
|
88 |
+
for split in range(0, 10):
|
89 |
+
split = split / 10.0
|
90 |
+
gen(
|
91 |
+
f"sample.expert_ensemble_refiner.{split}.txt2img.png",
|
92 |
+
prompt="A studio portrait photo of a cat",
|
93 |
+
num_inference_steps=25,
|
94 |
+
guidance_scale=7,
|
95 |
+
refine="expert_ensemble_refiner",
|
96 |
+
high_noise_frac=split,
|
97 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
98 |
+
seed=1000,
|
99 |
+
width=1024,
|
100 |
+
height=1024,
|
101 |
+
)
|
102 |
+
|
103 |
+
gen(
|
104 |
+
f"sample.refine.txt2img.png",
|
105 |
+
prompt="A studio portrait photo of a cat",
|
106 |
+
num_inference_steps=25,
|
107 |
+
guidance_scale=7,
|
108 |
+
refine="base_image_refiner",
|
109 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
110 |
+
seed=1000,
|
111 |
+
width=1024,
|
112 |
+
height=1024,
|
113 |
+
)
|
114 |
+
gen(
|
115 |
+
f"sample.refine.10.txt2img.png",
|
116 |
+
prompt="A studio portrait photo of a cat",
|
117 |
+
num_inference_steps=25,
|
118 |
+
guidance_scale=7,
|
119 |
+
refine="base_image_refiner",
|
120 |
+
refine_steps=10,
|
121 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
122 |
+
seed=1000,
|
123 |
+
width=1024,
|
124 |
+
height=1024,
|
125 |
+
)
|
126 |
+
|
127 |
+
gen(
|
128 |
+
"samples.2.txt2img.png",
|
129 |
+
prompt="A studio portrait photo of a cat",
|
130 |
+
num_inference_steps=25,
|
131 |
+
guidance_scale=7,
|
132 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
133 |
+
scheduler="KarrasDPM",
|
134 |
+
num_outputs=2,
|
135 |
+
seed=1000,
|
136 |
+
width=1024,
|
137 |
+
height=1024,
|
138 |
+
)
|
139 |
+
|
140 |
+
for s in SCHEDULERS:
|
141 |
+
gen(
|
142 |
+
f"sample.{s}.txt2img.png",
|
143 |
+
prompt="A studio portrait photo of a cat",
|
144 |
+
num_inference_steps=25,
|
145 |
+
guidance_scale=7,
|
146 |
+
negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured",
|
147 |
+
scheduler=s,
|
148 |
+
seed=1000,
|
149 |
+
width=1024,
|
150 |
+
height=1024,
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
|
155 |
+
main()
|
cog_sdxl/script/download_preprocessing_weights.py
ADDED
@@ -0,0 +1,54 @@
import argparse
import os
import shutil

from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
    Swin2SRForImageSuperResolution,
)

DEFAULT_BLIP = "Salesforce/blip-image-captioning-large"
DEFAULT_CLIPSEG = "CIDAS/clipseg-rd64-refined"
DEFAULT_SWINIR = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"


def upload(args):
    blip_processor = BlipProcessor.from_pretrained(DEFAULT_BLIP)
    blip_model = BlipForConditionalGeneration.from_pretrained(DEFAULT_BLIP)

    clip_processor = CLIPSegProcessor.from_pretrained(DEFAULT_CLIPSEG)
    clip_model = CLIPSegForImageSegmentation.from_pretrained(DEFAULT_CLIPSEG)

    swin_model = Swin2SRForImageSuperResolution.from_pretrained(DEFAULT_SWINIR)

    temp_models = "tmp/models"
    if os.path.exists(temp_models):
        shutil.rmtree(temp_models)
    os.makedirs(temp_models)

    blip_processor.save_pretrained(os.path.join(temp_models, "blip_processor"))
    blip_model.save_pretrained(os.path.join(temp_models, "blip_large"))
    clip_processor.save_pretrained(os.path.join(temp_models, "clip_seg_processor"))
    clip_model.save_pretrained(os.path.join(temp_models, "clip_seg_rd64_refined"))
    swin_model.save_pretrained(
        os.path.join(temp_models, "swin2sr_realworld_sr_x4_64_bsrgan_psnr")
    )

    for val in os.listdir(temp_models):
        if "tar" not in val:
            os.system(
                f"sudo tar -cvf {os.path.join(temp_models, val)}.tar -C {os.path.join(temp_models, val)} ."
            )
            os.system(
                f"gcloud storage cp -R {os.path.join(temp_models, val)}.tar gs://{args.bucket}/{val}/"
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bucket", "-m", type=str)
    args = parser.parse_args()
    upload(args)
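Usage note (not part of the vendored file): this script appears to be run manually to mirror the preprocessing weights into GCS, something like `python script/download_preprocessing_weights.py --bucket my-replicate-weights`, assuming gcloud credentials are configured; the bucket name is illustrative.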
cog_sdxl/script/download_weights.py
ADDED
@@ -0,0 +1,50 @@
# Run this before you deploy it on replicate, because if you don't
# whenever you run the model, it will download the weights from the
# internet, which will take a long time.

import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

# pipe = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-1.0",
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# )

# pipe.save_pretrained("./cache", safe_serialization=True)

better_vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=better_vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)

pipe.save_pretrained("./sdxl-cache", safe_serialization=True)

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)

# TODO - we don't need to save all of this and in fact should save just the unet, tokenizer, and config.
pipe.save_pretrained("./refiner-cache", safe_serialization=True)


safety = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker",
    torch_dtype=torch.float16,
)

safety.save_pretrained("./safety-cache")
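Usage note (not part of the vendored file): per the header comment, this is run once before building or deploying so that the sdxl-cache, refiner-cache, and safety-cache directories referenced in predict.py already exist locally; a plausible invocation, assuming the cog CLI environment, is `cog run python script/download_weights.py`.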
cog_sdxl/tests/assets/out.png
ADDED
Git LFS Details
cog_sdxl/tests/test_predict.py
ADDED
@@ -0,0 +1,205 @@
1 |
+
import base64
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
import time
|
7 |
+
from functools import partial
|
8 |
+
from io import BytesIO
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import pytest
|
12 |
+
import replicate
|
13 |
+
import requests
|
14 |
+
from PIL import Image, ImageChops
|
15 |
+
|
16 |
+
ENV = os.getenv('TEST_ENV', 'local')
|
17 |
+
LOCAL_ENDPOINT = "http://localhost:5000/predictions"
|
18 |
+
MODEL = os.getenv('STAGING_MODEL', 'no model configured')
|
19 |
+
|
20 |
+
def local_run(model_endpoint: str, model_input: dict):
|
21 |
+
response = requests.post(model_endpoint, json={"input": model_input})
|
22 |
+
data = response.json()
|
23 |
+
|
24 |
+
try:
|
25 |
+
# TODO: this will break if we test batching
|
26 |
+
datauri = data["output"][0]
|
27 |
+
base64_encoded_data = datauri.split(",")[1]
|
28 |
+
data = base64.b64decode(base64_encoded_data)
|
29 |
+
return Image.open(BytesIO(data))
|
30 |
+
except Exception as e:
|
31 |
+
print("Error!")
|
32 |
+
print("input:", model_input)
|
33 |
+
print(data["logs"])
|
34 |
+
raise e
|
35 |
+
|
36 |
+
|
37 |
+
def replicate_run(model: str, version: str, model_input: dict):
|
38 |
+
output = replicate.run(
|
39 |
+
f"{model}:{version}",
|
40 |
+
input=model_input)
|
41 |
+
url = output[0]
|
42 |
+
|
43 |
+
response = requests.get(url)
|
44 |
+
return Image.open(BytesIO(response.content))
|
45 |
+
|
46 |
+
|
47 |
+
def wait_for_server_to_be_ready(url, timeout=300):
|
48 |
+
"""
|
49 |
+
Waits for the server to be ready.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
- url: The health check URL to poll.
|
53 |
+
- timeout: Maximum time (in seconds) to wait for the server to be ready.
|
54 |
+
"""
|
55 |
+
start_time = time.time()
|
56 |
+
while True:
|
57 |
+
try:
|
58 |
+
response = requests.get(url)
|
59 |
+
data = response.json()
|
60 |
+
|
61 |
+
if data["status"] == "READY":
|
62 |
+
return
|
63 |
+
elif data["status"] == "SETUP_FAILED":
|
64 |
+
raise RuntimeError(
|
65 |
+
"Server initialization failed with status: SETUP_FAILED"
|
66 |
+
)
|
67 |
+
|
68 |
+
except requests.RequestException:
|
69 |
+
pass
|
70 |
+
|
71 |
+
if time.time() - start_time > timeout:
|
72 |
+
raise TimeoutError("Server did not become ready in the expected time.")
|
73 |
+
|
74 |
+
time.sleep(5) # Poll every 5 seconds
|
75 |
+
|
76 |
+
|
77 |
+
@pytest.fixture(scope="session")
|
78 |
+
def inference_func():
|
79 |
+
"""
|
80 |
+
local inference uses http API to hit local server; staging inference uses python API b/c it's cleaner.
|
81 |
+
"""
|
82 |
+
if ENV == 'local':
|
83 |
+
return partial(local_run, LOCAL_ENDPOINT)
|
84 |
+
elif ENV == 'staging':
|
85 |
+
model = replicate.models.get(MODEL)
|
86 |
+
print(f"model,", model)
|
87 |
+
version = model.versions.list()[0]
|
88 |
+
return partial(replicate_run, MODEL, version.id)
|
89 |
+
else:
|
90 |
+
raise Exception(f"env should be local or staging but was {ENV}")
|
91 |
+
|
92 |
+
|
93 |
+
@pytest.fixture(scope="session", autouse=True)
|
94 |
+
def service():
|
95 |
+
"""
|
96 |
+
Spins up local cog server to hit for tests if running locally, no-op otherwise
|
97 |
+
"""
|
98 |
+
if ENV == 'local':
|
99 |
+
print("building model")
|
100 |
+
# starts local server if we're running things locally
|
101 |
+
build_command = 'cog build -t test-model'.split()
|
102 |
+
subprocess.run(build_command, check=True)
|
103 |
+
container_name = 'cog-test'
|
104 |
+
try:
|
105 |
+
subprocess.check_output(['docker', 'inspect', '--format="{{.State.Running}}"', container_name])
|
106 |
+
print(f"Container '{container_name}' is running. Stopping and removing...")
|
107 |
+
subprocess.check_call(['docker', 'stop', container_name])
|
108 |
+
subprocess.check_call(['docker', 'rm', container_name])
|
109 |
+
print(f"Container '{container_name}' stopped and removed.")
|
110 |
+
except subprocess.CalledProcessError:
|
111 |
+
# Container not found
|
112 |
+
print(f"Container '{container_name}' not found or not running.")
|
113 |
+
|
114 |
+
run_command = f'docker run -d -p 5000:5000 --gpus all --name {container_name} test-model '.split()
|
115 |
+
process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr)
|
116 |
+
|
117 |
+
wait_for_server_to_be_ready("http://localhost:5000/health-check")
|
118 |
+
|
119 |
+
yield
|
120 |
+
process.terminate()
|
121 |
+
process.wait()
|
122 |
+
stop_command = "docker stop cog-test".split()
|
123 |
+
subprocess.run(stop_command)
|
124 |
+
else:
|
125 |
+
yield
|
126 |
+
|
127 |
+
|
128 |
+
def image_equal_fuzzy(img_expected, img_actual, test_name='default', tol=20):
|
129 |
+
"""
|
130 |
+
Assert that average pixel values differ by less than tol across image
|
131 |
+
Tol determined empirically - holding everything else equal but varying seed
|
132 |
+
generates images that vary by at least 50
|
133 |
+
"""
|
134 |
+
img1 = np.array(img_expected, dtype=np.int32)
|
135 |
+
img2 = np.array(img_actual, dtype=np.int32)
|
136 |
+
|
137 |
+
mean_delta = np.mean(np.abs(img1 - img2))
|
138 |
+
imgs_equal = (mean_delta < tol)
|
139 |
+
if not imgs_equal:
|
140 |
+
# save failures for quick inspection
|
141 |
+
save_dir = f"tmp/{test_name}"
|
142 |
+
if not os.path.exists(save_dir):
|
143 |
+
os.makedirs(save_dir)
|
144 |
+
img_expected.save(os.path.join(save_dir, 'expected.png'))
|
145 |
+
img_actual.save(os.path.join(save_dir, 'actual.png'))
|
146 |
+
difference = ImageChops.difference(img_expected, img_actual)
|
147 |
+
difference.save(os.path.join(save_dir, 'delta.png'))
|
148 |
+
|
149 |
+
return imgs_equal
|
150 |
+
|
151 |
+
|
152 |
+
def test_seeded_prediction(inference_func, request):
|
153 |
+
"""
|
154 |
+
SDXL w/seed should be deterministic. may need to adjust tolerance for optimized SDXLs
|
155 |
+
"""
|
156 |
+
data = {
|
157 |
+
"prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic",
|
158 |
+
"num_inference_steps": 50,
|
159 |
+
"width": 1024,
|
160 |
+
"height": 1024,
|
161 |
+
"scheduler": "DDIM",
|
162 |
+
"refine": "expert_ensemble_refiner",
|
163 |
+
"seed": 12103,
|
164 |
+
}
|
165 |
+
actual_image = inference_func(data)
|
166 |
+
expected_image = Image.open("tests/assets/out.png")
|
167 |
+
assert image_equal_fuzzy(actual_image, expected_image, test_name=request.node.name)
|
168 |
+
|
169 |
+
|
170 |
+
def test_lora_load_unload(inference_func, request):
|
171 |
+
"""
|
172 |
+
Tests generation with & without loras.
|
173 |
+
This is checking for some gnarly state issues (can SDXL load / unload LoRAs), so predictions need to run in series.
|
174 |
+
"""
|
175 |
+
SEED = 1234
|
176 |
+
base_data = {
|
177 |
+
"prompt": "A photo of a dog on the beach",
|
178 |
+
"num_inference_steps": 50,
|
179 |
+
# Add other parameters here
|
180 |
+
"seed": SEED,
|
181 |
+
}
|
182 |
+
base_img_1 = inference_func(base_data)
|
183 |
+
|
184 |
+
lora_a_data = {
|
185 |
+
"prompt": "A photo of a TOK on the beach",
|
186 |
+
"num_inference_steps": 50,
|
187 |
+
# Add other parameters here
|
188 |
+
"replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/other_model.tar",
|
189 |
+
"seed": SEED
|
190 |
+
}
|
191 |
+
lora_a_img_1 = inference_func(lora_a_data)
|
192 |
+
assert not image_equal_fuzzy(lora_a_img_1, base_img_1, test_name=request.node.name)
|
193 |
+
|
194 |
+
lora_a_img_2 = inference_func(lora_a_data)
|
195 |
+
assert image_equal_fuzzy(lora_a_img_1, lora_a_img_2, test_name=request.node.name)
|
196 |
+
|
197 |
+
lora_b_data = {
|
198 |
+
"prompt": "A photo of a TOK on the beach",
|
199 |
+
"num_inference_steps": 50,
|
200 |
+
"replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/monstertoy_model.tar",
|
201 |
+
"seed": SEED,
|
202 |
+
}
|
203 |
+
lora_b_img = inference_func(lora_b_data)
|
204 |
+
assert not image_equal_fuzzy(lora_a_img_1, lora_b_img, test_name=request.node.name)
|
205 |
+
assert not image_equal_fuzzy(base_img_1, lora_b_img, test_name=request.node.name)
|
cog_sdxl/tests/test_remote_train.py
ADDED
@@ -0,0 +1,69 @@
import time
import pytest
import replicate


@pytest.fixture(scope="module")
def model_name(request):
    return "stability-ai/sdxl"


@pytest.fixture(scope="module")
def model(model_name):
    return replicate.models.get(model_name)


@pytest.fixture(scope="module")
def version(model):
    versions = model.versions.list()
    return versions[0]


@pytest.fixture(scope="module")
def training(model_name, version):
    training_input = {
        "input_images": "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar"
    }
    print(f"Training on {model_name}:{version.id}")
    return replicate.trainings.create(
        version=model_name + ":" + version.id,
        input=training_input,
        destination="replicate-internal/training-scratch",
    )


@pytest.fixture(scope="module")
def prediction_tests():
    return [
        {
            "prompt": "A photo of TOK at the beach",
            "refine": "expert_ensemble_refiner",
        },
    ]


def test_training(training):
    while training.completed_at is None:
        time.sleep(60)
        training.reload()
    assert training.status == "succeeded"


@pytest.fixture(scope="module")
def trained_model_and_version(training):
    trained_model, trained_version = training.output["version"].split(":")
    return trained_model, trained_version


def test_post_training_predictions(trained_model_and_version, prediction_tests):
    trained_model, trained_version = trained_model_and_version
    model = replicate.models.get(trained_model)
    version = model.versions.get(trained_version)
    predictions = [
        replicate.predictions.create(version=version, input=val)
        for val in prediction_tests
    ]

    for val in predictions:
        val.wait()
        assert val.status == "succeeded"
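Usage note (not part of the vendored file): these tests drive the hosted model end to end through the replicate client, so they assume an authenticated environment; a plausible invocation is setting REPLICATE_API_TOKEN (the standard variable the replicate Python client reads) and running `pytest -x tests/test_remote_train.py`.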
cog_sdxl/tests/test_utils.py
ADDED
@@ -0,0 +1,105 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
+
import time
|
5 |
+
from threading import Thread, Lock
|
6 |
+
import re
|
7 |
+
import multiprocessing
|
8 |
+
import subprocess
|
9 |
+
|
10 |
+
ERROR_PATTERN = re.compile(r"ERROR:")
|
11 |
+
|
12 |
+
|
13 |
+
def get_image_name():
|
14 |
+
current_dir = os.path.basename(os.getcwd())
|
15 |
+
|
16 |
+
if "cog" in current_dir:
|
17 |
+
return current_dir
|
18 |
+
else:
|
19 |
+
return f"cog-{current_dir}"
|
20 |
+
|
21 |
+
|
22 |
+
def process_log_line(line):
|
23 |
+
line = line.decode("utf-8").strip()
|
24 |
+
try:
|
25 |
+
log_data = json.loads(line)
|
26 |
+
return json.dumps(log_data, indent=2)
|
27 |
+
except json.JSONDecodeError:
|
28 |
+
return line
|
29 |
+
|
30 |
+
|
31 |
+
def capture_output(pipe, print_lock, logs=None, error_detected=None):
|
32 |
+
for line in iter(pipe.readline, b""):
|
33 |
+
formatted_line = process_log_line(line)
|
34 |
+
with print_lock:
|
35 |
+
print(formatted_line)
|
36 |
+
if logs is not None:
|
37 |
+
logs.append(formatted_line)
|
38 |
+
if error_detected is not None:
|
39 |
+
if ERROR_PATTERN.search(formatted_line):
|
40 |
+
error_detected[0] = True
|
41 |
+
|
42 |
+
|
43 |
+
def wait_for_server_to_be_ready(url, timeout=300):
|
44 |
+
"""
|
45 |
+
Waits for the server to be ready.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
- url: The health check URL to poll.
|
49 |
+
- timeout: Maximum time (in seconds) to wait for the server to be ready.
|
50 |
+
"""
|
51 |
+
start_time = time.time()
|
52 |
+
while True:
|
53 |
+
try:
|
54 |
+
response = requests.get(url)
|
55 |
+
data = response.json()
|
56 |
+
|
57 |
+
if data["status"] == "READY":
|
58 |
+
return
|
59 |
+
elif data["status"] == "SETUP_FAILED":
|
60 |
+
raise RuntimeError(
|
61 |
+
"Server initialization failed with status: SETUP_FAILED"
|
62 |
+
)
|
63 |
+
|
64 |
+
except requests.RequestException:
|
65 |
+
pass
|
66 |
+
|
67 |
+
if time.time() - start_time > timeout:
|
68 |
+
raise TimeoutError("Server did not become ready in the expected time.")
|
69 |
+
|
70 |
+
time.sleep(5) # Poll every 5 seconds
|
71 |
+
|
72 |
+
|
73 |
+
def run_training_subprocess(command):
|
74 |
+
# Start the subprocess with pipes for stdout and stderr
|
75 |
+
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
76 |
+
|
77 |
+
# Create a lock for printing and a list to accumulate logs
|
78 |
+
print_lock = multiprocessing.Lock()
|
79 |
+
logs = multiprocessing.Manager().list()
|
80 |
+
error_detected = multiprocessing.Manager().list([False])
|
81 |
+
|
82 |
+
# Start two separate processes to handle stdout and stderr
|
83 |
+
stdout_processor = multiprocessing.Process(
|
84 |
+
target=capture_output, args=(process.stdout, print_lock, logs, error_detected)
|
85 |
+
)
|
86 |
+
stderr_processor = multiprocessing.Process(
|
87 |
+
target=capture_output, args=(process.stderr, print_lock, logs, error_detected)
|
88 |
+
)
|
89 |
+
|
90 |
+
# Start the log processors
|
91 |
+
stdout_processor.start()
|
92 |
+
stderr_processor.start()
|
93 |
+
|
94 |
+
# Wait for the subprocess to finish
|
95 |
+
process.wait()
|
96 |
+
|
97 |
+
# Wait for the log processors to finish
|
98 |
+
stdout_processor.join()
|
99 |
+
stderr_processor.join()
|
100 |
+
|
101 |
+
# Check if an error pattern was detected
|
102 |
+
if error_detected[0]:
|
103 |
+
raise Exception("Error detected in training logs! Check logs for details")
|
104 |
+
|
105 |
+
return list(logs)
|