hidehisa-arai committed
Commit ac9a398 (parent: 214aaec)

add config and model architecture

config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "_commit_hash": null,
+   "_name_or_path": "recruit-jp/japanese-clip-vit-b-32-roberta-base",
+   "architectures": [
+     "JapaneseCLIPModel"
+   ],
+   "auto_map": {
+     "AutoModel": "modeling_japanese_clip.JapaneseCLIPModel",
+     "AutoConfig": "configuration_japanese_clip.JapaneseCLIPConfig"
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.36.2",
+   "model_type": "japanese_clip",
+   "text_config": {
+     "_name_or_path": "",
+     "architectures": [
+       "RobertaModel"
+     ],
+     "attention_probs_dropout_prob": 0.1,
+     "bos_token_id": 1,
+     "eos_token_id": 2,
+     "gradient_checkpointing": false,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 768,
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "layer_norm_eps": 1e-05,
+     "max_position_embeddings": 514,
+     "model_type": "roberta",
+     "num_attention_heads": 12,
+     "num_hidden_layers": 12,
+     "pad_token_id": 3,
+     "position_embedding_type": "absolute",
+     "transformers_version": "4.6.1",
+     "type_vocab_size": 2,
+     "use_cache": true,
+     "vocab_size": 32000
+   },
+   "vision_config": {
+     "_name_or_path": "",
+     "image_size": 224,
+     "patch_size": 32,
+     "width": 768,
+     "layers": 12,
+     "mlp_ratio": 4.0,
+     "ls_init_value": null,
+     "attentional_pool": false,
+     "attn_pooler_queries": 256,
+     "attn_pooler_heads": 8,
+     "output_dim": 512,
+     "patch_dropout": 0.0,
+     "no_ln_pre": false,
+     "pool_type": "tok",
+     "final_ln_after_pool": false,
+     "output_tokens": false
+   }
+ }
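
The `auto_map` block above is what lets the `transformers` Auto classes resolve the custom code in the two Python files added below. A minimal sketch of loading this configuration (illustrative, not part of the commit): it assumes the files are published under the `_name_or_path` repo id, and it also assumes a `heads` value is available for the vision tower, since `JapaneseCLIPVisionConfig` below requires one that this `vision_config` does not list.

from transformers import AutoConfig

# trust_remote_code is required because JapaneseCLIPConfig lives in the repo
# (configuration_japanese_clip.py), not inside the transformers package.
config = AutoConfig.from_pretrained(
    "recruit-jp/japanese-clip-vit-b-32-roberta-base",
    trust_remote_code=True,
)
print(config.model_type)                # "japanese_clip"
print(config.text_config.hidden_size)   # 768
print(config.vision_config.patch_size)  # 32 -> a 224x224 image gives 7x7 = 49 patch tokens + 1 CLS token
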
configuration_japanese_clip.py ADDED
@@ -0,0 +1,88 @@
+
+ from transformers import PretrainedConfig, RobertaConfig
+
+
+ class JapaneseCLIPVisionConfig(PretrainedConfig):
+     model_type = "vit"
+
+     def __init__(self,
+                  image_size: int,
+                  patch_size: int,
+                  width: int,
+                  layers: int,
+                  heads: int,
+                  mlp_ratio: float,
+                  ls_init_value: float = None,
+                  attentional_pool: bool = False,
+                  attn_pooler_queries: int = 256,
+                  attn_pooler_heads: int = 8,
+                  output_dim: int = 512,
+                  patch_dropout: float = 0.0,
+                  no_ln_pre: bool = False,
+                  pool_type: str = "tok",
+                  final_ln_after_pool: bool = False,
+                  output_tokens: bool = False,
+                  **kwargs
+                  ):
+         super().__init__(**kwargs)
+
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.width = width
+         self.layers = layers
+         self.heads = heads
+         self.mlp_ratio = mlp_ratio
+         self.ls_init_value = ls_init_value
+         self.attentional_pool = attentional_pool
+         self.attn_pooler_queries = attn_pooler_queries
+         self.attn_pooler_heads = attn_pooler_heads
+         self.output_dim = output_dim
+         self.patch_dropout = patch_dropout
+         self.no_ln_pre = no_ln_pre
+         self.pool_type = pool_type
+         self.final_ln_after_pool = final_ln_after_pool
+         self.output_tokens = output_tokens
+
+
+ class JapaneseCLIPConfig(PretrainedConfig):
+     model_type = "japanese_clip"
+
+     def __init__(
+         self,
+         max_length: int = 77,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         self.max_length = max_length
+
+         if "vision_config" not in kwargs:
+             raise ValueError("vision_config must be provided")
+         if "text_config" not in kwargs:
+             raise ValueError("text_config must be provided")
+
+         vision_config = kwargs.pop("vision_config")
+         text_config = kwargs.pop("text_config")
+
+         self.vision_config = JapaneseCLIPVisionConfig(**vision_config)
+         self.text_config = RobertaConfig(**text_config)
+
+     @classmethod
+     def from_vision_text_configs(
+         cls,
+         vision_config: PretrainedConfig,
+         text_config: PretrainedConfig,
+         **kwargs
+     ):
+         r"""
+         Instantiate a [`JapaneseCLIPConfig`] (or a derived class) from a text model configuration and a vision
+         model configuration.
+         Returns:
+             [`JapaneseCLIPConfig`]: An instance of a configuration object
+         """
+
+         return cls(
+             vision_config=vision_config.to_dict(),
+             text_config=text_config.to_dict(),
+             **kwargs,
+         )
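
For completeness, a sketch of building the same configuration programmatically via `from_vision_text_configs` (illustrative only; it assumes `configuration_japanese_clip.py` is importable locally, and `heads=12` is an assumed value since config.json above does not list one):

from transformers import RobertaConfig

from configuration_japanese_clip import JapaneseCLIPConfig, JapaneseCLIPVisionConfig

# Vision tower: ViT-B/32 geometry matching config.json; heads=12 is an assumption.
vision_config = JapaneseCLIPVisionConfig(
    image_size=224, patch_size=32, width=768, layers=12, heads=12,
    mlp_ratio=4.0, output_dim=512,
)
# Text tower: a roberta-base-sized encoder matching the text_config above.
text_config = RobertaConfig(
    vocab_size=32000, hidden_size=768, num_hidden_layers=12,
    num_attention_heads=12, max_position_embeddings=514, pad_token_id=3,
)
config = JapaneseCLIPConfig.from_vision_text_configs(vision_config, text_config, max_length=77)
assert config.vision_config.output_dim == 512
assert config.text_config.hidden_size == 768

Because `from_vision_text_configs` serializes both sub-configs with `to_dict()`, `JapaneseCLIPConfig.__init__` checks that both dicts are present and rebuilds them as `JapaneseCLIPVisionConfig` and `RobertaConfig` instances, which is why `vision_config` and `text_config` are mandatory keys.
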
modeling_japanese_clip.py ADDED
@@ -0,0 +1,476 @@
+ import collections.abc
+ import math
+ from collections import OrderedDict
+ from itertools import repeat
+ from typing import Callable, Optional, Sequence, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.checkpoint import checkpoint
+
+ from transformers import AutoModel, PreTrainedModel
+
+ from .configuration_japanese_clip import JapaneseCLIPConfig
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         orig_dtype = x.dtype
+         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         return x.to(dtype=orig_dtype)
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, dim, init_values=1e-5, inplace=False):
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(torch.ones(dim) * init_values)
+
+     def forward(self, x):
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ class PatchDropout(nn.Module):
+     """
+     https://arxiv.org/abs/2212.00794
+     """
+
+     def __init__(self, prob, exclude_first_token=True):
+         super().__init__()
+         assert 0 <= prob < 1.0
+         self.prob = prob
+         self.exclude_first_token = exclude_first_token  # exclude CLS token
+
+     def forward(self, x):
+         if not self.training or self.prob == 0.:
+             return x
+
+         if self.exclude_first_token:
+             cls_tokens, x = x[:, :1], x[:, 1:]
+         else:
+             cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+         batch = x.size()[0]
+         num_tokens = x.size()[1]
+
+         batch_indices = torch.arange(batch)
+         batch_indices = batch_indices[..., None]
+
+         keep_prob = 1 - self.prob
+         num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+         rand = torch.randn(batch, num_tokens)
+         patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+         x = x[batch_indices, patch_indices_keep]
+
+         if self.exclude_first_token:
+             x = torch.cat((cls_tokens, x), dim=1)
+
+         return x
+
+
+ class AttentionalPooler(nn.Module):
+     def __init__(
+             self,
+             d_model: int,
+             context_dim: int,
+             n_head: int = 8,
+             n_queries: int = 256,
+             norm_layer: Callable = LayerNorm
+     ):
+         super().__init__()
+         self.query = nn.Parameter(torch.randn(n_queries, d_model))
+         self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
+         self.ln_q = norm_layer(d_model)
+         self.ln_k = norm_layer(context_dim)
+
+     def forward(self, x: torch.Tensor):
+         x = self.ln_k(x).permute(1, 0, 2)  # NLD -> LND
+         N = x.shape[1]
+         q = self.ln_q(self.query)
+         out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]
+         return out.permute(1, 0, 2)  # LND -> NLD
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+             self,
+             d_model: int,
+             n_head: int,
+             mlp_ratio: float = 4.0,
+             ls_init_value: Optional[float] = None,
+             act_layer: Callable = nn.GELU,
+             norm_layer: Callable = LayerNorm,
+             is_cross_attention: bool = False,
+     ):
+         super().__init__()
+
+         self.ln_1 = norm_layer(d_model)
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+         if is_cross_attention:
+             self.ln_1_kv = norm_layer(d_model)
+
+         self.ln_2 = norm_layer(d_model)
+         mlp_width = int(d_model * mlp_ratio)
+         self.mlp = nn.Sequential(OrderedDict([
+             ("c_fc", nn.Linear(d_model, mlp_width)),
+             ("gelu", act_layer()),
+             ("c_proj", nn.Linear(mlp_width, d_model))
+         ]))
+         self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+     def attention(
+             self,
+             q_x: torch.Tensor,
+             k_x: Optional[torch.Tensor] = None,
+             v_x: Optional[torch.Tensor] = None,
+             attn_mask: Optional[torch.Tensor] = None,
+     ):
+         k_x = k_x if k_x is not None else q_x
+         v_x = v_x if v_x is not None else q_x
+
+         attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
+         return self.attn(
+             q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
+         )[0]
+
+     def forward(
+             self,
+             q_x: torch.Tensor,
+             k_x: Optional[torch.Tensor] = None,
+             v_x: Optional[torch.Tensor] = None,
+             attn_mask: Optional[torch.Tensor] = None,
+     ):
+         k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
+         v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
+
+         x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
+         x = x + self.ls_2(self.mlp(self.ln_2(x)))
+         return x
+
+
+ # From PyTorch internals
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable):
+             return x
+         return tuple(repeat(x, n))
+     return parse
+
+ to_2tuple = _ntuple(2)
+
+
+ def _expand_token(token, batch_size: int):
+     return token.view(1, 1, -1).expand(batch_size, -1, -1)
+
+
+ class Transformer(nn.Module):
+     def __init__(
+             self,
+             width: int,
+             layers: int,
+             heads: int,
+             mlp_ratio: float = 4.0,
+             ls_init_value: float = None,
+             act_layer: Callable = nn.GELU,
+             norm_layer: Callable = LayerNorm,
+     ):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.grad_checkpointing = False
+
+         self.resblocks = nn.ModuleList([
+             ResidualAttentionBlock(
+                 width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
+             for _ in range(layers)
+         ])
+
+     def get_cast_dtype(self) -> torch.dtype:
+         if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
+             return self.resblocks[0].mlp.c_fc.int8_original_dtype
+         return self.resblocks[0].mlp.c_fc.weight.dtype
+
+     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+         for r in self.resblocks:
+             if self.grad_checkpointing and not torch.jit.is_scripting():
+                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+                 x = checkpoint(r, x, None, None, attn_mask)
+             else:
+                 x = r(x, attn_mask=attn_mask)
+         return x
+
+
+ class JapaneseCLIPVisionTransformer(PreTrainedModel):
+     output_tokens: torch.jit.Final[bool]
+
+     def __init__(
+             self,
+             image_size: int,
+             patch_size: int,
+             width: int,
+             layers: int,
+             heads: int,
+             mlp_ratio: float,
+             ls_init_value: float = None,
+             attentional_pool: bool = False,
+             attn_pooler_queries: int = 256,
+             attn_pooler_heads: int = 8,
+             output_dim: int = 512,
+             patch_dropout: float = 0.,
+             no_ln_pre: bool = False,
+             pool_type: str = 'tok',
+             final_ln_after_pool: bool = False,
+             act_layer: Callable = nn.GELU,
+             norm_layer: Callable = LayerNorm,
+             output_tokens: bool = False,
+     ):
+         super().__init__()
+         assert pool_type in ('tok', 'avg', 'none')
+         self.output_tokens = output_tokens
+         image_height, image_width = self.image_size = to_2tuple(image_size)
+         patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
+         self.grid_size = (image_height // patch_height, image_width // patch_width)
+         self.final_ln_after_pool = final_ln_after_pool  # currently ignored w/ attn pool enabled
+         self.output_dim = output_dim
+
+         self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+         # class embeddings and positional embeddings
+         scale = width ** -0.5
+         self.class_embedding = nn.Parameter(scale * torch.randn(width))
+         self.positional_embedding = nn.Parameter(
+             scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
+
+         # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+         self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
+
+         self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
+         self.transformer = Transformer(
+             width,
+             layers,
+             heads,
+             mlp_ratio,
+             ls_init_value=ls_init_value,
+             act_layer=act_layer,
+             norm_layer=norm_layer,
+         )
+
+         if attentional_pool:
+             if isinstance(attentional_pool, str):
+                 self.attn_pool_type = attentional_pool
+                 self.pool_type = 'none'
+                 if attentional_pool in ('parallel', 'cascade'):
+                     self.attn_pool = AttentionalPooler(
+                         output_dim,
+                         width,
+                         n_head=attn_pooler_heads,
+                         n_queries=attn_pooler_queries,
+                     )
+                     self.attn_pool_contrastive = AttentionalPooler(
+                         output_dim,
+                         width,
+                         n_head=attn_pooler_heads,
+                         n_queries=1,
+                     )
+                 else:
+                     assert False
+             else:
+                 self.attn_pool_type = ''
+                 self.pool_type = pool_type
+                 self.attn_pool = AttentionalPooler(
+                     output_dim,
+                     width,
+                     n_head=attn_pooler_heads,
+                     n_queries=attn_pooler_queries,
+                 )
+                 self.attn_pool_contrastive = None
+             pool_dim = output_dim
+         else:
+             self.attn_pool = None
+             pool_dim = width
+             self.pool_type = pool_type
+
+         self.ln_post = norm_layer(pool_dim)
+         self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
+
+         self.init_parameters()
+
+     def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+         for param in self.parameters():
+             param.requires_grad = False
+
+         if unlocked_groups != 0:
+             groups = [
+                 [
+                     self.conv1,
+                     self.class_embedding,
+                     self.positional_embedding,
+                     self.ln_pre,
+                 ],
+                 *self.transformer.resblocks[:-1],
+                 [
+                     self.transformer.resblocks[-1],
+                     self.ln_post,
+                 ],
+                 self.proj,
+             ]
+
+             def _unlock(x):
+                 if isinstance(x, Sequence):
+                     for g in x:
+                         _unlock(g)
+                 else:
+                     if isinstance(x, torch.nn.Parameter):
+                         x.requires_grad = True
+                     else:
+                         for p in x.parameters():
+                             p.requires_grad = True
+
+             _unlock(groups[-unlocked_groups:])
+
+     def init_parameters(self):
+         # FIXME OpenAI CLIP did not define an init for the VisualTransformer
+         # TODO experiment if default PyTorch init, below, or alternate init is best.
+
+         # nn.init.normal_(self.class_embedding, std=self.scale)
+         # nn.init.normal_(self.positional_embedding, std=self.scale)
+         #
+         # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+         # attn_std = self.transformer.width ** -0.5
+         # fc_std = (2 * self.transformer.width) ** -0.5
+         # for block in self.transformer.resblocks:
+         #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+         #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+         #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+         #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+         #
+         # if self.text_projection is not None:
+         #     nn.init.normal_(self.text_projection, std=self.scale)
+         pass
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.transformer.grad_checkpointing = enable
+
+     def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         if self.pool_type == 'avg':
+             pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
+         elif self.pool_type == 'tok':
+             pooled, tokens = x[:, 0], x[:, 1:]
+         else:
+             pooled = tokens = x
+
+         return pooled, tokens
+
+     def forward(self, x: torch.Tensor):
+         x = self.conv1(x)  # shape = [*, width, grid, grid]
+         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+         # class embeddings and positional embeddings
+         x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
+         # shape = [*, grid ** 2 + 1, width]
+         x = x + self.positional_embedding.to(x.dtype)
+
+         x = self.patch_dropout(x)
+         x = self.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         if self.attn_pool is not None:
+             if self.attn_pool_contrastive is not None:
+                 # This is untested, WIP pooling that should match paper
+                 x = self.ln_post(x)  # TBD LN first or separate one after each pool?
+                 tokens = self.attn_pool(x)
+                 if self.attn_pool_type == 'parallel':
+                     pooled = self.attn_pool_contrastive(x)
+                 else:
+                     assert self.attn_pool_type == 'cascade'
+                     pooled = self.attn_pool_contrastive(tokens)
+             else:
+                 # this is the original OpenCLIP CoCa setup, does not match paper
+                 x = self.attn_pool(x)
+                 x = self.ln_post(x)
+                 pooled, tokens = self._global_pool(x)
+         elif self.final_ln_after_pool:
+             pooled, tokens = self._global_pool(x)
+             pooled = self.ln_post(pooled)
+         else:
+             x = self.ln_post(x)
+             pooled, tokens = self._global_pool(x)
+
+         if self.proj is not None:
+             pooled = pooled @ self.proj
+
+         if self.output_tokens:
+             return pooled, tokens
+
+         return pooled
+
+
+ class JapaneseCLIPModel(PreTrainedModel):
+     config_class = JapaneseCLIPConfig
+
+     def __init__(self, config: JapaneseCLIPConfig):
+         super().__init__(config)
+         text_config = config.text_config
+         vision_config = config.vision_config
+
+         self.image_encoder = JapaneseCLIPVisionTransformer(
+             **vision_config.to_dict()
+         )
+         self.text_encoder = AutoModel.from_config(text_config, add_pooling_layer=False)
+         hidden_size = text_config.hidden_size
+         self.projection_dim = self.image_encoder.output_dim
+         self.text_projection = nn.Linear(hidden_size, self.projection_dim, bias=False)
+         self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
+         self.max_length = config.max_length
+         self.position_ids = list(range(0, self.max_length))
+
+     def _create_position_id_tensor(self, batch_size: int) -> torch.LongTensor:
+         # rinna/japanese-roberta-base requires providing custom position ids
+         # see: https://huggingface.co/rinna/japanese-roberta-base#note-3-provide-position_ids-as-an-argument-explicitly
+         return torch.LongTensor([self.position_ids for _ in range(batch_size)])
+
+     def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
+         return self.image_encoder(pixel_values)  # (batch_size, hidden_dim)
+
+     def get_text_features(
+         self, input_ids: torch.Tensor, position_ids: torch.Tensor = None
+     ) -> torch.FloatTensor:
+         if position_ids is None:
+             position_ids = self._create_position_id_tensor(input_ids.size(0)).to(
+                 input_ids.device
+             )
+         last_hidden_state = self.text_encoder(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             output_hidden_states=True,
+             return_dict=True,
+         ).hidden_states[
+             -1
+         ]  # (batch_size, tokens, embed_dim)
+         pooled_output = last_hidden_state[:, 0, :]  # (batch_size, embed_dim)
+         return self.text_projection(pooled_output)  # (batch_size, hidden_dim)
+
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor = None,
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+         """
+         When using DDP (DistributedDataParallel), computation must go through this method;
+         gradients obtained via the other methods are not synchronized across GPUs.
+         """
+         image_features = self.get_image_features(pixel_values)
+         text_features = self.get_text_features(input_ids, position_ids)
+         return image_features, text_features, self.logit_scale
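
Taken together, the three files define a dual-encoder CLIP: `JapaneseCLIPVisionTransformer` maps images to a 512-dimensional embedding, the RoBERTa encoder plus `text_projection` maps text into the same space, and `logit_scale` is the learned temperature (initialized to log(1/0.07), as in CLIP). A minimal sketch of the intended inference flow, assuming a repo revision that loads cleanly and using random tensors as stand-ins for a preprocessed 224x224 image and tokenized captions:

import torch
import torch.nn.functional as F
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "recruit-jp/japanese-clip-vit-b-32-roberta-base", trust_remote_code=True
)
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)    # stand-in for one preprocessed image
input_ids = torch.randint(5, 32000, (2, 77))  # stand-in for two captions padded to max_length

with torch.no_grad():
    image_features, text_features, logit_scale = model(pixel_values, input_ids)
    image_features = F.normalize(image_features, dim=-1)  # (1, 512)
    text_features = F.normalize(text_features, dim=-1)    # (2, 512)
    logits_per_image = logit_scale.exp() * image_features @ text_features.T
    probs = logits_per_image.softmax(dim=-1)              # similarity over the two captions

In practice the repository's own image transform and tokenizer would replace the random stand-ins; the position-id handling above points at rinna/japanese-roberta-base for the text side.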