kittchy committed
Commit 30099ac
1 Parent(s): d544090

[ADD] image_vector_search

.env.default ADDED
@@ -0,0 +1,7 @@
+ LOG_LEVEL="info"
+ DB_USERNAME=
+ DB_PASSWORD=
+ DB_HOST=
+ BUCKET_NAME=
+ AWS_ACCESS_KEY_ID=
+ AWS_SECRET_ACCESS_KEY=
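Nothing in this commit shows how BUCKET_NAME and the AWS credentials are consumed (the storage code presumably lives in model.py, which is not part of this diff). A minimal sketch of the assumed S3 upload path, purely illustrative — the function name, key scheme, and URL format are assumptions:

# Hypothetical sketch of how the .env values might be used for S3 uploads.
# Nothing here is from the commit; the boto3 usage and URL format are assumptions.
import os

import boto3
from dotenv import load_dotenv

load_dotenv()

s3 = boto3.client(
    "s3",
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
)


def upload_image(local_path: str, key: str) -> str:
    """Upload a local image and return a plain S3 URL (assumed convention)."""
    bucket = os.environ.get("BUCKET_NAME")
    s3.upload_file(local_path, bucket, key)
    return f"https://{bucket}.s3.amazonaws.com/{key}"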
.gitignore ADDED
@@ -0,0 +1,12 @@
+ # python generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # venv
+ .venv
+
+ .env
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12.3
app.py ADDED
@@ -0,0 +1,35 @@
+ import gradio as gr
+
+ from model import MLModel
+ from dotenv import load_dotenv
+ from os.path import join, dirname
+
+ dotenv_path = join(dirname(__file__), ".env")
+ load_dotenv(dotenv_path)
+
+
+ URL = "localhost:8000"
+ model = MLModel()
+
+
+ def save(image_path: str):
+     _ = model.save(image_path)
+     return
+
+
+ def search(prompt: str):
+     urls = model.search(prompt)
+     return urls[0]
+
+
+ with gr.Blocks() as app:
+     # Define the layout
+     input = gr.Textbox(placeholder="可愛いワンコ", label="検索")
+     output = gr.Image(type="filepath")
+     btn = gr.Button("検索")
+     btn.click(fn=search, inputs=input, outputs=output)
+
+     save_image = gr.Image()
+     save_btn = gr.Button("保存")
+     save_btn.click(fn=save, inputs=save_image, outputs=None)
+ app.launch()
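app.py imports MLModel from model, which is not included in this diff. A minimal sketch of the interface app.py relies on — only the method names save and search come from the calls above; the signatures and return types are assumptions inferred from how they are used (search must return a non-empty list, since only urls[0] is displayed):

# Hypothetical interface that app.py assumes; not part of this commit.
class MLModel:
    def save(self, image_path: str) -> str:
        """Embed the image at image_path and persist it; return an identifier (assumed)."""
        raise NotImplementedError

    def search(self, prompt: str) -> list[str]:
        """Return image paths/URLs ranked by similarity to the text prompt (assumed)."""
        raise NotImplementedError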
db_session.py ADDED
@@ -0,0 +1,29 @@
+ import os
+
+ from pymongo.database import Database
+ from pymongo.mongo_client import MongoClient
+ from pymongo.server_api import ServerApi
+
+ from dotenv import load_dotenv
+
+ load_dotenv(verbose=True)
+
+ USER_NAME = os.environ.get("DB_USERNAME")
+ PASSWORD = os.environ.get("DB_PASSWORD")
+ HOST = os.environ.get("DB_HOST")
+
+
+ MONGO_DATABASE_URL = f"mongodb+srv://{USER_NAME}:{PASSWORD}@{HOST}"
+ client = MongoClient(MONGO_DATABASE_URL, server_api=ServerApi("1"))
+
+
+ def get_db() -> Database:
+     """Return a handle to the MongoDB database.
+
+     Returns:
+         Database: the database object
+     """
+     db: Database = client.db
+     return db
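get_db() only returns a handle; the vector-search query itself is not part of this diff. A sketch of how it could back the image_vector_search feature with MongoDB Atlas Vector Search — the collection name images, index name vector_index, and the embedding/url field names are assumptions:

# Sketch: a similarity query via Atlas Vector Search, assuming an index named
# "vector_index" over an "embedding" field; only get_db() comes from this commit.
from db_session import get_db


def search_similar(query_vector: list[float], top_k: int = 5):
    db = get_db()
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",   # assumed Atlas Search index name
                "path": "embedding",       # assumed field holding the image vector
                "queryVector": query_vector,
                "numCandidates": 100,
                "limit": top_k,
            }
        },
        {"$project": {"url": 1, "score": {"$meta": "vectorSearchScore"}}},
    ]
    return list(db.images.aggregate(pipeline))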
japanese_clip/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # coding=utf-8
+ # Copyright 2022 rinna Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .clip import CLIPModel, CLIPConfig
+ from .cloob import CLOOBModel, CLOOBConfig
+ from .auto_model import load, available_models
+ from .tokenizer import load_tokenizer, tokenize
japanese_clip/auto_model.py ADDED
@@ -0,0 +1,95 @@
+ # coding=utf-8
+ # Copyright 2022 rinna Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Union
+ import json
+ import torch
+ from torchvision import transforms as T
+ from huggingface_hub import hf_hub_url, cached_download
+ import os
+
+ from .clip import CLIPModel
+ from .cloob import CLOOBModel
+
+ # TODO: Fill in repo_ids
+ MODELS = {
+     'rinna/japanese-clip-vit-b-16': {
+         'repo_id': 'rinna/japanese-clip-vit-b-16',
+         'model_class': CLIPModel,
+     },
+     'rinna/japanese-cloob-vit-b-16': {
+         'repo_id': 'rinna/japanese-cloob-vit-b-16',
+         'model_class': CLOOBModel,
+     }
+ }
+ MODEL_CLASSES = {
+     "cloob": CLOOBModel,
+     "clip": CLIPModel,
+ }
+ MODEL_FILE = "pytorch_model.bin"
+ CONFIG_FILE = "config.json"
+
+
+ def available_models():
+     return list(MODELS.keys())
+
+
+ def _convert_to_rgb(image):
+     return image.convert('RGB')
+
+
+ def _transform(image_size):
+     return T.Compose([
+         T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
+         T.CenterCrop(image_size),
+         _convert_to_rgb,
+         T.ToTensor(),
+         T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711),)
+     ])
+
+
+ def _download(repo_id: str, cache_dir: str):
+     config_file_url = hf_hub_url(repo_id=repo_id, filename=CONFIG_FILE)
+     cached_download(config_file_url, cache_dir=cache_dir)
+     model_file_url = hf_hub_url(repo_id=repo_id, filename=MODEL_FILE)
+     cached_download(model_file_url, cache_dir=cache_dir)
+
+
+ def load(
+     model_name: str,
+     device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+     **kwargs
+ ):
+     """
+     Args:
+         model_name: model unique name or path to pre-downloaded model
+         device: device to put the loaded model
+         kwargs: kwargs for huggingface pretrained model class
+     Return:
+         (torch.nn.Module, A torchvision transform)
+     """
+     if model_name in MODELS.keys():
+         ModelClass = CLIPModel if 'clip' in model_name else CLOOBModel
+     elif os.path.exists(model_name):
+         assert os.path.exists(os.path.join(model_name, CONFIG_FILE))
+         with open(os.path.join(model_name, CONFIG_FILE), "r", encoding="utf-8") as f:
+             j = json.load(f)
+         ModelClass = MODEL_CLASSES[j["model_type"]]
+     else:
+         raise RuntimeError(f"Model {model_name} not found; available models = {available_models()}")
+
+     model = ModelClass.from_pretrained(model_name, **kwargs)
+     model = model.eval().requires_grad_(False).to(device)
+     return model, _transform(model.config.vision_config.image_size)
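For reference, a typical way to use load() together with the tokenizer helpers exported from japanese_clip/__init__.py (this follows rinna's published usage; the image path and prompt are placeholders, and the exact tokenize signature lives in tokenizer.py, which is not shown in this diff):

# Sketch of typical usage of this package; file name and prompt are placeholders.
import torch
from PIL import Image

import japanese_clip as ja_clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", device=device)
tokenizer = ja_clip.load_tokenizer()

image = preprocess(Image.open("dog.jpg")).unsqueeze(0).to(device)
encodings = ja_clip.tokenize(["可愛いワンコ"], tokenizer=tokenizer, device=device)

with torch.no_grad():
    image_features = model.get_image_features(image)        # [1, projection_dim]
    text_features = model.get_text_features(**encodings)    # [1, projection_dim]

# Cosine similarity between the L2-normalised embeddings
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = (image_features @ text_features.T).item()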
japanese_clip/clip/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # coding=utf-8
+ # Copyright 2022 rinna Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from .modeling_clip import *
+ from .configuration_clip import *
japanese_clip/clip/configuration_clip.py ADDED
@@ -0,0 +1,219 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ CLIP model configuration"""
16
+ import logging
17
+ import copy
18
+ import os
19
+ from typing import Union
20
+
21
+ import numpy as np
22
+ from transformers import AutoConfig, PretrainedConfig
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class CLIPTextConfig(PretrainedConfig):
29
+ model_type = "clip_text_model"
30
+
31
+ def __init__(
32
+ self,
33
+ vocab_size=49408,
34
+ hidden_size=512,
35
+ intermediate_size=2048,
36
+ num_hidden_layers=12,
37
+ num_attention_heads=8,
38
+ max_position_embeddings=77,
39
+ hidden_act="quick_gelu",
40
+ layer_norm_eps=0.00001,
41
+ dropout=0.0,
42
+ attention_dropout=0.0,
43
+ initializer_range=0.02,
44
+ initializer_factor=1.0,
45
+ pad_token_id=1,
46
+ bos_token_id=0,
47
+ eos_token_id=2,
48
+ **kwargs
49
+ ):
50
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
51
+
52
+ self.vocab_size = vocab_size
53
+ self.hidden_size = hidden_size
54
+ self.intermediate_size = intermediate_size
55
+ self.dropout = dropout
56
+ self.num_hidden_layers = num_hidden_layers
57
+ self.num_attention_heads = num_attention_heads
58
+ self.max_position_embeddings = max_position_embeddings
59
+ self.layer_norm_eps = layer_norm_eps
60
+ self.hidden_act = hidden_act
61
+ self.initializer_range = initializer_range
62
+ self.initializer_factor = initializer_factor
63
+ self.attention_dropout = attention_dropout
64
+
65
+ @classmethod
66
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
67
+
68
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
69
+
70
+ # get the text config dict if we are loading from CLIPConfig
71
+ if config_dict.get("model_type") == "clip":
72
+ config_dict = config_dict["text_config"]
73
+
74
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
75
+ logger.warning(
76
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
77
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
78
+ )
79
+
80
+ return cls.from_dict(config_dict, **kwargs)
81
+
82
+
83
+ class CLIPVisionConfig(PretrainedConfig):
84
+ model_type = "clip_vision_model"
85
+
86
+ def __init__(
87
+ self,
88
+ hidden_size=768,
89
+ intermediate_size=3072,
90
+ num_hidden_layers=12,
91
+ num_attention_heads=12,
92
+ image_size=224,
93
+ patch_size=32,
94
+ hidden_act="quick_gelu",
95
+ layer_norm_eps=0.00001,
96
+ dropout=0.0,
97
+ attention_dropout=0.0,
98
+ initializer_range=0.02,
99
+ initializer_factor=1.0,
100
+ **kwargs
101
+ ):
102
+ super().__init__(**kwargs)
103
+
104
+ self.hidden_size = hidden_size
105
+ self.intermediate_size = intermediate_size
106
+ self.dropout = dropout
107
+ self.num_hidden_layers = num_hidden_layers
108
+ self.num_attention_heads = num_attention_heads
109
+ self.patch_size = patch_size
110
+ self.image_size = image_size
111
+ self.initializer_range = initializer_range
112
+ self.initializer_factor = initializer_factor
113
+ self.attention_dropout = attention_dropout
114
+ self.layer_norm_eps = layer_norm_eps
115
+ self.hidden_act = hidden_act
116
+
117
+ @classmethod
118
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
119
+
120
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
121
+
122
+ # get the vision config dict if we are loading from CLIPConfig
123
+ if config_dict.get("model_type") == "clip":
124
+ config_dict = config_dict["vision_config"]
125
+
126
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
127
+ logger.warning(
128
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
129
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
130
+ )
131
+
132
+ return cls.from_dict(config_dict, **kwargs)
133
+
134
+
135
+ class CLIPConfig(PretrainedConfig):
136
+ r"""
137
+ [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate
138
+ CLIP model according to the specified arguments, defining the text model and vision model configs.
139
+
140
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
141
+ documentation from [`PretrainedConfig`] for more information.
142
+
143
+ Args:
144
+ text_config_dict (`dict`, *optional*):
145
+ Dictionary of configuration options used to initialize [`CLIPTextConfig`].
146
+ vision_config_dict (`dict`, *optional*):
147
+ Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
148
+ projection_dim (`int`, *optional*, defaults to 512):
149
+ Dimensionality of text and vision projection layers.
150
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
151
+ The initial value of the *logit_scale* parameter. The default is used as per the original CLIP implementation.
152
+ kwargs (*optional*):
153
+ Dictionary of keyword arguments.
154
+ """
155
+
156
+ model_type = "clip"
157
+ is_composition = True
158
+
159
+ def __init__(
160
+ self,
161
+ text_config=None,
162
+ vision_config=None,
163
+ projection_dim=512,
164
+ logit_scale_init_value=None,
165
+ **kwargs
166
+ ):
167
+ super().__init__(text_config=text_config, vision_config=vision_config, **kwargs)
168
+
169
+ if vision_config is None:
170
+ raise ValueError("`vision_config` can not be `None`.")
171
+
172
+ if text_config is None:
173
+ raise ValueError("`text_config` can not be `None`.")
174
+
175
+ vision_model_type = vision_config.pop("model_type")
176
+ text_model_type = text_config.pop("model_type")
177
+
178
+ if vision_model_type == "clip_vision_model":
179
+ self.vision_config = CLIPVisionConfig(**vision_config)
180
+ else:
181
+ self.vision_config = AutoConfig.for_model(
182
+ vision_model_type, **vision_config
183
+ )
184
+
185
+ if text_model_type == "clip_text_model":
186
+ self.text_config = CLIPTextConfig(**text_config)
187
+ else:
188
+ self.text_config = AutoConfig.for_model(
189
+ text_model_type, **text_config
190
+ )
191
+
192
+ self.projection_dim = projection_dim
193
+ self.logit_scale_init_value = logit_scale_init_value if logit_scale_init_value is not None else np.log(1 / 0.07)
194
+ self.initializer_factor = 1.0
195
+
196
+ @classmethod
197
+ def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs):
198
+ r"""
199
+ Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model
200
+ configuration.
201
+
202
+ Returns:
203
+ [`CLIPConfig`]: An instance of a configuration object
204
+ """
205
+
206
+ return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
207
+
208
+ def to_dict(self):
209
+ """
210
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
211
+
212
+ Returns:
213
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
214
+ """
215
+ output = copy.deepcopy(self.__dict__)
216
+ output["text_config"] = self.text_config.to_dict()
217
+ output["vision_config"] = self.vision_config.to_dict()
218
+ output["model_type"] = self.__class__.model_type
219
+ return output
japanese_clip/clip/modeling_clip.py ADDED
@@ -0,0 +1,815 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import logging
16
+ from dataclasses import dataclass
17
+ from typing import Any, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+
23
+ from transformers import AutoModel
24
+ from transformers.activations import ACT2FN
25
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
26
+ from transformers.modeling_utils import PreTrainedModel, ModelOutput
27
+ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
34
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
35
+ """
36
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
37
+ """
38
+ bsz, src_len = mask.size()
39
+ tgt_len = tgt_len if tgt_len is not None else src_len
40
+
41
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
42
+
43
+ inverted_mask = 1.0 - expanded_mask
44
+
45
+ return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
46
+
47
+
48
+ # contrastive loss function, adapted from
49
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
50
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
51
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
52
+
53
+
54
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
55
+ caption_loss = contrastive_loss(similarity)
56
+ image_loss = contrastive_loss(similarity.T)
57
+ return (caption_loss + image_loss) / 2.0
58
+
59
+
60
+ @dataclass
61
+ class CLIPOutput(ModelOutput):
62
+ loss: Optional[torch.FloatTensor] = None
63
+ logits_per_image: torch.FloatTensor = None
64
+ logits_per_text: torch.FloatTensor = None
65
+ text_embeds: torch.FloatTensor = None
66
+ image_embeds: torch.FloatTensor = None
67
+ text_model_output: BaseModelOutputWithPooling = None
68
+ vision_model_output: BaseModelOutputWithPooling = None
69
+
70
+ def to_tuple(self) -> Tuple[Any]:
71
+ return tuple(
72
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
73
+ for k in self.keys()
74
+ )
75
+
76
+
77
+ class CLIPVisionEmbeddings(nn.Module):
78
+ def __init__(self, config: CLIPVisionConfig):
79
+ super().__init__()
80
+ self.config = config
81
+ self.embed_dim = config.hidden_size
82
+ self.image_size = config.image_size
83
+ self.patch_size = config.patch_size
84
+
85
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
86
+
87
+ self.patch_embedding = nn.Conv2d(
88
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
89
+ )
90
+
91
+ self.num_patches = (self.image_size // self.patch_size) ** 2
92
+ self.num_positions = self.num_patches + 1
93
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
94
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
95
+
96
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
97
+ batch_size = pixel_values.shape[0]
98
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
99
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
100
+
101
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
102
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
103
+ embeddings = embeddings + self.position_embedding(self.position_ids)
104
+ return embeddings
105
+
106
+
107
+ class CLIPTextEmbeddings(nn.Module):
108
+ def __init__(self, config: CLIPTextConfig):
109
+ super().__init__()
110
+ embed_dim = config.hidden_size
111
+
112
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
113
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
114
+
115
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
116
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
117
+
118
+ def forward(
119
+ self,
120
+ input_ids: Optional[torch.LongTensor] = None,
121
+ position_ids: Optional[torch.LongTensor] = None,
122
+ inputs_embeds: Optional[torch.FloatTensor] = None,
123
+ ) -> torch.Tensor:
124
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
125
+
126
+ if position_ids is None:
127
+ position_ids = self.position_ids[:, :seq_length]
128
+
129
+ if inputs_embeds is None:
130
+ inputs_embeds = self.token_embedding(input_ids)
131
+
132
+ position_embeddings = self.position_embedding(position_ids)
133
+ embeddings = inputs_embeds + position_embeddings
134
+
135
+ return embeddings
136
+
137
+
138
+ class CLIPAttention(nn.Module):
139
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
140
+
141
+ def __init__(self, config):
142
+ super().__init__()
143
+ self.config = config
144
+ self.embed_dim = config.hidden_size
145
+ self.num_heads = config.num_attention_heads
146
+ self.head_dim = self.embed_dim // self.num_heads
147
+ if self.head_dim * self.num_heads != self.embed_dim:
148
+ raise ValueError(
149
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
150
+ f" {self.num_heads})."
151
+ )
152
+ self.scale = self.head_dim**-0.5
153
+ self.dropout = config.attention_dropout
154
+
155
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
156
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
157
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
158
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
159
+
160
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
161
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
162
+
163
+ def forward(
164
+ self,
165
+ hidden_states: torch.Tensor,
166
+ attention_mask: Optional[torch.Tensor] = None,
167
+ causal_attention_mask: Optional[torch.Tensor] = None,
168
+ output_attentions: Optional[bool] = False,
169
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
170
+ """Input shape: Batch x Time x Channel"""
171
+
172
+ bsz, tgt_len, embed_dim = hidden_states.size()
173
+
174
+ # get query proj
175
+ query_states = self.q_proj(hidden_states) * self.scale
176
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
177
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
178
+
179
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
180
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
181
+ key_states = key_states.view(*proj_shape)
182
+ value_states = value_states.view(*proj_shape)
183
+
184
+ src_len = key_states.size(1)
185
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
186
+
187
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
188
+ raise ValueError(
189
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
190
+ f" {attn_weights.size()}"
191
+ )
192
+
193
+ # apply the causal_attention_mask first
194
+ if causal_attention_mask is not None:
195
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
196
+ raise ValueError(
197
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
198
+ f" {causal_attention_mask.size()}"
199
+ )
200
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
201
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
202
+
203
+ if attention_mask is not None:
204
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
205
+ raise ValueError(
206
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
207
+ )
208
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
209
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
210
+
211
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
212
+
213
+ if output_attentions:
214
+ # this operation is a bit akward, but it's required to
215
+ # make sure that attn_weights keeps its gradient.
216
+ # In order to do so, attn_weights have to reshaped
217
+ # twice and have to be reused in the following
218
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
219
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
220
+ else:
221
+ attn_weights_reshaped = None
222
+
223
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
224
+
225
+ attn_output = torch.bmm(attn_probs, value_states)
226
+
227
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
228
+ raise ValueError(
229
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
230
+ f" {attn_output.size()}"
231
+ )
232
+
233
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
234
+ attn_output = attn_output.transpose(1, 2)
235
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
236
+
237
+ attn_output = self.out_proj(attn_output)
238
+
239
+ return attn_output, attn_weights_reshaped
240
+
241
+
242
+ class CLIPMLP(nn.Module):
243
+ def __init__(self, config):
244
+ super().__init__()
245
+ self.config = config
246
+ self.activation_fn = ACT2FN[config.hidden_act]
247
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
248
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
249
+
250
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
251
+ hidden_states = self.fc1(hidden_states)
252
+ hidden_states = self.activation_fn(hidden_states)
253
+ hidden_states = self.fc2(hidden_states)
254
+ return hidden_states
255
+
256
+
257
+ class CLIPEncoderLayer(nn.Module):
258
+ def __init__(self, config: CLIPConfig):
259
+ super().__init__()
260
+ self.embed_dim = config.hidden_size
261
+ self.self_attn = CLIPAttention(config)
262
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim)
263
+ self.mlp = CLIPMLP(config)
264
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim)
265
+
266
+ def forward(
267
+ self,
268
+ hidden_states: torch.Tensor,
269
+ attention_mask: torch.Tensor,
270
+ causal_attention_mask: torch.Tensor,
271
+ output_attentions: Optional[bool] = False,
272
+ ) -> Tuple[torch.FloatTensor]:
273
+ """
274
+ Args:
275
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
276
+ attention_mask (`torch.FloatTensor`): attention mask of size
277
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
278
+ `(config.encoder_attention_heads,)`.
279
+ output_attentions (`bool`, *optional*):
280
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
281
+ returned tensors for more detail.
282
+ """
283
+ residual = hidden_states
284
+
285
+ hidden_states = self.layer_norm1(hidden_states)
286
+ hidden_states, attn_weights = self.self_attn(
287
+ hidden_states=hidden_states,
288
+ attention_mask=attention_mask,
289
+ causal_attention_mask=causal_attention_mask,
290
+ output_attentions=output_attentions,
291
+ )
292
+ hidden_states = residual + hidden_states
293
+
294
+ residual = hidden_states
295
+ hidden_states = self.layer_norm2(hidden_states)
296
+ hidden_states = self.mlp(hidden_states)
297
+ hidden_states = residual + hidden_states
298
+
299
+ outputs = (hidden_states,)
300
+
301
+ if output_attentions:
302
+ outputs += (attn_weights,)
303
+
304
+ return outputs
305
+
306
+
307
+ class CLIPPreTrainedModel(PreTrainedModel):
308
+ """
309
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
310
+ models.
311
+ """
312
+
313
+ config_class = CLIPConfig
314
+ base_model_prefix = "clip"
315
+ supports_gradient_checkpointing = True
316
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
317
+
318
+ def _init_weights(self, module):
319
+ """Initialize the weights"""
320
+ factor = self.config.initializer_factor
321
+ if isinstance(module, CLIPTextEmbeddings):
322
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
323
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
324
+ elif isinstance(module, CLIPVisionEmbeddings):
325
+ factor = self.config.initializer_factor
326
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
327
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
328
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
329
+ elif isinstance(module, CLIPAttention):
330
+ factor = self.config.initializer_factor
331
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
332
+ out_proj_std = (module.embed_dim**-0.5) * factor
333
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
334
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
335
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
336
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
337
+ elif isinstance(module, CLIPMLP):
338
+ factor = self.config.initializer_factor
339
+ in_proj_std = (
340
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
341
+ )
342
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
343
+ nn.init.normal_(module.fc1.weight, std=fc_std)
344
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
345
+ elif isinstance(module, CLIPModel):
346
+ nn.init.normal_(
347
+ module.text_projection.weight,
348
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
349
+ )
350
+ nn.init.normal_(
351
+ module.visual_projection.weight,
352
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
353
+ )
354
+
355
+ if isinstance(module, nn.LayerNorm):
356
+ module.bias.data.zero_()
357
+ module.weight.data.fill_(1.0)
358
+ if isinstance(module, nn.Linear) and module.bias is not None:
359
+ module.bias.data.zero_()
360
+
361
+ def _set_gradient_checkpointing(self, module, value=False):
362
+ if isinstance(module, CLIPEncoder):
363
+ module.gradient_checkpointing = value
364
+
365
+
366
+ class CLIPEncoder(nn.Module):
367
+ """
368
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
369
+ [`CLIPEncoderLayer`].
370
+ Args:
371
+ config: CLIPConfig
372
+ """
373
+
374
+ def __init__(self, config: CLIPConfig):
375
+ super().__init__()
376
+ self.config = config
377
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
378
+ self.gradient_checkpointing = False
379
+
380
+ def forward(
381
+ self,
382
+ inputs_embeds,
383
+ attention_mask: Optional[torch.Tensor] = None,
384
+ causal_attention_mask: Optional[torch.Tensor] = None,
385
+ output_attentions: Optional[bool] = None,
386
+ output_hidden_states: Optional[bool] = None,
387
+ return_dict: Optional[bool] = None,
388
+ ) -> Union[Tuple, BaseModelOutput]:
389
+ r"""
390
+ Args:
391
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
392
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
393
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
394
+ than the model's internal embedding lookup matrix.
395
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
396
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
397
+ - 1 for tokens that are **not masked**,
398
+ - 0 for tokens that are **masked**.
399
+ [What are attention masks?](../glossary#attention-mask)
400
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
401
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
402
+ - 1 for tokens that are **not masked**,
403
+ - 0 for tokens that are **masked**.
404
+ [What are attention masks?](../glossary#attention-mask)
405
+ output_attentions (`bool`, *optional*):
406
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
407
+ returned tensors for more detail.
408
+ output_hidden_states (`bool`, *optional*):
409
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
410
+ for more detail.
411
+ return_dict (`bool`, *optional*):
412
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
413
+ """
414
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
415
+ output_hidden_states = (
416
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
417
+ )
418
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
419
+
420
+ encoder_states = () if output_hidden_states else None
421
+ all_attentions = () if output_attentions else None
422
+
423
+ hidden_states = inputs_embeds
424
+ for idx, encoder_layer in enumerate(self.layers):
425
+ if output_hidden_states:
426
+ encoder_states = encoder_states + (hidden_states,)
427
+ if self.gradient_checkpointing and self.training:
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(encoder_layer),
437
+ hidden_states,
438
+ attention_mask,
439
+ causal_attention_mask,
440
+ )
441
+ else:
442
+ layer_outputs = encoder_layer(
443
+ hidden_states,
444
+ attention_mask,
445
+ causal_attention_mask,
446
+ output_attentions=output_attentions,
447
+ )
448
+
449
+ hidden_states = layer_outputs[0]
450
+
451
+ if output_attentions:
452
+ all_attentions = all_attentions + (layer_outputs[1],)
453
+
454
+ if output_hidden_states:
455
+ encoder_states = encoder_states + (hidden_states,)
456
+
457
+ if not return_dict:
458
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
459
+ return BaseModelOutput(
460
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
461
+ )
462
+
463
+
464
+ class CLIPTextTransformer(nn.Module):
465
+ def __init__(self, config: CLIPTextConfig):
466
+ super().__init__()
467
+ self.config = config
468
+ embed_dim = config.hidden_size
469
+ self.embeddings = CLIPTextEmbeddings(config)
470
+ self.encoder = CLIPEncoder(config)
471
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
472
+
473
+ def forward(
474
+ self,
475
+ input_ids: Optional[torch.Tensor] = None,
476
+ attention_mask: Optional[torch.Tensor] = None,
477
+ position_ids: Optional[torch.Tensor] = None,
478
+ output_attentions: Optional[bool] = None,
479
+ output_hidden_states: Optional[bool] = None,
480
+ return_dict: Optional[bool] = None,
481
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
482
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
483
+ output_hidden_states = (
484
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
485
+ )
486
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
487
+
488
+ if input_ids is None:
489
+ raise ValueError("You have to specify input_ids")
490
+
491
+ input_shape = input_ids.size()
492
+ input_ids = input_ids.view(-1, input_shape[-1])
493
+
494
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
495
+
496
+ bsz, seq_len = input_shape
497
+ # CLIP's text model uses causal mask, prepare it here.
498
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
499
+ causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
500
+ # expand attention_mask
501
+ if attention_mask is not None:
502
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
503
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
504
+
505
+ encoder_outputs = self.encoder(
506
+ inputs_embeds=hidden_states,
507
+ attention_mask=attention_mask,
508
+ causal_attention_mask=causal_attention_mask,
509
+ output_attentions=output_attentions,
510
+ output_hidden_states=output_hidden_states,
511
+ return_dict=return_dict,
512
+ )
513
+
514
+ last_hidden_state = encoder_outputs[0]
515
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
516
+
517
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
518
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
519
+ pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
520
+
521
+ if not return_dict:
522
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
523
+
524
+ return BaseModelOutputWithPooling(
525
+ last_hidden_state=last_hidden_state,
526
+ pooler_output=pooled_output,
527
+ hidden_states=encoder_outputs.hidden_states,
528
+ attentions=encoder_outputs.attentions,
529
+ )
530
+
531
+ def _build_causal_attention_mask(self, bsz, seq_len):
532
+ # lazily create causal attention mask, with full attention between the vision tokens
533
+ # pytorch uses additive attention mask; fill with -inf
534
+ mask = torch.empty(bsz, seq_len, seq_len)
535
+ mask.fill_(float("-inf"))
536
+ mask.triu_(1) # zero out the lower diagonal
537
+ mask = mask.unsqueeze(1) # expand mask
538
+ return mask
539
+
540
+
541
+ class CLIPTextModel(CLIPPreTrainedModel):
542
+ config_class = CLIPTextConfig
543
+
544
+ def __init__(self, config: CLIPTextConfig):
545
+ super().__init__(config)
546
+ self.text_model = CLIPTextTransformer(config)
547
+ # Initialize weights and apply final processing
548
+ self.post_init()
549
+
550
+ def get_input_embeddings(self) -> nn.Module:
551
+ return self.text_model.embeddings.token_embedding
552
+
553
+ def set_input_embeddings(self, value):
554
+ self.text_model.embeddings.token_embedding = value
555
+
556
+ def forward(
557
+ self,
558
+ input_ids: Optional[torch.Tensor] = None,
559
+ attention_mask: Optional[torch.Tensor] = None,
560
+ position_ids: Optional[torch.Tensor] = None,
561
+ output_attentions: Optional[bool] = None,
562
+ output_hidden_states: Optional[bool] = None,
563
+ return_dict: Optional[bool] = None,
564
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
565
+ return self.text_model(
566
+ input_ids=input_ids,
567
+ attention_mask=attention_mask,
568
+ position_ids=position_ids,
569
+ output_attentions=output_attentions,
570
+ output_hidden_states=output_hidden_states,
571
+ return_dict=return_dict,
572
+ )
573
+
574
+
575
+ class CLIPVisionTransformer(nn.Module):
576
+ def __init__(self, config: CLIPVisionConfig):
577
+ super().__init__()
578
+ self.config = config
579
+ embed_dim = config.hidden_size
580
+
581
+ self.embeddings = CLIPVisionEmbeddings(config)
582
+ self.pre_layrnorm = nn.LayerNorm(embed_dim)
583
+ self.encoder = CLIPEncoder(config)
584
+ self.post_layernorm = nn.LayerNorm(embed_dim)
585
+
586
+ def forward(
587
+ self,
588
+ pixel_values: Optional[torch.FloatTensor] = None,
589
+ output_attentions: Optional[bool] = None,
590
+ output_hidden_states: Optional[bool] = None,
591
+ return_dict: Optional[bool] = None,
592
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
593
+ r"""
594
+ Returns:
595
+ """
596
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
597
+ output_hidden_states = (
598
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
599
+ )
600
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
601
+
602
+ if pixel_values is None:
603
+ raise ValueError("You have to specify pixel_values")
604
+
605
+ hidden_states = self.embeddings(pixel_values)
606
+ hidden_states = self.pre_layrnorm(hidden_states)
607
+
608
+ encoder_outputs = self.encoder(
609
+ inputs_embeds=hidden_states,
610
+ output_attentions=output_attentions,
611
+ output_hidden_states=output_hidden_states,
612
+ return_dict=return_dict,
613
+ )
614
+
615
+ last_hidden_state = encoder_outputs[0]
616
+ pooled_output = last_hidden_state[:, 0, :]
617
+ pooled_output = self.post_layernorm(pooled_output)
618
+
619
+ if not return_dict:
620
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
621
+
622
+ return BaseModelOutputWithPooling(
623
+ last_hidden_state=last_hidden_state,
624
+ pooler_output=pooled_output,
625
+ hidden_states=encoder_outputs.hidden_states,
626
+ attentions=encoder_outputs.attentions,
627
+ )
628
+
629
+
630
+ class CLIPVisionModel(CLIPPreTrainedModel):
631
+ config_class = CLIPVisionConfig
632
+ main_input_name = "pixel_values"
633
+
634
+ def __init__(self, config: CLIPVisionConfig):
635
+ super().__init__(config)
636
+ self.vision_model = CLIPVisionTransformer(config)
637
+ # Initialize weights and apply final processing
638
+ self.post_init()
639
+
640
+ def get_input_embeddings(self) -> nn.Module:
641
+ return self.vision_model.embeddings.patch_embedding
642
+
643
+ def forward(
644
+ self,
645
+ pixel_values: Optional[torch.FloatTensor] = None,
646
+ output_attentions: Optional[bool] = None,
647
+ output_hidden_states: Optional[bool] = None,
648
+ return_dict: Optional[bool] = None,
649
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
650
+ return self.vision_model(
651
+ pixel_values=pixel_values,
652
+ output_attentions=output_attentions,
653
+ output_hidden_states=output_hidden_states,
654
+ return_dict=return_dict,
655
+ )
656
+
657
+
658
+ class CLIPModel(CLIPPreTrainedModel):
659
+ config_class = CLIPConfig
660
+
661
+ def __init__(self, config: CLIPConfig):
662
+ super().__init__(config)
663
+ text_config = config.text_config
664
+ vision_config = config.vision_config
665
+
666
+ self.projection_dim = config.projection_dim
667
+ self.text_embed_dim = text_config.hidden_size
668
+ self.vision_embed_dim = vision_config.hidden_size
669
+
670
+ if isinstance(text_config, CLIPTextConfig):
671
+ text_model = CLIPTextTransformer(text_config)
672
+ else:
673
+ text_model = AutoModel.from_config(config.text_config, add_pooling_layer=False)
674
+
675
+ if isinstance(config.vision_config, CLIPVisionConfig):
676
+ vision_model = CLIPVisionModel(config.vision_config)
677
+ else:
678
+ vision_model = AutoModel.from_config(config.vision_config, add_pooling_layer=False)
679
+
680
+ self.text_model = text_model
681
+ self.vision_model = vision_model
682
+
683
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
684
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
685
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
686
+
687
+ # Initialize weights and apply final processing
688
+ self.post_init()
689
+
690
+ def encode_text(self, input_ids, **kwargs):
691
+ return self.get_text_features(input_ids=input_ids, **kwargs)
692
+
693
+ def get_text_features(
694
+ self,
695
+ input_ids: Optional[torch.Tensor] = None,
696
+ attention_mask: Optional[torch.Tensor] = None,
697
+ position_ids: Optional[torch.Tensor] = None,
698
+ output_attentions: Optional[bool] = None,
699
+ output_hidden_states: Optional[bool] = None,
700
+ return_dict: Optional[bool] = None,
701
+ ) -> torch.FloatTensor:
702
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
703
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
704
+ output_hidden_states = (
705
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
706
+ )
707
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
708
+
709
+ text_outputs = self.text_model(
710
+ input_ids=input_ids,
711
+ attention_mask=attention_mask,
712
+ position_ids=position_ids,
713
+ output_attentions=output_attentions,
714
+ output_hidden_states=output_hidden_states,
715
+ return_dict=return_dict,
716
+ )
717
+ pooled_output = text_outputs.last_hidden_state[:, 0, :]
718
+ text_features = self.text_projection(pooled_output)
719
+
720
+ return text_features
721
+
722
+ def encode_image(self, pixel_values, **kwargs):
723
+ return self.get_image_features(pixel_values=pixel_values, **kwargs)
724
+
725
+ def get_image_features(
726
+ self,
727
+ pixel_values: Optional[torch.FloatTensor] = None,
728
+ output_attentions: Optional[bool] = None,
729
+ output_hidden_states: Optional[bool] = None,
730
+ return_dict: Optional[bool] = None,
731
+ ) -> torch.FloatTensor:
732
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
733
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
734
+ output_hidden_states = (
735
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
736
+ )
737
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
738
+
739
+ vision_outputs = self.vision_model(
740
+ pixel_values=pixel_values,
741
+ output_attentions=output_attentions,
742
+ output_hidden_states=output_hidden_states,
743
+ return_dict=return_dict,
744
+ )
745
+ pooled_output = vision_outputs.last_hidden_state[:, 0, :]
746
+ image_features = self.visual_projection(pooled_output)
747
+
748
+ return image_features
749
+
750
+ def forward(
751
+ self,
752
+ input_ids: Optional[torch.LongTensor] = None,
753
+ pixel_values: Optional[torch.FloatTensor] = None,
754
+ attention_mask: Optional[torch.Tensor] = None,
755
+ position_ids: Optional[torch.LongTensor] = None,
756
+ return_loss: Optional[bool] = None,
757
+ output_attentions: Optional[bool] = None,
758
+ output_hidden_states: Optional[bool] = None,
759
+ return_dict: Optional[bool] = None,
760
+ ) -> Union[Tuple, CLIPOutput]:
761
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
762
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
763
+ output_hidden_states = (
764
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
765
+ )
766
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
767
+
768
+ vision_outputs = self.vision_model(
769
+ pixel_values=pixel_values,
770
+ output_attentions=output_attentions,
771
+ output_hidden_states=output_hidden_states,
772
+ return_dict=return_dict,
773
+ )
774
+
775
+ text_outputs = self.text_model(
776
+ input_ids=input_ids,
777
+ attention_mask=attention_mask,
778
+ position_ids=position_ids,
779
+ output_attentions=output_attentions,
780
+ output_hidden_states=output_hidden_states,
781
+ return_dict=return_dict,
782
+ )
783
+ image_embeds = vision_outputs.last_hidden_state[:, 0, :]
784
+ image_embeds = self.visual_projection(image_embeds)
785
+
786
+ text_embeds = text_outputs.last_hidden_state[:, 0, :]
787
+ text_embeds = self.text_projection(text_embeds)
788
+
789
+ # normalized features
790
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
791
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
792
+
793
+ # cosine similarity as logits
794
+ logit_scale = self.logit_scale.exp()
795
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
796
+ logits_per_image = logits_per_text.T
797
+
798
+ loss = None
799
+ if return_loss:
800
+ loss = clip_loss(logits_per_text)
801
+
802
+ if not return_dict:
803
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
804
+ return ((loss,) + output) if loss is not None else output
805
+
806
+ return CLIPOutput(
807
+ loss=loss,
808
+ logits_per_image=logits_per_image,
809
+ logits_per_text=logits_per_text,
810
+ text_embeds=text_embeds,
811
+ image_embeds=image_embeds,
812
+ text_model_output=text_outputs,
813
+ vision_model_output=vision_outputs,
814
+ )
815
+
japanese_clip/cloob/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # coding=utf-8
+ # Copyright 2022 rinna Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from .configuration_cloob import *
+ from .modeling_cloob import *
japanese_clip/cloob/configuration_cloob.py ADDED
@@ -0,0 +1,203 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ CLOOB model configuration"""
16
+ import logging
17
+ import copy
18
+ import os
19
+ from typing import Union
20
+
21
+ from transformers import AutoConfig, PretrainedConfig
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class CLOOBTextConfig(PretrainedConfig):
28
+ model_type = "cloob_text_model"
29
+
30
+ def __init__(
31
+ self,
32
+ vocab_size=49408,
33
+ hidden_size=512,
34
+ intermediate_size=2048,
35
+ num_hidden_layers=12,
36
+ num_attention_heads=8,
37
+ max_position_embeddings=77,
38
+ hidden_act="quick_gelu",
39
+ layer_norm_eps=0.00001,
40
+ dropout=0.0,
41
+ attention_dropout=0.0,
42
+ initializer_range=0.02,
43
+ initializer_factor=1.0,
44
+ pad_token_id=1,
45
+ bos_token_id=0,
46
+ eos_token_id=2,
47
+ **kwargs
48
+ ):
49
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
50
+
51
+ self.vocab_size = vocab_size
52
+ self.hidden_size = hidden_size
53
+ self.intermediate_size = intermediate_size
54
+ self.dropout = dropout
55
+ self.num_hidden_layers = num_hidden_layers
56
+ self.num_attention_heads = num_attention_heads
57
+ self.max_position_embeddings = max_position_embeddings
58
+ self.layer_norm_eps = layer_norm_eps
59
+ self.hidden_act = hidden_act
60
+ self.initializer_range = initializer_range
61
+ self.initializer_factor = initializer_factor
62
+ self.attention_dropout = attention_dropout
63
+
64
+ @classmethod
65
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
66
+
67
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
68
+
69
+ # get the text config dict if we are loading from CLIPConfig
70
+ if config_dict.get("model_type") == "clip":
71
+ config_dict = config_dict["text_config"]
72
+
73
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
74
+ logger.warning(
75
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
76
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
77
+ )
78
+
79
+ return cls.from_dict(config_dict, **kwargs)
80
+
81
+
82
+ class CLOOBVisionConfig(PretrainedConfig):
83
+ model_type = "cloob_vision_model"
84
+
85
+ def __init__(
86
+ self,
87
+ hidden_size=768,
88
+ intermediate_size=3072,
89
+ num_hidden_layers=12,
90
+ num_attention_heads=12,
91
+ image_size=224,
92
+ patch_size=32,
93
+ hidden_act="quick_gelu",
94
+ layer_norm_eps=0.00001,
95
+ dropout=0.0,
96
+ attention_dropout=0.0,
97
+ initializer_range=0.02,
98
+ initializer_factor=1.0,
99
+ **kwargs
100
+ ):
101
+ super().__init__(**kwargs)
102
+
103
+ self.hidden_size = hidden_size
104
+ self.intermediate_size = intermediate_size
105
+ self.dropout = dropout
106
+ self.num_hidden_layers = num_hidden_layers
107
+ self.num_attention_heads = num_attention_heads
108
+ self.patch_size = patch_size
109
+ self.image_size = image_size
110
+ self.initializer_range = initializer_range
111
+ self.initializer_factor = initializer_factor
112
+ self.attention_dropout = attention_dropout
113
+ self.layer_norm_eps = layer_norm_eps
114
+ self.hidden_act = hidden_act
115
+
116
+ @classmethod
117
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
118
+
119
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
120
+
121
+ # get the vision config dict if we are loading from CLIPConfig
122
+ if config_dict.get("model_type") == "clip":
123
+ config_dict = config_dict["vision_config"]
124
+
125
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
126
+ logger.warning(
127
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
128
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
129
+ )
130
+
131
+ return cls.from_dict(config_dict, **kwargs)
132
+
133
+
134
+ class CLOOBConfig(PretrainedConfig):
135
+ model_type = "cloob"
136
+ is_composition = True
137
+
138
+ def __init__(
139
+ self,
140
+ text_config=None,
141
+ vision_config=None,
142
+ projection_dim=512,
143
+ init_inv_tau=30.0,
144
+ scale_hopfield=15.0,
145
+ **kwargs
146
+ ):
147
+ super().__init__(text_config=text_config, vision_config=vision_config, **kwargs)
148
+
149
+ if vision_config is None:
150
+ raise ValueError("`vision_config` can not be `None`.")
151
+
152
+ if text_config is None:
153
+ raise ValueError("`text_config` can not be `None`.")
154
+
155
+ vision_model_type = vision_config.pop("model_type")
156
+ text_model_type = text_config.pop("model_type")
157
+
158
+ if vision_model_type == "cloob_vision_model":
159
+ self.vision_config = CLOOBVisionConfig(**vision_config)
160
+ else:
161
+ self.vision_config = AutoConfig.for_model(
162
+ vision_model_type, **vision_config
163
+ )
164
+
165
+ if text_model_type == "cloob_text_model":
166
+ self.text_config = CLOOBTextConfig(**text_config)
167
+ else:
168
+ self.text_config = AutoConfig.for_model(
169
+ text_model_type, **text_config
170
+ )
171
+
172
+ self.projection_dim = projection_dim
173
+ self.initializer_factor = 1.0
174
+ self.init_inv_tau = init_inv_tau
175
+ self.scale_hopfield = scale_hopfield
176
+
177
+
178
+ @classmethod
179
+ def from_text_vision_configs(cls, text_config: CLOOBTextConfig, vision_config: CLOOBVisionConfig, **kwargs):
180
+ r"""
181
+ Instantiate a [`CLOOBConfig`] (or a derived class) from a CLOOB text model configuration and a CLOOB vision model
182
+ configuration.
183
+
184
+ Returns:
185
+ [`CLOOBConfig`]: An instance of a configuration object
186
+ """
187
+
188
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
189
+
190
+ def to_dict(self):
191
+ """
192
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
193
+
194
+ Returns:
195
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
196
+ """
197
+ output = copy.deepcopy(self.__dict__)
198
+ output["text_config"] = self.text_config.to_dict()
199
+ output["vision_config"] = self.vision_config.to_dict()
200
+ output["model_type"] = self.__class__.model_type
201
+ return output
202
+
203
+
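As a quick orientation before the loss and modeling files, here is a minimal sketch (not part of this commit) of how the three config classes above compose. The import path is assumed from the `from .configuration_cloob import ...` statement in modeling_cloob.py, and all hyperparameters are the defaults defined above.

    from japanese_clip.cloob.configuration_cloob import (
        CLOOBConfig, CLOOBTextConfig, CLOOBVisionConfig,
    )

    # Sub-configs with the default hyperparameters defined above.
    text_config = CLOOBTextConfig()      # 12-layer, 512-dim text transformer
    vision_config = CLOOBVisionConfig()  # ViT-B/32-style vision transformer (224 px, patch 32)

    # CLOOBConfig requires both sub-config dicts; it re-instantiates them
    # based on their "model_type" fields.
    config = CLOOBConfig(
        text_config=text_config.to_dict(),
        vision_config=vision_config.to_dict(),
        projection_dim=512,
    )
    print(config.text_config.hidden_size, config.vision_config.image_size)  # 512 224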
japanese_clip/cloob/loss.py ADDED
@@ -0,0 +1,58 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+
20
+ def cloob_loss(image_features, text_features, inv_tau, scale_hopfield):
21
+ """
22
+ Note: this loss has been rescaled from the original CLOOB loss for interpretability;
23
+ to convert to the original, divide it by inv_tau / 2.
24
+ """
25
+ p_xx, p_yy, p_xy, p_yx = hopfield_retrieval(image_features, text_features, scale_hopfield)
26
+ identity = torch.eye(p_xx.shape[1]) > 0.5
27
+ i = identity.to(p_xx.device)
28
+ loss_img = infoLOOB_loss(p_xx.T, p_xy.T, i, inv_tau=inv_tau)
29
+ loss_txt = infoLOOB_loss(p_yy.T, p_yx.T, i, inv_tau=inv_tau)
30
+ return (loss_img + loss_txt) / 2
31
+
32
+
33
+ def infoLOOB_loss(x, y, i, inv_tau):
34
+ tau = 1 / inv_tau
35
+ k = x @ y.T / tau
36
+ positives = -torch.mean(torch.sum(k * i, dim=1))
37
+
38
+ # For logsumexp the zero entries must be equal to a very large negative number
39
+ large_neg = -10000.0
40
+ arg_lse = k * torch.logical_not(i) + i * large_neg
41
+ negatives = torch.mean(torch.logsumexp(arg_lse, dim=1))
42
+ return positives + negatives
43
+
44
+
45
+ def hopfield_retrieval(image_features, text_features, scale_hopfield):
46
+ patterns_xx = hopfield(state_patterns=image_features, stored_patterns=image_features, scale_hopfield=scale_hopfield)
47
+ patterns_yy = hopfield(state_patterns=text_features, stored_patterns=text_features, scale_hopfield=scale_hopfield)
48
+ patterns_xy = hopfield(state_patterns=text_features, stored_patterns=image_features, scale_hopfield=scale_hopfield)
49
+ patterns_yx = hopfield(state_patterns=image_features, stored_patterns=text_features, scale_hopfield=scale_hopfield)
50
+
51
+ return patterns_xx, patterns_yy, patterns_xy, patterns_yx
52
+
53
+
54
+ def hopfield(state_patterns, stored_patterns, scale_hopfield):
55
+ retrieved_patterns = stored_patterns.T @ F.softmax(scale_hopfield * stored_patterns @ state_patterns.T, dim=0)
56
+ # Each column is a retrieved pattern -> dim=0 to normalize the column vectors
57
+ retrieved_patterns = retrieved_patterns / retrieved_patterns.norm(dim=0, keepdim=True)
58
+ return retrieved_patterns
japanese_clip/cloob/modeling_cloob.py ADDED
@@ -0,0 +1,783 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import logging
16
+ from dataclasses import dataclass
17
+ from typing import Any, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+
23
+ from transformers import AutoModel
24
+ from transformers.activations import ACT2FN
25
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
26
+ from transformers.modeling_utils import PreTrainedModel, ModelOutput
27
+ from .configuration_cloob import CLOOBConfig, CLOOBTextConfig, CLOOBVisionConfig
28
+ from .loss import cloob_loss
29
+ from ..clip.modeling_clip import _expand_mask
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ @dataclass
35
+ class CLOOBOutput(ModelOutput):
36
+ loss: Optional[torch.FloatTensor] = None
37
+ inv_tau: Union[torch.FloatTensor, float] = None
38
+ text_embeds: torch.FloatTensor = None
39
+ image_embeds: torch.FloatTensor = None
40
+ text_model_output: BaseModelOutputWithPooling = None
41
+ vision_model_output: BaseModelOutputWithPooling = None
42
+
43
+ def to_tuple(self) -> Tuple[Any]:
44
+ return tuple(
45
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
46
+ for k in self.keys()
47
+ )
48
+
49
+
50
+ class CLOOBVisionEmbeddings(nn.Module):
51
+ def __init__(self, config: CLOOBVisionConfig):
52
+ super().__init__()
53
+ self.config = config
54
+ self.embed_dim = config.hidden_size
55
+ self.image_size = config.image_size
56
+ self.patch_size = config.patch_size
57
+
58
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
59
+
60
+ self.patch_embedding = nn.Conv2d(
61
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
62
+ )
63
+
64
+ self.num_patches = (self.image_size // self.patch_size) ** 2
65
+ self.num_positions = self.num_patches + 1
66
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
67
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
68
+
69
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
70
+ batch_size = pixel_values.shape[0]
71
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
72
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
73
+
74
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
75
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
76
+ embeddings = embeddings + self.position_embedding(self.position_ids)
77
+ return embeddings
78
+
79
+
80
+ class CLOOBTextEmbeddings(nn.Module):
81
+ def __init__(self, config: CLOOBTextConfig):
82
+ super().__init__()
83
+ embed_dim = config.hidden_size
84
+
85
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
86
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
87
+
88
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
89
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
90
+
91
+ def forward(
92
+ self,
93
+ input_ids: Optional[torch.LongTensor] = None,
94
+ position_ids: Optional[torch.LongTensor] = None,
95
+ inputs_embeds: Optional[torch.FloatTensor] = None,
96
+ ) -> torch.Tensor:
97
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
98
+
99
+ if position_ids is None:
100
+ position_ids = self.position_ids[:, :seq_length]
101
+
102
+ if inputs_embeds is None:
103
+ inputs_embeds = self.token_embedding(input_ids)
104
+
105
+ position_embeddings = self.position_embedding(position_ids)
106
+ embeddings = inputs_embeds + position_embeddings
107
+
108
+ return embeddings
109
+
110
+
111
+ class CLOOBAttention(nn.Module):
112
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
113
+
114
+ def __init__(self, config):
115
+ super().__init__()
116
+ self.config = config
117
+ self.embed_dim = config.hidden_size
118
+ self.num_heads = config.num_attention_heads
119
+ self.head_dim = self.embed_dim // self.num_heads
120
+ if self.head_dim * self.num_heads != self.embed_dim:
121
+ raise ValueError(
122
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
123
+ f" {self.num_heads})."
124
+ )
125
+ self.scale = self.head_dim**-0.5
126
+ self.dropout = config.attention_dropout
127
+
128
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
129
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
130
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
131
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
132
+
133
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
134
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
135
+
136
+ def forward(
137
+ self,
138
+ hidden_states: torch.Tensor,
139
+ attention_mask: Optional[torch.Tensor] = None,
140
+ causal_attention_mask: Optional[torch.Tensor] = None,
141
+ output_attentions: Optional[bool] = False,
142
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
143
+ """Input shape: Batch x Time x Channel"""
144
+
145
+ bsz, tgt_len, embed_dim = hidden_states.size()
146
+
147
+ # get query proj
148
+ query_states = self.q_proj(hidden_states) * self.scale
149
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
150
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
151
+
152
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
153
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
154
+ key_states = key_states.view(*proj_shape)
155
+ value_states = value_states.view(*proj_shape)
156
+
157
+ src_len = key_states.size(1)
158
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
159
+
160
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
161
+ raise ValueError(
162
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
163
+ f" {attn_weights.size()}"
164
+ )
165
+
166
+ # apply the causal_attention_mask first
167
+ if causal_attention_mask is not None:
168
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
169
+ raise ValueError(
170
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
171
+ f" {causal_attention_mask.size()}"
172
+ )
173
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
174
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
175
+
176
+ if attention_mask is not None:
177
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
178
+ raise ValueError(
179
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
180
+ )
181
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
182
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
183
+
184
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
185
+
186
+ if output_attentions:
187
+ # this operation is a bit awkward, but it's required to
188
+ # make sure that attn_weights keeps its gradient.
189
+ # In order to do so, attn_weights have to be reshaped
190
+ # twice and have to be reused in the following
191
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
192
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
193
+ else:
194
+ attn_weights_reshaped = None
195
+
196
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
197
+
198
+ attn_output = torch.bmm(attn_probs, value_states)
199
+
200
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
201
+ raise ValueError(
202
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
203
+ f" {attn_output.size()}"
204
+ )
205
+
206
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
207
+ attn_output = attn_output.transpose(1, 2)
208
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
209
+
210
+ attn_output = self.out_proj(attn_output)
211
+
212
+ return attn_output, attn_weights_reshaped
213
+
214
+
215
+ class CLOOBMLP(nn.Module):
216
+ def __init__(self, config):
217
+ super().__init__()
218
+ self.config = config
219
+ self.activation_fn = ACT2FN[config.hidden_act]
220
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
221
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
222
+
223
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
224
+ hidden_states = self.fc1(hidden_states)
225
+ hidden_states = self.activation_fn(hidden_states)
226
+ hidden_states = self.fc2(hidden_states)
227
+ return hidden_states
228
+
229
+
230
+ class CLOOBEncoderLayer(nn.Module):
231
+ def __init__(self, config: CLOOBConfig):
232
+ super().__init__()
233
+ self.embed_dim = config.hidden_size
234
+ self.self_attn = CLOOBAttention(config)
235
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim)
236
+ self.mlp = CLOOBMLP(config)
237
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim)
238
+
239
+ def forward(
240
+ self,
241
+ hidden_states: torch.Tensor,
242
+ attention_mask: torch.Tensor,
243
+ causal_attention_mask: torch.Tensor,
244
+ output_attentions: Optional[bool] = False,
245
+ ) -> Tuple[torch.FloatTensor]:
246
+ """
247
+ Args:
248
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
249
+ attention_mask (`torch.FloatTensor`): attention mask of size
250
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
251
+
252
+ output_attentions (`bool`, *optional*):
253
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
254
+ returned tensors for more detail.
255
+ """
256
+ residual = hidden_states
257
+
258
+ hidden_states = self.layer_norm1(hidden_states)
259
+ hidden_states, attn_weights = self.self_attn(
260
+ hidden_states=hidden_states,
261
+ attention_mask=attention_mask,
262
+ causal_attention_mask=causal_attention_mask,
263
+ output_attentions=output_attentions,
264
+ )
265
+ hidden_states = residual + hidden_states
266
+
267
+ residual = hidden_states
268
+ hidden_states = self.layer_norm2(hidden_states)
269
+ hidden_states = self.mlp(hidden_states)
270
+ hidden_states = residual + hidden_states
271
+
272
+ outputs = (hidden_states,)
273
+
274
+ if output_attentions:
275
+ outputs += (attn_weights,)
276
+
277
+ return outputs
278
+
279
+
280
+ class CLOOBPreTrainedModel(PreTrainedModel):
281
+ """
282
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
283
+ models.
284
+ """
285
+
286
+ config_class = CLOOBConfig
287
+ base_model_prefix = "cloob"
288
+ supports_gradient_checkpointing = True
289
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
290
+
291
+ def _init_weights(self, module):
292
+ """Initialize the weights"""
293
+ factor = self.config.initializer_factor
294
+ if isinstance(module, CLOOBTextEmbeddings):
295
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
296
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
297
+ elif isinstance(module, CLOOBVisionEmbeddings):
298
+ factor = self.config.initializer_factor
299
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
300
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
301
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
302
+ elif isinstance(module, CLOOBAttention):
303
+ factor = self.config.initializer_factor
304
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
305
+ out_proj_std = (module.embed_dim**-0.5) * factor
306
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
307
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
308
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
309
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
310
+ elif isinstance(module, CLOOBMLP):
311
+ factor = self.config.initializer_factor
312
+ in_proj_std = (
313
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
314
+ )
315
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
316
+ nn.init.normal_(module.fc1.weight, std=fc_std)
317
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
318
+ elif isinstance(module, CLOOBModel):
319
+ nn.init.normal_(
320
+ module.text_projection.weight,
321
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
322
+ )
323
+ nn.init.normal_(
324
+ module.visual_projection.weight,
325
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
326
+ )
327
+
328
+ if isinstance(module, nn.LayerNorm):
329
+ module.bias.data.zero_()
330
+ module.weight.data.fill_(1.0)
331
+ if isinstance(module, nn.Linear) and module.bias is not None:
332
+ module.bias.data.zero_()
333
+
334
+ def _set_gradient_checkpointing(self, module, value=False):
335
+ if isinstance(module, CLOOBEncoder):
336
+ module.gradient_checkpointing = value
337
+
338
+
339
+ class CLOOBEncoder(nn.Module):
340
+ """
341
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
342
+ [`CLOOBEncoderLayer`].
343
+ Args:
344
+ config: CLOOBConfig
345
+ """
346
+
347
+ def __init__(self, config: CLOOBConfig):
348
+ super().__init__()
349
+ self.config = config
350
+ self.layers = nn.ModuleList([CLOOBEncoderLayer(config) for _ in range(config.num_hidden_layers)])
351
+ self.gradient_checkpointing = False
352
+
353
+ def forward(
354
+ self,
355
+ inputs_embeds,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ causal_attention_mask: Optional[torch.Tensor] = None,
358
+ output_attentions: Optional[bool] = None,
359
+ output_hidden_states: Optional[bool] = None,
360
+ return_dict: Optional[bool] = None,
361
+ ) -> Union[Tuple, BaseModelOutput]:
362
+ r"""
363
+ Args:
364
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
365
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
366
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
367
+ than the model's internal embedding lookup matrix.
368
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
369
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
370
+ - 1 for tokens that are **not masked**,
371
+ - 0 for tokens that are **masked**.
372
+ [What are attention masks?](../glossary#attention-mask)
373
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
374
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
375
+ - 1 for tokens that are **not masked**,
376
+ - 0 for tokens that are **masked**.
377
+ [What are attention masks?](../glossary#attention-mask)
378
+ output_attentions (`bool`, *optional*):
379
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
380
+ returned tensors for more detail.
381
+ output_hidden_states (`bool`, *optional*):
382
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
383
+ for more detail.
384
+ return_dict (`bool`, *optional*):
385
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
386
+ """
387
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
388
+ output_hidden_states = (
389
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
390
+ )
391
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
392
+
393
+ encoder_states = () if output_hidden_states else None
394
+ all_attentions = () if output_attentions else None
395
+
396
+ hidden_states = inputs_embeds
397
+ for idx, encoder_layer in enumerate(self.layers):
398
+ if output_hidden_states:
399
+ encoder_states = encoder_states + (hidden_states,)
400
+ if self.gradient_checkpointing and self.training:
401
+
402
+ def create_custom_forward(module):
403
+ def custom_forward(*inputs):
404
+ return module(*inputs, output_attentions)
405
+
406
+ return custom_forward
407
+
408
+ layer_outputs = torch.utils.checkpoint.checkpoint(
409
+ create_custom_forward(encoder_layer),
410
+ hidden_states,
411
+ attention_mask,
412
+ causal_attention_mask,
413
+ )
414
+ else:
415
+ layer_outputs = encoder_layer(
416
+ hidden_states,
417
+ attention_mask,
418
+ causal_attention_mask,
419
+ output_attentions=output_attentions,
420
+ )
421
+
422
+ hidden_states = layer_outputs[0]
423
+
424
+ if output_attentions:
425
+ all_attentions = all_attentions + (layer_outputs[1],)
426
+
427
+ if output_hidden_states:
428
+ encoder_states = encoder_states + (hidden_states,)
429
+
430
+ if not return_dict:
431
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
432
+ return BaseModelOutput(
433
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
434
+ )
435
+
436
+
437
+ class CLOOBTextTransformer(nn.Module):
438
+ def __init__(self, config: CLOOBTextConfig):
439
+ super().__init__()
440
+ self.config = config
441
+ embed_dim = config.hidden_size
442
+ self.embeddings = CLOOBTextEmbeddings(config)
443
+ self.encoder = CLOOBEncoder(config)
444
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
445
+
446
+ def forward(
447
+ self,
448
+ input_ids: Optional[torch.Tensor] = None,
449
+ attention_mask: Optional[torch.Tensor] = None,
450
+ position_ids: Optional[torch.Tensor] = None,
451
+ output_attentions: Optional[bool] = None,
452
+ output_hidden_states: Optional[bool] = None,
453
+ return_dict: Optional[bool] = None,
454
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
455
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
456
+ output_hidden_states = (
457
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
458
+ )
459
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
460
+
461
+ if input_ids is None:
462
+ raise ValueError("You have to specify either input_ids")
463
+
464
+ input_shape = input_ids.size()
465
+ input_ids = input_ids.view(-1, input_shape[-1])
466
+
467
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
468
+
469
+ bsz, seq_len = input_shape
470
+ # CLOOB's text model uses causal mask, prepare it here.
471
+ # https://github.com/openai/CLOOB/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/CLOOB/model.py#L324
472
+ causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
473
+ # expand attention_mask
474
+ if attention_mask is not None:
475
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
476
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
477
+
478
+ encoder_outputs = self.encoder(
479
+ inputs_embeds=hidden_states,
480
+ attention_mask=attention_mask,
481
+ causal_attention_mask=causal_attention_mask,
482
+ output_attentions=output_attentions,
483
+ output_hidden_states=output_hidden_states,
484
+ return_dict=return_dict,
485
+ )
486
+
487
+ last_hidden_state = encoder_outputs[0]
488
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
489
+
490
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
491
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
492
+ pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
493
+
494
+ if not return_dict:
495
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
496
+
497
+ return BaseModelOutputWithPooling(
498
+ last_hidden_state=last_hidden_state,
499
+ pooler_output=pooled_output,
500
+ hidden_states=encoder_outputs.hidden_states,
501
+ attentions=encoder_outputs.attentions,
502
+ )
503
+
504
+ def _build_causal_attention_mask(self, bsz, seq_len):
505
+ # lazily create causal attention mask, with full attention between the vision tokens
506
+ # pytorch uses additive attention mask; fill with -inf
507
+ mask = torch.empty(bsz, seq_len, seq_len)
508
+ mask.fill_(float("-inf"))
509
+ mask.triu_(1) # zero out the lower diagonal
510
+ mask = mask.unsqueeze(1) # expand mask
511
+ return mask
512
+
513
+
514
+ class CLOOBTextModel(CLOOBPreTrainedModel):
515
+ config_class = CLOOBTextConfig
516
+
517
+ def __init__(self, config: CLOOBTextConfig):
518
+ super().__init__(config)
519
+ self.text_model = CLOOBTextTransformer(config)
520
+ # Initialize weights and apply final processing
521
+ self.post_init()
522
+
523
+ def get_input_embeddings(self) -> nn.Module:
524
+ return self.text_model.embeddings.token_embedding
525
+
526
+ def set_input_embeddings(self, value):
527
+ self.text_model.embeddings.token_embedding = value
528
+
529
+ def forward(
530
+ self,
531
+ input_ids: Optional[torch.Tensor] = None,
532
+ attention_mask: Optional[torch.Tensor] = None,
533
+ position_ids: Optional[torch.Tensor] = None,
534
+ output_attentions: Optional[bool] = None,
535
+ output_hidden_states: Optional[bool] = None,
536
+ return_dict: Optional[bool] = None,
537
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
538
+ return self.text_model(
539
+ input_ids=input_ids,
540
+ attention_mask=attention_mask,
541
+ position_ids=position_ids,
542
+ output_attentions=output_attentions,
543
+ output_hidden_states=output_hidden_states,
544
+ return_dict=return_dict,
545
+ )
546
+
547
+
548
+ class CLOOBVisionTransformer(nn.Module):
549
+ def __init__(self, config: CLOOBVisionConfig):
550
+ super().__init__()
551
+ self.config = config
552
+ embed_dim = config.hidden_size
553
+
554
+ self.embeddings = CLOOBVisionEmbeddings(config)
555
+ self.pre_layrnorm = nn.LayerNorm(embed_dim)
556
+ self.encoder = CLOOBEncoder(config)
557
+ self.post_layernorm = nn.LayerNorm(embed_dim)
558
+
559
+ def forward(
560
+ self,
561
+ pixel_values: Optional[torch.FloatTensor] = None,
562
+ output_attentions: Optional[bool] = None,
563
+ output_hidden_states: Optional[bool] = None,
564
+ return_dict: Optional[bool] = None,
565
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
566
+ r"""
567
+ Returns:
568
+ """
569
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
570
+ output_hidden_states = (
571
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
572
+ )
573
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
574
+
575
+ if pixel_values is None:
576
+ raise ValueError("You have to specify pixel_values")
577
+
578
+ hidden_states = self.embeddings(pixel_values)
579
+ hidden_states = self.pre_layrnorm(hidden_states)
580
+
581
+ encoder_outputs = self.encoder(
582
+ inputs_embeds=hidden_states,
583
+ output_attentions=output_attentions,
584
+ output_hidden_states=output_hidden_states,
585
+ return_dict=return_dict,
586
+ )
587
+
588
+ last_hidden_state = encoder_outputs[0]
589
+ pooled_output = last_hidden_state[:, 0, :]
590
+ pooled_output = self.post_layernorm(pooled_output)
591
+
592
+ if not return_dict:
593
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
594
+
595
+ return BaseModelOutputWithPooling(
596
+ last_hidden_state=last_hidden_state,
597
+ pooler_output=pooled_output,
598
+ hidden_states=encoder_outputs.hidden_states,
599
+ attentions=encoder_outputs.attentions,
600
+ )
601
+
602
+
603
+ class CLOOBVisionModel(CLOOBPreTrainedModel):
604
+ config_class = CLOOBVisionConfig
605
+ main_input_name = "pixel_values"
606
+
607
+ def __init__(self, config: CLOOBVisionConfig):
608
+ super().__init__(config)
609
+ self.vision_model = CLOOBVisionTransformer(config)
610
+ # Initialize weights and apply final processing
611
+ self.post_init()
612
+
613
+ def get_input_embeddings(self) -> nn.Module:
614
+ return self.vision_model.embeddings.patch_embedding
615
+
616
+ def forward(
617
+ self,
618
+ pixel_values: Optional[torch.FloatTensor] = None,
619
+ output_attentions: Optional[bool] = None,
620
+ output_hidden_states: Optional[bool] = None,
621
+ return_dict: Optional[bool] = None,
622
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
623
+ return self.vision_model(
624
+ pixel_values=pixel_values,
625
+ output_attentions=output_attentions,
626
+ output_hidden_states=output_hidden_states,
627
+ return_dict=return_dict,
628
+ )
629
+
630
+
631
+ class CLOOBModel(CLOOBPreTrainedModel):
632
+ config_class = CLOOBConfig
633
+
634
+ def __init__(self, config: CLOOBConfig):
635
+ super().__init__(config)
636
+ text_config = config.text_config
637
+ vision_config = config.vision_config
638
+
639
+ self.projection_dim = config.projection_dim
640
+ self.text_embed_dim = text_config.hidden_size
641
+ self.vision_embed_dim = vision_config.hidden_size
642
+
643
+ if isinstance(text_config, CLOOBTextConfig):
644
+ text_model = CLOOBTextTransformer(text_config)
645
+ else:
646
+ text_model = AutoModel.from_config(config.text_config, add_pooling_layer=False)
647
+
648
+ if isinstance(config.vision_config, CLOOBVisionConfig):
649
+ vision_model = CLOOBVisionModel(config.vision_config)
650
+ else:
651
+ vision_model = AutoModel.from_config(config.vision_config, add_pooling_layer=False)
652
+
653
+ self.text_model = text_model
654
+ self.vision_model = vision_model
655
+
656
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
657
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
658
+
659
+ self.inv_tau = config.init_inv_tau
660
+ self.scale_hopfield = config.scale_hopfield
661
+
662
+ # Initialize weights and apply final processing
663
+ self.post_init()
664
+
665
+ def encode_text(self, input_ids, **kwargs):
666
+ return self.get_text_features(input_ids=input_ids, **kwargs)
667
+
668
+ def get_text_features(
669
+ self,
670
+ input_ids: Optional[torch.Tensor] = None,
671
+ attention_mask: Optional[torch.Tensor] = None,
672
+ position_ids: Optional[torch.Tensor] = None,
673
+ output_attentions: Optional[bool] = None,
674
+ output_hidden_states: Optional[bool] = None,
675
+ return_dict: Optional[bool] = None,
676
+ ) -> torch.FloatTensor:
677
+ # Use CLOOB model's config for some fields (if specified) instead of those of vision & text components.
678
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
679
+ output_hidden_states = (
680
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
681
+ )
682
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
683
+
684
+ text_outputs = self.text_model(
685
+ input_ids=input_ids,
686
+ attention_mask=attention_mask,
687
+ position_ids=position_ids,
688
+ output_attentions=output_attentions,
689
+ output_hidden_states=output_hidden_states,
690
+ return_dict=return_dict,
691
+ )
692
+ pooled_output = text_outputs.last_hidden_state[:, 0, :]
693
+ text_features = self.text_projection(pooled_output)
694
+
695
+ return text_features
696
+
697
+ def encode_image(self, pixel_values, **kwargs):
698
+ return self.get_image_features(pixel_values=pixel_values, **kwargs)
699
+
700
+ def get_image_features(
701
+ self,
702
+ pixel_values: Optional[torch.FloatTensor] = None,
703
+ output_attentions: Optional[bool] = None,
704
+ output_hidden_states: Optional[bool] = None,
705
+ return_dict: Optional[bool] = None,
706
+ ) -> torch.FloatTensor:
707
+ # Use CLOOB model's config for some fields (if specified) instead of those of vision & text components.
708
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
709
+ output_hidden_states = (
710
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
711
+ )
712
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
713
+
714
+ vision_outputs = self.vision_model(
715
+ pixel_values=pixel_values,
716
+ output_attentions=output_attentions,
717
+ output_hidden_states=output_hidden_states,
718
+ return_dict=return_dict,
719
+ )
720
+ pooled_output = vision_outputs.last_hidden_state[:, 0, :]
721
+ image_features = self.visual_projection(pooled_output)
722
+
723
+ return image_features
724
+
725
+ def forward(
726
+ self,
727
+ input_ids: Optional[torch.LongTensor] = None,
728
+ pixel_values: Optional[torch.FloatTensor] = None,
729
+ attention_mask: Optional[torch.Tensor] = None,
730
+ position_ids: Optional[torch.LongTensor] = None,
731
+ return_loss: Optional[bool] = None,
732
+ output_attentions: Optional[bool] = None,
733
+ output_hidden_states: Optional[bool] = None,
734
+ return_dict: Optional[bool] = None,
735
+ ) -> Union[Tuple, CLOOBOutput]:
736
+ # Use CLOOB model's config for some fields (if specified) instead of those of vision & text components.
737
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
738
+ output_hidden_states = (
739
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
740
+ )
741
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
742
+
743
+ vision_outputs = self.vision_model(
744
+ pixel_values=pixel_values,
745
+ output_attentions=output_attentions,
746
+ output_hidden_states=output_hidden_states,
747
+ return_dict=return_dict,
748
+ )
749
+
750
+ text_outputs = self.text_model(
751
+ input_ids=input_ids,
752
+ attention_mask=attention_mask,
753
+ position_ids=position_ids,
754
+ output_attentions=output_attentions,
755
+ output_hidden_states=output_hidden_states,
756
+ return_dict=return_dict,
757
+ )
758
+ image_embeds = vision_outputs.last_hidden_state[:, 0, :]
759
+ image_embeds = self.visual_projection(image_embeds)
760
+
761
+ text_embeds = text_outputs.last_hidden_state[:, 0, :]
762
+ text_embeds = self.text_projection(text_embeds)
763
+
764
+ # normalized features
765
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
766
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
767
+
768
+ loss = None
769
+ if return_loss:
770
+ loss = cloob_loss(image_embeds, text_embeds, self.inv_tau, self.scale_hopfield)
771
+
772
+ if not return_dict:
773
+ output = (text_embeds, image_embeds, self.inv_tau, text_outputs, vision_outputs)
774
+ return ((loss,) + output) if loss is not None else output
775
+
776
+ return CLOOBOutput(
777
+ loss=loss,
778
+ text_embeds=text_embeds,
779
+ image_embeds=image_embeds,
780
+ inv_tau=self.inv_tau,
781
+ text_model_output=text_outputs,
782
+ vision_model_output=vision_outputs,
783
+ )
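To see how the pieces above fit together end to end, here is a hedged sketch (not part of this commit) that runs a randomly initialized CLOOBModel on dummy inputs; loading pretrained checkpoints is outside the scope of this sketch, and the import paths are assumed from the module layout in this commit.

    import torch
    from japanese_clip.cloob.configuration_cloob import (
        CLOOBConfig, CLOOBTextConfig, CLOOBVisionConfig,
    )
    from japanese_clip.cloob.modeling_cloob import CLOOBModel

    # Randomly initialized weights built from the default sub-configs.
    config = CLOOBConfig(
        text_config=CLOOBTextConfig().to_dict(),
        vision_config=CLOOBVisionConfig().to_dict(),
    )
    model = CLOOBModel(config).eval()

    pixel_values = torch.randn(2, 3, 224, 224)  # two 224x224 RGB images
    input_ids = torch.randint(0, config.text_config.vocab_size, (2, 77))

    with torch.no_grad():
        out = model(input_ids=input_ids, pixel_values=pixel_values, return_loss=True)

    print(out.image_embeds.shape, out.text_embeds.shape)  # torch.Size([2, 512]) each, unit-norm
    print(out.loss)  # cloob_loss between the two embedding sets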
japanese_clip/tokenizer.py ADDED
@@ -0,0 +1,63 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Union, List
17
+ import torch
18
+ from transformers import T5Tokenizer
19
+
20
+
21
+ def load_tokenizer():
22
+ """
23
+ https://huggingface.co/rinna/japanese-roberta-base
24
+ """
25
+ tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
26
+ tokenizer.do_lower_case = True  # work around a bug in tokenizer config loading
27
+ return tokenizer
28
+
29
+
30
+ def tokenize(
31
+ texts: Union[str, List[str]],
32
+ tokenizer: T5Tokenizer = None,
33
+ max_seq_len: int = 77,
34
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
35
+ ):
36
+ """
37
+ This mirrors the tokenize function in the original CLIP code:
38
+ https://github.com/openai/CLIP/blob/main/clip/clip.py#L195
39
+ """
40
+ if isinstance(texts, str):
41
+ texts = [texts]
42
+ if tokenizer is None:
43
+ tokenizer = load_tokenizer()
44
+ inputs = tokenizer(
45
+ texts,
46
+ max_length=max_seq_len-1,
47
+ padding="max_length",
48
+ truncation=True,
49
+ add_special_tokens=False,
50
+ )
51
+ # add cls token at first place
52
+ input_ids = [[tokenizer.cls_token_id] + ids for ids in inputs['input_ids']]
53
+ attention_mask = [[1] + am for am in inputs['attention_mask']]
54
+ position_ids = [list(range(0, len(input_ids[0])))] * len(texts)
55
+
56
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
57
+ attention_mask = torch.tensor(attention_mask, dtype=torch.long)
58
+ position_ids = torch.tensor(position_ids, dtype=torch.long)
59
+ return {
60
+ "input_ids": input_ids.to(device),
61
+ "attention_mask": attention_mask.to(device),
62
+ "position_ids": position_ids.to(device),
63
+ }
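As a quick check of what tokenize() produces, here is a small sketch (not part of this commit); it downloads the rinna/japanese-roberta-base tokenizer on first use, which requires the sentencepiece package to be installed.

    from japanese_clip.tokenizer import load_tokenizer, tokenize

    tokenizer = load_tokenizer()  # T5Tokenizer for rinna/japanese-roberta-base

    encoded = tokenize(
        ["かわいい犬の写真", "海辺の夕日"],  # "a photo of a cute dog", "sunset at the beach"
        tokenizer=tokenizer,
        device="cpu",
    )
    # CLS is prepended and each sequence is padded/truncated to max_seq_len tokens.
    print(encoded["input_ids"].shape)       # torch.Size([2, 77])
    print(encoded["attention_mask"].shape)  # torch.Size([2, 77]); 1 on real tokens, 0 on padding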
japanese_clip/utils/__init__.py ADDED
File without changes
japanese_clip/utils/callbacks.py ADDED
@@ -0,0 +1,96 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from tqdm.auto import tqdm
17
+ import numpy as np
18
+ import torch
19
+
20
+
21
+ def accuracy(output, target, topk=(1,)):
22
+ output = torch.from_numpy(np.asarray(output))
23
+ target = torch.from_numpy(np.asarray(target))
24
+ pred = output.topk(max(topk), dim=1, largest=True, sorted=True)[1].t()
25
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
26
+ return [
27
+ float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
28
+ for k in topk
29
+ ]
30
+
31
+
32
+ class ImagenetClassificationCallback:
33
+ def __init__(
34
+ self,
35
+ imagenet_classes,
36
+ imagenet_templates,
37
+ imagenet_dataloader,
38
+ ):
39
+ self.imagenet_classes = imagenet_classes
40
+ self.imagenet_templates = imagenet_templates
41
+ self.imagenet_dataloader = imagenet_dataloader
42
+
43
+ def tokenize(self, tokenizer, examples, device):
44
+ encoding_inputs = tokenizer(examples, max_length=76, padding="max_length", truncation=True, add_special_tokens=False)
45
+ # add cls token at first place
46
+ input_ids = [[tokenizer.cls_token_id] + ids for ids in encoding_inputs['input_ids']]
47
+ attention_mask = [[1] + am for am in encoding_inputs['attention_mask']]
48
+ position_ids = [list(range(0, len(input_ids[0])))] * len(examples)
49
+
50
+ input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
51
+ attention_mask = torch.tensor(attention_mask, dtype=torch.long, device=device)
52
+ position_ids = torch.tensor(position_ids, dtype=torch.long, device=device)
53
+ return {
54
+ "input_ids": input_ids,
55
+ "attention_mask": attention_mask,
56
+ "position_ids": position_ids,
57
+ }
58
+
59
+ def zeroshot_classifier(self, model, tokenizer, classnames, templates):
60
+ zeroshot_weights = []
61
+ for classname in tqdm(classnames):
62
+ texts = [template.format(classname) for template in templates]
63
+ class_embeddings = model.get_text_features(**self.tokenize(tokenizer, texts, model.device)).detach().cpu().numpy()
64
+ class_embeddings = class_embeddings / np.linalg.norm(
65
+ class_embeddings, axis=-1, keepdims=True
66
+ )
67
+ class_embedding = np.mean(class_embeddings, axis=0)
68
+ class_embedding /= np.linalg.norm(class_embedding, axis=-1)
69
+ zeroshot_weights.append(class_embedding)
70
+ zeroshot_weights = np.stack(zeroshot_weights, axis=1)
71
+ return zeroshot_weights
72
+
73
+ def zeroshot(self, model, tokenizer) -> dict:
74
+ print("Imagenet Zeroshot Classification...")
75
+ zeroshot_weights = self.zeroshot_classifier(model, tokenizer, self.imagenet_classes, self.imagenet_templates)
76
+ top_ns = [1, 5, 10, 100]
77
+ acc_counters = [0.0 for _ in top_ns]
78
+ n = 0.0
79
+
80
+ for i, (images, target) in enumerate(tqdm(self.imagenet_dataloader)):
81
+ target = target.numpy()
82
+ # predict
83
+ image_features = model.get_image_features(images.to(model.device)).detach().cpu().numpy()
84
+ image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
85
+ logits = 100.0 * image_features @ zeroshot_weights
86
+
87
+ # measure accuracy
88
+ accs = accuracy(logits, target, topk=top_ns)
89
+ for j in range(len(top_ns)):
90
+ acc_counters[j] += accs[j]
91
+ n += images.shape[0]
92
+
93
+ tops = {f"imagenet/top{top_ns[i]}": acc_counters[i] / n * 100 for i in range(len(top_ns))}
94
+
95
+ return tops
96
+
japanese_clip/utils/imagenet_zeroshot_data.py ADDED
@@ -0,0 +1,1043 @@
1
+ imagenet_classnames = [{'en': 'tench', 'ja': 'テンチ'},
2
+ {'en': 'goldfish', 'ja': '金魚'},
3
+ {'en': 'great white shark', 'ja': 'ホホジロザメ'},
4
+ {'en': 'tiger shark', 'ja': 'イタチザメ'},
5
+ {'en': 'hammerhead shark', 'ja': 'ハンマーヘッド'},
6
+ {'en': 'electric ray', 'ja': 'シビレエイ'},
7
+ {'en': 'stingray', 'ja': 'アカエイ'},
8
+ {'en': 'rooster', 'ja': 'コック'},
9
+ {'en': 'hen', 'ja': 'めんどり'},
10
+ {'en': 'ostrich', 'ja': 'ダチョウ'},
11
+ {'en': 'brambling', 'ja': 'アトリ'},
12
+ {'en': 'goldfinch', 'ja': 'ゴシキヒワ'},
13
+ {'en': 'house finch', 'ja': 'ハウスフィンチ'},
14
+ {'en': 'junco', 'ja': 'ユキヒメドリ'},
15
+ {'en': 'indigo bunting', 'ja': 'インディゴホオジロ'},
16
+ {'en': 'American robin', 'ja': 'ロビン'},
17
+ {'en': 'bulbul', 'ja': 'ブルブル'},
18
+ {'en': 'jay', 'ja': 'カケス'},
19
+ {'en': 'magpie', 'ja': 'カササギ'},
20
+ {'en': 'chickadee', 'ja': '四十雀'},
21
+ {'en': 'American dipper', 'ja': '水クロウタドリ'},
22
+ {'en': 'kite (bird of prey)', 'ja': '凧'},
23
+ {'en': 'bald eagle', 'ja': '白頭ワシ'},
24
+ {'en': 'vulture', 'ja': 'ハゲワシ'},
25
+ {'en': 'great grey owl', 'ja': 'カラフトフクロウ'},
26
+ {'en': 'fire salamander', 'ja': '欧州ファイアサラマンダー'},
27
+ {'en': 'smooth newt', 'ja': '共通イモリ'},
28
+ {'en': 'newt', 'ja': 'イモリ'},
29
+ {'en': 'spotted salamander', 'ja': 'サンショウウオを発見'},
30
+ {'en': 'axolotl', 'ja': 'アホロートル'},
31
+ {'en': 'American bullfrog', 'ja': 'ウシガエル'},
32
+ {'en': 'tree frog', 'ja': 'アマガエル'},
33
+ {'en': 'tailed frog', 'ja': 'つかれたカエル'},
34
+ {'en': 'loggerhead sea turtle', 'ja': 'とんちき'},
35
+ {'en': 'leatherback sea turtle', 'ja': 'オサガメ'},
36
+ {'en': 'mud turtle', 'ja': '鼈'},
37
+ {'en': 'terrapin', 'ja': 'テラピン'},
38
+ {'en': 'box turtle', 'ja': 'ハコガメ'},
39
+ {'en': 'banded gecko', 'ja': '縞模様のヤモリ'},
40
+ {'en': 'green iguana', 'ja': '共通イグアナ'},
41
+ {'en': 'Carolina anole', 'ja': 'アメリカンカメレオン'},
42
+ {'en': 'desert grassland whiptail lizard', 'ja': 'ウィッペイル'},
43
+ {'en': 'agama', 'ja': 'アガマトカゲ'},
44
+ {'en': 'frilled-necked lizard', 'ja': 'フリルトカゲ'},
45
+ {'en': 'alligator lizard', 'ja': 'アリゲータートカゲ'},
46
+ {'en': 'Gila monster', 'ja': 'アメリカドクトカゲ'},
47
+ {'en': 'European green lizard', 'ja': '緑のトカゲ'},
48
+ {'en': 'chameleon', 'ja': 'アフリカのカメレオン'},
49
+ {'en': 'Komodo dragon', 'ja': 'コモドドラゴン'},
50
+ {'en': 'Nile crocodile', 'ja': 'アフリカのワニ'},
51
+ {'en': 'American alligator', 'ja': 'アメリカワニ'},
52
+ {'en': 'triceratops', 'ja': 'トリケラトプス'},
53
+ {'en': 'worm snake', 'ja': '雷のヘビ'},
54
+ {'en': 'ring-necked snake', 'ja': 'リングネックスネーク'},
55
+ {'en': 'eastern hog-nosed snake', 'ja': 'ホーノースヘビ'},
56
+ {'en': 'smooth green snake', 'ja': '緑のヘビ'},
57
+ {'en': 'kingsnake', 'ja': 'キングスネーク'},
58
+ {'en': 'garter snake', 'ja': 'ガータースネーク'},
59
+ {'en': 'water snake', 'ja': '水蛇'},
60
+ {'en': 'vine snake', 'ja': 'つるヘビ'},
61
+ {'en': 'night snake', 'ja': '夜のヘビ'},
62
+ {'en': 'boa constrictor', 'ja': 'ボア・コンストリクター'},
63
+ {'en': 'African rock python', 'ja': 'ロックパイソン'},
64
+ {'en': 'Indian cobra', 'ja': 'インドコブラ'},
65
+ {'en': 'green mamba', 'ja': 'グリーンマンバ'},
66
+ {'en': 'sea snake', 'ja': 'ウミヘビ'},
67
+ {'en': 'Saharan horned viper', 'ja': 'ツノクサリヘビ'},
68
+ {'en': 'eastern diamondback rattlesnake', 'ja': 'ダイヤ'},
69
+ {'en': 'sidewinder rattlesnake', 'ja': 'サイドワインダー'},
70
+ {'en': 'trilobite', 'ja': '三葉虫'},
71
+ {'en': 'harvestman', 'ja': '刈り入れ作業者'},
72
+ {'en': 'scorpion', 'ja': 'サソリ'},
73
+ {'en': 'yellow garden spider', 'ja': '黒と金の庭クモ'},
74
+ {'en': 'barn spider', 'ja': '納屋クモ'},
75
+ {'en': 'European garden spider', 'ja': '庭クモ'},
76
+ {'en': 'southern black widow', 'ja': 'クロゴケグモ'},
77
+ {'en': 'tarantula', 'ja': 'タランチュラ'},
78
+ {'en': 'wolf spider', 'ja': 'オオカミのクモ'},
79
+ {'en': 'tick', 'ja': 'ダニ'},
80
+ {'en': 'centipede', 'ja': '百足'},
81
+ {'en': 'black grouse', 'ja': 'クロライチョウ'},
82
+ {'en': 'ptarmigan', 'ja': '雷鳥'},
83
+ {'en': 'ruffed grouse', 'ja': 'ひだえりの付いたライチョウ'},
84
+ {'en': 'prairie grouse', 'ja': '草原チキン'},
85
+ {'en': 'peafowl', 'ja': '孔雀'},
86
+ {'en': 'quail', 'ja': 'ウズラ'},
87
+ {'en': 'partridge', 'ja': 'ヤマウズラ'},
88
+ {'en': 'african grey parrot', 'ja': 'アフリカの灰色'},
89
+ {'en': 'macaw', 'ja': 'コンゴウインコ'},
90
+ {'en': 'sulphur-crested cockatoo', 'ja': '硫黄トキオウム'},
91
+ {'en': 'lorikeet', 'ja': 'インコ'},
92
+ {'en': 'coucal', 'ja': 'バンケン'},
93
+ {'en': 'bee eater', 'ja': '蜂食べる人'},
94
+ {'en': 'hornbill', 'ja': 'サイチョウ'},
95
+ {'en': 'hummingbird', 'ja': 'ハチドリ'},
96
+ {'en': 'jacamar', 'ja': '錐嘴'},
97
+ {'en': 'toucan', 'ja': 'オオハシ'},
98
+ {'en': 'duck', 'ja': 'ドレイク'},
99
+ {'en': 'red-breasted merganser', 'ja': '赤ブレストアイサ属のガモ'},
100
+ {'en': 'goose', 'ja': 'ガチョウ'},
101
+ {'en': 'black swan', 'ja': '黒い白鳥'},
102
+ {'en': 'tusker', 'ja': 'タスカービール'},
103
+ {'en': 'echidna', 'ja': 'ハリモグラ'},
104
+ {'en': 'platypus', 'ja': 'カモノハシ'},
105
+ {'en': 'wallaby', 'ja': 'ワラビー'},
106
+ {'en': 'koala', 'ja': 'コアラ'},
107
+ {'en': 'wombat', 'ja': 'ウォンバット'},
108
+ {'en': 'jellyfish', 'ja': 'クラゲ'},
109
+ {'en': 'sea anemone', 'ja': 'イソギンチャク'},
110
+ {'en': 'brain coral', 'ja': '脳サンゴ'},
111
+ {'en': 'flatworm', 'ja': '扁形動物'},
112
+ {'en': 'nematode', 'ja': '線虫'},
113
+ {'en': 'conch', 'ja': '巻き貝'},
114
+ {'en': 'snail', 'ja': 'カタツムリ'},
115
+ {'en': 'slug', 'ja': 'ナメクジ'},
116
+ {'en': 'sea slug', 'ja': 'ウミウシ'},
117
+ {'en': 'chiton', 'ja': 'キトン'},
118
+ {'en': 'chambered nautilus', 'ja': 'オウムガイ'},
119
+ {'en': 'Dungeness crab', 'ja': 'アメリカイチョウガニ'},
120
+ {'en': 'rock crab', 'ja': '岩カニ'},
121
+ {'en': 'fiddler crab', 'ja': 'シオマネキ'},
122
+ {'en': 'red king crab', 'ja': 'タラバガニ'},
123
+ {'en': 'American lobster', 'ja': 'アメリカンロブスター'},
124
+ {'en': 'spiny lobster', 'ja': '伊勢エビ'},
125
+ {'en': 'crayfish', 'ja': 'ザリガニ'},
126
+ {'en': 'hermit crab', 'ja': 'ヤドカリ'},
127
+ {'en': 'isopod', 'ja': '等脚類'},
128
+ {'en': 'white stork', 'ja': 'コウノトリ'},
129
+ {'en': 'black stork', 'ja': 'ナベコウ'},
130
+ {'en': 'spoonbill', 'ja': 'ヘラサギ'},
131
+ {'en': 'flamingo', 'ja': 'フラミンゴ'},
132
+ {'en': 'little blue heron', 'ja': '小さな青いサギ'},
133
+ {'en': 'great egret', 'ja': 'アメリカン白鷺'},
134
+ {'en': 'bittern bird', 'ja': 'にがり'},
135
+ {'en': 'crane bird', 'ja': 'クレーン'},
136
+ {'en': 'limpkin', 'ja': 'ツルモドキ科の鳥'},
137
+ {'en': 'common gallinule', 'ja': 'ヨーロピアン水鳥'},
138
+ {'en': 'American coot', 'ja': 'アメリカオオバン'},
139
+ {'en': 'bustard', 'ja': 'ノガン'},
140
+ {'en': 'ruddy turnstone', 'ja': 'キョウジョシギ'},
141
+ {'en': 'dunlin', 'ja': '赤担保シギ'},
142
+ {'en': 'common redshank', 'ja': 'アカアシシギ'},
143
+ {'en': 'dowitcher', 'ja': 'オオハシシギ'},
144
+ {'en': 'oystercatcher', 'ja': 'ミヤコドリ'},
145
+ {'en': 'pelican', 'ja': 'ペリカン'},
146
+ {'en': 'king penguin', 'ja': 'キングペンギン'},
147
+ {'en': 'albatross', 'ja': 'アルバトロス'},
148
+ {'en': 'grey whale', 'ja': 'コククジラ'},
149
+ {'en': 'killer whale', 'ja': 'シャチ'},
150
+ {'en': 'dugong', 'ja': 'ジュゴン'},
151
+ {'en': 'sea lion', 'ja': 'アシカ'},
152
+ {'en': 'Chihuahua', 'ja': 'チワワ'},
153
+ {'en': 'Japanese Chin', 'ja': '狆'},
154
+ {'en': 'Maltese', 'ja': 'マルチーズ犬'},
155
+ {'en': 'Pekingese', 'ja': '狆'},
156
+ {'en': 'Shih Tzu', 'ja': 'シーズー、シーズー'},
157
+ {'en': 'King Charles Spaniel', 'ja': 'ブレナムスパニエル'},
158
+ {'en': 'Papillon', 'ja': 'パピヨン'},
159
+ {'en': 'toy terrier', 'ja': 'トイテリア'},
160
+ {'en': 'Rhodesian Ridgeback', 'ja': 'ローデシアン・リッジバック'},
161
+ {'en': 'Afghan Hound', 'ja': 'アフガンハウンド'},
162
+ {'en': 'Basset Hound', 'ja': 'バセット犬'},
163
+ {'en': 'Beagle', 'ja': 'ビーグル'},
164
+ {'en': 'Bloodhound', 'ja': 'ブラッドハウンド'},
165
+ {'en': 'Bluetick Coonhound', 'ja': 'ブルーティック'},
166
+ {'en': 'Black and Tan Coonhound', 'ja': '黒と黄褐色の猟犬'},
167
+ {'en': 'Treeing Walker Coonhound', 'ja': 'ウォーカーハウンド'},
168
+ {'en': 'English foxhound', 'ja': 'イングリッシュフォックスハウンド'},
169
+ {'en': 'Redbone Coonhound', 'ja': 'レッドボーン'},
170
+ {'en': 'borzoi', 'ja': 'ボルゾイ'},
171
+ {'en': 'Irish Wolfhound', 'ja': 'アイリッシュ・ウルフハウンド'},
172
+ {'en': 'Italian Greyhound', 'ja': 'イタリアングレーハウンド'},
173
+ {'en': 'Whippet', 'ja': 'ウィペット'},
174
+ {'en': 'Ibizan Hound', 'ja': 'イビサハウンド'},
175
+ {'en': 'Norwegian Elkhound', 'ja': 'ノルウェーエルクハウンド'},
176
+ {'en': 'Otterhound', 'ja': 'オッターハウンド'},
177
+ {'en': 'Saluki', 'ja': 'サルーキ'},
178
+ {'en': 'Scottish Deerhound', 'ja': 'スコティッシュ・ディアハウンド'},
179
+ {'en': 'Weimaraner', 'ja': 'ワイマラナー'},
180
+ {'en': 'Staffordshire Bull Terrier', 'ja': 'スタフォードシャーブルテリア'},
181
+ {'en': 'American Staffordshire Terrier', 'ja': 'アメリカン・スタッフォードシャー・テリア'},
182
+ {'en': 'Bedlington Terrier', 'ja': 'ベドリントンテリア'},
183
+ {'en': 'Border Terrier', 'ja': 'ボーダーテリア'},
184
+ {'en': 'Kerry Blue Terrier', 'ja': 'ケリーブルーテリア'},
185
+ {'en': 'Irish Terrier', 'ja': 'アイリッシュテリア'},
186
+ {'en': 'Norfolk Terrier', 'ja': 'ノーフォークテリア'},
187
+ {'en': 'Norwich Terrier', 'ja': 'ノーリッチ・テリア'},
188
+ {'en': 'Yorkshire Terrier', 'ja': 'ヨークシャーテリア'},
189
+ {'en': 'Wire Fox Terrier', 'ja': 'ワイヤーヘアー・フォックステリア'},
190
+ {'en': 'Lakeland Terrier', 'ja': 'レークランドテリア'},
191
+ {'en': 'Sealyham Terrier', 'ja': 'シーリーハムテリア'},
192
+ {'en': 'Airedale Terrier', 'ja': 'エアデール'},
193
+ {'en': 'Cairn Terrier', 'ja': 'ケルン'},
194
+ {'en': 'Australian Terrier', 'ja': 'オーストラリアテリア'},
195
+ {'en': 'Dandie Dinmont Terrier', 'ja': 'ダンディディンモントテリア'},
196
+ {'en': 'Boston Terrier', 'ja': 'ボストンブル'},
197
+ {'en': 'Miniature Schnauzer', 'ja': 'ミニチュアシュナウザー'},
198
+ {'en': 'Giant Schnauzer', 'ja': 'ジャイアントシュナウザー'},
199
+ {'en': 'Standard Schnauzer', 'ja': 'スタンダードシュナウザー'},
200
+ {'en': 'Scottish Terrier', 'ja': 'スコッチテリア'},
201
+ {'en': 'Tibetan Terrier', 'ja': 'チベタンテリア'},
202
+ {'en': 'Australian Silky Terrier', 'ja': 'シルキーテリア'},
203
+ {'en': 'Soft-coated Wheaten Terrier', 'ja': 'ソフトコーテッド・ウィートン・テリア'},
204
+ {'en': 'West Highland White Terrier', 'ja': 'ウェストハイランドホワイトテリア'},
205
+ {'en': 'Lhasa Apso', 'ja': 'ラサ'},
206
+ {'en': 'Flat-Coated Retriever', 'ja': 'フラットコーテッド・レトリーバー'},
207
+ {'en': 'Curly-coated Retriever', 'ja': 'カーリーコーティングされたレトリーバー'},
208
+ {'en': 'Golden Retriever', 'ja': 'ゴールデンレトリバー'},
209
+ {'en': 'Labrador Retriever', 'ja': 'ラブラドル・レトリーバー犬'},
210
+ {'en': 'Chesapeake Bay Retriever', 'ja': 'チェサピーク湾レトリーバー'},
211
+ {'en': 'German Shorthaired Pointer', 'ja': 'ジャーマン・ショートヘア・ポインタ'},
212
+ {'en': 'Vizsla', 'ja': 'ビズラ'},
213
+ {'en': 'English Setter', 'ja': 'イングリッシュセッター'},
214
+ {'en': 'Irish Setter', 'ja': 'アイリッシュセッター'},
215
+ {'en': 'Gordon Setter', 'ja': 'ゴードンセッター'},
216
+ {'en': 'Brittany dog', 'ja': 'ブリタニースパニエル'},
217
+ {'en': 'Clumber Spaniel', 'ja': 'クランバー'},
218
+ {'en': 'English Springer Spaniel', 'ja': 'イングリッシュスプリンガー'},
219
+ {'en': 'Welsh Springer Spaniel', 'ja': 'ウェルシュスプリンガースパニエル'},
220
+ {'en': 'Cocker Spaniel', 'ja': 'コッカースパニエル'},
221
+ {'en': 'Sussex Spaniel', 'ja': 'サセックススパニエル'},
222
+ {'en': 'Irish Water Spaniel', 'ja': 'アイルランドのウォータースパニエル'},
223
+ {'en': 'Kuvasz', 'ja': 'クバース犬'},
224
+ {'en': 'Schipperke', 'ja': 'スキッパーキー'},
225
+ {'en': 'Groenendael dog', 'ja': 'ベルジアン・シェパード・ドッグ・グローネンダール'},
226
+ {'en': 'Malinois', 'ja': 'マリノア'},
227
+ {'en': 'Briard', 'ja': 'ブリアール'},
228
+ {'en': 'Australian Kelpie', 'ja': 'ケルピー'},
229
+ {'en': 'Komondor', 'ja': 'コモンドール'},
230
+ {'en': 'Old English Sheepdog', 'ja': 'オールドイングリッシュシープドッグ'},
231
+ {'en': 'Shetland Sheepdog', 'ja': 'シェトランドシープドッグ'},
232
+ {'en': 'collie', 'ja': 'コリー'},
233
+ {'en': 'Border Collie', 'ja': 'ボーダーコリー'},
234
+ {'en': 'Bouvier des Flandres dog', 'ja': 'ブーヴィエ・デ・フランドル'},
235
+ {'en': 'Rottweiler', 'ja': 'ロットワイラー'},
236
+ {'en': 'German Shepherd Dog', 'ja': 'ジャーマンシェパード'},
237
+ {'en': 'Dobermann', 'ja': 'ドーベルマン犬'},
238
+ {'en': 'Miniature Pinscher', 'ja': 'ミニチュアピンシャー'},
239
+ {'en': 'Greater Swiss Mountain Dog', 'ja': 'グレータースイスマウンテンドッグ'},
240
+ {'en': 'Bernese Mountain Dog', 'ja': 'バーネーズマウンテンドッグ'},
241
+ {'en': 'Appenzeller Sennenhund', 'ja': 'アッペンツェル'},
242
+ {'en': 'Entlebucher Sennenhund', 'ja': 'エントレブッシャー'},
243
+ {'en': 'Boxer', 'ja': 'ボクサー'},
244
+ {'en': 'Bullmastiff', 'ja': 'ブルマスチフ'},
245
+ {'en': 'Tibetan Mastiff', 'ja': 'チベットマスチフ'},
246
+ {'en': 'French Bulldog', 'ja': 'フレンチブルドッグ'},
247
+ {'en': 'Great Dane', 'ja': 'グレートデーン'},
248
+ {'en': 'St. Bernard', 'ja': 'セントバーナード'},
249
+ {'en': 'husky', 'ja': 'エスキモー犬'},
250
+ {'en': 'Alaskan Malamute', 'ja': 'マラミュート'},
251
+ {'en': 'Siberian Husky', 'ja': 'シベリアンハスキー'},
252
+ {'en': 'Dalmatian', 'ja': 'ダルメシアン'},
253
+ {'en': 'Affenpinscher', 'ja': 'アーフェンピンシャー'},
254
+ {'en': 'Basenji', 'ja': 'バセンジー'},
255
+ {'en': 'pug', 'ja': 'パグ'},
256
+ {'en': 'Leonberger', 'ja': 'レオンバーグ'},
257
+ {'en': 'Newfoundland dog', 'ja': 'ニューファンドランド島'},
258
+ {'en': 'Great Pyrenees dog', 'ja': 'グレートピレニーズ'},
259
+ {'en': 'Samoyed', 'ja': 'サモエド'},
260
+ {'en': 'Pomeranian', 'ja': 'ポメラニアン'},
261
+ {'en': 'Chow Chow', 'ja': 'チャウ'},
262
+ {'en': 'Keeshond', 'ja': 'キースホンド'},
263
+ {'en': 'brussels griffon', 'ja': 'ブラバンソングリフォン'},
264
+ {'en': 'Pembroke Welsh Corgi', 'ja': 'ペンブローク'},
265
+ {'en': 'Cardigan Welsh Corgi', 'ja': 'カーディガン'},
266
+ {'en': 'Toy Poodle', 'ja': 'トイプードル'},
267
+ {'en': 'Miniature Poodle', 'ja': 'ミニチュアプードル'},
268
+ {'en': 'Standard Poodle', 'ja': 'スタンダードプードル'},
269
+ {'en': 'Mexican hairless dog (xoloitzcuintli)', 'ja': 'メキシカン・ヘアーレス'},
270
+ {'en': 'grey wolf', 'ja': 'シンリンオオカミ'},
271
+ {'en': 'Alaskan tundra wolf', 'ja': '白いオオカミ'},
272
+ {'en': 'red wolf or maned wolf', 'ja': 'レッドウルフ'},
273
+ {'en': 'coyote', 'ja': 'コヨーテ'},
274
+ {'en': 'dingo', 'ja': 'ディンゴ'},
275
+ {'en': 'dhole', 'ja': 'ドール'},
276
+ {'en': 'African wild dog', 'ja': 'リカオン'},
277
+ {'en': 'hyena', 'ja': 'ハイエナ'},
278
+ {'en': 'red fox', 'ja': 'アカギツネ'},
279
+ {'en': 'kit fox', 'ja': 'キットキツネ'},
280
+ {'en': 'Arctic fox', 'ja': 'ホッキョクギツネ'},
281
+ {'en': 'grey fox', 'ja': '灰色のキツネ'},
282
+ {'en': 'tabby cat', 'ja': 'タビー'},
283
+ {'en': 'tiger cat', 'ja': '虎猫'},
284
+ {'en': 'Persian cat', 'ja': 'ペルシャ猫'},
285
+ {'en': 'Siamese cat', 'ja': 'シャム猫'},
286
+ {'en': 'Egyptian Mau', 'ja': 'エジプトの猫'},
287
+ {'en': 'cougar', 'ja': 'クーガー'},
288
+ {'en': 'lynx', 'ja': 'オオヤマネコ'},
289
+ {'en': 'leopard', 'ja': 'ヒョウ'},
290
+ {'en': 'snow leopard', 'ja': 'ユキヒョウ'},
291
+ {'en': 'jaguar', 'ja': 'ジャガー'},
292
+ {'en': 'lion', 'ja': 'ライオン'},
293
+ {'en': 'tiger', 'ja': '虎'},
294
+ {'en': 'cheetah', 'ja': 'チーター'},
295
+ {'en': 'brown bear', 'ja': 'ヒグマ'},
296
+ {'en': 'American black bear', 'ja': 'アメリカクロクマ'},
297
+ {'en': 'polar bear', 'ja': '氷のクマ'},
298
+ {'en': 'sloth bear', 'ja': 'ナマケグマ'},
299
+ {'en': 'mongoose', 'ja': 'マングース'},
300
+ {'en': 'meerkat', 'ja': 'ミーアキャット'},
301
+ {'en': 'tiger beetle', 'ja': 'ハンミョウ'},
302
+ {'en': 'ladybug', 'ja': 'てんとう虫'},
303
+ {'en': 'ground beetle', 'ja': 'グランドビートル'},
304
+ {'en': 'longhorn beetle', 'ja': 'カミキリムシ'},
305
+ {'en': 'leaf beetle', 'ja': 'ハムシ'},
306
+ {'en': 'dung beetle', 'ja': 'フンコロガシ'},
307
+ {'en': 'rhinoceros beetle', 'ja': 'サイハムシ'},
308
+ {'en': 'weevil', 'ja': 'ゾウムシ'},
309
+ {'en': 'fly', 'ja': 'ハエ'},
310
+ {'en': 'bee', 'ja': '蜂'},
311
+ {'en': 'ant', 'ja': '蟻'},
312
+ {'en': 'grasshopper', 'ja': 'バッタ'},
313
+ {'en': 'cricket insect', 'ja': 'クリケット'},
314
+ {'en': 'stick insect', 'ja': '杖'},
315
+ {'en': 'cockroach', 'ja': 'ゴキブリ'},
316
+ {'en': 'praying mantis', 'ja': 'カマキリ'},
317
+ {'en': 'cicada', 'ja': '蝉'},
318
+ {'en': 'leafhopper', 'ja': 'ヨコバイ'},
319
+ {'en': 'lacewing', 'ja': 'クサカゲロウ'},
320
+ {'en': 'dragonfly', 'ja': 'トンボ'},
321
+ {'en': 'damselfly', 'ja': 'イトトンボ'},
322
+ {'en': 'red admiral butterfly', 'ja': '提督'},
323
+ {'en': 'ringlet butterfly', 'ja': 'リングレット'},
324
+ {'en': 'monarch butterfly', 'ja': '君主'},
325
+ {'en': 'small white butterfly', 'ja': 'モンシロチョウ'},
326
+ {'en': 'sulphur butterfly', 'ja': '硫黄蝶'},
327
+ {'en': 'gossamer-winged butterfly', 'ja': 'シジミチョウ'},
328
+ {'en': 'starfish', 'ja': 'ヒトデ'},
329
+ {'en': 'sea urchin', 'ja': 'うに'},
330
+ {'en': 'sea cucumber', 'ja': 'ナマコ'},
331
+ {'en': 'cottontail rabbit', 'ja': '木のウサギ'},
332
+ {'en': 'hare', 'ja': '野ウサギ'},
333
+ {'en': 'Angora rabbit', 'ja': 'アンゴラ'},
334
+ {'en': 'hamster', 'ja': 'ハムスター'},
335
+ {'en': 'porcupine', 'ja': 'ヤマアラシ'},
336
+ {'en': 'fox squirrel', 'ja': 'キツネリス'},
337
+ {'en': 'marmot', 'ja': 'マーモット'},
338
+ {'en': 'beaver', 'ja': 'ビーバー'},
339
+ {'en': 'guinea pig', 'ja': 'モルモット'},
340
+ {'en': 'common sorrel horse', 'ja': '栗色'},
341
+ {'en': 'zebra', 'ja': 'シマウマ'},
342
+ {'en': 'pig', 'ja': '豚'},
343
+ {'en': 'wild boar', 'ja': 'イノシシ'},
344
+ {'en': 'warthog', 'ja': 'イボイノシシ'},
345
+ {'en': 'hippopotamus', 'ja': 'カバ'},
346
+ {'en': 'ox', 'ja': '雄牛'},
347
+ {'en': 'water buffalo', 'ja': '水牛'},
348
+ {'en': 'bison', 'ja': 'バイソン'},
349
+ {'en': 'ram (adult male sheep)', 'ja': 'ラム'},
350
+ {'en': 'bighorn sheep', 'ja': 'ビッグホーン'},
351
+ {'en': 'Alpine ibex', 'ja': 'アイベックス'},
352
+ {'en': 'hartebeest', 'ja': 'ハーテビースト'},
353
+ {'en': 'impala (antelope)', 'ja': 'インパラ'},
354
+ {'en': 'gazelle', 'ja': 'ガゼル'},
355
+ {'en': 'arabian camel', 'ja': 'アラビアラクダ'},
356
+ {'en': 'llama', 'ja': 'ラマ'},
357
+ {'en': 'weasel', 'ja': 'イタチ'},
358
+ {'en': 'mink', 'ja': 'ミンク'},
359
+ {'en': 'European polecat', 'ja': 'ケナガイタチ'},
360
+ {'en': 'black-footed ferret', 'ja': 'クロアシイタチ'},
361
+ {'en': 'otter', 'ja': 'カワウソ'},
362
+ {'en': 'skunk', 'ja': 'スカンク'},
363
+ {'en': 'badger', 'ja': '狸'},
364
+ {'en': 'armadillo', 'ja': 'アルマジロ'},
365
+ {'en': 'three-toed sloth', 'ja': 'ミユビナマケモノ'},
366
+ {'en': 'orangutan', 'ja': 'オランウータン'},
367
+ {'en': 'gorilla', 'ja': 'ゴリラ'},
368
+ {'en': 'chimpanzee', 'ja': 'チンパンジー'},
369
+ {'en': 'gibbon', 'ja': 'テナガザル'},
370
+ {'en': 'siamang', 'ja': 'フクロテナガザル'},
371
+ {'en': 'guenon', 'ja': 'オナガザル'},
372
+ {'en': 'patas monkey', 'ja': 'パタス'},
373
+ {'en': 'baboon', 'ja': 'ヒヒ'},
374
+ {'en': 'macaque', 'ja': 'マカク'},
375
+ {'en': 'langur', 'ja': 'ヤセザル'},
376
+ {'en': 'black-and-white colobus', 'ja': 'コロブス属'},
377
+ {'en': 'proboscis monkey', 'ja': 'テングザル'},
378
+ {'en': 'marmoset', 'ja': 'マーモセット'},
379
+ {'en': 'white-headed capuchin', 'ja': 'オマキザル'},
380
+ {'en': 'howler monkey', 'ja': 'ホエザル'},
381
+ {'en': 'titi monkey', 'ja': 'ティティ'},
382
+ {'en': "Geoffroy's spider monkey", 'ja': 'クモザル'},
383
+ {'en': 'common squirrel monkey', 'ja': 'リスザル'},
384
+ {'en': 'ring-tailed lemur', 'ja': 'マダガスカル猫'},
385
+ {'en': 'indri', 'ja': 'インドリ'},
386
+ {'en': 'Asian elephant', 'ja': 'インドゾウ'},
387
+ {'en': 'African bush elephant', 'ja': 'アフリカゾウ'},
388
+ {'en': 'red panda', 'ja': 'レッサーパンダ'},
389
+ {'en': 'giant panda', 'ja': 'ジャイアントパンダ'},
390
+ {'en': 'snoek fish', 'ja': 'バラクータ'},
391
+ {'en': 'eel', 'ja': 'ウナギ'},
392
+ {'en': 'silver salmon', 'ja': 'ギンザケ'},
393
+ {'en': 'rock beauty fish', 'ja': '岩の美しさ'},
394
+ {'en': 'clownfish', 'ja': 'クマノミ'},
395
+ {'en': 'sturgeon', 'ja': 'チョウザメ'},
396
+ {'en': 'gar fish', 'ja': 'ガー'},
397
+ {'en': 'lionfish', 'ja': 'ミノカサゴ'},
398
+ {'en': 'pufferfish', 'ja': 'フグ'},
399
+ {'en': 'abacus', 'ja': 'そろばん'},
400
+ {'en': 'abaya', 'ja': 'アバヤ'},
401
+ {'en': 'academic gown', 'ja': 'アカデミックガウン'},
402
+ {'en': 'accordion', 'ja': 'アコーディオン'},
403
+ {'en': 'acoustic guitar', 'ja': 'アコースティックギター'},
404
+ {'en': 'aircraft carrier', 'ja': '空母'},
405
+ {'en': 'airliner', 'ja': '旅客機'},
406
+ {'en': 'airship', 'ja': '飛行船'},
407
+ {'en': 'altar', 'ja': '祭壇'},
408
+ {'en': 'ambulance', 'ja': '救急車'},
409
+ {'en': 'amphibious vehicle', 'ja': '両生類'},
410
+ {'en': 'analog clock', 'ja': 'アナログ時計'},
411
+ {'en': 'apiary', 'ja': '養蜂場'},
412
+ {'en': 'apron', 'ja': 'エプロン'},
413
+ {'en': 'trash can', 'ja': 'ごみ入れ'},
414
+ {'en': 'assault rifle', 'ja': 'アサルトライフル'},
415
+ {'en': 'backpack', 'ja': 'バックパック'},
416
+ {'en': 'bakery', 'ja': 'ベーカリー'},
417
+ {'en': 'balance beam', 'ja': '平均台'},
418
+ {'en': 'balloon', 'ja': 'バルーン'},
419
+ {'en': 'ballpoint pen', 'ja': 'ボールペン'},
420
+ {'en': 'Band-Aid', 'ja': 'バンドエイド'},
421
+ {'en': 'banjo', 'ja': 'バンジョー'},
422
+ {'en': 'baluster / handrail', 'ja': 'バニスター'},
423
+ {'en': 'barbell', 'ja': 'バーベル'},
424
+ {'en': 'barber chair', 'ja': '理髪店の椅子'},
425
+ {'en': 'barbershop', 'ja': '理髪店'},
426
+ {'en': 'barn', 'ja': '納屋'},
427
+ {'en': 'barometer', 'ja': 'バロメーター'},
428
+ {'en': 'barrel', 'ja': 'バレル'},
429
+ {'en': 'wheelbarrow', 'ja': 'バロー'},
430
+ {'en': 'baseball', 'ja': '野球'},
431
+ {'en': 'basketball', 'ja': 'バスケットボール'},
432
+ {'en': 'bassinet', 'ja': 'バシネット'},
433
+ {'en': 'bassoon', 'ja': 'ファゴット'},
434
+ {'en': 'swimming cap', 'ja': '水泳帽'},
435
+ {'en': 'bath towel', 'ja': 'バスタオル'},
436
+ {'en': 'bathtub', 'ja': 'バスタブ'},
437
+ {'en': 'station wagon', 'ja': 'ビーチワゴン'},
438
+ {'en': 'lighthouse', 'ja': 'ビーコン'},
439
+ {'en': 'beaker', 'ja': 'ビーカー'},
440
+ {'en': 'military hat (bearskin or shako)', 'ja': 'ベアスキン'},
441
+ {'en': 'beer bottle', 'ja': 'ビール瓶'},
442
+ {'en': 'beer glass', 'ja': 'ビールグラス'},
443
+ {'en': 'bell tower', 'ja': 'ベルコート'},
444
+ {'en': 'baby bib', 'ja': 'ビブ'},
445
+ {'en': 'tandem bicycle', 'ja': '自転車'},
446
+ {'en': 'bikini', 'ja': 'ビキニ'},
447
+ {'en': 'ring binder', 'ja': 'バインダー'},
448
+ {'en': 'binoculars', 'ja': '双眼鏡'},
449
+ {'en': 'birdhouse', 'ja': '巣箱'},
450
+ {'en': 'boathouse', 'ja': 'ボートハウス'},
451
+ {'en': 'bobsleigh', 'ja': 'ボブスレー'},
452
+ {'en': 'bolo tie', 'ja': 'ループタイ'},
453
+ {'en': 'poke bonnet', 'ja': 'ボンネット'},
454
+ {'en': 'bookcase', 'ja': '本棚'},
455
+ {'en': 'bookstore', 'ja': '書店'},
456
+ {'en': 'bottle cap', 'ja': '瓶のキャップ'},
457
+ {'en': 'hunting bow', 'ja': '弓'},
458
+ {'en': 'bow tie', 'ja': 'ちょうネクタイ'},
459
+ {'en': 'brass memorial plaque', 'ja': '真鍮'},
460
+ {'en': 'bra', 'ja': 'ブラジャー'},
461
+ {'en': 'breakwater', 'ja': '防波堤'},
462
+ {'en': 'breastplate', 'ja': '胸当て'},
463
+ {'en': 'broom', 'ja': 'ほうき'},
464
+ {'en': 'bucket', 'ja': 'バケツ'},
465
+ {'en': 'buckle', 'ja': 'バックル'},
466
+ {'en': 'bulletproof vest', 'ja': '防弾チョッキ'},
467
+ {'en': 'high-speed train', 'ja': '新幹線'},
468
+ {'en': 'butcher shop', 'ja': '精肉店'},
469
+ {'en': 'taxicab', 'ja': 'タクシー'},
470
+ {'en': 'cauldron', 'ja': '大釜'},
471
+ {'en': 'candle', 'ja': 'キャンドル'},
472
+ {'en': 'cannon', 'ja': '大砲'},
473
+ {'en': 'canoe', 'ja': 'カヌー'},
474
+ {'en': 'can opener', 'ja': '缶切り'},
475
+ {'en': 'cardigan', 'ja': 'カーディガン'},
476
+ {'en': 'car mirror', 'ja': '車のミラー'},
477
+ {'en': 'carousel', 'ja': '回転木馬'},
478
+ {'en': 'tool kit', 'ja': '大工のキット'},
479
+ {'en': 'cardboard box / carton', 'ja': 'カートン'},
480
+ {'en': 'car wheel', 'ja': '車のホイール'},
481
+ {'en': 'automated teller machine', 'ja': '現金自動預け払い機'},
482
+ {'en': 'cassette', 'ja': 'カセット'},
483
+ {'en': 'cassette player', 'ja': 'カセット・プレーヤー'},
484
+ {'en': 'castle', 'ja': '城'},
485
+ {'en': 'catamaran', 'ja': 'カタマラン'},
486
+ {'en': 'CD player', 'ja': 'CDプレーヤー'},
487
+ {'en': 'cello', 'ja': 'チェロ'},
488
+ {'en': 'mobile phone', 'ja': 'スマートフォン'},
489
+ {'en': 'chain', 'ja': '鎖'},
490
+ {'en': 'chain-link fence', 'ja': 'チェーンリンクフェンス'},
491
+ {'en': 'chain mail', 'ja': 'チェーンメール'},
492
+ {'en': 'chainsaw', 'ja': 'チェーンソー'},
493
+ {'en': 'storage chest', 'ja': '胸'},
494
+ {'en': 'chiffonier', 'ja': 'シフォニア'},
495
+ {'en': 'bell or wind chime', 'ja': 'チャイム'},
496
+ {'en': 'china cabinet', 'ja': '中国キャビネット'},
497
+ {'en': 'Christmas stocking', 'ja': 'クリスマスの靴下'},
498
+ {'en': 'church', 'ja': '教会'},
499
+ {'en': 'movie theater', 'ja': '映画'},
500
+ {'en': 'cleaver', 'ja': 'クリーバー'},
501
+ {'en': 'cliff dwelling', 'ja': '崖の住居'},
502
+ {'en': 'cloak', 'ja': 'マント'},
503
+ {'en': 'clogs', 'ja': 'クロッグ'},
504
+ {'en': 'cocktail shaker', 'ja': 'カクテルシェーカー'},
505
+ {'en': 'coffee mug', 'ja': 'コーヒーマグ'},
506
+ {'en': 'coffeemaker', 'ja': 'コーヒーポット'},
507
+ {'en': 'spiral or coil', 'ja': 'コイル'},
508
+ {'en': 'combination lock', 'ja': 'ダイヤル錠'},
509
+ {'en': 'computer keyboard', 'ja': 'コンピュータのキーボード'},
510
+ {'en': 'candy store', 'ja': '製菓'},
511
+ {'en': 'container ship', 'ja': 'コンテナ船'},
512
+ {'en': 'convertible', 'ja': 'コンバーチブル'},
513
+ {'en': 'corkscrew', 'ja': 'コークスクリュー'},
514
+ {'en': 'cornet', 'ja': 'コルネット'},
515
+ {'en': 'cowboy boot', 'ja': 'カウボーイブーツ'},
516
+ {'en': 'cowboy hat', 'ja': 'カウボーイハット'},
517
+ {'en': 'cradle', 'ja': 'クレードル'},
518
+ {'en': 'construction crane', 'ja': 'クレーン'},
519
+ {'en': 'crash helmet', 'ja': 'クラッシュヘルメット'},
520
+ {'en': 'crate', 'ja': '木箱'},
521
+ {'en': 'infant bed', 'ja': 'ベビーベッド'},
522
+ {'en': 'Crock Pot', 'ja': 'クロークポット'},
523
+ {'en': 'croquet ball', 'ja': 'クロケットボール'},
524
+ {'en': 'crutch', 'ja': '松葉杖'},
525
+ {'en': 'cuirass', 'ja': '胸当て'},
526
+ {'en': 'dam', 'ja': 'ダム'},
527
+ {'en': 'desk', 'ja': '机'},
528
+ {'en': 'desktop computer', 'ja': 'デスクトップコンピューター'},
529
+ {'en': 'rotary dial telephone', 'ja': 'ダイヤル電話'},
530
+ {'en': 'diaper', 'ja': 'おむつ'},
531
+ {'en': 'digital clock', 'ja': 'デジタル時計'},
532
+ {'en': 'digital watch', 'ja': 'デジタル腕時計'},
533
+ {'en': 'dining table', 'ja': 'ダイニングテーブル'},
534
+ {'en': 'dishcloth', 'ja': '意気地なし'},
535
+ {'en': 'dishwasher', 'ja': '食器洗い機'},
536
+ {'en': 'disc brake', 'ja': 'ディスクブレーキ'},
537
+ {'en': 'dock', 'ja': 'ドック'},
538
+ {'en': 'dog sled', 'ja': '犬ぞり'},
539
+ {'en': 'dome', 'ja': 'ドーム'},
540
+ {'en': 'doormat', 'ja': '玄関マット'},
541
+ {'en': 'drilling rig', 'ja': '掘削基地'},
542
+ {'en': 'drum', 'ja': 'ドラム'},
543
+ {'en': 'drumstick', 'ja': 'ドラムスティック'},
544
+ {'en': 'dumbbell', 'ja': 'ダンベル'},
545
+ {'en': 'Dutch oven', 'ja': 'ダッチオーブン'},
546
+ {'en': 'electric fan', 'ja': '扇風機'},
547
+ {'en': 'electric guitar', 'ja': 'エレキギター'},
548
+ {'en': 'electric locomotive', 'ja': '電気機関車'},
549
+ {'en': 'entertainment center', 'ja': '娯楽施設'},
550
+ {'en': 'envelope', 'ja': '封筒'},
551
+ {'en': 'espresso machine', 'ja': 'エスプレッソマシーン'},
552
+ {'en': 'face powder', 'ja': 'フェースパウダー'},
553
+ {'en': 'feather boa', 'ja': 'フェザーボア'},
554
+ {'en': 'filing cabinet', 'ja': 'ファイル'},
555
+ {'en': 'fireboat', 'ja': '消防艇'},
556
+ {'en': 'fire truck', 'ja': '消防車'},
557
+ {'en': 'fire screen', 'ja': 'ファイアースクリーン'},
558
+ {'en': 'flagpole', 'ja': '旗竿'},
559
+ {'en': 'flute', 'ja': 'フルート'},
560
+ {'en': 'folding chair', 'ja': '折り畳み式椅子'},
561
+ {'en': 'football helmet', 'ja': 'フットボールヘルメット'},
562
+ {'en': 'forklift', 'ja': 'フォークリフト'},
563
+ {'en': 'fountain', 'ja': '噴水'},
564
+ {'en': 'fountain pen', 'ja': '万年筆'},
565
+ {'en': 'four-poster bed', 'ja': '四柱'},
566
+ {'en': 'freight car', 'ja': '貨車'},
567
+ {'en': 'French horn', 'ja': 'フレンチホルン'},
568
+ {'en': 'frying pan', 'ja': 'フライパン'},
569
+ {'en': 'fur coat', 'ja': '毛皮のコート'},
570
+ {'en': 'garbage truck', 'ja': 'ごみ収集車'},
571
+ {'en': 'gas mask or respirator', 'ja': 'ガスマスク'},
572
+ {'en': 'gas pump', 'ja': 'ガソリンポンプ'},
573
+ {'en': 'goblet', 'ja': 'ゴブレット'},
574
+ {'en': 'go-kart', 'ja': 'ゴーカート'},
575
+ {'en': 'golf ball', 'ja': 'ゴルフボール'},
576
+ {'en': 'golf cart', 'ja': 'ゴルフカート'},
577
+ {'en': 'gondola', 'ja': 'ゴンドラ'},
578
+ {'en': 'gong', 'ja': 'ゴング'},
579
+ {'en': 'gown', 'ja': 'ガウン'},
580
+ {'en': 'grand piano', 'ja': 'グランドピアノ'},
581
+ {'en': 'greenhouse', 'ja': '温室'},
582
+ {'en': 'radiator grille', 'ja': 'グリル'},
583
+ {'en': 'grocery store', 'ja': '食料品店'},
584
+ {'en': 'guillotine', 'ja': 'ギロチン'},
585
+ {'en': 'hair clip', 'ja': 'ヘアスライド'},
586
+ {'en': 'hair spray', 'ja': 'ヘアスプレー'},
587
+ {'en': 'half-track', 'ja': '半トラック'},
588
+ {'en': 'hammer', 'ja': 'ハンマー'},
589
+ {'en': 'hamper', 'ja': '妨げます'},
590
+ {'en': 'hair dryer', 'ja': 'ハンドブロワー'},
591
+ {'en': 'hand-held computer', 'ja': 'タブレット'},
592
+ {'en': 'handkerchief', 'ja': 'ハンカチ'},
593
+ {'en': 'hard disk drive', 'ja': 'ハードディスク'},
594
+ {'en': 'harmonica', 'ja': 'ハーモニカ'},
595
+ {'en': 'harp', 'ja': 'ハープ'},
596
+ {'en': 'combine harvester', 'ja': 'ハーベスタ'},
597
+ {'en': 'hatchet', 'ja': '斧'},
598
+ {'en': 'holster', 'ja': 'ホルスター'},
599
+ {'en': 'home theater', 'ja': 'ホームシアター'},
600
+ {'en': 'honeycomb', 'ja': 'ハニカム'},
601
+ {'en': 'hook', 'ja': 'フック'},
602
+ {'en': 'hoop skirt', 'ja': 'フープスカート'},
603
+ {'en': 'gymnastic horizontal bar', 'ja': '水平バー'},
604
+ {'en': 'horse-drawn vehicle', 'ja': '馬車'},
605
+ {'en': 'hourglass', 'ja': '砂時計'},
606
+ {'en': 'iPod', 'ja': 'アイフォーン'},
607
+ {'en': 'clothes iron', 'ja': '鉄'},
608
+ {'en': 'carved pumpkin', 'ja': 'ジャックオーランタン'},
609
+ {'en': 'jeans', 'ja': 'ジーンズ'},
610
+ {'en': 'jeep', 'ja': 'ジープ'},
611
+ {'en': 'T-shirt', 'ja': 'ジャージー'},
612
+ {'en': 'jigsaw puzzle', 'ja': 'ジグソーパズル'},
613
+ {'en': 'rickshaw', 'ja': '人力車'},
614
+ {'en': 'joystick', 'ja': 'ジョイスティック'},
615
+ {'en': 'kimono', 'ja': '着物'},
616
+ {'en': 'knee pad', 'ja': '膝パッド'},
617
+ {'en': 'knot', 'ja': '結び目'},
618
+ {'en': 'lab coat', 'ja': '白衣'},
619
+ {'en': 'ladle', 'ja': 'ひしゃく'},
620
+ {'en': 'lampshade', 'ja': 'ランプのかさ'},
621
+ {'en': 'laptop computer', 'ja': 'ノートパソコン'},
622
+ {'en': 'lawn mower', 'ja': '芝刈り機'},
623
+ {'en': 'lens cap', 'ja': 'レンズキャップ'},
624
+ {'en': 'letter opener', 'ja': 'レターオープナー'},
625
+ {'en': 'library', 'ja': 'ライブラリ'},
626
+ {'en': 'lifeboat', 'ja': '救命ボート'},
627
+ {'en': 'lighter', 'ja': 'ライター'},
628
+ {'en': 'limousine', 'ja': 'リムジン'},
629
+ {'en': 'ocean liner', 'ja': 'ライナー'},
630
+ {'en': 'lipstick', 'ja': '口紅'},
631
+ {'en': 'slip-on shoe', 'ja': 'ローファー'},
632
+ {'en': 'lotion', 'ja': 'ローション'},
633
+ {'en': 'music speaker', 'ja': 'スピーカー'},
634
+ {'en': 'loupe magnifying glass', 'ja': 'ルーペ'},
635
+ {'en': 'sawmill', 'ja': '製材所'},
636
+ {'en': 'magnetic compass', 'ja': '磁気コンパス'},
637
+ {'en': 'messenger bag', 'ja': '郵袋'},
638
+ {'en': 'mailbox', 'ja': 'メールボックス'},
639
+ {'en': 'tights', 'ja': 'マイヨ'},
640
+ {'en': 'one-piece bathing suit', 'ja': 'マイヨ'},
641
+ {'en': 'manhole cover', 'ja': 'マンホールの蓋'},
642
+ {'en': 'maraca', 'ja': 'マラカス'},
643
+ {'en': 'marimba', 'ja': 'マリンバ'},
644
+ {'en': 'mask', 'ja': 'マスク'},
645
+ {'en': 'matchstick', 'ja': 'マッチ棒'},
646
+ {'en': 'maypole', 'ja': 'メイポール'},
647
+ {'en': 'maze', 'ja': '迷路'},
648
+ {'en': 'measuring cup', 'ja': '計量カップ'},
649
+ {'en': 'medicine cabinet', 'ja': '薬箱'},
650
+ {'en': 'megalith', 'ja': '巨石'},
651
+ {'en': 'microphone', 'ja': 'マイク'},
652
+ {'en': 'microwave oven', 'ja': 'マイクロ波'},
653
+ {'en': 'military uniform', 'ja': '軍服'},
654
+ {'en': 'milk can', 'ja': 'ミルク缶'},
655
+ {'en': 'minibus', 'ja': 'ミニバス'},
656
+ {'en': 'miniskirt', 'ja': 'ミニスカート'},
657
+ {'en': 'minivan', 'ja': 'ミニバン'},
658
+ {'en': 'missile', 'ja': 'ミサイル'},
659
+ {'en': 'mitten', 'ja': 'ミトン'},
660
+ {'en': 'mixing bowl', 'ja': 'ミキシングボウル'},
661
+ {'en': 'mobile home', 'ja': '移動住宅'},
662
+ {'en': 'ford model t', 'ja': 'モデルT'},
663
+ {'en': 'modem', 'ja': 'モデム'},
664
+ {'en': 'monastery', 'ja': '修道院'},
665
+ {'en': 'monitor', 'ja': 'モニター'},
666
+ {'en': 'moped', 'ja': 'モペット'},
667
+ {'en': 'mortar and pestle', 'ja': 'モルタル'},
668
+ {'en': 'graduation cap', 'ja': 'モルタルボード'},
669
+ {'en': 'mosque', 'ja': 'モスク'},
670
+ {'en': 'mosquito net', 'ja': '蚊帳'},
671
+ {'en': 'vespa', 'ja': 'スクーター'},
672
+ {'en': 'mountain bike', 'ja': 'マウンテンバイク'},
673
+ {'en': 'tent', 'ja': '山のテント'},
674
+ {'en': 'computer mouse', 'ja': 'マウス'},
675
+ {'en': 'mousetrap', 'ja': 'ネズミ捕り'},
676
+ {'en': 'moving van', 'ja': '引っ越しトラック'},
677
+ {'en': 'muzzle', 'ja': '銃口'},
678
+ {'en': 'metal nail', 'ja': 'ネイル'},
679
+ {'en': 'neck brace', 'ja': 'ネックブレース'},
680
+ {'en': 'necklace', 'ja': 'ネックレス'},
681
+ {'en': 'baby pacifier', 'ja': '乳首'},
682
+ {'en': 'notebook computer', 'ja': 'ノート'},
683
+ {'en': 'obelisk', 'ja': 'オベリスク'},
684
+ {'en': 'oboe', 'ja': 'オーボエ'},
685
+ {'en': 'ocarina', 'ja': 'オカリナ'},
686
+ {'en': 'odometer', 'ja': 'オドメーター'},
687
+ {'en': 'oil filter', 'ja': 'オイルフィルター'},
688
+ {'en': 'pipe organ', 'ja': '器官'},
689
+ {'en': 'oscilloscope', 'ja': 'オシロスコープ'},
690
+ {'en': 'overskirt', 'ja': 'オーバースカート'},
691
+ {'en': 'bullock cart', 'ja': '牛車'},
692
+ {'en': 'oxygen mask', 'ja': '酸素マスク'},
693
+ {'en': 'product packet / packaging', 'ja': 'パケット'},
694
+ {'en': 'paddle', 'ja': 'パドル'},
695
+ {'en': 'paddle wheel', 'ja': 'パドルホイール'},
696
+ {'en': 'padlock', 'ja': '南京錠'},
697
+ {'en': 'paintbrush', 'ja': '絵筆'},
698
+ {'en': 'pajamas', 'ja': 'パジャマ'},
699
+ {'en': 'palace', 'ja': '宮殿'},
700
+ {'en': 'pan flute', 'ja': 'パンパイプ'},
701
+ {'en': 'paper towel', 'ja': 'ペーパータオル'},
702
+ {'en': 'parachute', 'ja': 'パラシュート'},
703
+ {'en': 'parallel bars', 'ja': '平行棒'},
704
+ {'en': 'park bench', 'ja': '公園のベンチ'},
705
+ {'en': 'parking meter', 'ja': 'パーキングメーター'},
706
+ {'en': 'railroad car', 'ja': '乗用車'},
707
+ {'en': 'patio', 'ja': 'パティオ'},
708
+ {'en': 'payphone', 'ja': '有料電話'},
709
+ {'en': 'pedestal', 'ja': '台座'},
710
+ {'en': 'pencil case', 'ja': '筆箱'},
711
+ {'en': 'pencil sharpener', 'ja': '鉛筆削り'},
712
+ {'en': 'perfume', 'ja': '香水'},
713
+ {'en': 'Petri dish', 'ja': 'ペトリ皿'},
714
+ {'en': 'photocopier', 'ja': 'コピー機'},
715
+ {'en': 'plectrum', 'ja': '選ぶ'},
716
+ {'en': 'Pickelhaube', 'ja': 'スパイク付き鉄かぶと'},
717
+ {'en': 'picket fence', 'ja': '杭柵'},
718
+ {'en': 'pickup truck', 'ja': '拾う'},
719
+ {'en': 'pier', 'ja': '桟橋'},
720
+ {'en': 'piggy bank', 'ja': '貯金箱'},
721
+ {'en': 'pill bottle', 'ja': '錠剤瓶'},
722
+ {'en': 'pillow', 'ja': '枕'},
723
+ {'en': 'ping-pong ball', 'ja': 'ピンポン球'},
724
+ {'en': 'pinwheel', 'ja': '風車'},
725
+ {'en': 'pirate ship', 'ja': '海賊'},
726
+ {'en': 'drink pitcher', 'ja': 'ピッチャー'},
727
+ {'en': 'block plane', 'ja': '飛行機'},
728
+ {'en': 'planetarium', 'ja': 'プラネタリウム'},
729
+ {'en': 'plastic bag', 'ja': 'ビニール袋'},
730
+ {'en': 'plate rack', 'ja': '皿立て'},
731
+ {'en': 'farm plow', 'ja': 'プラウ'},
732
+ {'en': 'plunger', 'ja': 'プランジャー'},
733
+ {'en': 'Polaroid camera', 'ja': 'ポラロイドカメラ'},
734
+ {'en': 'pole', 'ja': 'ポール'},
735
+ {'en': 'police van', 'ja': '警察車'},
736
+ {'en': 'poncho', 'ja': 'ポンチョ'},
737
+ {'en': 'pool table', 'ja': 'ビリヤード台'},
738
+ {'en': 'soda bottle', 'ja': 'ポップ・ボトル'},
739
+ {'en': 'plant pot', 'ja': 'ポット'},
740
+ {'en': "potter's wheel", 'ja': 'ろくろ'},
741
+ {'en': 'power drill', 'ja': 'パワードリル'},
742
+ {'en': 'prayer rug', 'ja': '礼拝用敷物'},
743
+ {'en': 'printer', 'ja': 'プリンタ'},
744
+ {'en': 'prison', 'ja': '刑務所'},
745
+ {'en': 'missile', 'ja': '発射体'},
746
+ {'en': 'projector', 'ja': 'プロジェクター'},
747
+ {'en': 'hockey puck', 'ja': 'パック'},
748
+ {'en': 'punching bag', 'ja': 'サンドバッグ'},
749
+ {'en': 'purse', 'ja': '財布'},
750
+ {'en': 'quill', 'ja': 'クイル'},
751
+ {'en': 'quilt', 'ja': 'キルト'},
752
+ {'en': 'race car', 'ja': 'レーサー'},
753
+ {'en': 'racket', 'ja': 'ラケット'},
754
+ {'en': 'radiator', 'ja': 'ラジエーター'},
755
+ {'en': 'radio', 'ja': '無線'},
756
+ {'en': 'radio telescope', 'ja': '電波望遠鏡'},
757
+ {'en': 'rain barrel', 'ja': '天水桶'},
758
+ {'en': 'recreational vehicle', 'ja': 'RV車'},
759
+ {'en': 'fishing casting reel', 'ja': 'リール'},
760
+ {'en': 'reflex camera', 'ja': 'レフレックスカメラ'},
761
+ {'en': 'refrigerator', 'ja': '冷蔵庫'},
762
+ {'en': 'remote control', 'ja': 'リモコン'},
763
+ {'en': 'restaurant', 'ja': 'レストラン'},
764
+ {'en': 'revolver', 'ja': 'リボルバー'},
765
+ {'en': 'rifle', 'ja': 'ライフル'},
766
+ {'en': 'rocking chair', 'ja': 'ロッキングチェア'},
767
+ {'en': 'rotisserie', 'ja': '焼肉料理店'},
768
+ {'en': 'eraser', 'ja': '消しゴム'},
769
+ {'en': 'rugby ball', 'ja': 'ラグビーボール'},
770
+ {'en': 'ruler measuring stick', 'ja': 'ルール'},
771
+ {'en': 'sneaker', 'ja': 'ランニングシューズ'},
772
+ {'en': 'safe', 'ja': '安全'},
773
+ {'en': 'safety pin', 'ja': '安全ピン'},
774
+ {'en': 'salt shaker', 'ja': '塩の入れ物'},
775
+ {'en': 'sandal', 'ja': 'サンダル'},
776
+ {'en': 'sarong', 'ja': 'サロン'},
777
+ {'en': 'saxophone', 'ja': 'サックス'},
778
+ {'en': 'scabbard', 'ja': '鞘'},
779
+ {'en': 'weighing scale', 'ja': '規模'},
780
+ {'en': 'school bus', 'ja': 'スクールバス'},
781
+ {'en': 'schooner', 'ja': 'スクーナー'},
782
+ {'en': 'scoreboard', 'ja': 'スコアボード'},
783
+ {'en': 'CRT monitor', 'ja': '画面'},
784
+ {'en': 'screw', 'ja': 'スクリュー'},
785
+ {'en': 'screwdriver', 'ja': 'ドライバー'},
786
+ {'en': 'seat belt', 'ja': 'シートベルト'},
787
+ {'en': 'sewing machine', 'ja': 'ミシン'},
788
+ {'en': 'shield', 'ja': 'シールド'},
789
+ {'en': 'shoe store', 'ja': '靴屋'},
790
+ {'en': 'shoji screen / room divider', 'ja': '障子'},
791
+ {'en': 'shopping basket', 'ja': '買い物かご'},
792
+ {'en': 'shopping cart', 'ja': 'ショッピングカート'},
793
+ {'en': 'shovel', 'ja': 'シャベル'},
794
+ {'en': 'shower cap', 'ja': 'シャワーキャップ'},
795
+ {'en': 'shower curtain', 'ja': 'シャワーカーテン'},
796
+ {'en': 'ski', 'ja': 'スキー'},
797
+ {'en': 'balaclava ski mask', 'ja': 'スキーマスク'},
798
+ {'en': 'sleeping bag', 'ja': '寝袋'},
799
+ {'en': 'slide rule', 'ja': '計算尺'},
800
+ {'en': 'sliding door', 'ja': '引き戸'},
801
+ {'en': 'slot machine', 'ja': 'スロット'},
802
+ {'en': 'snorkel', 'ja': 'スノーケル'},
803
+ {'en': 'snowmobile', 'ja': 'スノーモービル'},
804
+ {'en': 'snowplow', 'ja': '除雪機'},
805
+ {'en': 'soap dispenser', 'ja': 'ソープディスペンサー'},
806
+ {'en': 'soccer ball', 'ja': 'サッカーボール'},
807
+ {'en': 'sock', 'ja': '靴下'},
808
+ {'en': 'solar thermal collector', 'ja': '太陽の皿'},
809
+ {'en': 'sombrero', 'ja': 'ソンブレロ'},
810
+ {'en': 'soup bowl', 'ja': 'スープ皿'},
811
+ {'en': 'keyboard space bar', 'ja': 'スペースキー'},
812
+ {'en': 'space heater', 'ja': 'スペースヒーター'},
813
+ {'en': 'space shuttle', 'ja': 'スペースシャトル'},
814
+ {'en': 'spatula', 'ja': 'へら'},
815
+ {'en': 'motorboat', 'ja': 'スピードボート'},
816
+ {'en': 'spider web', 'ja': 'クモの巣'},
817
+ {'en': 'spindle', 'ja': 'スピンドル'},
818
+ {'en': 'sports car', 'ja': 'スポーツカー'},
819
+ {'en': 'spotlight', 'ja': 'スポットライト'},
820
+ {'en': 'stage', 'ja': 'ステージ'},
821
+ {'en': 'steam locomotive', 'ja': '蒸気機関車'},
822
+ {'en': 'through arch bridge', 'ja': '鋼アーチ橋'},
823
+ {'en': 'steel drum', 'ja': 'スチールドラム'},
824
+ {'en': 'stethoscope', 'ja': '聴診器'},
825
+ {'en': 'scarf', 'ja': 'ストール'},
826
+ {'en': 'stone wall', 'ja': '石垣'},
827
+ {'en': 'stopwatch', 'ja': 'ストップウォッチ'},
828
+ {'en': 'stove', 'ja': 'レンジ'},
829
+ {'en': 'strainer', 'ja': 'ストレーナー'},
830
+ {'en': 'tram', 'ja': '路面電車'},
831
+ {'en': 'stretcher', 'ja': 'ストレッチャー'},
832
+ {'en': 'couch', 'ja': 'スタジオソファ'},
833
+ {'en': 'stupa', 'ja': '仏舎利塔'},
834
+ {'en': 'submarine', 'ja': '潜水艦'},
835
+ {'en': 'suit', 'ja': 'スーツ'},
836
+ {'en': 'sundial', 'ja': '日時計'},
837
+ {'en': 'sunglasses', 'ja': 'サングラス'},
838
+ {'en': 'sunglasses', 'ja': 'サングラス'},
839
+ {'en': 'sunscreen', 'ja': '日焼け止め剤'},
840
+ {'en': 'suspension bridge', 'ja': 'つり橋'},
841
+ {'en': 'mop', 'ja': '綿棒'},
842
+ {'en': 'sweatshirt', 'ja': 'トレーナー'},
843
+ {'en': 'swim trunks / shorts', 'ja': '海パン'},
844
+ {'en': 'swing', 'ja': 'スイング'},
845
+ {'en': 'electrical switch', 'ja': 'スイッチ'},
846
+ {'en': 'syringe', 'ja': '注射器'},
847
+ {'en': 'table lamp', 'ja': '電気スタンド'},
848
+ {'en': 'tank', 'ja': 'タンク'},
849
+ {'en': 'tape player', 'ja': 'テーププレーヤー'},
850
+ {'en': 'teapot', 'ja': 'ティーポット'},
851
+ {'en': 'teddy bear', 'ja': 'テディ'},
852
+ {'en': 'television', 'ja': 'テレビ'},
853
+ {'en': 'tennis ball', 'ja': 'テニスボール'},
854
+ {'en': 'thatched roof', 'ja': 'サッチ'},
855
+ {'en': 'front curtain', 'ja': '劇場のカーテン'},
856
+ {'en': 'thimble', 'ja': '指ぬき'},
857
+ {'en': 'threshing machine', 'ja': '脱穀機'},
858
+ {'en': 'throne', 'ja': '王位'},
859
+ {'en': 'tile roof', 'ja': '瓦屋根'},
860
+ {'en': 'toaster', 'ja': 'トースター'},
861
+ {'en': 'tobacco shop', 'ja': 'タバコ屋'},
862
+ {'en': 'toilet seat', 'ja': '便座'},
863
+ {'en': 'torch', 'ja': 'トーチ'},
864
+ {'en': 'totem pole', 'ja': 'トーテムポール'},
865
+ {'en': 'tow truck', 'ja': 'レッカー車'},
866
+ {'en': 'toy store', 'ja': '玩具屋'},
867
+ {'en': 'tractor', 'ja': 'トラクター'},
868
+ {'en': 'semi-trailer truck', 'ja': 'トレーラートラック'},
869
+ {'en': 'tray', 'ja': 'トレイ'},
870
+ {'en': 'trench coat', 'ja': 'トレンチコート'},
871
+ {'en': 'tricycle', 'ja': '三輪車'},
872
+ {'en': 'trimaran', 'ja': '三胴船'},
873
+ {'en': 'tripod', 'ja': '三脚'},
874
+ {'en': 'triumphal arch', 'ja': '凱旋門'},
875
+ {'en': 'trolleybus', 'ja': 'トロリーバス'},
876
+ {'en': 'trombone', 'ja': 'トロンボーン'},
877
+ {'en': 'hot tub', 'ja': 'バスタブ'},
878
+ {'en': 'turnstile', 'ja': '回転ドア'},
879
+ {'en': 'typewriter keyboard', 'ja': 'タイプライターのキーボード'},
880
+ {'en': 'umbrella', 'ja': '傘'},
881
+ {'en': 'unicycle', 'ja': '一輪車'},
882
+ {'en': 'upright piano', 'ja': '直立'},
883
+ {'en': 'vacuum cleaner', 'ja': '真空'},
884
+ {'en': 'vase', 'ja': '花瓶'},
885
+ {'en': 'vaulted or arched ceiling', 'ja': 'ボールト'},
886
+ {'en': 'velvet fabric', 'ja': 'ベルベット'},
887
+ {'en': 'vending machine', 'ja': '自動販売機'},
888
+ {'en': 'vestment', 'ja': '祭服'},
889
+ {'en': 'viaduct', 'ja': '高架橋'},
890
+ {'en': 'violin', 'ja': 'バイオリン'},
891
+ {'en': 'volleyball', 'ja': 'バレーボール'},
892
+ {'en': 'waffle iron', 'ja': 'ワッフル焼き型'},
893
+ {'en': 'wall clock', 'ja': '壁時計'},
894
+ {'en': 'wallet', 'ja': '財布'},
895
+ {'en': 'wardrobe', 'ja': 'ワードローブ'},
896
+ {'en': 'military aircraft', 'ja': '戦闘機'},
897
+ {'en': 'sink', 'ja': '洗面器'},
898
+ {'en': 'washing machine', 'ja': 'ワッシャー'},
899
+ {'en': 'water bottle', 'ja': '水筒'},
900
+ {'en': 'water jug', 'ja': '水差し'},
901
+ {'en': 'water tower', 'ja': '給水塔'},
902
+ {'en': 'whiskey jug', 'ja': 'ウイスキージャグ'},
903
+ {'en': 'whistle', 'ja': 'ホイッスル'},
904
+ {'en': 'hair wig', 'ja': 'かつら'},
905
+ {'en': 'window screen', 'ja': '窓網戸'},
906
+ {'en': 'window shade', 'ja': 'ブラインド'},
907
+ {'en': 'Windsor tie', 'ja': 'ウィンザーネクタイ'},
908
+ {'en': 'wine bottle', 'ja': 'ワインボトル'},
909
+ {'en': 'airplane wing', 'ja': '翼'},
910
+ {'en': 'wok', 'ja': '中華鍋'},
911
+ {'en': 'wooden spoon', 'ja': '木製スプーン'},
912
+ {'en': 'wool', 'ja': 'ウール'},
913
+ {'en': 'split-rail fence', 'ja': 'ワームフェンス'},
914
+ {'en': 'shipwreck', 'ja': '難破船'},
915
+ {'en': 'sailboat', 'ja': 'ヨール'},
916
+ {'en': 'yurt', 'ja': 'パオ'},
917
+ {'en': 'website', 'ja': 'サイト'},
918
+ {'en': 'comic book', 'ja': 'コミックブック'},
919
+ {'en': 'crossword', 'ja': 'クロスワードパズル'},
920
+ {'en': 'traffic or street sign', 'ja': '道路標識'},
921
+ {'en': 'traffic light', 'ja': '交通信号灯'},
922
+ {'en': 'dust jacket', 'ja': 'ブックカバー'},
923
+ {'en': 'menu', 'ja': 'メニュー'},
924
+ {'en': 'plate', 'ja': 'プレート'},
925
+ {'en': 'guacamole', 'ja': 'グアカモーレ'},
926
+ {'en': 'consomme', 'ja': 'コンソメ'},
927
+ {'en': 'hot pot', 'ja': 'ホットポット'},
928
+ {'en': 'trifle', 'ja': 'パフェ'},
929
+ {'en': 'ice cream', 'ja': 'アイスクリーム'},
930
+ {'en': 'popsicle', 'ja': 'アイスキャンディー'},
931
+ {'en': 'baguette', 'ja': 'フランスパン'},
932
+ {'en': 'bagel', 'ja': 'ベーグル'},
933
+ {'en': 'pretzel', 'ja': 'プレッツェル'},
934
+ {'en': 'cheeseburger', 'ja': 'チーズバーガー'},
935
+ {'en': 'hot dog', 'ja': 'ホットドッグ'},
936
+ {'en': 'mashed potatoes', 'ja': 'マッシュポテト'},
937
+ {'en': 'cabbage', 'ja': 'キャベツ'},
938
+ {'en': 'broccoli', 'ja': 'ブロッコリー'},
939
+ {'en': 'cauliflower', 'ja': 'カリフラワー'},
940
+ {'en': 'zucchini', 'ja': 'ズッキーニ'},
941
+ {'en': 'spaghetti squash', 'ja': 'そうめんかぼちゃ'},
942
+ {'en': 'acorn squash', 'ja': 'ドングリかぼちゃ'},
943
+ {'en': 'butternut squash', 'ja': 'カボチャ'},
944
+ {'en': 'cucumber', 'ja': 'キュウリ'},
945
+ {'en': 'artichoke', 'ja': 'アーティチョーク'},
946
+ {'en': 'bell pepper', 'ja': 'ピーマン'},
947
+ {'en': 'cardoon', 'ja': 'カルドン'},
948
+ {'en': 'mushroom', 'ja': 'キノコ'},
949
+ {'en': 'Granny Smith apple', 'ja': 'リンゴ'},
950
+ {'en': 'strawberry', 'ja': 'イチゴ'},
951
+ {'en': 'orange', 'ja': 'オレンジ'},
952
+ {'en': 'lemon', 'ja': 'レモン'},
953
+ {'en': 'fig', 'ja': 'イチジク'},
954
+ {'en': 'pineapple', 'ja': 'パイナップル'},
955
+ {'en': 'banana', 'ja': 'バナナ'},
956
+ {'en': 'jackfruit', 'ja': 'パラミツ'},
957
+ {'en': 'cherimoya (custard apple)', 'ja': 'カスタードアップル'},
958
+ {'en': 'pomegranate', 'ja': 'ザクロ'},
959
+ {'en': 'hay', 'ja': '干し草'},
960
+ {'en': 'carbonara', 'ja': 'カルボナーラ'},
961
+ {'en': 'chocolate syrup', 'ja': 'チョコレートソース'},
962
+ {'en': 'dough', 'ja': 'パン生地'},
963
+ {'en': 'meatloaf', 'ja': 'ミートローフ'},
964
+ {'en': 'pizza', 'ja': 'ピザ'},
965
+ {'en': 'pot pie', 'ja': 'ポットパイ'},
966
+ {'en': 'burrito', 'ja': 'ブリトー'},
967
+ {'en': 'red wine', 'ja': '赤ワイン'},
968
+ {'en': 'espresso', 'ja': 'エスプレッソ'},
969
+ {'en': 'tea cup', 'ja': 'カップ'},
970
+ {'en': 'eggnog', 'ja': 'エッグノッグ'},
971
+ {'en': 'mountain', 'ja': 'アルプス'},
972
+ {'en': 'bubble', 'ja': 'バブル'},
973
+ {'en': 'cliff', 'ja': '崖'},
974
+ {'en': 'coral reef', 'ja': 'サンゴ礁'},
975
+ {'en': 'geyser', 'ja': '間欠泉'},
976
+ {'en': 'lakeshore', 'ja': '湖畔'},
977
+ {'en': 'promontory', 'ja': '岬'},
978
+ {'en': 'sandbar', 'ja': '砂州'},
979
+ {'en': 'beach', 'ja': '海岸'},
980
+ {'en': 'valley', 'ja': '谷'},
981
+ {'en': 'volcano', 'ja': '火山'},
982
+ {'en': 'baseball player', 'ja': '野球選手'},
983
+ {'en': 'bridegroom', 'ja': '新郎'},
984
+ {'en': 'scuba diver', 'ja': 'スキューバダイバー'},
985
+ {'en': 'rapeseed', 'ja': '菜種'},
986
+ {'en': 'daisy', 'ja': 'デイジー'},
987
+ {'en': "yellow lady's slipper", 'ja': '蘭'},
988
+ {'en': 'corn', 'ja': 'トウモロコシ'},
989
+ {'en': 'acorn', 'ja': 'ドングリ'},
990
+ {'en': 'rose hip', 'ja': 'ヒップ'},
991
+ {'en': 'horse chestnut seed', 'ja': 'トチノキ'},
992
+ {'en': 'coral fungus', 'ja': 'サンゴ菌'},
993
+ {'en': 'agaric', 'ja': 'ハラタケ'},
994
+ {'en': 'gyromitra', 'ja': 'シャグマアミガサタケ'},
995
+ {'en': 'stinkhorn mushroom', 'ja': 'スッポンタケ'},
996
+ {'en': 'earth star fungus', 'ja': 'ハラタケ'},
997
+ {'en': 'hen of the woods mushroom', 'ja': '舞茸'},
998
+ {'en': 'bolete', 'ja': 'きのこ'},
999
+ {'en': 'corn cob', 'ja': '耳'},
1000
+ {'en': 'toilet paper', 'ja': 'トイレットペーパー'}]
1001
+
1002
+
1003
+ imagenet_templates = [{'en': 'a bad photo of a {}.', 'ja': '{}の悪い写真'},
1004
+ {'en': 'a photo of many {}.', 'ja': '多くの{}の写真'},
1005
+ {'en': 'a sculpture of a {}.', 'ja': '{}の彫刻'},
1006
+ {'en': 'a photo of the hard to see {}.', 'ja': '見づらい{}の写真'},
1007
+ {'en': 'a low resolution photo of the {}.', 'ja': '{}の低解像度写真'},
1008
+ {'en': 'a rendering of a {}.', 'ja': '{}のレンダリング'},
1009
+ {'en': 'graffiti of a {}.', 'ja': '{}の落書き'},
1010
+ {'en': 'a cropped photo of the {}.', 'ja': '{}のトリミング写真'},
1011
+ {'en': 'a tattoo of a {}.', 'ja': '{}のタトゥー'},
1012
+ {'en': 'the embroidered {}.', 'ja': '刺繍された{}'},
1013
+ {'en': 'a bright photo of a {}.', 'ja': '{}の明るい写真'},
1014
+ {'en': 'a photo of a clean {}.', 'ja': 'きれいな{}の写真'},
1015
+ {'en': 'a photo of a dirty {}.', 'ja': '汚れた{}の写真'},
1016
+ {'en': 'a dark photo of the {}.', 'ja': '{}の暗い写真'},
1017
+ {'en': 'a drawing of a {}.', 'ja': '{}の絵'},
1018
+ {'en': 'a photo of my {}.', 'ja': '私の{}の写真'},
1019
+ {'en': 'the plastic {}.', 'ja': 'プラスチック製の{}'},
1020
+ {'en': 'a photo of the cool {}.', 'ja': 'かっこいい{}の写真'},
1021
+ {'en': 'a close-up photo of a {}.', 'ja': '{}のクローズアップ写真'},
1022
+ {'en': 'a black and white photo of the {}.', 'ja': '{}の白黒写真'},
1023
+ {'en': 'a pixelated photo of the {}.', 'ja': '{}のピクセル写真'},
1024
+ {'en': 'a jpeg corrupted photo of a {}.', 'ja': 'jpegで加工した{}の写真'},
1025
+ {'en': 'a blurry photo of the {}.', 'ja': '{}のぼやけた写真'},
1026
+ {'en': 'a photo of the {}.', 'ja': '{}の写真'},
1027
+ {'en': 'a good photo of the {}.', 'ja': '{}の良い写真'},
1028
+ {'en': 'a {} in a video game.', 'ja': 'ゲームに登場する{}'},
1029
+ {'en': 'the origami {}.', 'ja': '折り紙で作った{}'},
1030
+ {'en': 'a sketch of a {}.', 'ja': '{}のスケッチ'},
1031
+ {'en': 'the toy {}.', 'ja': 'おもちゃの{}'},
1032
+ {'en': 'a rendition of the {}.', 'ja': '{}の演出'},
1033
+ {'en': 'a photo of a large {}.', 'ja': '大きな{}の写真'},
1034
+ {'en': 'a photo of a nice {}.', 'ja': '素敵な{}の写真'},
1035
+ {'en': 'a photo of a weird {}.', 'ja': '奇妙な{}の写真'},
1036
+ {'en': 'a cartoon {}.', 'ja': '漫画の{}'},
1037
+ {'en': 'art of a {}.', 'ja': '{}の芸術'},
1038
+ {'en': 'a plushie {}.', 'ja': '{}のぬいぐるみ'},
1039
+ {'en': 'a photo of the small {}.', 'ja': '小さな{}の写真'},]
1040
+
1041
+
1042
+
1043
+
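The dictionaries above pair each ImageNet class name and prompt template with a machine-translated Japanese rendering. They are typically consumed by formatting every template with a class name, embedding the prompts with the CLIP text encoder, and averaging the results into a single vector per class. A minimal sketch of that zero-shot setup follows; the ja_clip.load / load_tokenizer / tokenize / get_text_features calls mirror the ones used in model.py later in this commit, while the import path japanese_clip.utils.imagenet_zeroshot_data is an assumption based on the English counterpart file added next.

import torch
import japanese_clip as ja_clip

# Assumed module path for the Japanese data above
# (imagenet_classnames: list of {'en', 'ja'} dicts; imagenet_templates: same shape).
from japanese_clip.utils.imagenet_zeroshot_data import (
    imagenet_classnames,
    imagenet_templates,
)

device = "cpu"
tokenizer = ja_clip.load_tokenizer()
model, _ = ja_clip.load(
    "rinna/japanese-clip-vit-b-16", cache_dir="/tmp/japanese_clip", device=device
)


def build_zeroshot_weights() -> torch.Tensor:
    """Return one averaged, L2-normalised text embedding per class, shape (dim, n_classes)."""
    weights = []
    with torch.no_grad():
        for entry in imagenet_classnames:
            # e.g. '多くの{}の写真'.format('ザリガニ')
            prompts = [t["ja"].format(entry["ja"]) for t in imagenet_templates]
            encodings = ja_clip.tokenize(
                texts=prompts, max_seq_len=77, device=device, tokenizer=tokenizer
            )
            embeddings = model.get_text_features(**encodings)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            class_embedding = embeddings.mean(dim=0)
            weights.append(class_embedding / class_embedding.norm())
    return torch.stack(weights, dim=1)

An image embedding from get_image_features can then be L2-normalised and multiplied against this matrix to rank the classes.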
japanese_clip/utils/imagenet_zeroshot_data_en.py ADDED
@@ -0,0 +1,248 @@
1
+ imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
2
+ "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
3
+ "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
4
+ "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
5
+ "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
6
+ "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
7
+ "box turtle", "banded gecko", "green iguana", "Carolina anole",
8
+ "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
9
+ "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
10
+ "American alligator", "triceratops", "worm snake", "ring-necked snake",
11
+ "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
12
+ "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
13
+ "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
14
+ "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
15
+ "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
16
+ "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
17
+ "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
18
+ "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
19
+ "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
20
+ "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
21
+ "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
22
+ "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
23
+ "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
24
+ "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
25
+ "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
26
+ "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
27
+ "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
28
+ "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
29
+ "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
30
+ "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
31
+ "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
32
+ "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
33
+ "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
34
+ "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
35
+ "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
36
+ "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
37
+ "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
38
+ "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
39
+ "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
40
+ "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
41
+ "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
42
+ "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
43
+ "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
44
+ "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
45
+ "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
46
+ "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
47
+ "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
48
+ "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
49
+ "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
50
+ "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
51
+ "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
52
+ "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
53
+ "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
54
+ "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
55
+ "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
56
+ "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
57
+ "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
58
+ "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
59
+ "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
60
+ "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
61
+ "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
62
+ "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
63
+ "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
64
+ "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
65
+ "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
66
+ "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
67
+ "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
68
+ "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
69
+ "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
70
+ "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
71
+ "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
72
+ "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
73
+ "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
74
+ "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
75
+ "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
76
+ "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
77
+ "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
78
+ "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
79
+ "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
80
+ "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
81
+ "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
82
+ "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
83
+ "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
84
+ "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
85
+ "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
86
+ "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
87
+ "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
88
+ "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
89
+ "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
90
+ "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
91
+ "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
92
+ "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
93
+ "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
94
+ "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
95
+ "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
96
+ "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
97
+ "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
98
+ "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
99
+ "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
100
+ "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
101
+ "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
102
+ "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
103
+ "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
104
+ "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
105
+ "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
106
+ "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
107
+ "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
108
+ "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
109
+ "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
110
+ "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
111
+ "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
112
+ "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
113
+ "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
114
+ "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
115
+ "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
116
+ "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
117
+ "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
118
+ "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
119
+ "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
120
+ "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
121
+ "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
122
+ "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
123
+ "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
124
+ "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
125
+ "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
126
+ "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
127
+ "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
128
+ "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
129
+ "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
130
+ "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
131
+ "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
132
+ "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
133
+ "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
134
+ "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
135
+ "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
136
+ "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
137
+ "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
138
+ "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
139
+ "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
140
+ "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
141
+ "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
142
+ "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
143
+ "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
144
+ "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
145
+ "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
146
+ "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
147
+ "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
148
+ "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
149
+ "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
150
+ "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
151
+ "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
152
+ "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
153
+ "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
154
+ "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
155
+ "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
156
+ "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
157
+ "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
158
+ "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
159
+ "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
160
+ "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
161
+ "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
162
+ "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
163
+ "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
164
+ "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
165
+ "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
166
+
167
+ imagenet_templates = [
168
+ 'a bad photo of a {}.',
169
+ 'a photo of many {}.',
170
+ 'a sculpture of a {}.',
171
+ 'a photo of the hard to see {}.',
172
+ 'a low resolution photo of the {}.',
173
+ 'a rendering of a {}.',
174
+ 'graffiti of a {}.',
175
+ 'a bad photo of the {}.',
176
+ 'a cropped photo of the {}.',
177
+ 'a tattoo of a {}.',
178
+ 'the embroidered {}.',
179
+ 'a photo of a hard to see {}.',
180
+ 'a bright photo of a {}.',
181
+ 'a photo of a clean {}.',
182
+ 'a photo of a dirty {}.',
183
+ 'a dark photo of the {}.',
184
+ 'a drawing of a {}.',
185
+ 'a photo of my {}.',
186
+ 'the plastic {}.',
187
+ 'a photo of the cool {}.',
188
+ 'a close-up photo of a {}.',
189
+ 'a black and white photo of the {}.',
190
+ 'a painting of the {}.',
191
+ 'a painting of a {}.',
192
+ 'a pixelated photo of the {}.',
193
+ 'a sculpture of the {}.',
194
+ 'a bright photo of the {}.',
195
+ 'a cropped photo of a {}.',
196
+ 'a plastic {}.',
197
+ 'a photo of the dirty {}.',
198
+ 'a jpeg corrupted photo of a {}.',
199
+ 'a blurry photo of the {}.',
200
+ 'a photo of the {}.',
201
+ 'a good photo of the {}.',
202
+ 'a rendering of the {}.',
203
+ 'a {} in a video game.',
204
+ 'a photo of one {}.',
205
+ 'a doodle of a {}.',
206
+ 'a close-up photo of the {}.',
207
+ 'a photo of a {}.',
208
+ 'the origami {}.',
209
+ 'the {} in a video game.',
210
+ 'a sketch of a {}.',
211
+ 'a doodle of the {}.',
212
+ 'a origami {}.',
213
+ 'a low resolution photo of a {}.',
214
+ 'the toy {}.',
215
+ 'a rendition of the {}.',
216
+ 'a photo of the clean {}.',
217
+ 'a photo of a large {}.',
218
+ 'a rendition of a {}.',
219
+ 'a photo of a nice {}.',
220
+ 'a photo of a weird {}.',
221
+ 'a blurry photo of a {}.',
222
+ 'a cartoon {}.',
223
+ 'art of a {}.',
224
+ 'a sketch of the {}.',
225
+ 'a embroidered {}.',
226
+ 'a pixelated photo of a {}.',
227
+ 'itap of the {}.',
228
+ 'a jpeg corrupted photo of the {}.',
229
+ 'a good photo of a {}.',
230
+ 'a plushie {}.',
231
+ 'a photo of the nice {}.',
232
+ 'a photo of the small {}.',
233
+ 'a photo of the weird {}.',
234
+ 'the cartoon {}.',
235
+ 'art of the {}.',
236
+ 'a drawing of the {}.',
237
+ 'a photo of the large {}.',
238
+ 'a black and white photo of a {}.',
239
+ 'the plushie {}.',
240
+ 'a dark photo of a {}.',
241
+ 'itap of a {}.',
242
+ 'graffiti of the {}.',
243
+ 'a toy {}.',
244
+ 'itap of my {}.',
245
+ 'a photo of a cool {}.',
246
+ 'a photo of a small {}.',
247
+ 'a tattoo of the {}.',
248
+ ]
japanese_clip/version.py ADDED
@@ -0,0 +1,16 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 rinna Co., Ltd.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ __version__ = '0.2.0'
model.py ADDED
@@ -0,0 +1,69 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+ import japanese_clip as ja_clip
4
+ from s3_session import Bucket
5
+ from PIL import Image
6
+ import uuid
7
+ from db_session import get_db
8
+
9
+
10
+ @dataclass
11
+ class MLModel:
12
+ tokenizer: Any = None
13
+ model: Any = None
14
+ preprocess: Any = None
15
+ bucket: Any = None
16
+
17
+ def __post_init__(self):
18
+ tokenizer = ja_clip.load_tokenizer()
19
+ model, preprocess = ja_clip.load(
20
+ "rinna/japanese-clip-vit-b-16", cache_dir="/tmp/japanese_clip", device="cpu"
21
+ )
22
+ self.tokenizer = tokenizer
23
+ self.model = model
24
+ self.preprocess = preprocess
25
+ self.bucket = Bucket()
26
+
27
+ def save(self, image_path: str):
28
+ pillow_iamge = Image.open(image_path)
29
+ image = self.preprocess(pillow_iamge).unsqueeze(0).to("cpu")
30
+ image_features = self.model.get_image_features(image)
31
+ image_uuid = str(uuid.uuid4())
32
+
33
+ # media upload
34
+ self.bucket.upload_file(pillow_iamge, image_uuid)
35
+
36
+ # db insert
37
+ db = get_db()
38
+ result = db["embedding"].insert_one(
39
+ {"uuid": image_uuid, "vectorField": image_features[0].tolist()}
40
+ )
41
+ return result.inserted_id
42
+
43
+ def search(self, prompt: str):
44
+ db = get_db()
45
+ encodings = ja_clip.tokenize(
46
+ texts=[prompt], max_seq_len=77, device="cpu", tokenizer=self.tokenizer
47
+ )
48
+ text_features = self.model.get_text_features(**encodings)
49
+ pipeline = [
50
+ {
51
+ "$vectorSearch": {
52
+ "index": "vector_index",
53
+ "path": "vectorField",
54
+ "queryVector": text_features[0].tolist(),
55
+ "numCandidates": 150,
56
+ "limit": 10,
57
+ }
58
+ },
59
+ {
60
+ "$project": {
61
+ "_id": {"$toString": "$_id"},
62
+ "uuid": 1,
63
+ "score": {"$meta": "vectorSearchScore"},
64
+ }
65
+ },
66
+ ]
67
+ result = db["embedding"].aggregate(pipeline)
68
+ urls = [self.bucket.get_presigned_url(x["uuid"]) for x in result]
69
+ return urls
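The $vectorSearch stage in search() assumes a MongoDB Atlas Vector Search index named vector_index already exists on the embedding collection, with vectorField as the indexed path; nothing in this commit creates it. A minimal sketch of defining it through pymongo, assuming an Atlas cluster and assuming the CLIP features are 512-dimensional (check len(image_features[0].tolist()) and adjust numDimensions accordingly):

from pymongo.operations import SearchIndexModel

from db_session import get_db

db = get_db()
index_model = SearchIndexModel(
    name="vector_index",  # must match the "index" field used in the aggregation pipeline
    type="vectorSearch",
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "vectorField",   # field written by MLModel.save
                "numDimensions": 512,    # assumed embedding size; verify against the model output
                "similarity": "cosine",
            }
        ]
    },
)
db["embedding"].create_search_index(model=index_model)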
pyproject.toml ADDED
@@ -0,0 +1,37 @@
+ [project]
+ name = "image-vector-search"
+ version = "0.1.0"
+ description = "Add your description here"
+ authors = [{ name = "kittchy", email = "[email protected]" }]
+ dependencies = [
+     "fastapi>=0.114.2",
+     "boto3>=1.35.19",
+     "pymongo[srv]>=4.8.0",
+     "pydantic>=2.9.1",
+     "pydantic-settings>=2.5.2",
+     "torch>=2.4.1",
+     "torchvision>=0.19.1",
+     "sentencepiece>=0.2.0",
+     "pandas>=2.2.2",
+     "scipy>=1.14.1",
+     "transformers>=4.44.2",
+     "python-multipart>=0.0.9",
+     "python-dotenv>=1.0.1",
+     "gradio>=4.44.0",
+ ]
+ readme = "README.md"
+ requires-python = ">= 3.12"
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.rye]
+ managed = true
+ dev-dependencies = []
+
+ [tool.hatch.metadata]
+ allow-direct-references = true
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/image_vector_search"]
requirements-dev.lock ADDED
@@ -0,0 +1,221 @@
+ # generated by rye
+ # use `rye lock` or `rye sync` to update this lockfile
+ #
+ # last locked with the following flags:
+ #   pre: false
+ #   features: []
+ #   all-features: false
+ #   with-sources: false
+ #   generate-hashes: false
+
+ -e file:.
+ aiofiles==23.2.1
+     # via gradio
+ annotated-types==0.7.0
+     # via pydantic
+ anyio==4.4.0
+     # via gradio
+     # via httpx
+     # via starlette
+ boto3==1.35.19
+     # via image-vector-search
+ botocore==1.35.19
+     # via boto3
+     # via s3transfer
+ certifi==2024.8.30
+     # via httpcore
+     # via httpx
+     # via requests
+ charset-normalizer==3.3.2
+     # via requests
+ click==8.1.7
+     # via typer
+     # via uvicorn
+ contourpy==1.3.0
+     # via matplotlib
+ cycler==0.12.1
+     # via matplotlib
+ dnspython==2.6.1
+     # via pymongo
+ fastapi==0.114.2
+     # via gradio
+     # via image-vector-search
+ ffmpy==0.4.0
+     # via gradio
+ filelock==3.16.0
+     # via huggingface-hub
+     # via torch
+     # via transformers
+ fonttools==4.53.1
+     # via matplotlib
+ fsspec==2024.9.0
+     # via gradio-client
+     # via huggingface-hub
+     # via torch
+ gradio==4.44.0
+     # via image-vector-search
+ gradio-client==1.3.0
+     # via gradio
+ h11==0.14.0
+     # via httpcore
+     # via uvicorn
+ httpcore==1.0.5
+     # via httpx
+ httpx==0.27.2
+     # via gradio
+     # via gradio-client
+ huggingface-hub==0.24.7
+     # via gradio
+     # via gradio-client
+     # via tokenizers
+     # via transformers
+ idna==3.10
+     # via anyio
+     # via httpx
+     # via requests
+ importlib-resources==6.4.5
+     # via gradio
+ jinja2==3.1.4
+     # via gradio
+     # via torch
+ jmespath==1.0.1
+     # via boto3
+     # via botocore
+ kiwisolver==1.4.7
+     # via matplotlib
+ markdown-it-py==3.0.0
+     # via rich
+ markupsafe==2.1.5
+     # via gradio
+     # via jinja2
+ matplotlib==3.9.2
+     # via gradio
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ networkx==3.3
+     # via torch
+ numpy==2.1.1
+     # via contourpy
+     # via gradio
+     # via matplotlib
+     # via pandas
+     # via scipy
+     # via torchvision
+     # via transformers
+ orjson==3.10.7
+     # via gradio
+ packaging==24.1
+     # via gradio
+     # via gradio-client
+     # via huggingface-hub
+     # via matplotlib
+     # via transformers
+ pandas==2.2.2
+     # via gradio
+     # via image-vector-search
+ pillow==10.4.0
+     # via gradio
+     # via matplotlib
+     # via torchvision
+ pydantic==2.9.1
+     # via fastapi
+     # via gradio
+     # via image-vector-search
+     # via pydantic-settings
+ pydantic-core==2.23.3
+     # via pydantic
+ pydantic-settings==2.5.2
+     # via image-vector-search
+ pydub==0.25.1
+     # via gradio
+ pygments==2.18.0
+     # via rich
+ pymongo==4.8.0
+     # via image-vector-search
+ pyparsing==3.1.4
+     # via matplotlib
+ python-dateutil==2.9.0.post0
+     # via botocore
+     # via matplotlib
+     # via pandas
+ python-dotenv==1.0.1
+     # via image-vector-search
+     # via pydantic-settings
+ python-multipart==0.0.9
+     # via gradio
+     # via image-vector-search
+ pytz==2024.2
+     # via pandas
+ pyyaml==6.0.2
+     # via gradio
+     # via huggingface-hub
+     # via transformers
+ regex==2024.9.11
+     # via transformers
+ requests==2.32.3
+     # via huggingface-hub
+     # via transformers
+ rich==13.8.1
+     # via typer
+ ruff==0.6.5
+     # via gradio
+ s3transfer==0.10.2
+     # via boto3
+ safetensors==0.4.5
+     # via transformers
+ scipy==1.14.1
+     # via image-vector-search
+ semantic-version==2.10.0
+     # via gradio
+ sentencepiece==0.2.0
+     # via image-vector-search
+ setuptools==75.1.0
+     # via torch
+ shellingham==1.5.4
+     # via typer
+ six==1.16.0
+     # via python-dateutil
+ sniffio==1.3.1
+     # via anyio
+     # via httpx
+ starlette==0.38.5
+     # via fastapi
+ sympy==1.13.2
+     # via torch
+ tokenizers==0.19.1
+     # via transformers
+ tomlkit==0.12.0
+     # via gradio
+ torch==2.4.1
+     # via image-vector-search
+     # via torchvision
+ torchvision==0.19.1
+     # via image-vector-search
+ tqdm==4.66.5
+     # via huggingface-hub
+     # via transformers
+ transformers==4.44.2
+     # via image-vector-search
+ typer==0.12.5
+     # via gradio
+ typing-extensions==4.12.2
+     # via fastapi
+     # via gradio
+     # via gradio-client
+     # via huggingface-hub
+     # via pydantic
+     # via pydantic-core
+     # via torch
+     # via typer
+ tzdata==2024.1
+     # via pandas
+ urllib3==2.2.3
+     # via botocore
+     # via gradio
+     # via requests
+ uvicorn==0.30.6
+     # via gradio
+ websockets==12.0
+     # via gradio-client
requirements.lock ADDED
@@ -0,0 +1,221 @@
+ # generated by rye
+ # use `rye lock` or `rye sync` to update this lockfile
+ #
+ # last locked with the following flags:
+ #   pre: false
+ #   features: []
+ #   all-features: false
+ #   with-sources: false
+ #   generate-hashes: false
+
+ -e file:.
+ aiofiles==23.2.1
+     # via gradio
+ annotated-types==0.7.0
+     # via pydantic
+ anyio==4.4.0
+     # via gradio
+     # via httpx
+     # via starlette
+ boto3==1.35.19
+     # via image-vector-search
+ botocore==1.35.19
+     # via boto3
+     # via s3transfer
+ certifi==2024.8.30
+     # via httpcore
+     # via httpx
+     # via requests
+ charset-normalizer==3.3.2
+     # via requests
+ click==8.1.7
+     # via typer
+     # via uvicorn
+ contourpy==1.3.0
+     # via matplotlib
+ cycler==0.12.1
+     # via matplotlib
+ dnspython==2.6.1
+     # via pymongo
+ fastapi==0.114.2
+     # via gradio
+     # via image-vector-search
+ ffmpy==0.4.0
+     # via gradio
+ filelock==3.16.0
+     # via huggingface-hub
+     # via torch
+     # via transformers
+ fonttools==4.53.1
+     # via matplotlib
+ fsspec==2024.9.0
+     # via gradio-client
+     # via huggingface-hub
+     # via torch
+ gradio==4.44.0
+     # via image-vector-search
+ gradio-client==1.3.0
+     # via gradio
+ h11==0.14.0
+     # via httpcore
+     # via uvicorn
+ httpcore==1.0.5
+     # via httpx
+ httpx==0.27.2
+     # via gradio
+     # via gradio-client
+ huggingface-hub==0.24.7
+     # via gradio
+     # via gradio-client
+     # via tokenizers
+     # via transformers
+ idna==3.10
+     # via anyio
+     # via httpx
+     # via requests
+ importlib-resources==6.4.5
+     # via gradio
+ jinja2==3.1.4
+     # via gradio
+     # via torch
+ jmespath==1.0.1
+     # via boto3
+     # via botocore
+ kiwisolver==1.4.7
+     # via matplotlib
+ markdown-it-py==3.0.0
+     # via rich
+ markupsafe==2.1.5
+     # via gradio
+     # via jinja2
+ matplotlib==3.9.2
+     # via gradio
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ networkx==3.3
+     # via torch
+ numpy==2.1.1
+     # via contourpy
+     # via gradio
+     # via matplotlib
+     # via pandas
+     # via scipy
+     # via torchvision
+     # via transformers
+ orjson==3.10.7
+     # via gradio
+ packaging==24.1
+     # via gradio
+     # via gradio-client
+     # via huggingface-hub
+     # via matplotlib
+     # via transformers
+ pandas==2.2.2
+     # via gradio
+     # via image-vector-search
+ pillow==10.4.0
+     # via gradio
+     # via matplotlib
+     # via torchvision
+ pydantic==2.9.1
+     # via fastapi
+     # via gradio
+     # via image-vector-search
+     # via pydantic-settings
+ pydantic-core==2.23.3
+     # via pydantic
+ pydantic-settings==2.5.2
+     # via image-vector-search
+ pydub==0.25.1
+     # via gradio
+ pygments==2.18.0
+     # via rich
+ pymongo==4.8.0
+     # via image-vector-search
+ pyparsing==3.1.4
+     # via matplotlib
+ python-dateutil==2.9.0.post0
+     # via botocore
+     # via matplotlib
+     # via pandas
+ python-dotenv==1.0.1
+     # via image-vector-search
+     # via pydantic-settings
+ python-multipart==0.0.9
+     # via gradio
+     # via image-vector-search
+ pytz==2024.2
+     # via pandas
+ pyyaml==6.0.2
+     # via gradio
+     # via huggingface-hub
+     # via transformers
+ regex==2024.9.11
+     # via transformers
+ requests==2.32.3
+     # via huggingface-hub
+     # via transformers
+ rich==13.8.1
+     # via typer
+ ruff==0.6.5
+     # via gradio
+ s3transfer==0.10.2
+     # via boto3
+ safetensors==0.4.5
+     # via transformers
+ scipy==1.14.1
+     # via image-vector-search
+ semantic-version==2.10.0
+     # via gradio
+ sentencepiece==0.2.0
+     # via image-vector-search
+ setuptools==75.1.0
+     # via torch
+ shellingham==1.5.4
+     # via typer
+ six==1.16.0
+     # via python-dateutil
+ sniffio==1.3.1
+     # via anyio
+     # via httpx
+ starlette==0.38.5
+     # via fastapi
+ sympy==1.13.2
+     # via torch
+ tokenizers==0.19.1
+     # via transformers
+ tomlkit==0.12.0
+     # via gradio
+ torch==2.4.1
+     # via image-vector-search
+     # via torchvision
+ torchvision==0.19.1
+     # via image-vector-search
+ tqdm==4.66.5
+     # via huggingface-hub
+     # via transformers
+ transformers==4.44.2
+     # via image-vector-search
+ typer==0.12.5
+     # via gradio
+ typing-extensions==4.12.2
+     # via fastapi
+     # via gradio
+     # via gradio-client
+     # via huggingface-hub
+     # via pydantic
+     # via pydantic-core
+     # via torch
+     # via typer
+ tzdata==2024.1
+     # via pandas
+ urllib3==2.2.3
+     # via botocore
+     # via gradio
+     # via requests
+ uvicorn==0.30.6
+     # via gradio
+ websockets==12.0
+     # via gradio-client
requirements.txt ADDED
@@ -0,0 +1,210 @@
+ aiofiles==23.2.1
+     # via gradio
+ annotated-types==0.7.0
+     # via pydantic
+ anyio==4.4.0
+     # via gradio
+     # via httpx
+     # via starlette
+ boto3==1.35.19
+     # via image-vector-search
+ botocore==1.35.19
+     # via boto3
+     # via s3transfer
+ certifi==2024.8.30
+     # via httpcore
+     # via httpx
+     # via requests
+ charset-normalizer==3.3.2
+     # via requests
+ click==8.1.7
+     # via typer
+     # via uvicorn
+ contourpy==1.3.0
+     # via matplotlib
+ cycler==0.12.1
+     # via matplotlib
+ dnspython==2.6.1
+     # via pymongo
+ fastapi==0.114.2
+     # via gradio
+     # via image-vector-search
+ ffmpy==0.4.0
+     # via gradio
+ filelock==3.16.0
+     # via huggingface-hub
+     # via torch
+     # via transformers
+ fonttools==4.53.1
+     # via matplotlib
+ fsspec==2024.9.0
+     # via gradio-client
+     # via huggingface-hub
+     # via torch
+ gradio==4.44.0
+     # via image-vector-search
+ gradio-client==1.3.0
+     # via gradio
+ h11==0.14.0
+     # via httpcore
+     # via uvicorn
+ httpcore==1.0.5
+     # via httpx
+ httpx==0.27.2
+     # via gradio
+     # via gradio-client
+ huggingface-hub==0.24.7
+     # via gradio
+     # via gradio-client
+     # via tokenizers
+     # via transformers
+ idna==3.10
+     # via anyio
+     # via httpx
+     # via requests
+ importlib-resources==6.4.5
+     # via gradio
+ jinja2==3.1.4
+     # via gradio
+     # via torch
+ jmespath==1.0.1
+     # via boto3
+     # via botocore
+ kiwisolver==1.4.7
+     # via matplotlib
+ markdown-it-py==3.0.0
+     # via rich
+ markupsafe==2.1.5
+     # via gradio
+     # via jinja2
+ matplotlib==3.9.2
+     # via gradio
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ networkx==3.3
+     # via torch
+ numpy==2.1.1
+     # via contourpy
+     # via gradio
+     # via matplotlib
+     # via pandas
+     # via scipy
+     # via torchvision
+     # via transformers
+ orjson==3.10.7
+     # via gradio
+ packaging==24.1
+     # via gradio
+     # via gradio-client
+     # via huggingface-hub
+     # via matplotlib
+     # via transformers
+ pandas==2.2.2
+     # via gradio
+     # via image-vector-search
+ pillow==10.4.0
+     # via gradio
+     # via matplotlib
+     # via torchvision
+ pydantic==2.9.1
+     # via fastapi
+     # via gradio
+     # via image-vector-search
+     # via pydantic-settings
+ pydantic-core==2.23.3
+     # via pydantic
+ pydantic-settings==2.5.2
+     # via image-vector-search
+ pydub==0.25.1
+     # via gradio
+ pygments==2.18.0
+     # via rich
+ pymongo==4.8.0
+     # via image-vector-search
+ pyparsing==3.1.4
+     # via matplotlib
+ python-dateutil==2.9.0.post0
+     # via botocore
+     # via matplotlib
+     # via pandas
+ python-dotenv==1.0.1
+     # via image-vector-search
+     # via pydantic-settings
+ python-multipart==0.0.9
+     # via gradio
+     # via image-vector-search
+ pytz==2024.2
+     # via pandas
+ pyyaml==6.0.2
+     # via gradio
+     # via huggingface-hub
+     # via transformers
+ regex==2024.9.11
+     # via transformers
+ requests==2.32.3
+     # via huggingface-hub
+     # via transformers
+ rich==13.8.1
+     # via typer
+ ruff==0.6.5
+     # via gradio
+ s3transfer==0.10.2
+     # via boto3
+ safetensors==0.4.5
+     # via transformers
+ scipy==1.14.1
+     # via image-vector-search
+ semantic-version==2.10.0
+     # via gradio
+ sentencepiece==0.2.0
+     # via image-vector-search
+ setuptools==75.1.0
+     # via torch
+ shellingham==1.5.4
+     # via typer
+ six==1.16.0
+     # via python-dateutil
+ sniffio==1.3.1
+     # via anyio
+     # via httpx
+ starlette==0.38.5
+     # via fastapi
+ sympy==1.13.2
+     # via torch
+ tokenizers==0.19.1
+     # via transformers
+ tomlkit==0.12.0
+     # via gradio
+ torch==2.4.1
+     # via image-vector-search
+     # via torchvision
+ torchvision==0.19.1
+     # via image-vector-search
+ tqdm==4.66.5
+     # via huggingface-hub
+     # via transformers
+ transformers==4.44.2
+     # via image-vector-search
+ typer==0.12.5
+     # via gradio
+ typing-extensions==4.12.2
+     # via fastapi
+     # via gradio
+     # via gradio-client
+     # via huggingface-hub
+     # via pydantic
+     # via pydantic-core
+     # via torch
+     # via typer
+ tzdata==2024.1
+     # via pandas
+ urllib3==2.2.3
+     # via botocore
+     # via gradio
+     # via requests
+ uvicorn==0.30.6
+     # via gradio
+ websockets==12.0
+     # via gradio-client
s3_session.py ADDED
@@ -0,0 +1,52 @@
+ from PIL.ImageFile import ImageFile
+ import boto3
+
+ import tempfile
+ import os
+ import logging
+ from dotenv import load_dotenv
+
+ load_dotenv(verbose=True)
+
+ AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY_ID")
+ AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
+ AWS_BUCKET_NAME = os.getenv("BUCKET_NAME")  # matches the BUCKET_NAME key in .env.default
+ AWS_REGION = os.getenv("AWS_REGION")
+
+
+ class Bucket:
+     def __init__(self):
+         self.bucket_name = AWS_BUCKET_NAME
+         self.s3 = boto3.client(
+             "s3",
+             region_name=AWS_REGION,
+             aws_access_key_id=AWS_ACCESS_KEY,
+             aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+         )
+
+     def upload_file(self, image: ImageFile, uuid: str):
+         key = f"{uuid}.png"
+         try:
+             logging.info(
+                 f"Uploading image to S3 with key: s3://{self.bucket_name}/{key}"
+             )
+             with tempfile.TemporaryFile() as fp:
+                 image.save(fp, "PNG")
+                 fp.seek(0)
+                 self.s3.upload_fileobj(fp, self.bucket_name, key)
+         except Exception as e:
+             logging.error(e)
+             raise e
+
+     def get_presigned_url(self, uuid: str):
+         key = f"{uuid}.png"
+         try:
+             url = self.s3.generate_presigned_url(
+                 "get_object",
+                 Params={"Bucket": self.bucket_name, "Key": key},
+                 ExpiresIn=3600,
+             )
+             return url
+         except Exception as e:
+             logging.error(e)
+             raise e
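A quick manual check of the helper above in isolation (assuming the environment variables it reads are set and the target bucket exists) might look like:

from PIL import Image
from s3_session import Bucket

bucket = Bucket()
image = Image.open("sample.png")          # any local image file
bucket.upload_file(image, "smoke-test")   # uploads smoke-test.png to the bucket
print(bucket.get_presigned_url("smoke-test"))  # temporary download URL, valid for an hour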