Upload 3 files
xdecoder/language/LangEncoder/build.py
ADDED
@@ -0,0 +1,36 @@
+import os
+
+from transformers import CLIPTokenizer, CLIPTokenizerFast
+from transformers import AutoTokenizer
+
+from .registry import lang_encoders
+from .registry import is_lang_encoder
+
+
+def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
+    model_name = config_encoder['NAME']
+
+    if not is_lang_encoder(model_name):
+        raise ValueError(f'Unknown model: {model_name}')
+
+    return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
+
+
+def build_tokenizer(config_encoder):
+    tokenizer = None
+    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+    if config_encoder['TOKENIZER'] == 'clip':
+        pretrained_tokenizer = config_encoder.get(
+            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
+        )
+        tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
+        tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
+    elif config_encoder['TOKENIZER'] == 'clip-fast':
+        pretrained_tokenizer = config_encoder.get(
+            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
+        )
+        tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
+
+    return tokenizer
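Note: a minimal usage sketch of build_tokenizer, not part of the uploaded files. The config dicts below are illustrative assumptions showing only the keys the function reads, and the import path assumes the repository root is on sys.path.

from xdecoder.language.LangEncoder.build import build_tokenizer  # hypothetical import path

clip_tok = build_tokenizer({'TOKENIZER': 'clip'})               # slow CLIPTokenizer, cls_token aliased to eos_token
fast_tok = build_tokenizer({'TOKENIZER': 'clip-fast'})          # CLIPTokenizerFast converted from the slow tokenizer
auto_tok = build_tokenizer({'TOKENIZER': 'bert-base-uncased'})  # any other value is passed to AutoTokenizer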
xdecoder/language/LangEncoder/registry.py
ADDED
@@ -0,0 +1,18 @@
+_lang_encoders = {}
+
+
+def register_lang_encoder(fn):
+    module_name_split = fn.__module__.split('.')
+    model_name = module_name_split[-1]
+
+    _lang_encoders[model_name] = fn
+
+    return fn
+
+
+def lang_encoders(model_name):
+    return _lang_encoders[model_name]
+
+
+def is_lang_encoder(model_name):
+    return model_name in _lang_encoders
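Note: a sketch of the registry contract, assuming the package is importable as xdecoder.language.LangEncoder in the repository's own environment (timm and the utils package available); the sys.path setup is not shown in this diff. A function decorated with register_lang_encoder is keyed by the last component of its defining module, so the lang_encoder function in transformer.py below is stored under 'transformer'.

from xdecoder.language.LangEncoder.registry import is_lang_encoder, lang_encoders
from xdecoder.language.LangEncoder import transformer  # noqa: F401  importing runs @register_lang_encoder

assert is_lang_encoder('transformer')      # key = fn.__module__.split('.')[-1]
build_fn = lang_encoders('transformer')    # returns the decorated builder function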
xdecoder/language/LangEncoder/transformer.py
ADDED
@@ -0,0 +1,222 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from timm.models.layers import DropPath, trunc_normal_
+
+from .registry import register_lang_encoder
+from utils.distributed import is_main_process
+from utils.model import register_norm_module
+
+logger = logging.getLogger(__name__)
+
+
+@register_norm_module
+class LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-12):
+        """Construct a layernorm module in the TF style (epsilon inside the square root).
+        """
+        super(LayerNorm, self).__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.bias = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x):
+        pdtype = x.dtype
+        x = x.float()
+        u = x.mean(-1, keepdim=True)
+        s = (x - u).pow(2).mean(-1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+        return self.weight * x.to(pdtype) + self.bias
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self,
+                 d_model: int,
+                 n_head: int,
+                 attn_mask: torch.Tensor = None,
+                 drop_path: float = 0.0):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
+            if self.attn_mask is not None else None
+
+
+        return self.attn(
+            x, x, x,
+            key_padding_mask=key_padding_mask,
+            need_weights=False,
+            attn_mask=self.attn_mask
+        )[0]
+
+    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
+        x = x + self.drop_path(self.mlp(self.ln_2(x)))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self,
+                 context_length: int,
+                 vocab_size: int,
+                 width: int,
+                 layers: int,
+                 heads: int,
+                 drop_path: float = 0.0,
+                 autogressive: bool = True):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(vocab_size, width)
+
+        self.context_length = context_length
+        self.positional_embedding = nn.Parameter(
+            torch.empty(self.context_length, width)
+        )
+
+        self.width = width
+        self.layers = layers
+        self.autogressive = autogressive
+        attn_mask = self.build_attention_mask() if autogressive else None
+        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule
+        self.resblocks = nn.ModuleList(
+            [
+                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
+                for i in range(layers)
+            ]
+        )
+
+        self.ln_final = LayerNorm(width)
+
+        trunc_normal_(self.positional_embedding, std=.02)
+        # nn.init.normal_(self.token_embedding, std=.02)
+        trunc_normal_(self.token_embedding.weight, std=.02)
+        self.apply(self._init_weights)
+
+    @property
+    def dim_out(self):
+        return self.width
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Linear, nn.Conv2d)):
+            if is_main_process():
+                logger.info('=> init weight of Linear/Conv2d from trunc norm')
+            trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                if is_main_process():
+                    logger.info('=> init bias of Linear/Conv2d to zeros')
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+            nn.init.constant_(m.bias, 0)
+
+    def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
+        if os.path.isfile(pretrained):
+            pretrained_dict = torch.load(pretrained, map_location='cpu')
+            logging.info(f'=> loading pretrained model {pretrained}')
+            model_dict = self.state_dict()
+            stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x
+            pretrained_dict = {
+                stripped_key(k): v for k, v in pretrained_dict.items()
+                if stripped_key(k) in model_dict.keys()
+            }
+            need_init_state_dict = {}
+            for k, v in pretrained_dict.items():
+                need_init = (
+                    k.split('.')[0] in pretrained_layers
+                    or pretrained_layers[0] == '*'
+                )
+                if need_init:
+                    if verbose:
+                        logger.info(f'=> init {k} from {pretrained}')
+
+                    if 'positional_embedding' in k and v.size() != model_dict[k].size():
+                        positional_embedding_pretrained = v
+                        positional_embedding_current = model_dict[k]
+                        L1, nH1 = positional_embedding_pretrained.size()
+                        L2, nH2 = positional_embedding_current.size()
+                        if nH1 != nH2:
+                            logger.info(f"Error in loading {k}, passing")
+                        else:
+                            if L1 != L2:
+                                logger.info(
+                                    '=> load_pretrained: resized variant: {} to {}'
+                                    .format((L1, nH1), (L2, nH2))
+                                )
+
+                                posemb = positional_embedding_pretrained.float()
+                                posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
+                                posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
+                                posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
+                                v = posemb_grid
+
+                    need_init_state_dict[k] = v
+
+            self.load_state_dict(need_init_state_dict, strict=False)
+
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {
+            'positional_embedding',
+            'token_embedding',
+        }
+
+    def forward(self, input_ids, attention_mask=None):
+        key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None
+        # key_padding_mask = (input_ids == 0) if not self.autogressive else None
+        x = self.token_embedding(input_ids)  # [batch_size, n_ctx, d_model]
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        for block in self.resblocks:
+            x = block(x, key_padding_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_final(x)
+
+        return {'last_hidden_state': x}
+
+
+@register_lang_encoder
+def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
+    transformer = Transformer(
+        context_length=config_encoder['CONTEXT_LENGTH'],
+        vocab_size=tokenizer.vocab_size,
+        width=config_encoder['WIDTH'],
+        layers=config_encoder['LAYERS'],
+        heads=config_encoder['HEADS'],
+        autogressive=config_encoder.get('AUTOGRESSIVE', True)
+    )
+
+    if config_encoder.get('LOAD_PRETRAINED', False):
+        transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))
+    return transformer
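Note: an end-to-end sketch of building and running the registered encoder, not part of the uploaded files. It assumes the package __init__ (not in this diff) imports transformer.py so registration runs, and the config values (context length 77, width 512, 6 layers, 8 heads) are illustrative assumptions. forward returns a dict whose 'last_hidden_state' has shape [batch, CONTEXT_LENGTH, WIDTH].

import torch
from xdecoder.language.LangEncoder.build import build_tokenizer, build_lang_encoder  # hypothetical import path

config_encoder = {
    'NAME': 'transformer',   # must match the module name registered above
    'TOKENIZER': 'clip',
    'CONTEXT_LENGTH': 77,    # illustrative values from here down
    'WIDTH': 512,
    'LAYERS': 6,
    'HEADS': 8,
}

tokenizer = build_tokenizer(config_encoder)
model = build_lang_encoder(config_encoder, tokenizer, verbose=False)

tokens = tokenizer(
    ['a photo of a cat'],
    padding='max_length',
    truncation=True,
    max_length=config_encoder['CONTEXT_LENGTH'],
    return_tensors='pt',
)
with torch.no_grad():
    out = model(tokens['input_ids'], tokens['attention_mask'])
print(out['last_hidden_state'].shape)  # torch.Size([1, 77, 512])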