tgxs002 committed on
Commit
54199b6
·
1 Parent(s): 9588821
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +17 -0
  2. HPS_v2.pt +3 -0
  3. LICENSE +201 -0
  4. README.md +1 -5
  5. app.py +83 -0
  6. assets/hps_banner.png +3 -0
  7. assets/overview.png +3 -0
  8. configs/HPSv2.sh +32 -0
  9. configs/controller.sh +59 -0
  10. evaluate.py +220 -0
  11. requirements.txt +18 -0
  12. score.py +56 -0
  13. src/__init__.py +0 -0
  14. src/__pycache__/__init__.cpython-38.pyc +0 -0
  15. src/open_clip/__init__.py +14 -0
  16. src/open_clip/__pycache__/__init__.cpython-38.pyc +0 -0
  17. src/open_clip/__pycache__/coca_model.cpython-38.pyc +0 -0
  18. src/open_clip/__pycache__/constants.cpython-38.pyc +0 -0
  19. src/open_clip/__pycache__/factory.cpython-38.pyc +0 -0
  20. src/open_clip/__pycache__/hf_configs.cpython-38.pyc +0 -0
  21. src/open_clip/__pycache__/hf_model.cpython-38.pyc +0 -0
  22. src/open_clip/__pycache__/loss.cpython-38.pyc +0 -0
  23. src/open_clip/__pycache__/model.cpython-38.pyc +0 -0
  24. src/open_clip/__pycache__/modified_resnet.cpython-38.pyc +0 -0
  25. src/open_clip/__pycache__/openai.cpython-38.pyc +0 -0
  26. src/open_clip/__pycache__/pretrained.cpython-38.pyc +0 -0
  27. src/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc +0 -0
  28. src/open_clip/__pycache__/timm_model.cpython-38.pyc +0 -0
  29. src/open_clip/__pycache__/tokenizer.cpython-38.pyc +0 -0
  30. src/open_clip/__pycache__/transform.cpython-38.pyc +0 -0
  31. src/open_clip/__pycache__/transformer.cpython-38.pyc +0 -0
  32. src/open_clip/__pycache__/utils.cpython-38.pyc +0 -0
  33. src/open_clip/__pycache__/version.cpython-38.pyc +0 -0
  34. src/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  35. src/open_clip/coca_model.py +458 -0
  36. src/open_clip/constants.py +2 -0
  37. src/open_clip/factory.py +433 -0
  38. src/open_clip/generation_utils.py +0 -0
  39. src/open_clip/hf_configs.py +45 -0
  40. src/open_clip/hf_model.py +176 -0
  41. src/open_clip/loss.py +270 -0
  42. src/open_clip/model.py +461 -0
  43. src/open_clip/model_configs/RN101-quickgelu.json +22 -0
  44. src/open_clip/model_configs/RN101.json +21 -0
  45. src/open_clip/model_configs/RN50-quickgelu.json +22 -0
  46. src/open_clip/model_configs/RN50.json +21 -0
  47. src/open_clip/model_configs/RN50x16.json +21 -0
  48. src/open_clip/model_configs/RN50x4.json +21 -0
  49. src/open_clip/model_configs/RN50x64.json +21 -0
  50. src/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
.gitattributes CHANGED
@@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/hps_banner.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/overview.png filter=lfs diff=lfs merge=lfs -text
38
+ src/open_clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
39
+ HPS_v2.pt filter=lfs diff=lfs merge=lfs -text
40
+ tests/*.png filter=lfs diff=lfs merge=lfs -text
41
+ tests/docs/clip_loss.png filter=lfs diff=lfs merge=lfs -text
42
+ tests/docs/clip_val_loss.png filter=lfs diff=lfs merge=lfs -text
43
+ tests/docs/clip_zeroshot.png filter=lfs diff=lfs merge=lfs -text
44
+ tests/docs/laion2b_clip_zeroshot_b32.png filter=lfs diff=lfs merge=lfs -text
45
+ tests/docs/laion_clip_zeroshot_l14.png filter=lfs diff=lfs merge=lfs -text
46
+ tests/docs/CLIP.png filter=lfs diff=lfs merge=lfs -text
47
+ tests/docs/clip_recall.png filter=lfs diff=lfs merge=lfs -text
48
+ tests/docs/effective_robustness.png filter=lfs diff=lfs merge=lfs -text
49
+ tests/docs/laion_clip_zeroshot_b16_plus_240.png filter=lfs diff=lfs merge=lfs -text
50
+ tests/docs/laion_clip_zeroshot_b16.png filter=lfs diff=lfs merge=lfs -text
51
+ tests/docs/laion_clip_zeroshot.png filter=lfs diff=lfs merge=lfs -text
52
+ tests/docs/scaling.png filter=lfs diff=lfs merge=lfs -text
HPS_v2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9defdb2ba952d35ec1cb3334554bd4033e415d1397d742a9946d2ac884ed53a1
3
+ size 8063374362
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,4 +1,3 @@
1
- ---
2
  title: HPSv2
3
  emoji: 🚀
4
  colorFrom: purple
@@ -7,7 +6,4 @@ sdk: gradio
7
  sdk_version: 3.37.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  title: HPSv2
2
  emoji: 🚀
3
  colorFrom: purple
 
6
  sdk_version: 3.37.0
7
  app_file: app.py
8
  pinned: false
9
+ license: apache-2.0
 
 
 
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from src.open_clip import create_model_and_transforms, get_tokenizer
5
+ import warnings
6
+ import argparse
7
+
8
# Silence UserWarning noise from the libraries used below.
warnings.filterwarnings("ignore", category=UserWarning)

# Command-line options: only the checkpoint location is configurable.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='HPS_v2.pt', help='Path to the model checkpoint')

args = parser.parse_args()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess_train, preprocess_val = create_model_and_transforms(
    'ViT-H-14',
    'laion2B-s32B-b79K',
    precision='amp',
    device=device,
    jit=False,
    force_quick_gelu=False,
    force_custom_text=False,
    force_patch_dropout=False,
    force_image_size=None,
    pretrained_image=False,
    image_mean=None,
    image_std=None,
    light_augmentation=True,
    aug_cfg={},
    output_dict=True,
    with_score_predictor=False,
    with_region_predictor=False
)

# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved from a GPU can still be loaded on a CPU-only host.
checkpoint = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
tokenizer = get_tokenizer('ViT-H-14')
model.eval()
41
+
42
# HTML header rendered at the top of the demo page.
# NOTE(review): "font-weight: 1400" is outside the 1-1000 range CSS accepts;
# left unchanged here because replacing it would alter how the title renders.
intro = """
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
HPS v2
</h1>
<h3 style="font-weight: 600; text-align: center;">
evaluating human preference for generated images
</h3>
<h4 style="text-align: center; margin-bottom: 7px;">
<a href="https://github.com/tgxs002/HPSv2" style="text-decoration: underline;" target="_blank">GitHub</a> | <a href="https://arxiv.org/abs/2306.09341" style="text-decoration: underline;" target="_blank">ArXiv</a>
</h4>
<p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
</p>"""
54
+
55
def inference(image, prompt):
    """Compute the HPS v2 score of ``image`` for ``prompt``.

    Args:
        image: Image handed over by the Gradio image widget, or ``None`` when
            the user has not provided one.
        prompt: Text prompt the image is scored against.

    Returns:
        A display string: ``'HPSv2 score: <value>'`` on success, or a short
        hint when no image was supplied.
    """
    # Without this guard a click with no uploaded image crashes inside the
    # preprocessing transform instead of telling the user what is missing.
    if image is None:
        return 'Please upload an image before computing the HPS v2 score.'

    with torch.no_grad():
        # Preprocess the image into a single-item batch on the target device.
        image = preprocess_val(image).unsqueeze(0).to(device=device, non_blocking=True)
        # Tokenize the prompt.
        text = tokenizer([prompt]).to(device=device, non_blocking=True)
        # The score is the image/text feature product for the single pair.
        with torch.cuda.amp.autocast():
            outputs = model(image, text)
            image_features, text_features = outputs["image_features"], outputs["text_features"]
            logits_per_image = image_features @ text_features.T

        hps_score = torch.diagonal(logits_per_image).cpu().numpy()
    output = 'HPSv2 score: ' + str(hps_score[0])
    return output
72
+
73
# Page layout: header, then one column holding the inputs, the trigger
# button and the read-only result box.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    with gr.Column():
        image_input = gr.Image(label="Image", type="pil")
        prompt_input = gr.Textbox(lines=1, label="Prompt")
        run_button = gr.Button("Compute HPS v2")
        score_output = gr.Textbox(
            label="output",
            lines=1,
            interactive=False,
            elem_id="output",
        )
        # Clicking the button feeds (image, prompt) to inference() and shows
        # the returned string in the result box.
        run_button.click(
            inference,
            inputs=[image_input, prompt_input],
            outputs=score_output,
        )

# Handle one request at a time, then start serving.
demo.queue(concurrency_count=1)
demo.launch()
assets/hps_banner.png ADDED

Git LFS Details

  • SHA256: 2dedf40e90f844edd306f17f2d4d515c26c50baafe4bcd291148640581c68ba3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB
assets/overview.png ADDED

Git LFS Details

  • SHA256: a009e838f28ef1fc6e153fd687c84965dfa4f7ec4c8e1de8f394f9102d12cb54
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
configs/HPSv2.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# HPS v2 training launch configuration.
# Sources configs/controller.sh, which rewrites $name into the log directory
# and provides $header, $extra_args, $work_dir, $now and the dataset paths.
name=$0
. configs/controller.sh

# Logging / evaluation cadence
args=" --zeroshot-frequency 1 --report-to tensorboard"
# Datasets (first entry: ranking data, second entry: benchmark data)
args="$args --train-data $local_ranking_path/train.json $local_benchmark_path/annotations.json"
args="$args --val-data $local_ranking_path/test.json $local_benchmark_path/annotations.json"
args="$args --train-folder $local_ranking_path/train $local_benchmark_path"
args="$args --val-folder $local_ranking_path/test $local_benchmark_path"
# Optimisation
args="$args --warmup 500 --lr 0.0000033 --wd 0.35"
args="$args --workers 4 4 --batch-size 16 16"
# Model and per-dataset handling
args="$args --pretrained laion2B-s32B-b79K"
args="$args --dataset-type HPD ranking"
args="$args --ignore-in-train 0 1 --ignore-in-val 1 0"
args="$args --train-data-sample-ratio 1.0 0"
args="$args --model ViT-H-14"
args="$args --lock-text --lock-image"
args="$args --lock-text-unlocked-layers 11 --lock-image-unlocked-groups 20"
args="$args --logs none --light-augmentation"
# The trailing space keeps any forwarded extra arguments from fusing with "100".
args="$args --exp-name $name --iterations 100 "

# Run the assembled command and append its combined output to the run log.
eval "$header$args$extra_args 2>&1 | tee -a $work_dir/exp_$now.txt"
configs/controller.sh ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared launcher setup, meant to be sourced by an experiment script.
# $name is read but never assigned here, so the sourcing script must set it first.
# Positional arguments: $1 experiment name, $2 GPU count, $3 launcher type;
# remaining arguments are forwarded to the training command via $extra_args.
exp=${1:-'test'}
gpu=${2:-'1'}
type=${3:-'local'} # choose slurm if you are running on a cluster with slurm scheduler

if [ "$type" == 'local' ]; then
    extra_args=${@:4:99}
else
    quotatype=${4:-'auto'} # for slurm
    partition=${5:-'1'} # for slurm
    extra_args=${@:6:99}
    # NOTE(review): the three assignments below override the values parsed just
    # above, so $4/$5 are currently ignored as slurm settings and forwarded as
    # extra arguments instead. Kept as-is; edit or delete them for your cluster.
    quotatype=spot
    partition=YOUR_PARTITION
    extra_args=${@:4:99}
fi

# Map "configs/<script>.sh" to the log directory "logs/<script>/<exp>".
name=${name/#configs/logs}
name=${name//.sh//$exp}
work_dir="${name}"
now=$(date +"%Y%m%d_%H%M%S")
# Quoted so a path containing spaces is created as one directory.
mkdir -p "$work_dir"

ncpu='4'

if [ "$quotatype" == 'reserved_normal' ]; then
    # Single quotes on purpose: ${gpu} is expanded later, when the launching
    # script eval's the command line.
    quotatype='reserved --phx-priority=${gpu} normal'
fi

if [ "$type" == 'local' ]; then

    # Dataset locations for local runs.
    ava_path=/mnt/afs/xswu/datasets/AVA/images
    local_data_path=/mnt/afs/xswu/datasets/preference
    local_ava_path=/mnt/afs/xswu/datasets/AVA
    local_simulacra_path=/mnt/afs/xswu/datasets/simulacra
    local_region_path=/mnt/afs/xswu/datasets/regional_dataset
    local_ranking_path=/mnt/afs/xswu/datasets/HPDv2
    local_benchmark_path=/mnt/afs/xswu/datasets/benchmark
    local_ImageReward_path=/mnt/afs/xswu/datasets/ImageReward
    local_pap_path=/mnt/afs/xswu/datasets/PAP

    header="torchrun --nproc_per_node=${gpu} --nnodes=1 --max_restarts=3 -m src.training.main "

else

    # Dataset locations for slurm runs.
    data_path=s3://preference_images/
    ava_path=s3://AVA/
    simulacra_path=s3://simulacra/
    region_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/regional_dataset/
    local_data_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/human_preference
    local_ava_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/AVA
    local_simulacra_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/simulacra
    local_region_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/regional_dataset
    local_ranking_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/ranking_dataset
    local_benchmark_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/benchmark
    local_ImageReward_path=/mnt/lustre/wuxiaoshi1.vendor/datasets/ImageReward
    header="srun --async --partition=$partition -n${gpu} --mpi=pmi2 --gres=gpu:$gpu --ntasks-per-node=${gpu} --quotatype=$quotatype \
    --job-name=$exp --cpus-per-task=$ncpu --kill-on-bad-exit=1 -o local.out python -m src.training.main "

fi
evaluate.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cProfile import label
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from argparse import ArgumentParser
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
+ from src.open_clip import create_model_and_transforms, get_tokenizer
13
+ from src.training.train import calc_ImageReward, inversion_score
14
+ from src.training.data import ImageRewardDataset, collate_rank, RankingDataset
15
+
16
+
17
# Command-line interface. Every evaluation path below needs all four of
# --data-type, --data-path, --image-path and --checkpoint, so they are marked
# required: a missing one is reported by argparse immediately instead of
# failing later, after the model has been built.
parser = ArgumentParser()
parser.add_argument('--data-type', type=str, required=True, choices=['benchmark', 'test', 'ImageReward', 'drawbench'])
parser.add_argument('--data-path', type=str, required=True, help='path to dataset')
parser.add_argument('--image-path', type=str, required=True, help='path to image files')
parser.add_argument('--checkpoint', type=str, required=True, help='path to checkpoint')
parser.add_argument('--batch-size', type=int, default=20)
args = parser.parse_args()

batch_size = args.batch_size
args.model = "ViT-H-14"
args.precision = 'amp'
print(args.model)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess_train, preprocess_val = create_model_and_transforms(
    args.model,
    'laion2B-s32B-b79K',
    precision=args.precision,
    device=device,
    jit=False,
    force_quick_gelu=False,
    force_custom_text=False,
    force_patch_dropout=False,
    force_image_size=None,
    pretrained_image=False,
    image_mean=None,
    image_std=None,
    light_augmentation=True,
    aug_cfg={},
    output_dict=True,
    with_score_predictor=False,
    with_region_predictor=False
)

# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved from a GPU can still be loaded on a CPU-only host.
checkpoint = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
tokenizer = get_tokenizer(args.model)
model.eval()
54
+
55
class BenchmarkDataset(Dataset):
    """Pairs the images ``00000.jpg``, ``00001.jpg``, ... with annotations.

    The annotation file is JSON; entry ``idx`` is tokenized as the caption of
    the image named ``f'{idx:05d}.jpg'`` inside ``image_folder``.
    """

    def __init__(self, meta_file, image_folder,transforms, tokenizer):
        """Load the annotations and remember how to read and encode samples.

        Args:
            meta_file: Path of the JSON annotation file.
            image_folder: Directory holding the numbered ``.jpg`` images.
            transforms: Callable applied to each opened image.
            tokenizer: Callable applied to each annotation entry.
        """
        self.transforms = transforms
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.open_image = Image.open
        with open(meta_file, 'r') as f:
            self.annotations = json.load(f)

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for ``idx``.

        If the image cannot be opened or read, the following indices are tried
        in turn, wrapping around, and the first readable sample is returned
        together with its own caption.

        Raises:
            FileNotFoundError: If no image in the dataset could be read.
        """
        total = len(self)
        last_error = None
        # At most one pass over the dataset: the previous unbounded recursion
        # never terminated cleanly when every image was missing.
        for offset in range(total):
            candidate = (idx + offset) % total
            img_path = os.path.join(self.image_folder, f'{candidate:05d}.jpg')
            try:
                images = self.transforms(self.open_image(img_path))
            except OSError as err:
                # Only I/O failures trigger the fallback; any other error is a
                # real bug and is allowed to propagate.
                print('file not exist')
                last_error = err
                continue
            caption = self.tokenizer(self.annotations[candidate])
            return images, caption
        raise FileNotFoundError(
            f'no readable image found in {self.image_folder}'
        ) from last_error
76
+
77
def evaluate_IR(data_path, image_folder, model):
    """Evaluate ``model`` on ``ImageReward_test.json`` and print the mean score.

    Args:
        data_path: Directory containing ``ImageReward_test.json``.
        image_folder: Directory with the images referenced by that file.
        model: Model called as ``model(images, texts)``; its output must expose
            ``image_features``, ``text_features`` and ``logit_scale``.
    """
    meta_file = data_path + '/ImageReward_test.json'
    dataset = ImageRewardDataset(meta_file, image_folder, preprocess_val, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_rank)

    score = 0
    total = len(dataset)
    with torch.no_grad():
        for images, num_images, labels, texts in tqdm(dataloader):
            images = images.to(device=device, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
            num_images = num_images.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)
            # Number of images belonging to each prompt in this batch.
            group_sizes = num_images.tolist()

            with torch.cuda.amp.autocast():
                outputs = model(images, texts)
                image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                logits_per_image = logit_scale * image_features @ text_features.T
                # For the i-th group of images keep only the column of the i-th text.
                paired_logits_list = [logit[:, i] for i, logit in enumerate(logits_per_image.split(group_sizes))]

            # The per-group rankings computed here previously were only used
            # for their length, so the logits are paired with labels directly.
            label_groups = labels.split(group_sizes)
            score += sum(
                calc_ImageReward(logits.tolist(), group_labels)
                for logits, group_labels in zip(paired_logits_list, label_groups)
            )
    print('ImageReward:', score/total)
103
+
104
def evaluate_rank(data_path, image_folder, model):
    """Evaluate ranking accuracy on ``test.json`` and save predicted rankings.

    Prints the mean ``inversion_score`` over the dataset and writes every
    predicted ranking to ``logs/hps_rank.json``.

    Args:
        data_path: Directory containing ``test.json``.
        image_folder: Directory with the images referenced by that file.
        model: Model called as ``model(images, texts)``; its output must expose
            ``image_features``, ``text_features`` and ``logit_scale``.
    """
    meta_file = data_path + '/test.json'
    dataset = RankingDataset(meta_file, image_folder, preprocess_val, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_rank)

    score = 0
    total = len(dataset)
    all_rankings = []
    with torch.no_grad():
        for images, num_images, labels, texts in tqdm(dataloader):
            images = images.to(device=device, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
            num_images = num_images.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)
            # Number of images belonging to each prompt in this batch.
            group_sizes = num_images.tolist()

            with torch.cuda.amp.autocast():
                outputs = model(images, texts)
                image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                logits_per_image = logit_scale * image_features @ text_features.T
                # For the i-th group of images keep only the column of the i-th text.
                paired_logits_list = [logit[:, i] for i, logit in enumerate(logits_per_image.split(group_sizes))]

            hps_ranking = []
            for logits, n in zip(paired_logits_list, group_sizes):
                # order[r] is the image placed at rank r (highest logit first);
                # converted to a list once rather than once per lookup.
                order = torch.argsort(-logits).tolist()
                # Invert it: entry j is the rank assigned to image j.
                hps_ranking.append([order.index(j) for j in range(n)])
            label_groups = labels.split(group_sizes)
            all_rankings.extend(hps_ranking)
            score += sum(
                inversion_score(ranking, group_labels)
                for ranking, group_labels in zip(hps_ranking, label_groups)
            )
    print('ranking_acc:', score/total)
    # Create the output directory first so the results of a finished run are
    # not lost to a FileNotFoundError.
    os.makedirs('logs', exist_ok=True)
    with open('logs/hps_rank.json', 'w') as f:
        json.dump(all_rankings, f)
134
+
135
def collate_eval(batch):
    """Merge ``(image, caption)`` samples into one batch.

    Images are stacked along a new leading dimension; captions are
    concatenated along their existing first dimension.

    Returns:
        Tuple ``(images, captions)`` of batched tensors.
    """
    image_tensors = []
    caption_tensors = []
    for sample in batch:
        image_tensors.append(sample[0])
        caption_tensors.append(sample[1])
    return torch.stack(image_tensors), torch.cat(caption_tensors)
139
+
140
+
141
def evaluate_benchmark(data_path, root_dir, model):
    """Score every ``<model>/<style>`` image folder under ``root_dir``.

    The style list is taken from the first model directory. For each
    model/style pair the per-image scores are collected, then the mean and
    standard deviation over ten consecutive groups of 80 scores are printed.

    Args:
        data_path: Directory containing one ``<style>.json`` prompt file per style.
        root_dir: Directory with one sub-directory per evaluated model.
        model: Model called as ``model(images, texts)``; its output must expose
            ``image_features`` and ``text_features``.
    """
    model_names = os.listdir(root_dir)
    style_names = os.listdir(os.path.join(root_dir, model_names[0]))

    score = {}
    for model_id in model_names:
        per_style = score[model_id] = {}
        for style in style_names:
            style_scores = per_style[style] = []
            dataset = BenchmarkDataset(
                os.path.join(data_path, f'{style}.json'),
                os.path.join(root_dir, model_id, style),
                preprocess_val,
                tokenizer,
            )
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_eval)

            with torch.no_grad():
                for images, texts in dataloader:
                    images = images.to(device=device, non_blocking=True)
                    texts = texts.to(device=device, non_blocking=True)

                    with torch.cuda.amp.autocast():
                        outputs = model(images, texts)
                        logits_per_image = outputs["image_features"] @ outputs["text_features"].T
                    # The diagonal holds each image's score against its own prompt.
                    style_scores.extend(torch.diagonal(logits_per_image).cpu().tolist())

    print('-----------benchmark score ---------------- ')
    for model_id, per_style in score.items():
        for style, res in per_style.items():
            avg_score = [np.mean(res[start:start + 80]) for start in range(0, 800, 80)]
            print(model_id, '\t', style, '\t', np.mean(avg_score), '\t', np.std(avg_score))
174
+
175
+
176
def evaluate_benchmark_DB(data_path, root_dir, model):
    """Print the mean score of each model's images on ``drawbench.json``.

    Args:
        data_path: Directory containing ``drawbench.json``.
        root_dir: Directory with one image sub-directory per evaluated model.
        model: Model called as ``model(images, texts)``; its output must expose
            ``image_features`` and ``text_features``.
    """
    meta_file = data_path + '/drawbench.json'

    score = {}
    for model_id in os.listdir(root_dir):
        dataset = BenchmarkDataset(
            meta_file,
            os.path.join(root_dir, model_id),
            preprocess_val,
            tokenizer,
        )
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_eval)

        running_total = 0
        with torch.no_grad():
            for images, texts in tqdm(dataloader):
                images = images.to(device=device, non_blocking=True)
                texts = texts.to(device=device, non_blocking=True)

                with torch.cuda.amp.autocast():
                    outputs = model(images, texts)
                    logits_per_image = outputs["image_features"] @ outputs["text_features"].T
                # Sum of each image's score against its own prompt.
                running_total += torch.sum(torch.diagonal(logits_per_image)).cpu().item()
        score[model_id] = running_total / len(dataset)

    print('-----------drawbench score ---------------- ')
    for name, value in score.items():
        print(name, '\t', '\t', np.mean(value))
205
+
206
+
207
# Map each --data-type value to the evaluation routine that handles it.
_EVALUATORS = {
    'ImageReward': evaluate_IR,
    'test': evaluate_rank,
    'benchmark': evaluate_benchmark,
    'drawbench': evaluate_benchmark_DB,
}

if args.data_type not in _EVALUATORS:
    raise NotImplementedError
_EVALUATORS[args.data_type](args.data_path, args.image_path, model)
217
+
218
+
219
+
220
+
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ torchvision
3
+ regex
4
+ ftfy
5
+ einops
6
+ pandas
7
+ braceexpand
8
+ fsspec
9
+ tqdm
10
+ huggingface_hub
11
+ sentencepiece
12
+ protobuf<4
13
+ timm
14
+ transformers
15
+ webdataset
16
+ pyarrow
17
+ pytest-split==0.8.0
18
+ pytest==7.2.0
score.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from src.open_clip import create_model_and_transforms, get_tokenizer
4
+ import warnings
5
+ import argparse
6
+
7
warnings.filterwarnings("ignore", category=UserWarning)

# Command-line interface: one image, one prompt, optional checkpoint override.
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
parser.add_argument('--prompt', type=str, required=True, help='Text prompt')
parser.add_argument('--checkpoint', type=str, default='../HPSv2.pt', help='Path to the model checkpoint')

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess_train, preprocess_val = create_model_and_transforms(
    'ViT-H-14',
    'laion2B-s32B-b79K',
    precision='amp',
    device=device,
    jit=False,
    force_quick_gelu=False,
    force_custom_text=False,
    force_patch_dropout=False,
    force_image_size=None,
    pretrained_image=False,
    image_mean=None,
    image_std=None,
    light_augmentation=True,
    aug_cfg={},
    output_dict=True,
    with_score_predictor=False,
    with_region_predictor=False
)

# map_location keeps a checkpoint that was saved from GPU tensors loadable on
# the CPU fallback selected above.
checkpoint = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
tokenizer = get_tokenizer('ViT-H-14')
model.eval()

# Score the image against the prompt.
with torch.no_grad():
    # Process the image; the context manager closes the file handle afterwards.
    with Image.open(args.image_path) as pil_image:
        image = preprocess_val(pil_image).unsqueeze(0).to(device=device, non_blocking=True)
    # Process the prompt
    text = tokenizer([args.prompt]).to(device=device, non_blocking=True)
    # Calculate the HPS; CUDA autocast is only requested when running on CUDA.
    with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
        outputs = model(image, text)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits_per_image = image_features @ text_features.T

    hps_score = torch.diagonal(logits_per_image).cpu().numpy()
print('HPSv2 score:', hps_score[0])
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (170 Bytes). View file
 
src/open_clip/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .coca_model import CoCa
2
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
8
+ from .openai import load_openai_model, list_openai_models
9
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12
+ from .tokenizer import SimpleTokenizer, tokenize, decode
13
+ from .transform import image_transform, AugmentationCfg
14
+ from .utils import freeze_batch_norm_2d
src/open_clip/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
src/open_clip/__pycache__/coca_model.cpython-38.pyc ADDED
Binary file (9.72 kB). View file
 
src/open_clip/__pycache__/constants.cpython-38.pyc ADDED
Binary file (287 Bytes). View file
 
src/open_clip/__pycache__/factory.cpython-38.pyc ADDED
Binary file (10.2 kB). View file
 
src/open_clip/__pycache__/hf_configs.cpython-38.pyc ADDED
Binary file (638 Bytes). View file
 
src/open_clip/__pycache__/hf_model.cpython-38.pyc ADDED
Binary file (5.91 kB). View file
 
src/open_clip/__pycache__/loss.cpython-38.pyc ADDED
Binary file (7.47 kB). View file
 
src/open_clip/__pycache__/model.cpython-38.pyc ADDED
Binary file (13.7 kB). View file
 
src/open_clip/__pycache__/modified_resnet.cpython-38.pyc ADDED
Binary file (6.32 kB). View file
 
src/open_clip/__pycache__/openai.cpython-38.pyc ADDED
Binary file (4.78 kB). View file
 
src/open_clip/__pycache__/pretrained.cpython-38.pyc ADDED
Binary file (11.2 kB). View file
 
src/open_clip/__pycache__/push_to_hf_hub.cpython-38.pyc ADDED
Binary file (5.26 kB). View file
 
src/open_clip/__pycache__/timm_model.cpython-38.pyc ADDED
Binary file (4.02 kB). View file
 
src/open_clip/__pycache__/tokenizer.cpython-38.pyc ADDED
Binary file (8.79 kB). View file
 
src/open_clip/__pycache__/transform.cpython-38.pyc ADDED
Binary file (4.92 kB). View file
 
src/open_clip/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (20.3 kB). View file
 
src/open_clip/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.27 kB). View file
 
src/open_clip/__pycache__/version.cpython-38.pyc ADDED
Binary file (201 Bytes). View file
 
src/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
src/open_clip/coca_model.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StoppingCriteriaList
27
+ )
28
+
29
+ GENERATION_TYPES = {
30
+ "top_k": TopKLogitsWarper,
31
+ "top_p": TopPLogitsWarper,
32
+ "beam_search": "beam_search"
33
+ }
34
+ _has_transformers = True
35
+ except ImportError as e:
36
+ GENERATION_TYPES = {
37
+ "top_k": None,
38
+ "top_p": None,
39
+ "beam_search": "beam_search"
40
+ }
41
+ _has_transformers = False
42
+
43
+
44
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Settings for the multimodal text decoder, layered on top of ``CLIPTextCfg``."""

    # NOTE(review): only ``heads`` (plus inherited fields) is read by the
    # decoder builder in this file; the remaining fields are presumably
    # consumed elsewhere -- confirm before changing their defaults.
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8
51
+
52
+
53
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Create the ``MultimodalTransformer`` described by ``multimodal_cfg``.

    Args:
        embed_dim: Passed to the transformer as its ``output_dim``.
        multimodal_cfg: A ``MultimodalCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` as activation.
        cast_dtype: When float16 or bfloat16, ``LayerNormFp32`` is used as the
            norm layer; otherwise ``LayerNorm``.

    Returns:
        The constructed decoder module.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    if quick_gelu:
        activation = QuickGELU
    else:
        activation = nn.GELU

    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalization = LayerNormFp32 if low_precision else LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalization,
    )
77
+
78
+
79
class CoCa(nn.Module):
    """Model pairing a text tower and a vision tower with a multimodal text decoder.

    ``forward`` returns the pooled image/text features together with the
    decoder logits; ``generate`` produces token-id sequences for images.
    """

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        """Build the three sub-modules.

        Args:
            embed_dim: Embedding size handed to the text and vision towers.
            multimodal_cfg: Decoder configuration (dataclass or dict).
            text_cfg: Text tower configuration (dataclass or dict).
            vision_cfg: Vision tower configuration (dataclass or dict).
            quick_gelu: Forwarded to every tower builder.
            cast_dtype: Forwarded to every tower builder.
            pad_id: Default padding token id used by ``generate``.
        """
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # The original conditional on ``hf_model_name`` selected
        # ``text_cfg.vocab_size`` in both branches, so it is read directly.
        vocab_size = text_cfg.vocab_size

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # The decoder's output dimension is the vocabulary size.
        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pad_id = pad_id

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Forward the gradient-checkpointing switch to all three sub-modules."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize=True):
        """Return ``(image_latent, tokens_embs)``; the latent is L2-normalized if requested."""
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize=True, embed_cls=True):
        """Return ``(text_latent, token_emb)``; the latent is L2-normalized if requested.

        With ``embed_cls`` the last text position is dropped first.
        """
        text = text[:, :-1] if embed_cls else text  # make space for CLS token
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize=True):
        """Return only the pooled image latent."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize=True, embed_cls=True):
        """Return only the pooled text latent."""
        text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
        return text_latent

    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
        """Encode both modalities and run the text decoder.

        Args:
            image: Image batch; only encoded when ``image_latent`` or
                ``image_embs`` is missing.
            text: Token-id batch.
            embed_cls: Forwarded to ``_encode_text``.
            image_latent: Optional precomputed image latent.
            image_embs: Optional precomputed image token embeddings.

        Returns:
            Dict with ``image_features``, ``text_features``, ``logits``,
            ``labels`` and ``logit_scale``.
        """
        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        # TODO: add assertion to avoid bugs?
        # Labels are the trailing text positions, one per token embedding.
        labels = text[:, -token_embs.shape[1]:]

        logits = self.text_decoder(image_embs, token_embs)
        return {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "labels": labels,
            "logit_scale": self.logit_scale.exp()
        }

    def generate(
            self,
            image,
            text=None,
            seq_len=30,
            max_seq_len=77,
            temperature=1.,
            generation_type="beam_search",
            top_p=0.1,  # keep tokens in the 1 - top_p quantile
            top_k=1,  # keeps the top_k most probable tokens
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            repetition_penalty=1.0,
            fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        """Generate token-id sequences for ``image``.

        Args:
            image: Image batch to condition on.
            text: Optional starting token ids for ``top_p``/``top_k`` sampling;
                defaults to a single start-of-text token per image. Not used by
                beam search.
            seq_len: Length limit used for the default stopping criterion.
            max_seq_len: Only the last ``max_seq_len`` tokens are fed back to
                the model at each sampling step.
            temperature: Divides the filtered logits before the softmax.
            generation_type: ``"beam_search"``, ``"top_p"`` or ``"top_k"``.
            top_p: Argument for the top-p warper.
            top_k: Argument for the top-k warper.
            pad_token_id: Padding id; defaults to ``self.pad_id``.
            eos_token_id: End-of-text id; defaults to 49407.
            sot_token_id: Start-of-text id; defaults to 49406.
            num_beams: Beam count for beam search.
            num_beam_groups: Beam group count for beam search.
            min_seq_len: Minimum length enforced by the logits processor.
            stopping_criteria: Optional list of stopping criteria; defaults to
                a ``MaxLengthCriteria`` on ``seq_len``.
            repetition_penalty: Argument for the repetition-penalty processor.
            fixed_output_length: Pad (beam search) or keep sampling padding
                tokens so the output reaches the length limit.

        Returns:
            Tensor of generated token ids.

        Raises:
            AssertionError: If transformers is unavailable or
                ``seq_len <= min_seq_len``.
            ValueError: If ``generation_type`` is not supported.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"

        with torch.no_grad():
            sot_token_id = 49406 if sot_token_id is None else sot_token_id
            eos_token_id = 49407 if eos_token_id is None else eos_token_id
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]

            stopping_criteria = StoppingCriteriaList(
                stopping_criteria
            )

            device = image.device

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    # Pad with the resolved pad_token_id so a caller-supplied
                    # padding id is honoured (the original always used
                    # self.pad_id here).
                    pad_len = seq_len - output.shape[1]
                    padding = torch.full(
                        (output.shape[0], pad_len), pad_token_id, device=device, dtype=output.dtype
                    )
                    return torch.cat((output, padding), dim=1)
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            # try/finally so the module's train/eval mode is restored even if
            # a sampling step raises.
            try:
                while True:
                    x = out[:, -max_seq_len:]
                    cur_len = x.shape[1]
                    logits = self(
                        image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False
                    )["logits"][:, -1]
                    # Rows whose last token is EOS or padding are finished.
                    mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                    sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                    if mask.all():
                        if not fixed_output_length:
                            break
                    else:
                        active = ~mask
                        logits = logits[active, :]
                        filtered_logits = logit_processor(x[active, :], logits)
                        filtered_logits = logit_warper(x[active, :], filtered_logits)
                        probs = F.softmax(filtered_logits / temperature, dim=-1)

                        if cur_len + 1 == seq_len:
                            # Last allowed position: force EOS on unfinished rows.
                            sample[active, :] = torch.ones(
                                (int(active.sum()), 1), device=device, dtype=torch.long
                            ) * eos_token_id
                        else:
                            sample[active, :] = torch.multinomial(probs, 1)

                    out = torch.cat((out, sample), dim=-1)

                    if stopping_criteria(out, None):
                        break
            finally:
                self.train(was_training)

            if num_dims == 1:
                out = out.squeeze(0)

            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Run grouped beam search with a ``BeamSearchScorer``.

        Args:
            image_inputs: Image batch; repeated ``num_beams`` times per image.
            pad_token_id: Padding id handed to the beam scorer.
            eos_token_id: End-of-text id handed to the beam scorer.
            sot_token_id: Token id every beam starts with.
            num_beams: Beams per image.
            num_beam_groups: Number of beam groups.
            min_seq_len: Used for the default logits processor.
            stopping_criteria: Called each step as
                ``stopping_criteria(input_ids, None)``; its ``max_length`` is
                passed to ``finalize``.
            logit_processor: Optional replacement for the default processor list.
            logit_warper: Accepted but not used by this method.

        Returns:
            The ``'sequences'`` entry of the scorer's finalized output.

        Raises:
            ValueError: If the beam/batch sizes are inconsistent.
        """
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        # NOTE(review): this reads a private attribute of BeamSearchScorer; its
        # length may differ between transformers releases -- confirm against
        # the transformers version in use.
        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                embed_cls=False,
                image_latent=image_latent,
                image_embs=image_embs
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # Flat top-k positions decompose into (beam index, token id).
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
                    + group_start_idx
                    + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        # next_tokens / next_indices are the values left by the last processed group.
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
437
+
438
+
439
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Bundle the model inputs for one generation step.

    Args:
        input_ids: Token ids generated so far; reduced to the last position
            when ``past`` is truthy.
        image_inputs: Image batch, passed through unchanged.
        past: Optional cached state, returned as ``past_key_values``.
        **kwargs: May carry ``attention_mask`` and ``position_ids``.

    Returns:
        Dict with keys ``text``, ``images``, ``past_key_values``,
        ``position_ids`` and ``attention_mask``.
    """
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    # Derive position ids from the mask only when the caller did not supply
    # any; explicitly supplied position_ids are kept (the original reset them
    # to None).
    if position_ids is None and attention_mask is not None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
src/open_clip/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
src/open_clip/factory.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from turtle import forward
9
+ from typing import Any, Dict, Optional, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
14
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
15
+ resize_pos_embed, get_cast_dtype
16
+ from .coca_model import CoCa
17
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
18
+ from .openai import load_openai_model
19
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
20
+ from .transform import image_transform, AugmentationCfg
21
+ from .tokenizer import HFTokenizer, tokenize
22
+
23
+
24
+ HF_HUB_PREFIX = 'hf-hub:'
25
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
26
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
27
+
28
+
29
def _natural_key(string_):
    """Return a sort key that splits the lower-cased ``string_`` into text and integer chunks."""
    def _coerce(chunk):
        # Digit runs become ints so they are ordered by value.
        return int(chunk) if chunk.isdigit() else chunk

    chunks = re.split(r'(\d+)', string_.lower())
    return list(map(_coerce, chunks))
31
+
32
+
33
def _rescan_model_configs():
    """Reload model configs from ``_MODEL_CONFIG_PATHS`` into ``_MODEL_CONFIGS``.

    Each registered path is either a single ``.json`` file or a directory
    whose ``*.json`` files are collected. A file is registered under its stem
    when it defines ``embed_dim``, ``vision_cfg`` and ``text_cfg``. The
    registry is then rebuilt in natural-sort order of the names.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')

    discovered = []
    for location in _MODEL_CONFIG_PATHS:
        if location.is_file() and location.suffix in config_ext:
            discovered.append(location)
            continue
        if location.is_dir():
            for ext in config_ext:
                discovered.extend(location.glob(f'*{ext}'))

    for cfg_file in discovered:
        with open(cfg_file, 'r') as handle:
            candidate = json.load(handle)
        if all(key in candidate for key in required_keys):
            _MODEL_CONFIGS[cfg_file.stem] = candidate

    ordered = sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0]))
    _MODEL_CONFIGS = dict(ordered)
52
+
53
+
54
+ _rescan_model_configs() # initial populate of model config registry
55
+
56
+
57
def list_models():
    """Return the names of all registered model architecture configs."""
    return [name for name in _MODEL_CONFIGS]
60
+
61
+
62
def add_model_config(path):
    """Register an extra config file or directory, then rescan the registry."""
    config_path = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_path)
    _rescan_model_configs()
68
+
69
+
70
def get_model_config(model_name):
    """Return a deep copy of the config registered for ``model_name``, or ``None``."""
    if model_name not in _MODEL_CONFIGS:
        return None
    # Copy so callers can mutate the result without touching the registry.
    return deepcopy(_MODEL_CONFIGS[model_name])
75
+
76
+
77
def get_tokenizer(model_name):
    """Return the tokenizer that matches ``model_name``.

    Args:
        model_name: Either ``hf-hub:<id>`` or the name of a registered model
            config.

    Returns:
        An ``HFTokenizer`` for hub models and for configs whose ``text_cfg``
        names an ``hf_tokenizer_name``; otherwise the module's ``tokenize``
        function.

    Raises:
        RuntimeError: If no config is registered for ``model_name``.
    """
    if model_name.startswith(HF_HUB_PREFIX):
        return HFTokenizer(model_name[len(HF_HUB_PREFIX):])

    config = get_model_config(model_name)
    if config is None:
        # Fail with an actionable message instead of the TypeError produced by
        # subscripting None.
        raise RuntimeError(f'Model config for {model_name} not found; available models {list_models()}.')

    text_cfg = config['text_cfg']
    if 'hf_tokenizer_name' in text_cfg:
        return HFTokenizer(text_cfg['hf_tokenizer_name'])
    return tokenize
85
+
86
+
87
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Load a checkpoint file and return its state dict.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        map_location: Device mapping handed to ``torch.load``.

    Returns:
        The ``'state_dict'`` entry when the checkpoint is a dict that has one,
        otherwise the checkpoint itself. When every key starts with
        ``'module.'`` that prefix is removed from all keys.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # Strip the wrapper prefix only when it is present on every key; checking
    # just the first key (as before) could truncate unrelated keys, and an
    # empty state dict raised StopIteration.
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict
96
+
97
+
98
def load_checkpoint(model, checkpoint_path, strict=True):
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    Returns whatever ``model.load_state_dict`` returns (the incompatible keys).
    """
    weights = load_state_dict(checkpoint_path)

    # detect old format and make compatible with new format
    old_format = 'positional_embedding' in weights and not hasattr(model, 'positional_embedding')
    if old_format:
        weights = convert_to_custom_text_state_dict(weights)

    resize_pos_embed(weights, model)
    return model.load_state_dict(weights, strict=strict)
106
+
107
+
108
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
):
    """Instantiate a model and optionally load pretrained weights.

    Args:
        model_name: Registered config name, or ``hf-hub:<id>`` to fetch the
            checkpoint and ``open_clip_config.json`` from the hub.
        pretrained: ``'openai'`` (loads through ``load_openai_model``), a
            pretrained tag known for ``model_name``, or a path to an existing
            checkpoint file.
        precision: Passed to ``get_cast_dtype``; ``'fp16'``/``'bf16'`` also
            convert the weights to low precision.
        device: Target device (string or ``torch.device``).
        jit: Passed to ``load_openai_model``; otherwise the model is scripted
            with ``torch.jit.script``.
        force_quick_gelu: Set ``quick_gelu`` in the model config.
        force_custom_text: Build a custom-text model even if the config does
            not request it.
        force_patch_dropout: Override ``vision_cfg['patch_dropout']``.
        force_image_size: Override ``vision_cfg['image_size']``.
        pretrained_image: Request pretrained timm image tower weights.
        pretrained_hf: Stored as ``text_cfg['hf_model_pretrained']`` for HF
            text models.
        cache_dir: Download cache directory.
        output_dict: When truthy and supported, make the model return a dict.
        require_pretrained: Fail if no pretrained weights were loaded.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: If the model config or the requested pretrained weights
            cannot be found, or weights were required but not loaded.
        AssertionError: If ``pretrained_image`` is set for a non-timm image tower.
    """
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            jit=jit,
            cache_dir=cache_dir,
        )

        # to always output dict even if it is clip
        if output_dict and hasattr(model, "output_dict"):
            model.output_dict = True
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        if pretrained_image:
            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                # Explicit raise (same exception type as the former
                # ``assert False``) so the check survives ``python -O``.
                raise AssertionError('pretrained image towers currently only supported for timm models')

        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            if "coca" in model_name:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                # Message fixed: the two sentences were run together and the
                # closing parenthesis was missing.
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

        model.to(device=device)
        if precision in ("fp16", "bf16"):
            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

        # to always output dict even if it is clip
        if output_dict and hasattr(model, "output_dict"):
            model.output_dict = True

        if jit:
            model = torch.jit.script(model)

    return model
242
+
243
+
244
def create_loss(args):
    """Build the training loss selected by the parsed arguments.

    ``args.distill`` takes precedence and yields a ``DistillClipLoss``.
    Otherwise a model name containing "coca" (case-insensitive) yields a
    ``CoCaLoss`` configured with the CoCa loss weights, and anything else
    falls back to the plain ``ClipLoss``.
    """
    # Options that every loss variant receives unchanged.
    shared_kwargs = dict(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )

    if args.distill:
        return DistillClipLoss(**shared_kwargs)

    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            **shared_kwargs,
        )

    return ClipLoss(**shared_kwargs)
273
+
274
class MLP(torch.nn.Module):
    """Feed-forward head mapping an ``input_size`` feature vector to one value.

    The stack is linear layers with dropout in between and no non-linear
    activation: input_size -> 1024 -> 128 -> 64 -> 16 -> 1.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        # Module order is fixed: it determines the ``layers.<index>`` names
        # that appear in the state dict.
        stack = [
            torch.nn.Linear(self.input_size, 1024),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(1024, 128),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, 16),
            torch.nn.Linear(16, 1),
        ]
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
291
+
292
+ # class semantic_head(torch.nn.Module):
293
+ # def __init__(self, input_size):
294
+ # super().__init__()
295
+ # self.input_size = input_size # for ViT-L-14 is 1024
296
+ # self.seg_head = torch.nn.Sequential(
297
+ # torch.nn.Linear(input_size, 128),
298
+ # torch.nn.Dropout(0.2),
299
+ # torch.nn.Linear(128, 64),
300
+ # torch.nn.Dropout(0.1),
301
+ # torch.nn.Linear(64, 16),
302
+ # torch.nn.Linear(16, 1),
303
+ # )
304
+ # self.sigmoid = torch.nn.Sigmoid()
305
+
306
+ # def forward(self, x):
307
+ # return self.sigmoid(self.seg_head(x))
308
+
309
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        cache_dir: Optional[str] = None,
        light_augmentation = False,
        output_dict: Optional[bool] = None,
        with_score_predictor: bool = False,
        with_region_predictor: bool = False
):
    """Create a model plus its train and validation image transforms.

    The model comes from ``create_model``. Optional heads are attached on
    request: ``with_score_predictor`` adds ``model.score_predictor`` (an
    ``MLP`` sized by ``model.visual.proj.size(1)``) and
    ``with_region_predictor`` adds ``model.region_predictor`` (a linear layer
    from ``model.visual.proj.size(0)`` to 1). Both are moved to ``device``
    and to the dtype of ``model.visual.proj``.

    With ``light_augmentation`` the validation transform (built with
    ``resize_longest_max=True``) is also returned as the train transform.
    ``aug_cfg`` is accepted but not used by this function.

    Returns:
        ``(model, preprocess_train, preprocess_val)``
    """
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
    )

    # Explicit normalisation stats win; otherwise use what the model carries.
    if not image_mean:
        image_mean = getattr(model.visual, 'image_mean', None)
    if not image_std:
        image_std = getattr(model.visual, 'image_std', None)

    if with_score_predictor:
        visual_proj = model.visual.proj
        model.score_predictor = MLP(visual_proj.size(1)).to(device=device, dtype=visual_proj.dtype)

    if with_region_predictor:
        visual_proj = model.visual.proj
        model.region_predictor = torch.nn.Linear(visual_proj.size(0), 1).to(
            device=device, dtype=visual_proj.dtype)

    def _make_transform(is_train, **extra):
        # All transforms share the image size and normalisation stats.
        return image_transform(
            model.visual.image_size,
            is_train=is_train,
            mean=image_mean,
            std=image_std,
            **extra,
        )

    if light_augmentation:
        preprocess_val = _make_transform(False, resize_longest_max=True)
        preprocess_train = preprocess_val
    else:
        preprocess_train = _make_transform(True)
        preprocess_val = _make_transform(False)

    return model, preprocess_train, preprocess_val
392
+
393
+
394
def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    """Create a model that must have pretrained weights loaded.

    ``create_model`` is called with ``require_pretrained=True``. When
    ``return_transform`` is false only the model is returned; otherwise the
    result is ``(model, preprocess)`` where ``preprocess`` is the
    non-training image transform.
    """
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
    )

    if return_transform:
        # Explicit normalisation stats win; otherwise use what the model carries.
        if not image_mean:
            image_mean = getattr(model.visual, 'image_mean', None)
        if not image_std:
            image_std = getattr(model.visual, 'image_std', None)
        preprocess = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=image_mean,
            std=image_std,
        )
        return model, preprocess

    return model
src/open_clip/generation_utils.py ADDED
File without changes
src/open_clip/hf_configs.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# HF architecture dict:
# Keyed by HuggingFace ``model_type``. ``config_names`` maps generic field
# names to the attribute names that architecture uses; ``pooler`` names the
# default pooler for it.
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": dict(
        config_names=dict(
            context_length="max_position_embeddings",
            vocab_size="vocab_size",
            width="hidden_size",
            heads="num_attention_heads",
            layers="num_hidden_layers",
            layer_attr="layer",
            token_embeddings_attr="embeddings",
        ),
        pooler="mean_pooler",
    ),
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": dict(
        config_names=dict(
            context_length="max_position_embeddings",
            vocab_size="vocab_size",
            width="hidden_size",
            heads="num_attention_heads",
            layers="num_hidden_layers",
            layer_attr="layer",
            token_embeddings_attr="embeddings",
        ),
        pooler="mean_pooler",
    ),
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": dict(
        config_names=dict(
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            context_length="",
            vocab_size="vocab_size",
            width="d_model",
            heads="num_heads",
            layers="num_layers",
            layer_attr="block",
            token_embeddings_attr="embed_tokens",
        ),
        pooler="mean_pooler",
    ),
}
src/open_clip/hf_model.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import TensorType
11
+
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+
31
+ # utils
32
+ def _camel2snake(s):
33
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
34
+
35
+
36
+ # TODO: ?last - for gpt-like models
37
+ _POOLERS = {}
38
+
39
+
40
+ def register_pooler(cls):
41
+ """Decorator registering pooler class"""
42
+ _POOLERS[_camel2snake(cls.__name__)] = cls
43
+ return cls
44
+
45
+
46
+ @register_pooler
47
+ class MeanPooler(nn.Module):
48
+ """Mean pooling"""
49
+
50
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
51
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
52
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
53
+
54
+
55
+ @register_pooler
56
+ class MaxPooler(nn.Module):
57
+ """Max pooling"""
58
+
59
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
60
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
61
+ return masked_output.max(1).values
62
+
63
+
64
+ @register_pooler
65
+ class ClsPooler(nn.Module):
66
+ """CLS token pooling"""
67
+
68
+ def __init__(self, use_pooler_output=True):
69
+ super().__init__()
70
+ self.cls_token_position = 0
71
+ self.use_pooler_output = use_pooler_output
72
+
73
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
74
+ if (self.use_pooler_output and
75
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
76
+ (x.pooler_output is not None)
77
+ ):
78
+ return x.pooler_output
79
+
80
+ return x.last_hidden_state[:, self.cls_token_position, :]
81
+
82
+
83
+ class HFTextEncoder(nn.Module):
84
+ """HuggingFace model adapter"""
85
+ output_tokens: torch.jit.Final[bool]
86
+
87
+ def __init__(
88
+ self,
89
+ model_name_or_path: str,
90
+ output_dim: int,
91
+ config: PretrainedConfig = None,
92
+ pooler_type: str = None,
93
+ proj: str = None,
94
+ pretrained: bool = True,
95
+ output_tokens: bool = False,
96
+ ):
97
+ super().__init__()
98
+ self.output_tokens = output_tokens
99
+ self.output_dim = output_dim
100
+
101
+ # TODO: find better way to get this information
102
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
103
+
104
+ if transformers is None:
105
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
106
+ if config is None:
107
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
108
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
109
+ AutoModel.from_config, self.config)
110
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
111
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
112
+ self.transformer = create_func(model_args)
113
+ self.transformer = self.transformer.encoder
114
+ else:
115
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
116
+ else:
117
+ self.config = config
118
+ self.transformer = AutoModel.from_config(config)
119
+ if pooler_type is None: # get default arch pooler
120
+ pooler_type = (arch_dict[self.config.model_type]["pooler"])
121
+
122
+ self.pooler = _POOLERS[pooler_type]()
123
+
124
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
125
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
126
+ self.proj = nn.Identity()
127
+ elif proj == 'linear':
128
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
129
+ elif proj == 'mlp':
130
+ hidden_size = (d_model + output_dim) // 2
131
+ self.proj = nn.Sequential(
132
+ nn.Linear(d_model, hidden_size, bias=False),
133
+ nn.GELU(),
134
+ nn.Linear(hidden_size, output_dim, bias=False),
135
+ )
136
+
137
+ def forward(self, x: TensorType):
138
+ attn_mask = (x != self.config.pad_token_id).long()
139
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
140
+ pooled_out = self.pooler(out, attn_mask)
141
+ projected = self.proj(pooled_out)
142
+
143
+ seq_len = out.last_hidden_state.shape[1]
144
+ tokens = (
145
+ out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
146
+ if type(self.pooler) == ClsPooler
147
+ else out.last_hidden_state
148
+ )
149
+
150
+ if self.output_tokens:
151
+ return projected, tokens
152
+ return projected
153
+
154
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
155
+ if not unlocked_layers: # full freezing
156
+ for n, p in self.transformer.named_parameters():
157
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
158
+ return
159
+
160
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
161
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
162
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
163
+ embeddings = getattr(
164
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
165
+ modules = [embeddings, *layer_list][:-unlocked_layers]
166
+ # freeze layers
167
+ for module in modules:
168
+ for n, p in module.named_parameters():
169
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
170
+
171
+ @torch.jit.ignore
172
+ def set_grad_checkpointing(self, enable=True):
173
+ self.transformer.gradient_checkpointing_enable()
174
+
175
+ def init_parameters(self):
176
+ pass
src/open_clip/loss.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.utils.rnn import pad_sequence
5
+
6
+ try:
7
+ import torch.distributed.nn
8
+ from torch import distributed as dist
9
+
10
+ has_distributed = True
11
+ except ImportError:
12
+ has_distributed = False
13
+
14
+ try:
15
+ import horovod.torch as hvd
16
+ except ImportError:
17
+ hvd = None
18
+
19
+
20
def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """Gather image and text features from all workers.

    Uses horovod ``allgather`` when ``use_horovod`` is set, otherwise
    ``torch.distributed``. When gathering without gradients and
    ``local_loss`` is false, this rank's slot in the gathered result is
    replaced by the local tensors so that gradients reach them.

    Returns:
        ``(all_image_features, all_text_features)`` concatenated along dim 0.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            gathered_images = hvd.allgather(image_features)
            gathered_texts = hvd.allgather(text_features)
            return gathered_images, gathered_texts

        with torch.no_grad():
            gathered_images = hvd.allgather(image_features)
            gathered_texts = hvd.allgather(text_features)
        if local_loss:
            return gathered_images, gathered_texts

        # split back into per-rank pieces so the local one can be swapped in
        image_parts = list(gathered_images.chunk(world_size, dim=0))
        text_parts = list(gathered_texts.chunk(world_size, dim=0))
    elif gather_with_grad:
        # We gather tensors from all gpus
        gathered_images = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        gathered_texts = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        return gathered_images, gathered_texts
    else:
        # We gather tensors from all gpus
        image_parts = [torch.zeros_like(image_features) for _ in range(world_size)]
        text_parts = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(image_parts, image_features)
        dist.all_gather(text_parts, text_features)
        if local_loss:
            return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)

    # ensure grads for local rank when all_* features don't have a gradient
    image_parts[rank] = image_features
    text_parts[rank] = text_features
    return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)
65
+
66
+
67
class ClipLoss(nn.Module):
    """Symmetric image/text contrastive loss.

    The loss is the mean of the cross-entropy over image-to-text logits and
    the cross-entropy over text-to-image logits, with targets ``0..N-1``.
    """

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target indices, reusing the per-device cache if valid."""
        cache_is_valid = self.prev_num_logits == num_logits and device in self.labels
        if cache_is_valid:
            return self.labels[device]

        targets = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            # local logits are compared against globally gathered features,
            # so targets are offset by this rank's position
            targets = targets + num_logits * self.rank
        if self.cache_labels:
            self.labels[device] = targets
            self.prev_num_logits = num_logits
        return targets

    def get_logits(self, image_features, text_features, logit_scale):
        """Return ``(logits_per_image, logits_per_text)``."""
        if not self.world_size > 1:
            per_image = logit_scale * image_features @ text_features.T
            per_text = logit_scale * text_features @ image_features.T
            return per_image, per_text

        all_image_features, all_text_features = gather_features(
            image_features, text_features,
            self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

        if self.local_loss:
            per_image = logit_scale * image_features @ all_text_features.T
            per_text = logit_scale * text_features @ all_image_features.T
            return per_image, per_text

        per_image = logit_scale * all_image_features @ all_text_features.T
        return per_image, per_image.T

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        """Compute the loss. ``output_dict`` is accepted but has no effect."""
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])

        image_loss = F.cross_entropy(logits_per_image, labels)
        text_loss = F.cross_entropy(logits_per_text, labels)
        return (image_loss + text_loss) / 2
132
+
133
class PreferenceLoss(nn.Module):
    """Cross-entropy over the per-group image logits.

    ``logits_per_image`` is split along dim 0 into groups of sizes
    ``num_images``; from group ``i`` column ``i`` is kept. The groups are
    padded to a common length with -999 and scored against ``labels``.
    """

    def forward(self, logits_per_image, num_images, labels):
        groups = logits_per_image.split(num_images.tolist())

        group_scores = []
        for group_index, group in enumerate(groups):
            group_scores.append(group[:, group_index])

        padded_scores = pad_sequence(group_scores, batch_first=True, padding_value=-999)
        return F.cross_entropy(padded_scores, labels)
142
+
143
class HPSLoss(nn.Module):
    """Label-weighted two-way cross-entropy over paired text logits.

    ``text_logits`` and ``labels`` are each split in half along the last
    dim. The diagonals of the two logit halves form a two-class logit pair
    per row; the loss for class 0 is weighted by the first label half and
    the loss for class 1 by the second, and everything is summed.
    """

    def forward(self, text_logits, labels):
        device = text_logits.device
        first_half, second_half = text_logits.chunk(2, dim=-1)
        weight_first, weight_second = labels.chunk(2, dim=-1)

        # pick entry (i, i) from each half
        rows = torch.arange(first_half.shape[0], device=device, dtype=torch.long)
        pair_logits = torch.stack([first_half[rows, rows], second_half[rows, rows]], dim=-1)

        target_first = torch.zeros(pair_logits.shape[0], device=device, dtype=torch.long)
        target_second = target_first + 1

        loss_first = F.cross_entropy(pair_logits, target_first, reduction="none")
        loss_second = F.cross_entropy(pair_logits, target_second, reduction="none")

        # NOTE: the per-example losses are summed without any per-prompt weighting
        weighted = weight_first * loss_first + weight_second * loss_second
        return weighted.sum()
170
+
171
class RankingLoss(nn.Module):
    """Pairwise hinge loss over grouped image logits.

    Logits and labels are split into groups of sizes ``num_images`` (column
    ``i`` of group ``i`` is kept for the logits) and padded with -1 and 10
    respectively. For the pairs on and below the diagonal the loss is
    ``clamp(margin - logit_diff * label_diff, min=0)``, averaged.
    """

    def forward(self, logits_per_image, num_images, labels, margin=1.0):
        group_sizes = num_images.tolist()
        score_groups = [group[:, column] for column, group in enumerate(logits_per_image.split(group_sizes))]
        label_groups = list(labels.split(group_sizes))

        scores = pad_sequence(score_groups, batch_first=True, padding_value=-1)
        ranks = pad_sequence(label_groups, batch_first=True, padding_value=10)

        # entry [b, i, j] compares item j against item i of group b
        score_diff = scores.unsqueeze(1) - scores.unsqueeze(2)
        label_diff = -(ranks.unsqueeze(1) - ranks.unsqueeze(2))

        width = score_diff.shape[1]
        strictly_upper = torch.triu(torch.ones(width, width), diagonal=1).bool().detach()
        selected = ~strictly_upper

        hinge = torch.clamp(margin - torch.mul(score_diff[:, selected], label_diff[:, selected]), min=0)
        return hinge.mean()
191
+
192
class CoCaLoss(ClipLoss):
    """Weighted contrastive loss plus weighted captioning cross-entropy."""

    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        # targets equal to pad_id are ignored by the captioning loss
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        """Return the two weighted losses as a tuple, or as a dict when ``output_dict`` is set."""
        contrastive = super().forward(image_features, text_features, logit_scale)
        contrastive = self.clip_loss_weight * contrastive

        # CrossEntropyLoss expects the class dimension second
        class_first_logits = logits.permute(0, 2, 1)
        captioning = self.caption_loss(class_first_logits, labels)
        captioning = captioning * self.caption_loss_weight

        if output_dict:
            return {"contrastive_loss": contrastive, "caption_loss": captioning}

        return contrastive, captioning
232
+
233
+
234
class DistillClipLoss(ClipLoss):
    """Contrastive loss plus a distillation loss against teacher logits."""

    def dist_loss(self, teacher_logits, student_logits):
        """Cross-entropy of the student's log-softmax under the teacher's softmax."""
        teacher_probs = teacher_logits.softmax(dim=1)
        student_log_probs = student_logits.log_softmax(dim=1)
        per_sample = (teacher_probs * student_log_probs).sum(dim=1)
        return -per_sample.mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        """Return ``(contrastive_loss, distill_loss)``, or a dict when ``output_dict`` is set."""
        student_per_image, student_per_text = self.get_logits(
            image_features, text_features, logit_scale)
        teacher_per_image, teacher_per_text = self.get_logits(
            dist_image_features, dist_text_features, dist_logit_scale)

        labels = self.get_ground_truth(image_features.device, student_per_image.shape[0])

        image_ce = F.cross_entropy(student_per_image, labels)
        text_ce = F.cross_entropy(student_per_text, labels)
        contrastive_loss = (image_ce + text_ce) / 2

        image_distill = self.dist_loss(teacher_per_image, student_per_image)
        text_distill = self.dist_loss(teacher_per_text, student_per_text)
        distill_loss = (image_distill + text_distill) / 2

        if output_dict:
            return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

        return contrastive_loss, distill_loss
src/open_clip/model.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from .hf_model import HFTextEncoder
17
+ from .modified_resnet import ModifiedResNet
18
+ from .timm_model import TimmModel
19
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
20
+ from .utils import to_2tuple
21
+
22
+
23
@dataclass
class CLIPVisionCfg:
    """Configuration of the vision tower."""
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    input_patchnorm: bool = False  # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
    n_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    # Optional because the default is None (no timm model selected).
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
    output_tokens: bool = False
46
+
47
+
48
@dataclass
class CLIPTextCfg:
    """Configuration of the text tower."""
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    # Optional because the defaults are None (no HuggingFace model selected).
    hf_model_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False
64
+
65
+
66
def get_cast_dtype(precision: str):
    """Map a precision string to a torch dtype.

    ``'bf16'`` gives ``torch.bfloat16`` and ``'fp16'`` gives
    ``torch.float16``; any other value gives None.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None
73
+
74
+
75
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Build the vision tower described by ``vision_cfg``.

    A dict config is converted to ``CLIPVisionCfg`` first. The tower type is
    chosen in this order: ``TimmModel`` when ``timm_model_name`` is set,
    ``ModifiedResNet`` when ``layers`` is a tuple or list, otherwise
    ``VisionTransformer``.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    if vision_cfg.timm_model_name:
        # NOTE: timm models always use native GELU regardless of quick_gelu flag.
        return TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )

    if isinstance(vision_cfg.layers, (tuple, list)):
        resnet_heads = vision_cfg.width * 32 // vision_cfg.head_width
        return ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=resnet_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    activation = QuickGELU if quick_gelu else nn.GELU
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalization = LayerNormFp32 if low_precision else LayerNorm
    vit_heads = vision_cfg.width // vision_cfg.head_width

    return VisionTransformer(
        image_size=vision_cfg.image_size,
        patch_size=vision_cfg.patch_size,
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vit_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        patch_dropout=vision_cfg.patch_dropout,
        input_patchnorm=vision_cfg.input_patchnorm,
        global_average_pool=vision_cfg.global_average_pool,
        attentional_pool=vision_cfg.attentional_pool,
        n_queries=vision_cfg.n_queries,
        attn_pooler_heads=vision_cfg.attn_pooler_heads,
        output_tokens=vision_cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalization,
    )
135
+
136
+
137
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build the text encoder described by ``text_cfg``.

    Args:
        embed_dim: Passed to the encoder as its ``output_dim``.
        text_cfg: A ``CLIPTextCfg`` instance, or a dict of keyword arguments
            used to construct one.
        quick_gelu: Selects ``QuickGELU`` instead of ``nn.GELU`` as the
            activation of the built-in ``TextTransformer``. Not used when a
            Hugging Face model name is configured.
        cast_dtype: When fp16 or bf16, the built-in transformer is given
            ``LayerNormFp32`` instead of ``LayerNorm``.

    Returns:
        An ``HFTextEncoder`` if ``text_cfg.hf_model_name`` is set, otherwise a
        ``TextTransformer``.
    """
    cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg

    # A configured Hugging Face model takes precedence over the built-in transformer.
    if cfg.hf_model_name:
        return HFTextEncoder(
            cfg.hf_model_name,
            output_dim=embed_dim,
            proj=cfg.proj,
            pooler_type=cfg.pooler_type,
            pretrained=cfg.hf_model_pretrained,
            output_tokens=cfg.output_tokens,
        )

    activation = QuickGELU if quick_gelu else nn.GELU
    uses_low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    normalization = LayerNormFp32 if uses_low_precision else LayerNorm

    return TextTransformer(
        context_length=cfg.context_length,
        vocab_size=cfg.vocab_size,
        width=cfg.width,
        heads=cfg.heads,
        layers=cfg.layers,
        ls_init_value=cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=cfg.embed_cls,
        output_tokens=cfg.output_tokens,
        pad_id=cfg.pad_id,
        act_layer=activation,
        norm_layer=normalization,
    )
174
+
175
+
176
class CLIP(nn.Module):
    """CLIP model with the text-tower parts registered directly on the model.

    The image encoder lives under ``visual``. The text encoder is built by
    ``_build_text_tower`` and then unpacked: its transformer, embeddings,
    final norm, projection and attention mask become attributes of this
    module instead of sitting under a separate submodule.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Re-register the text tower's components on this module, in a fixed order.
        text_tower = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text_tower.transformer
        self.vocab_size = text_tower.vocab_size
        self.token_embedding = text_tower.token_embedding
        self.positional_embedding = text_tower.positional_embedding
        self.ln_final = text_tower.ln_final
        self.text_projection = text_tower.text_projection
        # Non-persistent: the mask is not written to the state dict.
        self.register_buffer('attn_mask', text_tower.attn_mask, persistent=False)

        initial_logit_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_logit_scale)

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze the image tower by delegating to ``self.visual.lock``."""
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the text tower, optionally leaving the last blocks trainable.

        Args:
            unlocked_layers: If positive, only ``resblocks[:-unlocked_layers]``
                of the transformer are frozen and ``ln_final`` /
                ``text_projection`` are left untouched. Otherwise the whole
                transformer, ``ln_final`` and ``text_projection`` are frozen.
            freeze_layer_norm: Parameters whose dotted name has a
                ``"LayerNorm"`` component get ``requires_grad`` set to
                ``not freeze_layer_norm``; all others are frozen.
        """
        # NOTE(review): the "LayerNorm" match is on parameter-name components;
        # confirm that the modules used here actually produce such names.
        self.positional_embedding.requires_grad = False

        if unlocked_layers > 0:
            modules_to_freeze = [
                self.token_embedding,
                self.transformer.resblocks[:-unlocked_layers],
            ]
        else:
            modules_to_freeze = [self.token_embedding, self.transformer, self.ln_final]
            self.text_projection.requires_grad = False

        # freeze layers
        for module in modules_to_freeze:
            for param_name, param in module.named_parameters():
                if "LayerNorm" in param_name.split("."):
                    param.requires_grad = not freeze_layer_norm
                else:
                    param.requires_grad = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; L2-normalize the last dim if ``normalize``."""
        image_features = self.visual(image)
        if normalize:
            return F.normalize(image_features, dim=-1)
        return image_features

    def encode_text(self, text, normalize: bool = False):
        """Encode token ids; L2-normalize the last dim if ``normalize``."""
        cast_dtype = self.transformer.get_cast_dtype()

        hidden = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.positional_embedding.to(cast_dtype)

        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden, attn_mask=self.attn_mask)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        text_features = hidden[batch_index, eot_index] @ self.text_projection

        if normalize:
            return F.normalize(text_features, dim=-1)
        return text_features

    def forward(self, image, text):
        """Return normalized image/text features and the exponentiated logit scale.

        The result is a dict when ``self.output_dict`` is true, else a tuple
        ``(image_features, text_features, logit_scale)``.
        """
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if not self.output_dict:
            return image_features, text_features, self.logit_scale.exp()
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale.exp()
        }
256
+
257
+
258
class CustomTextCLIP(nn.Module):
    """CLIP variant that keeps the text encoder as a single ``text`` submodule.

    Both towers are built by the module-level builder functions; text
    locking, checkpointing and encoding are delegated to ``self.text``.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)

        initial_logit_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_logit_scale)

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze the image tower by delegating to ``self.visual.lock``."""
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the text tower by delegating to ``self.text.lock``."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        for tower in (self.visual, self.text):
            tower.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Run the image tower; L2-normalize the last dim if ``normalize``."""
        image_features = self.visual(image)
        if normalize:
            return F.normalize(image_features, dim=-1)
        return image_features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; L2-normalize the last dim if ``normalize``."""
        text_features = self.text(text)
        if normalize:
            return F.normalize(text_features, dim=-1)
        return text_features

    def forward(self, image, text):
        """Return normalized image/text features and the exponentiated logit scale.

        The result is a dict when ``self.output_dict`` is true, else a tuple
        ``(image_features, text_features, logit_scale)``.
        """
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if not self.output_dict:
            return image_features, text_features, self.logit_scale.exp()
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale.exp()
        }
306
+
307
+
308
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16).

    The conversion is applied in place to every submodule via ``model.apply``:

    * weights and biases of ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Linear``;
    * the projection weights and biases of attention modules;
    * tensor-valued ``text_projection`` / ``proj`` attributes.

    Args:
        model: Module whose parameters are converted in place.
        dtype: Target dtype; defaults to ``torch.float16``.
    """
    attention_tensor_names = (
        "in_proj_weight",
        "q_proj_weight",
        "k_proj_weight",
        "v_proj_weight",
        "in_proj_bias",
        "bias_k",
        "bias_v",
    )

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in attention_tensor_names:
                # Use a default so an attention implementation that does not
                # define every one of these attributes is skipped, not crashed on.
                tensor = getattr(l, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        for name in ("text_projection", "proj"):
            attr = getattr(l, name, None)
            # Only raw tensors/parameters are cast here. When the attribute is
            # itself a module it has no ``.data``; its layers are visited
            # separately by ``model.apply`` and converted by the branch above.
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
333
+
334
+
335
+ # used to maintain checkpoint compatibility
336
def convert_to_custom_text_state_dict(state_dict: dict):
    """Move old-format text-tower keys under the ``text.`` prefix.

    A state dict is treated as old-format when it has a top-level
    ``'text_projection'`` key. In that case a new dict is returned in which
    every key starting with one of the text-tower prefixes is renamed to
    ``'text.' + key``; other keys are copied as they are. Otherwise the input
    dict itself is returned.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # old format state_dict, move text tower -> .text
    text_tower_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + key if key.startswith(text_tower_prefixes) else key): value
        for key, value in state_dict.items()
    }
352
+
353
+
354
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Rebuild a ``CLIP`` model from an OpenAI-format state dict.

    The architecture hyper-parameters are read off the tensor shapes and key
    names in ``state_dict``; the model is then constructed, converted with
    ``convert_weights_to_fp16`` and loaded with the weights.

    Note that the ``input_resolution``, ``context_length`` and ``vocab_size``
    entries are removed from ``state_dict`` in place before loading.

    Returns:
        The loaded model, switched to eval mode.
    """
    is_vit = "visual.proj" in state_dict

    if is_vit:
        conv1_weight = state_dict["visual.conv1.weight"]
        vision_width = conv1_weight.shape[0]
        vision_layers = sum(
            1 for key in state_dict.keys()
            if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = conv1_weight.shape[-1]
        num_positions = state_dict["visual.positional_embedding"].shape[0]
        # One position is subtracted before taking the square root of the grid.
        grid_size = round((num_positions - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # Count the distinct block indices under each of visual.layer1..layer4.
        vision_layers = tuple(
            len({key.split(".")[2] for key in state_dict if key.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        num_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((num_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == num_positions
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    resblock_indices = {
        key.split(".")[2] for key in state_dict if key.startswith("transformer.resblocks")
    }
    transformer_layers = len(resblock_indices)

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # These entries are not model weights; drop them if present before loading.
    for metadata_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(metadata_key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
412
+
413
+
414
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` with TorchScript.

    Example inputs are an all-ones image batch and an all-zeros int token
    batch, both created on ``device``.

    Args:
        model: Model exposing ``visual.image_size`` plus the three traced
            methods. The text length is taken from ``model.context_length``
            when present, otherwise from ``model.positional_embedding.shape[0]``.
        batch_size: Batch dimension of the example inputs.
        device: Device on which the example inputs are allocated.

    Returns:
        The traced module, with ``visual.image_size`` re-assigned to the
        original value.

    Raises:
        AttributeError: If the text length cannot be determined from the model.
    """
    model.eval()

    image_size = model.visual.image_size
    # Accept either a single side length or an (height, width) pair.
    if isinstance(image_size, (tuple, list)):
        image_height, image_width = image_size
    else:
        image_height = image_width = image_size

    # Not every model sets `context_length`; fall back to the length of the
    # positional embedding, which is added to the [batch, n_ctx, d_model]
    # token embeddings and therefore has n_ctx rows.
    context_length = getattr(model, 'context_length', None)
    if context_length is None:
        positional_embedding = getattr(model, 'positional_embedding', None)
        if positional_embedding is None:
            raise AttributeError(
                "trace_model needs `model.context_length` or `model.positional_embedding` "
                "to size the example text input"
            )
        context_length = positional_embedding.shape[0]

    example_images = torch.ones((batch_size, 3, image_height, image_width), device=device)
    example_text = torch.zeros((batch_size, context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model
428
+
429
+
430
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Rescale the grid of position embeddings when loading from state_dict.

    Operates in place on ``state_dict['visual.positional_embedding']``. Nothing
    happens when that entry is missing, when ``model.visual`` has no
    ``grid_size``, or when the stored sequence length already matches the
    model's grid plus the extra token.

    Args:
        state_dict: Checkpoint dict that may hold the visual position embedding.
        model: Model whose ``visual.grid_size`` gives the target grid.
        interpolation: ``mode`` passed to ``F.interpolate``.
        antialias: ``antialias`` flag passed to ``F.interpolate``.
    """
    stored_embed = state_dict.get('visual.positional_embedding')
    if stored_embed is None:
        return
    if not hasattr(model.visual, 'grid_size'):
        return

    target_grid = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    target_seq_len = target_grid[0] * target_grid[1] + extra_tokens
    if target_seq_len == stored_embed.shape[0]:
        return

    # Split the leading extra-token rows from the per-patch rows.
    if extra_tokens:
        token_part = stored_embed[:extra_tokens]
        grid_part = stored_embed[extra_tokens:]
    else:
        token_part = None
        grid_part = stored_embed
    source_grid = to_2tuple(int(math.sqrt(len(grid_part))))

    logging.info('Resizing position embedding grid-size from %s to %s', source_grid, target_grid)

    # (seq, dim) -> (1, dim, h, w) so the grid can be interpolated spatially.
    grid_part = grid_part.reshape(1, source_grid[0], source_grid[1], -1)
    grid_part = grid_part.permute(0, 3, 1, 2)
    grid_part = F.interpolate(
        grid_part,
        size=target_grid,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # (1, dim, h, w) -> (h * w, dim)
    grid_part = grid_part.permute(0, 2, 3, 1)
    grid_part = grid_part.reshape(1, target_grid[0] * target_grid[1], -1)[0]

    if token_part is None:
        resized_embed = grid_part
    else:
        resized_embed = torch.cat([token_part, grid_part], dim=0)
    state_dict['visual.positional_embedding'] = resized_embed
src/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 448,
5
+ "layers": [
6
+ 3,
7
+ 15,
8
+ 36,
9
+ 10
10
+ ],
11
+ "width": 128,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 1024,
18
+ "heads": 16,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/ViT-B-16-plus-240.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 240,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }