Upload 6 files
- xdecoder/language/build.py +11 -0
- xdecoder/language/fixvlpencoder.py +35 -0
- xdecoder/language/loss.py +225 -0
- xdecoder/language/misc.py +64 -0
- xdecoder/language/registry.py +13 -0
- xdecoder/language/vlpencoder.py +168 -0
xdecoder/language/build.py
ADDED
@@ -0,0 +1,11 @@
+from .registry import model_entrypoints
+from .registry import is_model
+
+
+def build_language_encoder(config, **kwargs):
+    model_name = config['MODEL']['TEXT']['ARCH']
+
+    if not is_model(model_name):
+        raise ValueError(f'Unknown model: {model_name}')
+
+    return model_entrypoints(model_name)(config, **kwargs)
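
A minimal usage sketch for the factory above, assuming the package is importable; the 'vlpencoder' value is an assumption based on how registry.py (below) keys entrypoints by defining module:

    # Hypothetical call site. config['MODEL']['TEXT']['ARCH'] names a module
    # that registered an entrypoint: 'vlpencoder' resolves to
    # vlpencoder.get_language_model, 'fixvlpencoder' to the frozen variant,
    # and any unregistered name raises ValueError.
    from xdecoder.language.build import build_language_encoder

    config = {'MODEL': {'TEXT': {'ARCH': 'vlpencoder'}}}  # real configs carry more TEXT keys
    encoder = build_language_encoder(config)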
xdecoder/language/fixvlpencoder.py
ADDED
@@ -0,0 +1,35 @@
+from importlib.metadata import requires
+import torch
+import torch.nn as nn
+
+from .registry import register_model
+from .vlpencoder import LanguageEncoder
+
+class FixLanguageEncoder(LanguageEncoder):
+
+    def __init__(
+        self,
+        *args, **kwargs):
+        super(FixLanguageEncoder, self).__init__(*args, **kwargs)
+        self.logit_scale = nn.Parameter(torch.ones([]), requires_grad=False)
+
+    @torch.no_grad()
+    def get_text_embeddings(self, *args, **kwargs):
+        return super().get_text_embeddings(*args, **kwargs)
+
+    @torch.no_grad()
+    def get_text_token_embeddings(self, *args, **kwargs):
+        return super().get_text_token_embeddings(*args, **kwargs)
+
+    @torch.no_grad()
+    def forward_language(self, *args, **kwargs):
+        return super().forward_language(*args, **kwargs)
+
+    @torch.no_grad()
+    def forward_language_token(self, *args, **kwargs):
+        return super().forward_language_token(*args, **kwargs)
+
+
+@register_model
+def get_language_model(cfg, **kwargs):
+    return FixLanguageEncoder(cfg)
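
The subclass above freezes the language tower: logit_scale is created with requires_grad=False and every encoding path is wrapped in @torch.no_grad(). A self-contained sketch of the same pattern on a toy module (names here are illustrative, not from the diff):

    import torch
    import torch.nn as nn

    class ToyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def encode(self, x):
            return self.proj(x)

    class FrozenToyEncoder(ToyEncoder):
        @torch.no_grad()
        def encode(self, x):          # same override style as FixLanguageEncoder
            return super().encode(x)

    out = FrozenToyEncoder().encode(torch.randn(2, 4))
    assert not out.requires_grad      # no autograd graph is built under no_grad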
xdecoder/language/loss.py
ADDED
@@ -0,0 +1,225 @@
+import pickle
+from distutils import log
+
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from einops import rearrange, repeat
+from timm.loss import SoftTargetCrossEntropy
+
+soft_cross_entropy = SoftTargetCrossEntropy()
+
+def is_dist_initialized():
+    return torch.distributed.is_initialized()
+
+def get_world_size():
+    if is_dist_initialized():
+        return torch.distributed.get_world_size()
+    return 1
+
+def get_rank():
+    if is_dist_initialized():
+        return dist.get_rank()
+    return 0
+
+def all_gather_grad(x):
+    if get_world_size() > 1:
+        all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
+        torch.distributed.all_gather(all_x, x)
+        all_x[torch.distributed.get_rank()] = x
+        x = torch.cat(all_x, dim=0)
+    return x
+
+def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1):
+    """
+    Args:
+        image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256
+        text_feat (torch.Tensor): shape [B, L2, C] # B: batch_size, L2: number of selected nouns, C: 256
+
+    Returns:
+    """
+    # [B, L1, C], L1 = 1
+    # image_feat = F.normalize(image_feat, dim=-1)
+    # [B, L2, C]
+    # text_feat = F.normalize(text_feat, dim=-1)
+    # HACK: normalize outside
+
+    # [B, L1, L2]
+    dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
+    # [B, L2, L1]
+    dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
+
+    batch = image_feat.shape[0]
+    img_len = image_feat.shape[1]
+    text_len = text_feat.shape[1]
+    # [B, L1, L2]
+    pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
+    # [B, L2, L1]
+    pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
+
+    image_x = rearrange(image_feat, 'b l c -> (b l) c')
+    text_x = rearrange(text_feat, 'b l c -> (b l) c')
+
+    logits_per_img = image_x @ all_gather_grad(text_x).t()
+    logits_per_text = text_x @ all_gather_grad(image_x).t()
+
+    # get label globally
+    # [B, L1, B, L2, W]
+    labels_per_img = F.one_hot(
+        torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(),
+        num_classes=get_world_size()).to(image_x.dtype)
+    labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
+        torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
+    # [BxL1, WxBxL2]
+    labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
+    # [B, L2, B, L1, W]
+    labels_per_text = F.one_hot(
+        torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(),
+        num_classes=get_world_size()).to(text_x.dtype)
+    labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
+        torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
+    # [BxL2, WxBxL1]
+    labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
+
+    logit_scale = temperature.exp().clamp(max=100)
+
+    loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img)
+    loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text)
+
+    loss = 0.5 * (loss_img + loss_text)
+    return loss
+
+def vl_contrastive_loss(image_feat, text_feat, temperature=1):
+    # if image_id or text_id is None, it should be None across all GPUs
+    # image_feat = F.normalize(image_feat, dim=1)
+    # text_feat = F.normalize(text_feat, dim=1)
+    # handle normalization outside
+
+    # add the following 4 lines
+    image_feat = all_gather_grad(image_feat)
+    text_feat = all_gather_grad(text_feat)
+
+    logits = torch.matmul(image_feat, text_feat.t())
+    logit_scale = temperature.exp().clamp(max=100)
+
+    gt = torch.arange(logits.shape[0], device=logits.device)
+    loss1 = F.cross_entropy(logit_scale * logits, gt)
+    loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+    return (loss1 + loss2) / 2  # scale it up by the number of GPUs
+
+
+def all_gather_pickle(data, device):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device)
+
+    # obtain Tensor size of each rank
+    local_size = torch.LongTensor([tensor.numel()]).cuda()
+    size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda())
+    if local_size != max_size:
+        padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+def all_gather_arbitary_tensor(tensor):
+    if get_world_size() > 1:
+        device = tensor.device
+        tensor_batch = all_gather_pickle(tensor.cpu(), device)
+        tensor_batch = [x.to(device) for x in tensor_batch]
+        tensor_batch[torch.distributed.get_rank()] = tensor
+        tensor_batch = torch.cat(tensor_batch, dim=0)
+    else:
+        tensor_batch = tensor
+    return tensor_batch
+
+def ql_contrastive_loss(image_feat, text_feat, temperature=1):
+    # add the following 4 lines
+    image_feat = all_gather_arbitary_tensor(image_feat)
+    text_feat = all_gather_arbitary_tensor(text_feat)
+
+    logits = torch.matmul(image_feat, text_feat.t())
+    logit_scale = temperature.exp().clamp(max=100)
+
+    gt = torch.arange(logits.shape[0], device=logits.device)
+    loss1 = F.cross_entropy(logit_scale * logits, gt)
+    loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+    return (loss1 + loss2) / 2  # scale it up by the number of GPUs
+
+def vl_similarity(image_feat, text_feat, temperature=1):
+    # Only support single GPU for now.
+    logits = torch.matmul(image_feat, text_feat.t())
+    logits = temperature.exp().clamp(max=100) * logits
+    return logits
+
+def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):
+    # add the following 4 lines
+    image_feat = all_gather_arbitary_tensor(image_feat)
+    text_feat = all_gather_arbitary_tensor(text_feat)
+
+    text_hash_batch = all_gather_pickle(text_hash, text_feat.device)
+    text_hash_all = torch.cat(text_hash_batch)
+
+    text_hash_all_unique = torch.unique(text_hash_all).tolist()
+    gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device)
+    text_hash_all = text_hash_all.tolist()
+    text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique])
+
+    for idx, txt in enumerate(text_hash_all):
+        gt[idx][text_hash_all_unique.index(txt)] = 1
+
+    logits = torch.matmul(image_feat, text_feat_unique.t())
+    logits = logits * temperature.exp().clamp(max=100)
+
+    loss_img = soft_cross_entropy(logits, gt)
+    loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True))
+
+    loss = 0.7 * loss_img + 0.3 * loss_text
+    return loss
+
+def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training):
+    # add the following 4 lines
+    image_feat = all_gather_grad(image_feat_inp.contiguous())
+    text_feat = all_gather_grad(text_feat_inp.contiguous())
+
+    image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7)
+    text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7)
+
+    temperature = lang_enc.logit_scale
+    logits = torch.matmul(image_feat, text_feat.t())
+    logit_scale = temperature.exp().clamp(max=100)
+
+    gt = torch.arange(logits.shape[0], device=logits.device)
+    loss1 = F.cross_entropy(logit_scale * logits, gt)
+    loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+
+    return (loss1 + loss2) / 2  # scale it up by the number of GPUs
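
One caveat for callers of this module: every loss declares temperature=1 as a default but then calls temperature.exp(), so the integer default would raise; a tensor (typically a learnable logit_scale) must be passed. A single-GPU sketch, assuming features are already L2-normalized as the comments above require:

    import torch
    import torch.nn.functional as F
    from xdecoder.language.loss import vl_contrastive_loss, vl_similarity

    B, C = 8, 256
    image_feat = F.normalize(torch.randn(B, C), dim=-1)
    text_feat = F.normalize(torch.randn(B, C), dim=-1)
    logit_scale = torch.nn.Parameter(torch.ones([]) * 2.6593)  # CLIP-style init, log(1/0.07)

    loss = vl_contrastive_loss(image_feat, text_feat, temperature=logit_scale)
    sim = vl_similarity(image_feat, text_feat, temperature=logit_scale)  # [B, B] scaled logits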
xdecoder/language/misc.py
ADDED
@@ -0,0 +1,64 @@
+import random
+
+import nltk
+nltk.data.path.append('/mnt/data/nltk_data')
+import numpy as np
+
+from utils.constants import IMAGENET_DEFAULT_TEMPLATES
+
+
+def get_tag(tokenized, tags):
+    if not isinstance(tags, (list, tuple)):
+        tags = [tags]
+    ret = []
+    for (word, pos) in nltk.pos_tag(tokenized):
+        for tag in tags:
+            if pos == tag:
+                ret.append(word)
+    return ret
+
+def get_noun_phrase(tokenized):
+    # Taken from Su Nam Kim Paper...
+    grammar = r"""
+        NBAR:
+            {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
+
+        NP:
+            {<NBAR>}
+            {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
+    """
+    chunker = nltk.RegexpParser(grammar)
+
+    chunked = chunker.parse(nltk.pos_tag(tokenized))
+    continuous_chunk = []
+    current_chunk = []
+
+    for subtree in chunked:
+        if isinstance(subtree, nltk.Tree):
+            current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
+        elif current_chunk:
+            named_entity = ' '.join(current_chunk)
+            if named_entity not in continuous_chunk:
+                continuous_chunk.append(named_entity)
+            current_chunk = []
+        else:
+            continue
+
+    return continuous_chunk
+
+def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
+    tokenized = nltk.word_tokenize(text)
+
+    if random.random() >= phrase_prob:
+        nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
+    else:
+        nouns = get_noun_phrase(tokenized)
+
+
+    prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]
+
+    if append_text:
+        prompt_texts += [text]
+        nouns += [text]
+
+    return prompt_texts, nouns
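
A usage sketch for text_noun_with_prompt_all, assuming the NLTK data (tokenizer and POS tagger) is available at the path appended above and utils.constants is importable:

    from xdecoder.language.misc import text_noun_with_prompt_all

    # phrase_prob=0.0 keeps single-word nouns (NN/NNS/NNP); higher values
    # sample chunked noun phrases instead. With append_text=True the raw
    # caption is appended to both returned lists.
    prompts, nouns = text_noun_with_prompt_all('a dog chasing a red ball')
    # e.g. nouns  -> ['dog', 'ball', 'a dog chasing a red ball']
    #      prompts -> one filled IMAGENET_DEFAULT_TEMPLATES entry per noun, plus the caption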
xdecoder/language/registry.py
ADDED
@@ -0,0 +1,13 @@
+_model_entrypoints = {}
+
+def register_model(fn):
+    module_name_split = fn.__module__.split('.')
+    model_name = module_name_split[-1]
+    _model_entrypoints[model_name] = fn
+    return fn
+
+def model_entrypoints(model_name):
+    return _model_entrypoints[model_name]
+
+def is_model(model_name):
+    return model_name in _model_entrypoints
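
A sketch of the registration mechanics: the key is the defining module's name, not the function name, which is why vlpencoder.py and fixvlpencoder.py can both expose get_language_model:

    from xdecoder.language.registry import register_model, is_model, model_entrypoints

    @register_model
    def get_language_model(cfg, **kwargs):   # registered under this file's module name
        return ('built', cfg)

    key = get_language_model.__module__.split('.')[-1]
    assert is_model(key)
    model_entrypoints(key)({'ARCH': 'demo'})  # -> ('built', {'ARCH': 'demo'})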
xdecoder/language/vlpencoder.py
ADDED
@@ -0,0 +1,168 @@
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from timm.models.layers import trunc_normal_
+
+from .registry import register_model
+from ..utils import configurable
+from .LangEncoder import build_tokenizer, build_lang_encoder
+from utils.misc import prompt_engineering, get_prompt_templates
+
+
+class LanguageEncoder(nn.Module):
+
+    @configurable
+    def __init__(
+        self,
+        tokenizer,
+        tokenizer_type,
+        lang_encoder,
+        lang_projection,
+        max_token_num,
+    ):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.tokenizer_type = tokenizer_type
+        self.lang_encoder = lang_encoder
+        self.lang_proj = lang_projection
+        self.max_token_num = max_token_num
+        self.logit_scale = nn.Parameter(torch.ones([]))
+
+    @classmethod
+    def from_config(cls, cfg):
+        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
+        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
+        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
+        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+
+        dim_lang = cfg['MODEL']['TEXT']['WIDTH']
+        dim_projection = cfg['MODEL']['DIM_PROJ']
+        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
+        trunc_normal_(lang_projection, std=.02)
+
+        return {
+            "tokenizer": tokenizer,
+            "tokenizer_type": tokenizer_type,
+            "lang_encoder": lang_encoder,
+            "lang_projection": lang_projection,
+            "max_token_num": max_token_num,
+        }
+
+    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):
+        if not is_eval:
+            if prompt:
+                # randomly sample one template
+                arbitary_concepts = [
+                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
+                    for label in range(len(class_names))
+                ]
+                if add_bgd:
+                    arbitary_concepts.append("A background in coco.")
+            else:
+                arbitary_concepts = class_names
+
+            input_ids = []
+            attention_masks = []
+            for txt in arbitary_concepts:
+                tokens = self.tokenizer(
+                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+                )
+                tokens['input_ids'].squeeze_()
+                tokens['attention_mask'].squeeze_()
+
+                input_ids.append(tokens['input_ids'])
+                attention_masks.append(tokens['attention_mask'])
+
+            arbitary_tokens = torch.stack(input_ids)
+            arbitary_attention_masks = torch.stack(attention_masks)
+
+            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)
+            setattr(self, '{}_text_embeddings'.format(name), text_emb)
+        else:
+            with torch.no_grad():
+                def extract_mean_emb(txts):
+                    tokens = self.tokenizer(
+                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+                    )
+                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
+                    clss_embedding = clss_embedding.mean(dim=0)
+                    clss_embedding /= clss_embedding.norm()
+                    return clss_embedding
+
+                templates = get_prompt_templates()
+                clss_embeddings = []
+                if prompt:
+                    for clss in class_names:
+                        txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
+                        clss_embeddings.append(extract_mean_emb(txts))
+                else:
+                    clss_embeddings.append(extract_mean_emb(class_names))
+
+                if add_bgd:
+                    txts = ["A background in coco."]
+                    clss_embeddings.append(extract_mean_emb(txts))
+
+                text_emb = torch.stack(clss_embeddings, dim=0)
+                setattr(self, '{}_text_embeddings'.format(name), text_emb)
+
+    def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
+        if not token:
+            tokens = self.tokenizer(
+                txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+            )
+            tokens = {key: value.cuda() for key, value in tokens.items()}
+        else:
+            tokens = txts
+        token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
+        ret = {"tokens": tokens,
+               "token_emb": token_emb,
+               "class_emb": class_emb,}
+        setattr(self, '{}_token_embeddings'.format(name), ret)
+        return ret
+
+    def forward_language(self, texts, norm=True):
+        x = self.lang_encoder(*texts)
+        x = x['last_hidden_state']
+
+        if self.tokenizer_type == 'clip':
+            x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
+        else:
+            x = x[:, 0]
+
+        x = x @ self.lang_proj
+        if norm:
+            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
+        return x
+
+    def forward_language_token(self, texts, norm=False):
+        x = self.lang_encoder(*texts)
+        token_x = x['last_hidden_state']
+
+        if self.tokenizer_type == 'clip':
+            class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]
+        else:
+            class_x = token_x[:, 0]
+
+        class_x = class_x @ self.lang_proj
+        token_x = token_x @ self.lang_proj
+
+        if norm:
+            class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)
+            token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)
+
+        return token_x, class_x
+
+    def compute_similarity(self, v_emb, name='default', fake=False):
+        if fake:
+            return None
+        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+        t_emb = getattr(self, '{}_text_embeddings'.format(name))
+        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
+        return output
+
+
+@register_model
+def get_language_model(cfg, **kwargs):
+    return LanguageEncoder(cfg)
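
The tokenizer_type == 'clip' branches in forward_language/forward_language_token pool the sentence embedding at the end-of-text token, located as the argmax of the input ids (EOT carries the largest id in CLIP's vocabulary); other tokenizers use the first ([CLS]) position. A self-contained toy illustration of that indexing:

    import torch

    last_hidden = torch.randn(2, 5, 8)                      # [batch, tokens, width]
    input_ids = torch.tensor([[49406, 320, 49407, 0, 0],    # 49406/49407: CLIP SOT/EOT ids
                              [49406, 1929, 320, 49407, 0]])
    pooled = last_hidden[torch.arange(2), input_ids.argmax(dim=-1)]
    assert pooled.shape == (2, 8)                           # one vector per sentence, taken at EOT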