fffiloni commited on
Commit
009b227
·
1 Parent(s): f10c916

Create xdecoder_model.py

Browse files
xdecoder/architectures/xdecoder_model.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import random
9
+ from typing import Tuple
10
+ from unicodedata import name
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ import numpy as np
16
+
17
+ from .registry import register_model
18
+ from ..utils import configurable
19
+ from ..backbone import build_backbone, Backbone
20
+ from ..body import build_xdecoder_head
21
+ from ..modules import sem_seg_postprocess, bbox_postprocess
22
+ from ..language import build_language_encoder
23
+ from ..language.loss import vl_similarity
24
+
25
+ from timm.models.layers import trunc_normal_
26
+ from nltk.stem.lancaster import LancasterStemmer
27
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode
28
+ from detectron2.utils.memory import retry_if_cuda_oom
29
+ from detectron2.data import MetadataCatalog
30
+ from utils.misc import prompt_engineering
31
+
32
+ st = LancasterStemmer()
33
+
34
+
35
+ class X_Decoder_Model(nn.Module):
36
+ @configurable
37
+ def __init__(
38
+ self,
39
+ *,
40
+ backbone: Backbone,
41
+ sem_seg_head: nn.Module,
42
+ criterion: nn.Module,
43
+ losses: dict,
44
+ num_queries: int,
45
+ object_mask_threshold: float,
46
+ overlap_threshold: float,
47
+ metadata,
48
+ task_switch: dict,
49
+ phrase_prob: float,
50
+ size_divisibility: int,
51
+ sem_seg_postprocess_before_inference: bool,
52
+ pixel_mean: Tuple[float],
53
+ pixel_std: Tuple[float],
54
+ # inference
55
+ semantic_on: bool,
56
+ panoptic_on: bool,
57
+ instance_on: bool,
58
+ test_topk_per_image: int,
59
+ train_dataset_name: str,
60
+ retrieval_emsemble: bool,
61
+ backbone_dim: int,
62
+ dim_proj: int,
63
+ ):
64
+ super().__init__()
65
+ self.backbone = backbone
66
+ self.sem_seg_head = sem_seg_head
67
+ self.criterion = criterion
68
+ self.losses = losses
69
+ self.num_queries = num_queries
70
+ self.overlap_threshold = overlap_threshold
71
+ self.object_mask_threshold = object_mask_threshold
72
+ self.metadata = metadata
73
+ if size_divisibility < 0:
74
+ # use backbone size_divisibility if not set
75
+ size_divisibility = self.backbone.size_divisibility
76
+ self.size_divisibility = size_divisibility
77
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
78
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
79
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
80
+
81
+ # additional args
82
+ self.semantic_on = semantic_on
83
+ self.instance_on = instance_on
84
+ self.panoptic_on = panoptic_on
85
+
86
+ # caption argument
87
+ self.task_switch = task_switch
88
+ self.phrase_prob = phrase_prob
89
+
90
+ self.test_topk_per_image = test_topk_per_image
91
+ self.train_class_names = None
92
+
93
+ self.retrieval_emsemble = retrieval_emsemble
94
+ # backbone itc loss
95
+ if task_switch['retrieval'] and retrieval_emsemble:
96
+ self.backbone_proj = nn.Parameter(torch.empty(backbone_dim, dim_proj))
97
+ trunc_normal_(self.backbone_proj, std=.02)
98
+
99
+ if not self.semantic_on:
100
+ assert self.sem_seg_postprocess_before_inference
101
+
102
+ @classmethod
103
+ def from_config(cls, cfg):
104
+ enc_cfg = cfg['MODEL']['ENCODER']
105
+ dec_cfg = cfg['MODEL']['DECODER']
106
+
107
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
108
+ 'mask': dec_cfg.get('MASK', True),
109
+ 'caption': dec_cfg['CAPTION'].get('ENABLED', False),
110
+ 'captioning': dec_cfg['CAPTIONING'].get('ENABLED', False),
111
+ 'retrieval': dec_cfg['RETRIEVAL'].get('ENABLED', False),
112
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False)}
113
+
114
+ # build model
115
+ extra = {'task_switch': task_switch}
116
+ backbone = build_backbone(cfg)
117
+ lang_encoder = build_language_encoder(cfg)
118
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra)
119
+
120
+ # Training Settings.
121
+ loss_weights = {}
122
+ matcher = None
123
+ losses = {}
124
+ weight_dict = {}
125
+ grd_weight = {}
126
+ top_x_layers = {}
127
+ criterion = None
128
+ train_dataset_name = None
129
+ phrase_prob = None
130
+ # Loss parameters:
131
+ deep_supervision = None
132
+ no_object_weight = None
133
+
134
+ return {
135
+ "backbone": backbone,
136
+ "sem_seg_head": sem_seg_head,
137
+ "criterion": criterion,
138
+ "losses": losses,
139
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
140
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
141
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
142
+ "metadata": None,
143
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
144
+ "sem_seg_postprocess_before_inference": (
145
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
146
+ or dec_cfg['TEST']['PANOPTIC_ON']
147
+ or dec_cfg['TEST']['INSTANCE_ON']
148
+ ),
149
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
150
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
151
+ "task_switch": task_switch,
152
+ "phrase_prob": phrase_prob,
153
+ # inference
154
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
155
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
156
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
157
+ "test_topk_per_image": cfg['MODEL']['DECODER']['TEST']['DETECTIONS_PER_IMAGE'],
158
+ "train_dataset_name": train_dataset_name,
159
+ "retrieval_emsemble": dec_cfg['RETRIEVAL']['ENSEMBLE'],
160
+ "backbone_dim": cfg['MODEL']['BACKBONE_DIM'],
161
+ "dim_proj": cfg['MODEL']['DIM_PROJ'],
162
+ }
163
+
164
+ @property
165
+ def device(self):
166
+ return self.pixel_mean.device
167
+
168
+ def forward(self, batched_inputs, mode=None):
169
+ if self.training:
170
+ assert False, "Not support trianing mode."
171
+ else:
172
+ if mode == 'retrieval':
173
+ return self.evaluate_retrieval(batched_inputs)
174
+ elif mode == 'captioning':
175
+ return self.evaluate_captioning(batched_inputs)
176
+ elif mode == 'classification':
177
+ return self.evaluate_classification(batched_inputs)
178
+ elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
179
+ return self.evaluate_grounding(batched_inputs, mode)
180
+ else:
181
+ return self.evaluate(batched_inputs)
182
+
183
+ def evaluate(self, batched_inputs):
184
+ images = [x["image"].to(self.device) for x in batched_inputs]
185
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
186
+
187
+ images = ImageList.from_tensors(images, self.size_divisibility)
188
+ img_bs = images.tensor.shape[0]
189
+
190
+ targets = targets_grounding = queries_grounding = None
191
+ features = self.backbone(images.tensor)
192
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
193
+
194
+ mask_cls_results = outputs["pred_logits"]
195
+ mask_pred_results = outputs["pred_masks"]
196
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
197
+ caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
198
+
199
+ # upsample masks
200
+ mask_pred_results = F.interpolate(
201
+ mask_pred_results,
202
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
203
+ mode="bilinear",
204
+ align_corners=False,
205
+ )
206
+
207
+ input_size = mask_pred_results.shape[-2:]
208
+ keep_sem_bgd = self.metadata.keep_sem_bgd if hasattr(self.metadata, 'keep_sem_bgd') else False
209
+ del outputs
210
+
211
+ processed_results = []
212
+ for mask_cls_result, mask_pred_result, box_pred_result, caption_pred_result, input_per_image, image_size in zip(
213
+ mask_cls_results, mask_pred_results, box_pred_results, caption_pred_results, batched_inputs, images.image_sizes
214
+ ):
215
+ height = input_per_image.get("height", image_size[0])
216
+ width = input_per_image.get("width", image_size[1])
217
+ processed_results.append({})
218
+
219
+ if self.sem_seg_postprocess_before_inference:
220
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
221
+ mask_pred_result, image_size, height, width
222
+ )
223
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
224
+
225
+ # semantic segmentation inference
226
+ if self.semantic_on:
227
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result, keep_sem_bgd)
228
+ if not self.sem_seg_postprocess_before_inference:
229
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
230
+ processed_results[-1]["sem_seg"] = r
231
+
232
+ # panoptic segmentation inference
233
+ if self.panoptic_on:
234
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
235
+ processed_results[-1]["panoptic_seg"] = panoptic_r
236
+
237
+ # instance segmentation inference
238
+ if self.instance_on:
239
+ if self.task_switch['bbox']:
240
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
241
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
242
+ processed_results[-1]["instances"] = instance_r
243
+ if self.task_switch['caption']:
244
+ processed_results[-1]["captions"] = caption_pred_result
245
+ processed_results[-1]["masks"] = mask_pred_result
246
+
247
+ return processed_results
248
+
249
+
250
+ def evaluate_retrieval(self, batched_inputs):
251
+ images = [x["image"].to(self.device) for x in batched_inputs]
252
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
253
+ images = ImageList.from_tensors(images, self.size_divisibility)
254
+ img_bs = images.tensor.shape[0]
255
+
256
+ targets = targets_grounding = queries_grounding = None
257
+ features = self.backbone(images.tensor)
258
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
259
+ v_emb_it = outputs['pred_captions'][:,-1]
260
+
261
+ # compute backbone score
262
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
263
+ _v_emb_it = features['res5']
264
+ bs,nc,_,_ = _v_emb_it.shape
265
+ _v_emb_it = _v_emb_it.reshape(bs,nc,-1)
266
+ _v_emb_it = F.adaptive_avg_pool1d(_v_emb_it, 1).reshape(bs,nc) @ self.backbone_proj
267
+
268
+ processed_results = []
269
+ for idx, batch_data in enumerate(batched_inputs):
270
+ caption_ids = []
271
+ t_emb_its = []
272
+ processed_results.append({})
273
+ for caption in batch_data['captions']:
274
+ lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(caption)
275
+ t_emb_it = lang_results['class_emb']
276
+ caption_ids.append(batch_data['image_id'])
277
+ t_emb_its.append(t_emb_it)
278
+
279
+ t_emb_it = torch.cat(t_emb_its, dim=0)
280
+
281
+ image_embeds = [v_emb_it[idx].unsqueeze(0)]
282
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
283
+ image_embeds += [_v_emb_it[idx].unsqueeze(0)]
284
+ caption_results = {
285
+ 'image_embeds': image_embeds,
286
+ 'text_embeds': t_emb_it,
287
+ 'caption_ids': caption_ids,
288
+ 'image_ids': batch_data['image_id'],
289
+ }
290
+ processed_results[-1]["caption"] = caption_results
291
+ return processed_results
292
+
293
+ def evaluate_captioning(self, batched_inputs, extra={}):
294
+ images = [x["image"].to(self.device) for x in batched_inputs]
295
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
296
+ images = ImageList.from_tensors(images, self.size_divisibility)
297
+ img_bs = images.tensor.shape[0]
298
+
299
+ if not hasattr(self, 'start_token'):
300
+ self.start_token = torch.tensor([[49406]*77], device=self.device)
301
+
302
+ targets = targets_grounding = queries_grounding = None
303
+ features = self.backbone(images.tensor)
304
+
305
+ captioning_mask = None
306
+ if 'captioning_mask' in batched_inputs[-1]:
307
+ captioning_mask = torch.cat([x['captioning_mask'] for x in batched_inputs])
308
+
309
+ extra.update({'start_token': self.start_token, 'captioning_mask': captioning_mask})
310
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding, task='captioning_infer', extra=extra)
311
+
312
+ processed_results = []
313
+ for idx, batch_data in enumerate(batched_inputs):
314
+ processed_results.append({})
315
+ processed_results[-1]["captioning_token"] = outputs['pred_captionings'][idx]
316
+ processed_results[-1]["captioning_text"] = outputs['pred_texts'][idx].split('.')[0]
317
+ processed_results[-1]["image_id"] = batched_inputs[idx]['image_id']
318
+
319
+ return processed_results
320
+
321
+ def evaluate_classification(self, batched_inputs):
322
+ images = [x["image"].to(self.device) for x in batched_inputs]
323
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
324
+ images = ImageList.from_tensors(images, self.size_divisibility)
325
+ img_bs = images.tensor.shape[0]
326
+
327
+ targets = targets_grounding = queries_grounding = None
328
+ features = self.backbone(images.tensor)
329
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
330
+
331
+ processed_results = []
332
+ for idx, batch_data in enumerate(batched_inputs):
333
+ processed_results.append({})
334
+ processed_results[-1]["pred_class"] = outputs['pred_logits'][idx,-1]
335
+ return processed_results
336
+
337
+ def evaluate_grounding_baseline(self, batched_inputs, mode):
338
+ images = [x["image"].to(self.device) for x in batched_inputs]
339
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
340
+ images = ImageList.from_tensors(images, self.size_divisibility)
341
+ img_bs = images.tensor.shape[0]
342
+
343
+ targets = targets_grounding = queries_grounding = None
344
+ features = self.backbone(images.tensor)
345
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
346
+
347
+ mask_pred_results = outputs["pred_masks"]
348
+ caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
349
+
350
+ # upsample masks
351
+ mask_pred_results = F.interpolate(
352
+ mask_pred_results,
353
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
354
+ mode="bilinear",
355
+ align_corners=False,
356
+ )
357
+
358
+ processed_results = []
359
+ for mask_pred_result, caption_pred_result, input_per_image, image_size in zip(
360
+ mask_pred_results, caption_pred_results, batched_inputs, images.image_sizes
361
+ ):
362
+ height = input_per_image.get("height", image_size[0])
363
+ width = input_per_image.get("width", image_size[1])
364
+ processed_results.append({})
365
+
366
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
367
+ mask_pred_result, image_size, height, width
368
+ )[:-1]
369
+
370
+ texts_all = input_per_image['groundings']['texts']
371
+ grd_masks = []
372
+ for texts in texts_all:
373
+ if mode == 'grounding_refcoco':
374
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=False, is_eval=True)
375
+ elif mode == 'grounding_phrasecut':
376
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=True, is_eval=False)
377
+ t_emb = getattr(self.sem_seg_head.predictor.lang_encoder, "{}_text_embeddings".format('grounding')).t()
378
+ v_emb = caption_pred_result[:-1]
379
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
380
+ vt_sim = v_emb @ t_emb
381
+ max_id = vt_sim.max(0)[1][0]
382
+ grd_masks += [mask_pred_result[max_id]]
383
+ processed_results[-1]['grounding_mask'] = torch.stack(grd_masks)
384
+
385
+ return processed_results
386
+
387
+ def evaluate_grounding(self, batched_inputs, mode):
388
+ images = [x["image"].to(self.device) for x in batched_inputs]
389
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
390
+ images = ImageList.from_tensors(images, self.size_divisibility)
391
+
392
+ extra = {}
393
+ # mask_pred_results = []
394
+ # for idx, batch_per_image in enumerate(batched_inputs):
395
+ # grd_texts = batch_per_image['groundings']['texts']
396
+ # grd_masks = []
397
+ # for anno_text in grd_texts:
398
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
399
+ # token_emb = gtext['token_emb']
400
+ # tokens = gtext['tokens']
401
+
402
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
403
+ # extra['grounding_tokens'] = grd_emb[:,None]
404
+
405
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
406
+ # features = self.backbone(images.tensor)
407
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
408
+
409
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
410
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
411
+ # t_emb = grd_emb[-1:]
412
+
413
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
414
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
415
+
416
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
417
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
418
+
419
+ # matched_id = out_prob.max(0)[1]
420
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
421
+ # mask_pred_results += [torch.cat(grd_masks)]
422
+
423
+ # comment for multi object inference.
424
+ mask_pred_results = []
425
+ for idx, batch_per_image in enumerate(batched_inputs):
426
+ grd_texts = batch_per_image['groundings']['texts']
427
+ grd_texts = [x[0] for x in grd_texts]
428
+
429
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
430
+ token_emb = gtext['token_emb']
431
+ tokens = gtext['tokens']
432
+ query_emb = token_emb[tokens['attention_mask'].bool()]
433
+ extra['grounding_tokens'] = query_emb[:,None]
434
+
435
+ features = self.backbone(images.tensor)
436
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
437
+
438
+ pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
439
+ v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
440
+ t_emb = gtext['class_emb']
441
+
442
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
443
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
444
+
445
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
446
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
447
+
448
+ matched_id = out_prob.max(0)[1]
449
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
450
+
451
+ for i in range(len(mask_pred_results)):
452
+ # upsample masks
453
+ mask_pred_results[i] = F.interpolate(
454
+ mask_pred_results[i][None,],
455
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
456
+ mode="bilinear",
457
+ align_corners=False,
458
+ )[0]
459
+
460
+ processed_results = []
461
+ for mask_pred_result, input_per_image, image_size in zip(
462
+ mask_pred_results, batched_inputs, images.image_sizes
463
+ ):
464
+ height = input_per_image.get("height", image_size[0])
465
+ width = input_per_image.get("width", image_size[1])
466
+ processed_results.append({})
467
+
468
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
469
+ mask_pred_result, image_size, height, width
470
+ )
471
+ processed_results[-1]['grounding_mask'] = mask_pred_result
472
+
473
+ # compute bbox
474
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
475
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
476
+ # processed_results[-1]['grounding_box'] = bbox
477
+
478
+ return processed_results
479
+
480
+ def prepare_vlp_targets(self, batched_inputs, device):
481
+ input_ids = []
482
+ attention_mask = []
483
+ for cnt, x in enumerate(batched_inputs):
484
+ captions = x['captions']
485
+ randid = random.randint(0, len(captions)-1)
486
+ input_ids += x['tokens']['input_ids'][randid:randid+1]
487
+ attention_mask += x['tokens']['attention_mask'][randid:randid+1]
488
+
489
+ input_ids = torch.stack(input_ids)
490
+ attention_mask = torch.stack(attention_mask)
491
+ tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
492
+ lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(tokens, token=True)
493
+
494
+ target_vlp = []
495
+ for cnt, x in enumerate(batched_inputs):
496
+ target_dict = {}
497
+ target_dict["caption_tokens"] = lang_results['token_emb'][cnt:cnt+1]
498
+ target_dict["caption_proj"] = lang_results['class_emb'][cnt:cnt+1]
499
+ target_dict["caption_tokenids"] = lang_results['tokens']['input_ids'][cnt:cnt+1]
500
+ target_dict["caption_mask"] = lang_results['tokens']['attention_mask'][cnt:cnt+1]
501
+ target_vlp.append(target_dict)
502
+ return target_vlp
503
+
504
+ def semantic_inference(self, mask_cls, mask_pred, keep_sem_bgd=False):
505
+ if keep_sem_bgd:
506
+ mask_cls = F.softmax(mask_cls, dim=-1)
507
+ else:
508
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
509
+ mask_pred = mask_pred.sigmoid()
510
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
511
+ return semseg
512
+
513
+ def panoptic_inference(self, mask_cls, mask_pred):
514
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
515
+ mask_pred = mask_pred.sigmoid()
516
+
517
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
518
+ cur_scores = scores[keep]
519
+ cur_classes = labels[keep]
520
+ cur_masks = mask_pred[keep]
521
+ cur_mask_cls = mask_cls[keep]
522
+ cur_mask_cls = cur_mask_cls[:, :-1]
523
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
524
+
525
+ h, w = cur_masks.shape[-2:]
526
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
527
+ segments_info = []
528
+
529
+ current_segment_id = 0
530
+
531
+ if cur_masks.shape[0] == 0:
532
+ # We didn't detect any mask :(
533
+ return panoptic_seg, segments_info
534
+ else:
535
+ # take argmax
536
+ cur_mask_ids = cur_prob_masks.argmax(0)
537
+ stuff_memory_list = {}
538
+ thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
539
+ for k in range(cur_classes.shape[0]):
540
+ pred_class = cur_classes[k].item()
541
+ isthing = pred_class in thing_dataset_id_to_contiguous_id.values()
542
+ mask_area = (cur_mask_ids == k).sum().item()
543
+ original_area = (cur_masks[k] >= 0.5).sum().item()
544
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
545
+
546
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
547
+ if mask_area / original_area < self.overlap_threshold:
548
+ continue
549
+
550
+ # merge stuff regions
551
+ if not isthing:
552
+ if int(pred_class) in stuff_memory_list.keys():
553
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
554
+ continue
555
+ else:
556
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
557
+
558
+ current_segment_id += 1
559
+ panoptic_seg[mask] = current_segment_id
560
+
561
+ segments_info.append(
562
+ {
563
+ "id": current_segment_id,
564
+ "isthing": bool(isthing),
565
+ "category_id": int(pred_class),
566
+ }
567
+ )
568
+ return panoptic_seg, segments_info
569
+
570
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
571
+ # mask_pred is already processed to have the same shape as original input
572
+ image_size = mask_pred.shape[-2:]
573
+
574
+ # [Q, K]
575
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
576
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
577
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
578
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
579
+
580
+ labels_per_image = labels[topk_indices]
581
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
582
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
583
+ mask_pred = mask_pred[topk_indices]
584
+ if box_pred is not None:
585
+ box_pred = box_pred[topk_indices]
586
+
587
+ # if this is panoptic segmentation, we only keep the "thing" classes
588
+ if self.panoptic_on:
589
+ thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
590
+ keep = torch.zeros_like(scores_per_image).bool()
591
+ for i, lab in enumerate(labels_per_image):
592
+ keep[i] = lab in thing_dataset_id_to_contiguous_id.values()
593
+
594
+ scores_per_image = scores_per_image[keep]
595
+ labels_per_image = labels_per_image[keep]
596
+ mask_pred = mask_pred[keep]
597
+
598
+ if box_pred is not None:
599
+ box_pred = box_pred[keep]
600
+
601
+ result = Instances(image_size)
602
+ # mask (before sigmoid)
603
+ result.pred_masks = (mask_pred > 0).float()
604
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
605
+ # Uncomment the following to get boxes from masks (this is slow)
606
+
607
+ if box_pred is not None:
608
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
609
+ else:
610
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
611
+
612
+ # calculate average mask prob
613
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
614
+ result.scores = scores_per_image * mask_scores_per_image
615
+ result.pred_classes = labels_per_image
616
+
617
+ return result
618
+
619
+
620
+ @register_model
621
+ def get_segmentation_model(cfg, **kwargs):
622
+ return X_Decoder_Model(cfg)