ywen committed on
Commit
2f43921
·
1 Parent(s): 3ba432d

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +86 -0
  2. open_clip/__init__.py +11 -0
  3. open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  4. open_clip/constants.py +2 -0
  5. open_clip/factory.py +280 -0
  6. open_clip/hf_configs.py +45 -0
  7. open_clip/hf_model.py +164 -0
  8. open_clip/loss.py +121 -0
  9. open_clip/model.py +440 -0
  10. open_clip/model_configs/RN101-quickgelu.json +22 -0
  11. open_clip/model_configs/RN101.json +21 -0
  12. open_clip/model_configs/RN50-quickgelu.json +22 -0
  13. open_clip/model_configs/RN50.json +21 -0
  14. open_clip/model_configs/RN50x16.json +21 -0
  15. open_clip/model_configs/RN50x4.json +21 -0
  16. open_clip/model_configs/RN50x64.json +21 -0
  17. open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
  18. open_clip/model_configs/ViT-B-16-plus.json +16 -0
  19. open_clip/model_configs/ViT-B-16.json +16 -0
  20. open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
  21. open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  22. open_clip/model_configs/ViT-B-32.json +16 -0
  23. open_clip/model_configs/ViT-H-14.json +17 -0
  24. open_clip/model_configs/ViT-H-16.json +17 -0
  25. open_clip/model_configs/ViT-L-14-280.json +16 -0
  26. open_clip/model_configs/ViT-L-14-336.json +16 -0
  27. open_clip/model_configs/ViT-L-14.json +16 -0
  28. open_clip/model_configs/ViT-L-16-320.json +16 -0
  29. open_clip/model_configs/ViT-L-16.json +16 -0
  30. open_clip/model_configs/ViT-M-16-alt.json +17 -0
  31. open_clip/model_configs/ViT-M-16.json +16 -0
  32. open_clip/model_configs/ViT-M-32-alt.json +16 -0
  33. open_clip/model_configs/ViT-M-32.json +16 -0
  34. open_clip/model_configs/ViT-S-16-alt.json +16 -0
  35. open_clip/model_configs/ViT-S-16.json +16 -0
  36. open_clip/model_configs/ViT-S-32-alt.json +16 -0
  37. open_clip/model_configs/ViT-S-32.json +16 -0
  38. open_clip/model_configs/ViT-bigG-14.json +18 -0
  39. open_clip/model_configs/ViT-e-14.json +18 -0
  40. open_clip/model_configs/ViT-g-14.json +18 -0
  41. open_clip/model_configs/mt5-base-ViT-B-32.json +15 -0
  42. open_clip/model_configs/mt5-xl-ViT-H-14.json +16 -0
  43. open_clip/model_configs/roberta-ViT-B-32.json +16 -0
  44. open_clip/model_configs/timm-convnext_base.json +17 -0
  45. open_clip/model_configs/timm-convnext_base_w.json +17 -0
  46. open_clip/model_configs/timm-convnext_large.json +17 -0
  47. open_clip/model_configs/timm-convnext_small.json +17 -0
  48. open_clip/model_configs/timm-convnext_tiny.json +17 -0
  49. open_clip/model_configs/timm-convnext_xlarge.json +17 -0
  50. open_clip/model_configs/timm-convnext_xxlarge.json +17 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import torch
4
+ import open_clip
5
+ import mediapy as media
6
+ from optim_utils import *
7
+
8
+ import argparse
9
+
10
# Build the run configuration: start from an empty namespace and copy in every
# entry of sample_config.json, then clear the print interval.
args = argparse.Namespace()
vars(args).update(read_json("sample_config.json"))
args.print_step = None

# Pick the device and load the CLIP model named in the configuration.
# Only the model and the third returned transform are kept.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(
    args.clip_model,
    pretrained=args.clip_pretrain,
    device=device,
)
18
+
19
+
20
def inference(target_image, prompt_len, iter):
    """Optimize a prompt for a target image.

    Args:
        target_image: image passed to ``optimize_prompt`` as the single entry
            of ``target_images``.
        prompt_len: requested prompt length; ``None`` falls back to 8.
        iter: requested number of optimization steps; ``None`` falls back
            to 1000.

    Returns:
        Whatever ``optimize_prompt`` returns for this request.
    """
    # Apply the same caps as inference_text(); the UI labels advertise
    # "max 75" tokens and "max 3000" steps, so the image path must honor them too.
    if prompt_len is not None:
        args.prompt_len = min(int(prompt_len), 75)
    else:
        args.prompt_len = 8

    if iter is not None:
        args.iter = min(int(iter), 3000)
    else:
        args.iter = 1000

    learned_prompt = optimize_prompt(model, preprocess, args, device, target_images=[target_image])

    return learned_prompt
34
+
35
def inference_text(target_prompt, prompt_len, iter):
    """Optimize a shorter prompt for a target text prompt.

    ``prompt_len`` is capped at 75 (default 8 when ``None``) and ``iter`` is
    capped at 3000 (default 1000 when ``None``). Both are stored on the shared
    ``args`` namespace before ``optimize_prompt`` runs with ``target_prompt``
    as its single target prompt.
    """
    args.prompt_len = 8 if prompt_len is None else min(int(prompt_len), 75)
    args.iter = 1000 if iter is None else min(int(iter), 3000)

    return optimize_prompt(model, preprocess, args, device, target_prompts=[target_prompt])
49
+
50
+
51
# NOTE(review): a Progress object created at module level is probably not
# attached to any event; Gradio documents passing it as a default argument of
# the handler function. Left as-is pending confirmation against the pinned
# Gradio version.
gr.Progress(track_tqdm=True)

demo = gr.Blocks()

with demo:
    # Page header and usage notes.
    gr.Markdown("# PEZ Dispenser")
    gr.Markdown("## Hard Prompts Made Easy (PEZ)")
    gr.Markdown("*Want to generate a text prompt for your image that is useful for Stable Diffusion?*")
    gr.Markdown("This space can either generate a text fragment that describes your image, or it can shorten an existing text prompt. This space is using OpenCLIP-ViT/H, the same text encoder used by Stable Diffusion V2. After you generate a prompt, try it out on Stable Diffusion [here](https://huggingface.co/stabilityai/stable-diffusion-2-base). For a quick PEZ demo, try clicking on one of the examples at the bottom of this page.")
    # The paper link previously had an empty target.
    gr.Markdown("For additional details, you can check out the [paper](https://arxiv.org/abs/2302.03668) and the code on [Github](https://github.com/YuxinWenRick/hard-prompts-made-easy).")
    gr.Markdown("Note: Generation with 1000 steps takes ~60 seconds with a T4. Don't want to wait? You can also run on [Google Colab](https://colab.research.google.com/drive/1VSFps4siwASXDwhK_o29dKA9COvTnG8A?usp=sharing) Or, you can reduce the number of steps.")

    with gr.Row():
        # Inputs: the two targets plus the shared optimization settings.
        with gr.Column():
            gr.Markdown("### Image to Prompt")
            input_image = gr.inputs.Image(type="pil", label="Target Image")
            image_button = gr.Button("Generate Prompt")

            gr.Markdown("### Long Prompt to Short Prompt")
            input_prompt = gr.Textbox(label="Target Prompt")
            prompt_button = gr.Button("Distill Prompt")

            prompt_len_field = gr.Number(label="Prompt Length (max 75, recommend 8-16)", default=8)
            num_step_field = gr.Number(label="Optimization Steps (max 3000 because of limited resources)", default=1000)

        # Output: both buttons write to the same textbox.
        with gr.Column():
            gr.Markdown("### Learned Prompt")
            output_prompt = gr.outputs.Textbox(label="Learned Prompt")

    image_button.click(inference, inputs=[input_image, prompt_len_field, num_step_field], outputs=output_prompt)
    prompt_button.click(inference_text, inputs=[input_prompt, prompt_len_field, num_step_field], outputs=output_prompt)

    # Clickable examples that pre-fill the input fields.
    gr.Examples([["sample.jpeg", 8, 1000]], inputs=[input_image, prompt_len_field, num_step_field])
    gr.Examples([["digital concept art of old wooden cabin in florida swamp, trending on artstation", 3, 1000]], inputs=[input_prompt, prompt_len_field, num_step_field])

demo.launch(enable_queue=True)
open_clip/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4
+ from .loss import ClipLoss
5
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
6
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7
+ from .openai import load_openai_model, list_openai_models
8
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10
+ from .tokenizer import SimpleTokenizer, tokenize
11
+ from .transform import image_transform
open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
open_clip/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Default image normalization statistics (three values each; presumably one per
# RGB channel - TODO confirm). Used as the fallback mean/std when a pretrained
# config does not supply its own.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
open_clip/factory.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
14
+ resize_pos_embed, get_cast_dtype
15
+ from .openai import load_openai_model
16
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
17
+ from .transform import image_transform
18
+ from .tokenizer import HFTokenizer, tokenize
19
+
20
+
21
# Locations (directories or single files) searched for model config JSON files.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
# Registry mapping model name -> parsed architecture config.
_MODEL_CONFIGS = {}
23
+
24
+
25
def _natural_key(string_):
    """Return a sort key that orders digit runs numerically.

    The lower-cased string is split into alternating text and digit chunks;
    digit chunks become ints, everything else stays a string.
    """
    key = []
    for chunk in re.split(r'(\d+)', string_.lower()):
        key.append(int(chunk) if chunk.isdigit() else chunk)
    return key
27
+
28
+
29
def _rescan_model_configs():
    """Reload model configs from every registered path into the registry.

    Each entry of ``_MODEL_CONFIG_PATHS`` may be a ``.json`` file or a
    directory whose ``*.json`` files are collected. A file is registered under
    its stem only if it contains ``embed_dim``, ``vision_cfg`` and ``text_cfg``.
    Existing registry entries are kept; the registry is then rebuilt in
    natural-sort order of the model names.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')

    candidates = []
    for location in _MODEL_CONFIG_PATHS:
        if location.is_file() and location.suffix in config_ext:
            candidates.append(location)
        elif location.is_dir():
            for ext in config_ext:
                candidates.extend(location.glob(f'*{ext}'))

    for candidate in candidates:
        with open(candidate, 'r') as handle:
            parsed = json.load(handle)
        if all(key in parsed for key in required_keys):
            _MODEL_CONFIGS[candidate.stem] = parsed

    ordered = sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0]))
    _MODEL_CONFIGS = dict(ordered)


_rescan_model_configs()  # initial populate of model config registry
51
+
52
+
53
def list_models():
    """Return the names of all registered model architectures as a list."""
    return list(_MODEL_CONFIGS)
56
+
57
+
58
def add_model_config(path):
    """Register an extra config file or directory and rescan the registry.

    ``path`` may be a ``Path`` or anything ``Path()`` accepts.
    """
    location = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(location)
    _rescan_model_configs()
64
+
65
+
66
def get_model_config(model_name):
    """Return a deep copy of the registered config for ``model_name``.

    Returns ``None`` when no config is registered under that name. The copy
    lets callers mutate the result without touching the registry.
    """
    try:
        registered = _MODEL_CONFIGS[model_name]
    except KeyError:
        return None
    return deepcopy(registered)
71
+
72
+
73
def get_tokenizer(model_name):
    """Return the tokenizer to use for ``model_name``.

    If the model's ``text_cfg`` names a HuggingFace tokenizer, an
    ``HFTokenizer`` for it is created; otherwise the module-level ``tokenize``
    function is returned.
    """
    text_cfg = get_model_config(model_name)['text_cfg']
    if 'hf_tokenizer_name' in text_cfg:
        return HFTokenizer(text_cfg['hf_tokenizer_name'])
    return tokenize
77
+
78
+
79
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Load a checkpoint file and return its state dict.

    Args:
        checkpoint_path: path handed to ``torch.load``.
        map_location: forwarded to ``torch.load`` (defaults to ``'cpu'``).

    Returns:
        The ``'state_dict'`` entry when the checkpoint is a dict that has one,
        otherwise the loaded object itself. When the first key carries a
        ``'module.'`` prefix, that prefix is stripped from every key that has it.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # Match the full 'module.' prefix (dot included) so keys that merely start
    # with "module" are left alone, and only trim keys that actually carry it.
    # An empty state dict is returned unchanged instead of raising StopIteration.
    prefix = 'module.'
    first_key = next(iter(state_dict), None)
    if first_key is not None and first_key.startswith(prefix):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict
88
+
89
+
90
def load_checkpoint(model, checkpoint_path, strict=True):
    """Load weights from ``checkpoint_path`` into ``model``.

    A state dict that has a top-level ``positional_embedding`` while the model
    has no such attribute is treated as the old format and converted with
    ``convert_to_custom_text_state_dict``. ``resize_pos_embed`` is then applied
    before loading.

    Returns:
        The result of ``model.load_state_dict`` (the incompatible keys).
    """
    weights = load_state_dict(checkpoint_path)

    old_format = 'positional_embedding' in weights and not hasattr(model, 'positional_embedding')
    if old_format:
        weights = convert_to_custom_text_state_dict(weights)

    resize_pos_embed(weights, model)
    return model.load_state_dict(weights, strict=strict)
98
+
99
+
100
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
):
    """Build a CLIP model by name and optionally load pretrained weights.

    Args:
        model_name: architecture name; ``/`` is replaced by ``-`` for callers
            using the old naming.
        pretrained: ``'openai'`` (case-insensitive) to load via
            ``load_openai_model``, otherwise a pretrained tag known for this
            model or a path to a checkpoint file. ``None``/empty skips loading.
        precision: passed to ``get_cast_dtype``; ``'fp16'``/``'bf16'`` also
            convert the weights to that low-precision dtype.
        device: device (string or ``torch.device``) the model is moved to.
        jit: script the model with ``torch.jit.script`` (non-OpenAI path) or
            forward the flag to ``load_openai_model``.
        force_quick_gelu: set ``quick_gelu`` in the model config.
        force_custom_text: build a ``CustomTextCLIP`` even if the config does
            not ask for one.
        force_patch_dropout: override ``vision_cfg.patch_dropout`` when not ``None``.
        pretrained_image: request pretrained timm image-tower weights.
        pretrained_hf: stored as ``hf_model_pretrained`` for HF text towers.
        cache_dir: forwarded to the download helpers.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: no config exists for ``model_name``, or ``pretrained``
            is neither a known tag nor an existing file.
        AssertionError: ``pretrained_image`` is set for a non-timm vision tower.
    """
    model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            jit=jit,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if pretrained_image:
            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                # Raised explicitly (same exception type as the former `assert False`)
                # so the check is not stripped when Python runs with -O.
                raise AssertionError('pretrained image towers currently only supported for timm models')

        cast_dtype = get_cast_dtype(precision)
        # An HF text tower always implies the custom-text model class.
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or ('hf_model_name' in model_cfg.get('text_cfg', {}))

        if custom_text:
            if 'hf_model_name' in model_cfg.get('text_cfg', {}):
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        pretrained_cfg = {}
        if pretrained:
            # Resolve `pretrained` as a known tag first, then as a local file path.
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                # Message fixed: the two sentences were run together and the
                # parenthesis around the tag list was never closed.
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)

        model.to(device=device)
        if precision in ("fp16", "bf16"):
            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

        if jit:
            model = torch.jit.script(model)

    return model
190
+
191
+
192
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    """Create a model plus its training and validation image transforms.

    All model-related arguments are forwarded to ``create_model``.
    ``image_mean`` / ``image_std`` override the statistics stored on
    ``model.visual`` when given.

    Returns:
        ``(model, preprocess_train, preprocess_val)``.
    """
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
    )

    # Explicit arguments win; otherwise fall back to what the model carries.
    mean = image_mean or getattr(model.visual, 'image_mean', None)
    std = image_std or getattr(model.visual, 'image_std', None)

    def _build_transform(is_train):
        return image_transform(
            model.visual.image_size,
            is_train=is_train,
            mean=mean,
            std=std
        )

    preprocess_train = _build_transform(True)
    preprocess_val = _build_transform(False)

    return model, preprocess_train, preprocess_val
237
+
238
+
239
def create_model_from_pretrained(
        model_name: str,
        pretrained: str,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    """Create a model that must be backed by pretrained weights.

    ``pretrained`` has to be a pretrained cfg known for ``model_name`` or an
    existing path; otherwise ``RuntimeError`` is raised before the model is built.

    Returns:
        The model alone when ``return_transform`` is false, otherwise
        ``(model, preprocess)`` where ``preprocess`` is the non-training
        image transform.
    """
    if not (is_pretrained_cfg(model_name, pretrained) or os.path.exists(pretrained)):
        raise RuntimeError(
            f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
            f' Use open_clip.list_pretrained() to find one.')

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        cache_dir=cache_dir,
    )

    if return_transform:
        # Explicit arguments win; otherwise use the statistics stored on the model.
        mean = image_mean or getattr(model.visual, 'image_mean', None)
        std = image_std or getattr(model.visual, 'image_std', None)
        preprocess = image_transform(
            model.visual.image_size,
            is_train=False,
            mean=mean,
            std=std
        )
        return model, preprocess

    return model
open_clip/hf_configs.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def _roberta_style_config_names():
    """Return a fresh config-name mapping shared by the RoBERTa-family entries."""
    return {
        "context_length": "max_position_embeddings",
        "vocab_size": "vocab_size",
        "width": "hidden_size",
        "heads": "num_attention_heads",
        "layers": "num_hidden_layers",
        "layer_attr": "layer",
        "token_embeddings_attr": "embeddings"
    }


# HF architecture dict: per model type, the names of the HF config/module
# attributes to look up, and the default pooler.
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": _roberta_style_config_names(),
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": _roberta_style_config_names(),
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
}
open_clip/hf_model.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import TensorType
11
+
12
# transformers is optional: when it is missing, `transformers` is set to None
# and empty placeholder classes keep the annotations below resolvable.
try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import (
        BaseModelOutput,
        BaseModelOutputWithPooling,
        BaseModelOutputWithPoolingAndCrossAttentions,
    )
except ImportError:
    transformers = None

    class BaseModelOutput:
        pass

    class PretrainedConfig:
        pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+
31
+ # utils
32
def _camel2snake(s):
    """Convert a CamelCase name to snake_case.

    An underscore is inserted before every uppercase letter except one at the
    very start, then the result is lower-cased.
    """
    before_capital = r'(?<!^)(?=[A-Z])'
    separated = re.sub(before_capital, '_', s)
    return separated.lower()
34
+
35
+
36
# TODO: ?last - for gpt-like models
# Registry of pooler classes keyed by the snake_case form of the class name.
_POOLERS = {}


def register_pooler(cls):
    """Class decorator: add ``cls`` to the pooler registry and return it unchanged."""
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls
44
+
45
+
46
@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Average ``x.last_hidden_state`` over dim 1, weighted by ``attention_mask``."""
        # Zero out the masked positions, then divide the sum by the mask total.
        weights = attention_mask.unsqueeze(-1)
        summed = (x.last_hidden_state * weights).sum(dim=1)
        counts = attention_mask.sum(-1, keepdim=True)
        return summed / counts
53
+
54
+
55
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return the element-wise maximum of ``x.last_hidden_state`` over dim 1.

        Positions where ``attention_mask`` is 0 are excluded from the maximum.
        """
        # masked_fill() requires a boolean mask and overwrites the positions where
        # it is True. HFTextEncoder builds attention_mask as an integer tensor that
        # is non-zero for real tokens, so the mask must be inverted to select the
        # padding; passing it through unchanged fails on the dtype and would blank
        # out the real tokens instead.
        padding = (attention_mask == 0).unsqueeze(-1)
        masked_output = x.last_hidden_state.masked_fill(padding, -torch.inf)
        return masked_output.max(1).values
62
+
63
+
64
@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        # Index along dim 1 of the hidden states that is used as the CLS token.
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return the model's ``pooler_output`` if usable, else the CLS hidden state.

        ``attention_mask`` is accepted for interface parity with the other
        poolers and is not used.
        """
        if self.use_pooler_output:
            has_pooler_field = isinstance(
                x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions))
            if has_pooler_field and x.pooler_output is not None:
                return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]
81
+
82
+
83
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter

    Wraps a HuggingFace transformer, a pooler and an output projection so that
    token ids go in and one pooled, projected vector per sequence comes out.
    """

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True):
        """Build the transformer, pooler and projection.

        Args:
            model_name_or_path: HF model identifier; only used when ``config`` is None.
            output_dim: size of the projected output.
            config: explicit HF config; when given, the model is built from it
                and ``model_name_or_path`` / ``pretrained`` are not consulted.
            pooler_type: key into the pooler registry; None selects the default
                pooler listed for the model type in ``arch_dict``.
            proj: ``'linear'``, ``'mlp'`` or None.
            pretrained: load pretrained weights (True) or build from the config only.

        Raises:
            RuntimeError: the ``transformers`` package is not installed.
        """
        super().__init__()

        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            # Pick the constructor together with its argument: pretrained weights
            # by name, or an untrained model from the config.
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                # Encoder-decoder models: keep only the encoder half.
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)

        if pooler_type is None:  # get default arch pooler
            self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
        else:
            self.pooler = _POOLERS[pooler_type]()

        # Hidden width of the transformer, read via the per-architecture attribute name.
        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        # NOTE(review): if proj is None and d_model != output_dim (or proj is any
        # other value), no branch matches and self.proj is never assigned.
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType) -> TensorType:
        """Encode token ids ``x`` and return the pooled, projected output."""
        # Every position that is not the pad token is attended to.
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)

        return self.proj(pooled_out)

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze transformer parameters.

        With ``unlocked_layers`` falsy, every transformer parameter is frozen.
        Otherwise the embeddings and layers are put in one list and all but the
        last ``unlocked_layers`` entries are frozen. Parameters whose dotted
        name contains a ``LayerNorm`` component stay trainable when
        ``freeze_layer_norm`` is False.
        """
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        # Embeddings count as the first "layer"; drop the trailing entries that stay trainable.
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable gradient checkpointing on the wrapped transformer.

        NOTE(review): ``enable`` is not consulted; checkpointing is always turned on.
        """
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        """No-op."""
        pass
open_clip/loss.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ try:
6
+ import torch.distributed.nn
7
+ from torch import distributed as dist
8
+ has_distributed = True
9
+ except ImportError:
10
+ has_distributed = False
11
+
12
+ try:
13
+ import horovod.torch as hvd
14
+ except ImportError:
15
+ hvd = None
16
+
17
+
18
def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """Gather image and text features from all workers.

    Args:
        image_features: this worker's image features.
        text_features: this worker's text features.
        local_loss: when False and gradients are not gathered, this worker's
            slice of the gathered result is replaced by the original local
            tensors so that it still carries gradients.
        gather_with_grad: use the gradient-propagating gather.
        rank: index of this worker's slice within the gathered result.
        world_size: number of workers.
        use_horovod: gather through horovod instead of ``torch.distributed``.

    Returns:
        ``(all_image_features, all_text_features)``, each concatenated along dim 0.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            # Pre-allocate one receive buffer per worker for all_gather to fill.
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features
63
+
64
+
65
class ClipLoss(nn.Module):
    """Symmetric cross-entropy loss over image-to-text and text-to-image logits."""

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state: last logit count and per-device label tensors
        self.prev_num_logits = 0
        self.labels = {}

    def _pairwise_logits(self, image_features, text_features, logit_scale):
        """Return ``(logits_per_image, logits_per_text)`` for the current batch."""
        if self.world_size <= 1:
            per_image = logit_scale * image_features @ text_features.T
            per_text = logit_scale * text_features @ image_features.T
            return per_image, per_text

        all_image_features, all_text_features = gather_features(
            image_features, text_features,
            self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

        if self.local_loss:
            # Local features against everyone's features.
            per_image = logit_scale * image_features @ all_text_features.T
            per_text = logit_scale * text_features @ all_image_features.T
        else:
            per_image = logit_scale * all_image_features @ all_text_features.T
            per_text = per_image.T
        return per_image, per_text

    def _ground_truth(self, num_logits, device):
        """Return the target indices, reusing the cached tensor when it is valid."""
        if self.prev_num_logits == num_logits and device in self.labels:
            return self.labels[device]

        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            # Shift targets to this rank's slice of the gathered features.
            labels = labels + num_logits * self.rank
        if self.cache_labels:
            self.labels[device] = labels
            self.prev_num_logits = num_logits
        return labels

    def forward(self, image_features, text_features, logit_scale):
        """Return the mean of the image-to-text and text-to-image cross-entropies."""
        device = image_features.device
        logits_per_image, logits_per_text = self._pairwise_logits(
            image_features, text_features, logit_scale)

        labels = self._ground_truth(logits_per_image.shape[0], device)

        image_loss = F.cross_entropy(logits_per_image, labels)
        text_loss = F.cross_entropy(logits_per_text, labels)
        return (image_loss + text_loss) / 2
open_clip/model.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from .hf_model import HFTextEncoder
17
+ from .modified_resnet import ModifiedResNet
18
+ from .timm_model import TimmModel
19
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
20
+ from .utils import to_2tuple
21
+
22
+
23
@dataclass
class CLIPVisionCfg:
    """Configuration of the image tower.

    ``layers`` selects the architecture: a tuple/list of four stage depths
    builds a ModifiedResNet, an int builds a VisionTransformer. A non-empty
    ``timm_model_name`` takes precedence over both.
    """
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: Optional[int] = 16  # None for ResNet towers (the RN* configs set it to null)
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
39
+
40
+
41
@dataclass
class CLIPTextCfg:
    """Configuration of the text tower.

    A non-empty ``hf_model_name`` selects a Hugging Face text encoder
    (configured by ``hf_model_pretrained``, ``proj`` and ``pooler_type``);
    otherwise the built-in TextTransformer is built from ``context_length``,
    ``vocab_size``, ``width``, ``heads``, ``layers`` and ``ls_init_value``.
    """
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    hf_model_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
54
+
55
+
56
def get_cast_dtype(precision: str):
    """Map a precision name to the torch dtype used for casting.

    Returns ``torch.bfloat16`` for ``'bf16'``, ``torch.float16`` for
    ``'fp16'`` and ``None`` for every other value.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None
63
+
64
+
65
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Build the image encoder described by ``vision_cfg``.

    Args:
        embed_dim: output dimension of the encoder.
        vision_cfg: a ``CLIPVisionCfg`` or a dict of its fields.
        quick_gelu: use ``QuickGELU`` instead of ``nn.GELU`` as the
            activation of the VisionTransformer branch.
        cast_dtype: when float16 or bfloat16, the VisionTransformer branch
            uses ``LayerNormFp32`` instead of ``LayerNorm``.

    Returns:
        A ``TimmModel`` when ``timm_model_name`` is set, a ``ModifiedResNet``
        when ``layers`` is a tuple/list, otherwise a ``VisionTransformer``.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    if vision_cfg.timm_model_name:
        # NOTE: quick_gelu is not forwarded to TimmModel.
        return TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size
        )

    if isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        return ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width
        )

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    act_layer = QuickGELU if quick_gelu else nn.GELU
    vision_heads = vision_cfg.width // vision_cfg.head_width
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    return VisionTransformer(
        image_size=vision_cfg.image_size,
        patch_size=vision_cfg.patch_size,
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vision_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        patch_dropout=vision_cfg.patch_dropout,
        global_average_pool=vision_cfg.global_average_pool,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
118
+
119
+
120
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build the text encoder described by ``text_cfg``.

    ``text_cfg`` may be a ``CLIPTextCfg`` or a dict of its fields. Returns an
    ``HFTextEncoder`` when ``hf_model_name`` is set, otherwise a
    ``TextTransformer`` whose activation follows ``quick_gelu`` and whose
    norm layer is ``LayerNormFp32`` for float16/bfloat16 ``cast_dtype``.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        return HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained
        )

    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=QuickGELU if quick_gelu else nn.GELU,
        norm_layer=LayerNormFp32 if low_precision else LayerNorm,
    )
153
+
154
+
155
class CLIP(nn.Module):
    """CLIP model whose text-tower submodules live directly on the model.

    The vision tower is stored as ``self.visual``. The text tower is built
    once and its ``transformer``, ``token_embedding``,
    ``positional_embedding``, ``ln_final``, ``text_projection`` and
    ``attn_mask`` are re-attached to this module under the same names.
    """

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: "CLIPVisionCfg",
            text_cfg: "CLIPTextCfg",
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        """Run the vision tower; L2-normalize the last dim if ``normalize``."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Encode token ids; L2-normalize the last dim if ``normalize``."""
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(self, image, text):
        """Return normalized image features, normalized text features and ``exp(logit_scale)``."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        return image_features, text_features, self.logit_scale.exp()

    def encode_text_embedding(self, text_embedding, ids, avg_text=False):
        """Encode already-embedded tokens, skipping the ``token_embedding`` lookup.

        Args:
            text_embedding: token embeddings, one row per token position;
                the positional embedding is added here.
            ids: token ids matching ``text_embedding``; used only to find
                the eot position (the argmax of each sequence).
            avg_text: if True, pool by averaging the outputs at the
                positions before the eot token of each sequence; otherwise
                take the output at the eot position.

        Returns:
            The pooled features multiplied by ``text_projection``
            (not normalized).
        """
        cast_dtype = self.transformer.get_cast_dtype()

        x = text_embedding + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # eot_token is the highest number in each sequence
        eot_index = ids.argmax(dim=-1)
        if avg_text:
            # Every sequence has its own eot position, so a plain slice
            # cannot express "everything before eot" for a whole batch;
            # use a per-sample mask instead.
            positions = torch.arange(x.shape[1], device=x.device)
            keep = positions.unsqueeze(0) < eot_index.reshape(-1, 1).to(x.device)
            keep = keep.unsqueeze(-1)  # broadcast over the feature dim
            summed = x.masked_fill(~keep, 0).sum(dim=1)
            x = (summed / keep.sum(dim=1)) @ self.text_projection
        else:
            # take features from the eot embedding
            x = x[torch.arange(x.shape[0]), eot_index] @ self.text_projection

        return x

    def forward_text_embedding(self, embeddings, ids, image_features, avg_text=False, return_feature=False):
        """Score ``image_features`` against text given as token embeddings.

        Returns the unnormalized text features when ``return_feature`` is
        True. Otherwise both feature sets are L2-normalized along dim 1 and
        the cosine-similarity matrices ``(logits_per_image, logits_per_text)``
        are returned. Note that ``logit_scale`` is not applied here.
        """
        text_features = self.encode_text_embedding(embeddings, ids, avg_text=avg_text)

        if return_feature:
            return text_features

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits (unscaled)
        logits_per_image = image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
247
+
248
class CustomTextCLIP(nn.Module):
    """CLIP variant that keeps the whole text tower as the ``text`` submodule."""

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        initial_scale = torch.ones([]) * np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(initial_scale)

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Delegate to the text tower's ``lock``."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        for tower in (self.visual, self.text):
            tower.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Run the vision tower; L2-normalize the last dim if ``normalize``."""
        features = self.visual(image)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def encode_text(self, text, normalize: bool = False):
        """Run the text tower; L2-normalize the last dim if ``normalize``."""
        features = self.text(text)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def forward(self, image, text):
        """Return normalized image features, normalized text features and ``exp(logit_scale)``."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        scale = self.logit_scale.exp()
        return image_features, text_features, scale
286
+
287
+
288
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16).

    Converted in place, for every submodule of ``model``:
      * weight and bias of ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Linear``;
      * the projection weights and biases of attention modules;
      * tensors stored under the attribute names ``text_projection`` / ``proj``.
    """
    attention_attrs = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )

    def _convert_weights(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.to(dtype)
            if module.bias is not None:
                module.bias.data = module.bias.data.to(dtype)

        if isinstance(module, (nn.MultiheadAttention, Attention)):
            for attr in attention_attrs:
                # Not every attention implementation defines all of these
                # attributes; treat a missing one like an unset (None) one.
                tensor = getattr(module, attr, None)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        for name in ("text_projection", "proj"):
            attr = getattr(module, name, None)
            # Only raw tensors/parameters are converted here. An attribute of
            # this name that is itself a submodule has no ``.data``; it is
            # visited on its own by ``model.apply``.
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
313
+
314
+
315
+ # used to maintain checkpoint compatibility
316
def convert_to_custom_text_state_dict(state_dict: dict):
    """Move text-tower entries of an old-format state dict under ``text.``.

    A state dict is treated as old-format when it has a top-level
    ``text_projection`` key. In that case a new dict is returned in which
    every key starting with one of the text-tower prefixes gets a ``text.``
    prefix; all other keys are copied unchanged and the input is left
    unmodified. Otherwise the input dict itself is returned.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # old format state_dict, move text tower -> .text
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + k if k.startswith(text_prefixes) else k): v
        for k, v in state_dict.items()
    }
332
+
333
+
334
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Build a ``CLIP`` model whose architecture is inferred from ``state_dict``.

    The presence of ``visual.proj`` selects the ViT layout; otherwise the
    ResNet layout is assumed. The keys ``input_resolution``,
    ``context_length`` and ``vocab_size`` are removed from ``state_dict``
    (the caller's dict is modified) before the weights are loaded.

    Returns:
        The model in eval mode, with weights converted via
        ``convert_weights_to_fp16`` and ``state_dict`` loaded.
    """
    is_vit = "visual.proj" in state_dict

    if is_vit:
        patch_weight = state_dict["visual.conv1.weight"]
        vision_width = patch_weight.shape[0]
        attn_keys = [
            k for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
        ]
        vision_layers = len(attn_keys)
        vision_patch_size = patch_weight.shape[-1]
        num_positions = state_dict["visual.positional_embedding"].shape[0]
        grid_size = round((num_positions - 1) ** 0.5)  # minus the extra (class) token
        image_size = vision_patch_size * grid_size
    else:
        stage_depths = []
        for stage in [1, 2, 3, 4]:
            # distinct block indices found under visual.layer<stage>
            block_ids = {k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{stage}")}
            stage_depths.append(len(block_ids))
        vision_layers = tuple(stage_depths)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        num_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((num_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == num_positions
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    text_block_ids = {k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")}

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=state_dict["positional_embedding"].shape[0],
        vocab_size=state_dict["token_embedding.weight"].shape[0],
        width=transformer_width,
        heads=transformer_width // 64,
        layers=len(text_block_ids),
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # these entries are not module parameters
    for extra_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(extra_key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
392
+
393
+
394
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` with ``torch.jit``.

    Args:
        model: module exposing ``visual.image_size`` and the three traced
            methods. The text length is read from ``model.context_length``
            when present, else from ``model.positional_embedding.shape[0]``.
        batch_size: batch size of the example inputs used for tracing.
        device: device on which the example inputs are created.

    Returns:
        The traced module, with ``visual.image_size`` restored on it.
    """
    model.eval()
    image_size = model.visual.image_size
    # image_size may be a single int or a (height, width) pair
    if isinstance(image_size, (tuple, list)):
        height, width = image_size
    else:
        height = width = image_size

    context_length = getattr(model, 'context_length', None)
    if context_length is None:
        # Models that do not expose context_length still carry a positional
        # embedding with one row per token position.
        context_length = model.positional_embedding.shape[0]

    example_images = torch.ones((batch_size, 3, height, width), device=device)
    example_text = torch.zeros((batch_size, context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model
408
+
409
+
410
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Interpolate ``visual.positional_embedding`` in ``state_dict`` to the model's grid.

    Does nothing when the entry is absent, when ``model.visual`` has no
    ``grid_size``, or when the sequence length already matches. Otherwise
    the first ``extra_tokens`` rows are kept as-is, the remaining rows are
    reshaped to a square grid, resized with ``F.interpolate`` and the result
    is written back into ``state_dict`` in place. ``seq_dim`` is unused.
    """
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None:
        return
    if not hasattr(model.visual, 'grid_size'):
        return

    target_grid = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    target_len = target_grid[0] * target_grid[1] + extra_tokens
    if old_pos_embed.shape[0] == target_len:
        return

    if extra_tokens:
        tok_part = old_pos_embed[:extra_tokens]
        grid_part = old_pos_embed[extra_tokens:]
    else:
        tok_part = None
        grid_part = old_pos_embed
    source_grid = to_2tuple(int(math.sqrt(len(grid_part))))

    logging.info('Resizing position embedding grid-size from %s to %s', source_grid, target_grid)
    # (seq, dim) -> (1, dim, h, w), the layout F.interpolate resizes spatially
    grid_part = grid_part.reshape(1, source_grid[0], source_grid[1], -1)
    grid_part = grid_part.permute(0, 3, 1, 2)
    grid_part = F.interpolate(
        grid_part,
        size=target_grid,
        mode=interpolation,
        align_corners=True,
    )
    # (1, dim, h, w) -> (h * w, dim)
    grid_part = grid_part.permute(0, 2, 3, 1)
    grid_part = grid_part.reshape(1, target_grid[0] * target_grid[1], -1)[0]

    if tok_part is None:
        resized = grid_part
    else:
        resized = torch.cat([tok_part, grid_part], dim=0)
    state_dict['visual.positional_embedding'] = resized
open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 448,
5
+ "layers": [
6
+ 3,
7
+ 15,
8
+ 36,
9
+ 10
10
+ ],
11
+ "width": 128,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 1024,
18
+ "heads": 16,
19
+ "layers": 12
20
+ }
21
+ }
open_clip/model_configs/ViT-B-16-plus-240.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 240,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-B-16-plus.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-B-32-plus-256.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 256,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
open_clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-H-14.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 14
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }
open_clip/model_configs/ViT-H-16.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 16
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }
open_clip/model_configs/ViT-L-14-280.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 280,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-L-14-336.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-L-16-320.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 320,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-L-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-M-16-alt.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 384,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 512,
7
+ "patch_size": 16,
8
+ "ls_init_value": 1e-4
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 384,
14
+ "heads": 6,
15
+ "layers": 12
16
+ }
17
+ }
open_clip/model_configs/ViT-M-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 512,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-M-32-alt.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 384,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 512,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 384,
13
+ "heads": 6,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-M-32.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 512,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-S-16-alt.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 256,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 384,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 256,
13
+ "heads": 4,
14
+ "layers": 10
15
+ }
16
+ }
open_clip/model_configs/ViT-S-16.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 384,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 384,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 384,
13
+ "heads": 6,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-S-32-alt.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 256,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 384,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 256,
13
+ "heads": 4,
14
+ "layers": 10
15
+ }
16
+ }
open_clip/model_configs/ViT-S-32.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 384,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 384,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 384,
13
+ "heads": 6,
14
+ "layers": 12
15
+ }
16
+ }
open_clip/model_configs/ViT-bigG-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1280,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 48,
6
+ "width": 1664,
7
+ "head_width": 104,
8
+ "mlp_ratio": 4.9231,
9
+ "patch_size": 14
10
+ },
11
+ "text_cfg": {
12
+ "context_length": 77,
13
+ "vocab_size": 49408,
14
+ "width": 1280,
15
+ "heads": 20,
16
+ "layers": 32
17
+ }
18
+ }
open_clip/model_configs/ViT-e-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1280,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 56,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.5715,
9
+ "patch_size": 14
10
+ },
11
+ "text_cfg": {
12
+ "context_length": 77,
13
+ "vocab_size": 49408,
14
+ "width": 1280,
15
+ "heads": 20,
16
+ "layers": 36
17
+ }
18
+ }
open_clip/model_configs/ViT-g-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14
10
+ },
11
+ "text_cfg": {
12
+ "context_length": 77,
13
+ "vocab_size": 49408,
14
+ "width": 1024,
15
+ "heads": 16,
16
+ "layers": 24
17
+ }
18
+ }
open_clip/model_configs/mt5-base-ViT-B-32.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "hf_model_name": "google/mt5-base",
11
+ "hf_tokenizer_name": "google/mt5-base",
12
+ "proj": "mlp",
13
+ "pooler_type": "mean_pooler"
14
+ }
15
+ }
open_clip/model_configs/mt5-xl-ViT-H-14.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 14
9
+ },
10
+ "text_cfg": {
11
+ "hf_model_name": "google/mt5-xl",
12
+ "hf_tokenizer_name": "google/mt5-xl",
13
+ "proj": "mlp",
14
+ "pooler_type": "mean_pooler"
15
+ }
16
+ }
open_clip/model_configs/roberta-ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "hf_model_name": "roberta-base",
12
+ "hf_tokenizer_name": "roberta-base",
13
+ "proj": "mlp",
14
+ "pooler_type": "mean_pooler"
15
+ }
16
+ }
open_clip/model_configs/timm-convnext_base.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_base",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "image_size": 224
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
open_clip/model_configs/timm-convnext_base_w.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_base",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "image_size": 256
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 640,
14
+ "heads": 10,
15
+ "layers": 12
16
+ }
17
+ }
open_clip/model_configs/timm-convnext_large.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_large",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "image_size": 224
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ }
17
+ }
open_clip/model_configs/timm-convnext_small.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_small",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "image_size": 224
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
open_clip/model_configs/timm-convnext_tiny.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_tiny",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "image_size": 224
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
open_clip/model_configs/timm-convnext_xlarge.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_xlarge",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "image_size": 224
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 16
16
+ }
17
+ }
open_clip/model_configs/timm-convnext_xxlarge.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "timm_model_name": "convnext_xxlarge",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "",
7
+ "timm_proj": "linear",
8
+ "image_size": 256
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }