Create xdecoder_model.py
xdecoder/architectures/xdecoder_model.py
ADDED
@@ -0,0 +1,622 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import random
from typing import Tuple
from unicodedata import name

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

from .registry import register_model
from ..utils import configurable
from ..backbone import build_backbone, Backbone
from ..body import build_xdecoder_head
from ..modules import sem_seg_postprocess, bbox_postprocess
from ..language import build_language_encoder
from ..language.loss import vl_similarity

from timm.models.layers import trunc_normal_
from nltk.stem.lancaster import LancasterStemmer
from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.data import MetadataCatalog
from utils.misc import prompt_engineering

st = LancasterStemmer()

class X_Decoder_Model(nn.Module):
    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        losses: dict,
        num_queries: int,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        task_switch: dict,
        phrase_prob: float,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        # inference
        semantic_on: bool,
        panoptic_on: bool,
        instance_on: bool,
        test_topk_per_image: int,
        train_dataset_name: str,
        retrieval_emsemble: bool,
        backbone_dim: int,
        dim_proj: int,
    ):
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.losses = losses
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # additional args
        self.semantic_on = semantic_on
        self.instance_on = instance_on
        self.panoptic_on = panoptic_on

        # caption argument
        self.task_switch = task_switch
        self.phrase_prob = phrase_prob

        self.test_topk_per_image = test_topk_per_image
        self.train_class_names = None

        self.retrieval_emsemble = retrieval_emsemble
        # backbone itc loss
        if task_switch['retrieval'] and retrieval_emsemble:
            self.backbone_proj = nn.Parameter(torch.empty(backbone_dim, dim_proj))
            trunc_normal_(self.backbone_proj, std=.02)

        if not self.semantic_on:
            assert self.sem_seg_postprocess_before_inference

    @classmethod
    def from_config(cls, cfg):
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        task_switch = {'bbox': dec_cfg.get('DETECTION', False),
                       'mask': dec_cfg.get('MASK', True),
                       'caption': dec_cfg['CAPTION'].get('ENABLED', False),
                       'captioning': dec_cfg['CAPTIONING'].get('ENABLED', False),
                       'retrieval': dec_cfg['RETRIEVAL'].get('ENABLED', False),
                       'grounding': dec_cfg['GROUNDING'].get('ENABLED', False)}

        # build model
        extra = {'task_switch': task_switch}
        backbone = build_backbone(cfg)
        lang_encoder = build_language_encoder(cfg)
        sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra)

        # Training Settings.
        loss_weights = {}
        matcher = None
        losses = {}
        weight_dict = {}
        grd_weight = {}
        top_x_layers = {}
        criterion = None
        train_dataset_name = None
        phrase_prob = None
        # Loss parameters:
        deep_supervision = None
        no_object_weight = None

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "losses": losses,
            "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
            "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
            "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
            "metadata": None,
            "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
            "sem_seg_postprocess_before_inference": (
                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
                or dec_cfg['TEST']['PANOPTIC_ON']
                or dec_cfg['TEST']['INSTANCE_ON']
            ),
            "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
            "pixel_std": cfg['INPUT']['PIXEL_STD'],
            "task_switch": task_switch,
            "phrase_prob": phrase_prob,
            # inference
            "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
            "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
            "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
            "test_topk_per_image": cfg['MODEL']['DECODER']['TEST']['DETECTIONS_PER_IMAGE'],
            "train_dataset_name": train_dataset_name,
            "retrieval_emsemble": dec_cfg['RETRIEVAL']['ENSEMBLE'],
            "backbone_dim": cfg['MODEL']['BACKBONE_DIM'],
            "dim_proj": cfg['MODEL']['DIM_PROJ'],
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs, mode=None):
        if self.training:
            assert False, "Training mode is not supported."
        else:
            if mode == 'retrieval':
                return self.evaluate_retrieval(batched_inputs)
            elif mode == 'captioning':
                return self.evaluate_captioning(batched_inputs)
            elif mode == 'classification':
                return self.evaluate_classification(batched_inputs)
            elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
                return self.evaluate_grounding(batched_inputs, mode)
            else:
                return self.evaluate(batched_inputs)

    def evaluate(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]

        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
        caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]

        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

        input_size = mask_pred_results.shape[-2:]
        keep_sem_bgd = self.metadata.keep_sem_bgd if hasattr(self.metadata, 'keep_sem_bgd') else False
        del outputs

        processed_results = []
        for mask_cls_result, mask_pred_result, box_pred_result, caption_pred_result, input_per_image, image_size in zip(
            mask_cls_results, mask_pred_results, box_pred_results, caption_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            if self.sem_seg_postprocess_before_inference:
                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width
                )
                mask_cls_result = mask_cls_result.to(mask_pred_result)

            # semantic segmentation inference
            if self.semantic_on:
                r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result, keep_sem_bgd)
                if not self.sem_seg_postprocess_before_inference:
                    r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
                processed_results[-1]["sem_seg"] = r

            # panoptic segmentation inference
            if self.panoptic_on:
                panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
                processed_results[-1]["panoptic_seg"] = panoptic_r

            # instance segmentation inference
            if self.instance_on:
                if self.task_switch['bbox']:
                    box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
                instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
                processed_results[-1]["instances"] = instance_r
            if self.task_switch['caption']:
                processed_results[-1]["captions"] = caption_pred_result
                processed_results[-1]["masks"] = mask_pred_result

        return processed_results

    def evaluate_retrieval(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)
        v_emb_it = outputs['pred_captions'][:,-1]

        # compute backbone score
        if self.task_switch['retrieval'] and self.retrieval_emsemble:
            _v_emb_it = features['res5']
            bs,nc,_,_ = _v_emb_it.shape
            _v_emb_it = _v_emb_it.reshape(bs,nc,-1)
            _v_emb_it = F.adaptive_avg_pool1d(_v_emb_it, 1).reshape(bs,nc) @ self.backbone_proj

        processed_results = []
        for idx, batch_data in enumerate(batched_inputs):
            caption_ids = []
            t_emb_its = []
            processed_results.append({})
            for caption in batch_data['captions']:
                lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(caption)
                t_emb_it = lang_results['class_emb']
                caption_ids.append(batch_data['image_id'])
                t_emb_its.append(t_emb_it)

            t_emb_it = torch.cat(t_emb_its, dim=0)

            image_embeds = [v_emb_it[idx].unsqueeze(0)]
            if self.task_switch['retrieval'] and self.retrieval_emsemble:
                image_embeds += [_v_emb_it[idx].unsqueeze(0)]
            caption_results = {
                'image_embeds': image_embeds,
                'text_embeds': t_emb_it,
                'caption_ids': caption_ids,
                'image_ids': batch_data['image_id'],
            }
            processed_results[-1]["caption"] = caption_results
        return processed_results

    def evaluate_captioning(self, batched_inputs, extra={}):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        if not hasattr(self, 'start_token'):
            self.start_token = torch.tensor([[49406]*77], device=self.device)

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)

        captioning_mask = None
        if 'captioning_mask' in batched_inputs[-1]:
            captioning_mask = torch.cat([x['captioning_mask'] for x in batched_inputs])

        extra.update({'start_token': self.start_token, 'captioning_mask': captioning_mask})
        outputs = self.sem_seg_head(features, target_queries=queries_grounding, task='captioning_infer', extra=extra)

        processed_results = []
        for idx, batch_data in enumerate(batched_inputs):
            processed_results.append({})
            processed_results[-1]["captioning_token"] = outputs['pred_captionings'][idx]
            processed_results[-1]["captioning_text"] = outputs['pred_texts'][idx].split('.')[0]
            processed_results[-1]["image_id"] = batched_inputs[idx]['image_id']

        return processed_results

    def evaluate_classification(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        processed_results = []
        for idx, batch_data in enumerate(batched_inputs):
            processed_results.append({})
            processed_results[-1]["pred_class"] = outputs['pred_logits'][idx,-1]
        return processed_results

    def evaluate_grounding_baseline(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        mask_pred_results = outputs["pred_masks"]
        caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]

        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

        processed_results = []
        for mask_pred_result, caption_pred_result, input_per_image, image_size in zip(
            mask_pred_results, caption_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )[:-1]

            texts_all = input_per_image['groundings']['texts']
            grd_masks = []
            for texts in texts_all:
                if mode == 'grounding_refcoco':
                    self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=False, is_eval=True)
                elif mode == 'grounding_phrasecut':
                    self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=True, is_eval=False)
                t_emb = getattr(self.sem_seg_head.predictor.lang_encoder, "{}_text_embeddings".format('grounding')).t()
                v_emb = caption_pred_result[:-1]
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
                vt_sim = v_emb @ t_emb
                max_id = vt_sim.max(0)[1][0]
                grd_masks += [mask_pred_result[max_id]]
            processed_results[-1]['grounding_mask'] = torch.stack(grd_masks)

        return processed_results

    def evaluate_grounding(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        extra = {}
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_masks = []
        #     for anno_text in grd_texts:
        #         gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
        #         token_emb = gtext['token_emb']
        #         tokens = gtext['tokens']

        #         grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
        #         extra['grounding_tokens'] = grd_emb[:,None]

        #         assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
        #         features = self.backbone(images.tensor)
        #         outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

        #         pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
        #         v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
        #         t_emb = grd_emb[-1:]

        #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        #         temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

        #         matched_id = out_prob.max(0)[1]
        #         grd_masks += [pred_gmasks[matched_id,:,:]]
        #     mask_pred_results += [torch.cat(grd_masks)]

        # comment for multi object inference.
        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_texts = [x[0] for x in grd_texts]

            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            extra['grounding_tokens'] = query_emb[:,None]

            features = self.backbone(images.tensor)
            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

            pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
            v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
            t_emb = gtext['class_emb']

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

            matched_id = out_prob.max(0)[1]
            mask_pred_results += [pred_gmasks[matched_id,:,:]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

            # compute bbox
            # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
            # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
            # processed_results[-1]['grounding_box'] = bbox

        return processed_results

    def prepare_vlp_targets(self, batched_inputs, device):
        input_ids = []
        attention_mask = []
        for cnt, x in enumerate(batched_inputs):
            captions = x['captions']
            randid = random.randint(0, len(captions)-1)
            input_ids += x['tokens']['input_ids'][randid:randid+1]
            attention_mask += x['tokens']['attention_mask'][randid:randid+1]

        input_ids = torch.stack(input_ids)
        attention_mask = torch.stack(attention_mask)
        tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
        lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(tokens, token=True)

        target_vlp = []
        for cnt, x in enumerate(batched_inputs):
            target_dict = {}
            target_dict["caption_tokens"] = lang_results['token_emb'][cnt:cnt+1]
            target_dict["caption_proj"] = lang_results['class_emb'][cnt:cnt+1]
            target_dict["caption_tokenids"] = lang_results['tokens']['input_ids'][cnt:cnt+1]
            target_dict["caption_mask"] = lang_results['tokens']['attention_mask'][cnt:cnt+1]
            target_vlp.append(target_dict)
        return target_vlp

    def semantic_inference(self, mask_cls, mask_pred, keep_sem_bgd=False):
        if keep_sem_bgd:
            mask_cls = F.softmax(mask_cls, dim=-1)
        else:
            mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg

    def panoptic_inference(self, mask_cls, mask_pred):
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_mask_cls = mask_cls[keep]
        cur_mask_cls = cur_mask_cls[:, :-1]
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []

        current_segment_id = 0

        if cur_masks.shape[0] == 0:
            # We didn't detect any mask :(
            return panoptic_seg, segments_info
        else:
            # take argmax
            cur_mask_ids = cur_prob_masks.argmax(0)
            stuff_memory_list = {}
            thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                isthing = pred_class in thing_dataset_id_to_contiguous_id.values()
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
                    if mask_area / original_area < self.overlap_threshold:
                        continue

                    # merge stuff regions
                    if not isthing:
                        if int(pred_class) in stuff_memory_list.keys():
                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    panoptic_seg[mask] = current_segment_id

                    segments_info.append(
                        {
                            "id": current_segment_id,
                            "isthing": bool(isthing),
                            "category_id": int(pred_class),
                        }
                    )
            return panoptic_seg, segments_info

    def instance_inference(self, mask_cls, mask_pred, box_pred):
        # mask_pred is already processed to have the same shape as original input
        image_size = mask_pred.shape[-2:]

        # [Q, K]
        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
        # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)

        labels_per_image = labels[topk_indices]
        topk_indices = (topk_indices // self.sem_seg_head.num_classes)
        # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
        mask_pred = mask_pred[topk_indices]
        if box_pred is not None:
            box_pred = box_pred[topk_indices]

        # if this is panoptic segmentation, we only keep the "thing" classes
        if self.panoptic_on:
            thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
            keep = torch.zeros_like(scores_per_image).bool()
            for i, lab in enumerate(labels_per_image):
                keep[i] = lab in thing_dataset_id_to_contiguous_id.values()

            scores_per_image = scores_per_image[keep]
            labels_per_image = labels_per_image[keep]
            mask_pred = mask_pred[keep]

            if box_pred is not None:
                box_pred = box_pred[keep]

        result = Instances(image_size)
        # mask (before sigmoid)
        result.pred_masks = (mask_pred > 0).float()
        # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
        # Uncomment the following to get boxes from masks (this is slow)

        if box_pred is not None:
            result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
        else:
            result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))

        # calculate average mask prob
        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
        result.scores = scores_per_image * mask_scores_per_image
        result.pred_classes = labels_per_image

        return result


@register_model
def get_segmentation_model(cfg, **kwargs):
    return X_Decoder_Model(cfg)
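For reference, a minimal sketch of how the entry point registered above might be exercised at inference time. It is not part of the committed file and makes several assumptions: `cfg` is the nested config dict that `from_config` reads (e.g. `cfg['MODEL']['DECODER']['TEST'][...]`), pretrained weights and any class-name text embeddings are loaded separately, and the placeholder tensor, sizes, and CUDA device below are illustrative only.

import torch

# @configurable routes the config dict through X_Decoder_Model.from_config.
model = get_segmentation_model(cfg).eval().to('cuda')  # weights assumed loaded elsewhere

# Detectron2-style batched input: an unnormalized CHW uint8 tensor plus the
# original height/width used for postprocessing back to the source resolution.
batched_inputs = [{
    "image": torch.zeros(3, 512, 512, dtype=torch.uint8, device='cuda'),
    "height": 480,
    "width": 640,
}]

with torch.no_grad():
    # default mode runs the semantic/panoptic/instance branches enabled in the config
    results = model(batched_inputs)
    # other modes dispatched by forward(): 'retrieval', 'captioning', 'classification',
    # 'grounding_refcoco', 'grounding_phrasecut'; retrieval and grounding additionally
    # expect per-image 'captions' / 'groundings' fields in batched_inputs.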