Spaces:
Runtime error
Runtime error
File size: 14,985 Bytes
2aac0e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
"""
import torch
import torch.nn.functional as F
from torch import nn
# from ..backbone import build_backbone, Backbone
# from ..body.encoder import build_encoder
# from ..body.decoder import build_decoder
from detectron2.modeling import build_backbone
from .pixel_decoder.maskdino_encoder import build_pixel_decoder
from .transformer_decoder.maskdino_decoder import build_transformer_decoder
import random
from transformers import AutoTokenizer
from collections import OrderedDict
from ..modules.point_features import point_sample
from timm.models.layers import trunc_normal_
from transformers import CLIPTokenizer,CLIPTextModel
from .vos_utils import masks_to_boxes, FeatureFuser
import numpy as np
import math
def rand_sample(x, max_len):
    """Randomly keep at most ``max_len`` columns (dim 1) of ``x``.

    If ``x`` already has ``max_len`` or fewer columns it is returned unchanged
    (the very same object). Otherwise ``max_len`` distinct column indices are
    drawn without replacement with ``torch.randperm`` and those columns are
    returned (in the random draw order).
    """
    num_cols = x.shape[1]
    if num_cols > max_len:
        keep = torch.randperm(num_cols)[:max_len]
        return x[:, keep]
    return x
def agg_lang_feat(features, mask, pool_type="average"):
    """Pool per-token language features into one vector per sequence.

    Args:
        features: token features indexed as ``(bs, seq_len, C)``.
        mask: ``(bs, seq_len)`` validity mask; non-zero / True marks a real
            token. Integer masks (e.g. a tokenizer ``attention_mask``) and bool
            masks are both accepted.
        pool_type: ``"average"`` for a masked mean, ``"max"`` for a masked max.

    Returns:
        ``(bs, C)`` pooled features. With ``"average"``, a row whose mask is
        entirely zero pools to a zero vector.

    Raises:
        ValueError: if ``pool_type`` is neither ``"average"`` nor ``"max"``.
    """
    if pool_type == "average":
        float_mask = mask.float()
        # zero out padding tokens before summing
        embedded = features * float_mask.unsqueeze(-1)
        # clamp so a row with no valid tokens gives zeros instead of 0/0 = NaN
        count = float_mask.sum(-1, keepdim=True).clamp(min=1)
        return embedded.sum(1) / count
    if pool_type == "max":
        # Indexing with an integer tensor would gather rows 0/1 rather than
        # select the valid tokens, so always convert the mask to bool first.
        bool_mask = mask.bool()
        pooled = [feat[valid].max(dim=0).values for feat, valid in zip(features, bool_mask)]  # each (C,)
        return torch.stack(pooled, dim=0)  # (bs, C)
    raise ValueError("pool_type should be average or max")
class GLEE_Model(nn.Module):
    """
    GLEE model: detectron2 backbone + CLIP text encoder + pixel decoder +
    transformer decoder.

    ``forward`` encodes text prompts (class names or grounding expressions) and
    visual prompts (spatial masks), fuses them early into the pixel decoder, and
    passes prompt tokens to the transformer decoder through ``extra``.
    """

    def __init__(self, cfg, matcher, device, video_info, contras_mean):
        super().__init__()
        self.cfg = cfg
        self.matcher = matcher
        self.backbone = build_backbone(cfg)
        output_channels = list(self.backbone._out_feature_channels.values())
        # fuses the last three backbone levels of a visual-prompt template
        # (get_template feeds it ref_srcs[1:], which presumes four levels)
        self.sot_fuser = FeatureFuser(output_channels[-3:], 256)
        self.tokenizer = CLIPTokenizer.from_pretrained('GLEE/clip_vit_base_patch32')
        self.tokenizer.add_special_tokens({'cls_token': self.tokenizer.eos_token})
        self.text_encoder = CLIPTextModel.from_pretrained('GLEE/clip_vit_base_patch32')
        # text is encoded here with CLIP; the decoder receives lang_encoder=None
        self.lang_encoder = None
        # projects CLIP token features (LANG_DIM) into the model dim (DIM_PROJ)
        self.lang_projection = nn.Parameter(torch.rand(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, cfg.MODEL.DIM_PROJ))
        self.text_encode_type = 'clip_teacher'

        self.pixel_decoder = build_pixel_decoder(cfg, self.backbone.output_shape())
        transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        self.predictor = build_transformer_decoder(cfg, transformer_predictor_in_channels, lang_encoder=self.lang_encoder, mask_classification=True)

        self.video_info = video_info
        self.contras_mean = contras_mean
        self.track_loss_version = cfg.MODEL.TRACK_VERSION
        self.no_mask_tasks = ['obj365', 'obj365_clip', 'openimage', 'openimage_clip', 'vg', 'grit', 'bdd_det', 'bdd_track_box']

        # --- visual prompt ---
        hidden_dim = 256
        # maximum number of prompt points sampled per feature level
        self.max_spatial_len = [512, 512, 512, 512]
        # per-level projection applied to multi-scale features before point sampling
        self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for _ in range(4)])
        for embed in self.mask_sptial_embed:
            trunc_normal_(embed, std=.02)
        # learnable indicator: row 0 is added to positive-point tokens, row 1 to negative ones
        self.pn_indicator = nn.Embedding(2, hidden_dim)
        # template crop side = sqrt(box area) * search_area_factor
        self.search_area_factor = 2

        # Move last so that every parameter created above lands on `device`.
        self.to(device)

    @property
    def device(self):
        # lang_projection is a registered parameter, so it follows .to()/.cuda() moves
        return self.lang_projection.device

    def forward(self, images, prompts, task, targets=None, batch_name_list=None, is_train=True, visual_prompt_type='scribble'):
        """Run the model on a batch of images with optional text / visual prompts.

        Args:
            images: a tensor fed straight to the backbone, or an object exposing
                ``.tensor`` / ``.device`` (presumably a detectron2 ``ImageList`` — confirm).
            prompts: mapping; the keys ``'grounding'`` (expressions to tokenize)
                and ``'spatial'`` (per-image masks, each ``[1, H, W]``) are read.
            task: task name; for tasks other than ``'grounding'`` / ``'rvos'`` the
                class names in ``batch_name_list`` are encoded.
            targets: forwarded unchanged to the predictor.
            batch_name_list: class-name strings; required for non-grounding tasks.
            is_train: unused here; kept for API compatibility.
            visual_prompt_type: prompt mode forwarded to ``get_template``.

        Returns:
            Whatever ``self.predictor`` returns.
        """
        extra = {}
        early_semantic = None

        # Class-name prompts (every task except grounding / rvos).
        if self.text_encode_type == "clip_teacher" and task not in ['grounding', 'rvos']:
            assert batch_name_list
            extra['class_embeddings'], early_semantic = self._encode_class_names(
                batch_name_list, images.device, len(images))

        # Referring-expression prompts; these replace any class-name early fusion.
        if 'grounding' in prompts and self.text_encode_type in ('clip_frozen', 'clip_teacher'):
            grounding_extra, early_semantic = self._encode_grounding(prompts['grounding'], images.device)
            extra.update(grounding_extra)

        features = self.backbone(images if isinstance(images, torch.Tensor) else images.tensor)

        # Visual prompts: encode a cropped template and append it to the early-fusion tokens.
        if 'spatial' in prompts:
            # NOTE(review): get_template zips [images] with the per-image masks, so only the
            # first mask drives the crop (of the whole batch) — confirm this is intended for bz > 1.
            key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']]  # bz * [1, 1, H, W]
            ref_feats, ref_masks = self.get_template([images], key_promptmasks, visual_prompt_type)
            if early_semantic is None:
                early_semantic = {"hidden": ref_feats, "masks": ref_masks}
            else:
                early_semantic["hidden"] = torch.cat([early_semantic["hidden"], ref_feats], dim=1)
                early_semantic["masks"] = torch.cat([early_semantic["masks"], ref_masks], dim=1)

        mask_features, _, multi_scale_features, _ = self.pixel_decoder.forward_features(
            features, masks=None, early_fusion=early_semantic)

        if 'spatial' in prompts:
            extra.update(self._spatial_prompt_tokens(prompts['spatial'], mask_features, multi_scale_features))

        return self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)

    def _encode_class_names(self, class_names, device, num_images):
        """Encode class names with CLIP; return (pooled class embeddings, early-fusion dict)."""
        lang_cfg = self.cfg.MODEL.LANGUAGE_BACKBONE
        tokenized = self.tokenizer.batch_encode_plus(
            class_names,
            max_length=lang_cfg.MAX_QUERY_LEN,
            padding='max_length' if lang_cfg.PAD_MAX else "longest",
            return_special_tokens_mask=True,
            return_tensors='pt',
            truncation=True,
        ).to(device)
        attention_mask = tokenized['attention_mask']
        token_x = self.text_encoder(tokenized['input_ids'], attention_mask)['last_hidden_state']
        token_x = token_x @ self.lang_projection
        class_embeddings = agg_lang_feat(token_x, attention_mask, pool_type="average")  # [num_classes, C]
        # Early fusion: every valid token of every class name, shared by all images.
        all_tokens = token_x.flatten(0, 1)[attention_mask.flatten(0, 1) > 0]
        all_tokens = all_tokens.unsqueeze(0).repeat(num_images, 1, 1)  # [bs, L, C]
        all_tokens_mask = torch.ones_like(all_tokens[:, :, 0]) > 0  # [bs, L], all True
        return class_embeddings, {"hidden": all_tokens.float(), "masks": all_tokens_mask}

    def _encode_grounding(self, expressions, device):
        """Encode grounding expressions with CLIP; return (entries for `extra`, early-fusion dict)."""
        tokens = self.tokenizer(
            expressions,
            padding='max_length',
            truncation=True,
            max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
            return_tensors='pt',
        )
        tokens = {key: value.to(device) for key, value in tokens.items()}
        attention_mask = tokens['attention_mask']
        token_x = self.text_encoder(tokens['input_ids'], attention_mask)['last_hidden_state']
        token_x = token_x @ self.lang_projection
        lang_feat_pool = agg_lang_feat(token_x, attention_mask, pool_type="average")  # [bz, C]
        grounding_extra = {
            'grounding_tokens': token_x.permute(1, 0, 2),  # [len, bz, C]
            'grounding_nonzero_mask': ~attention_mask.bool(),  # [bz, len], True = padding
            'grounding_class': lang_feat_pool,  # [bz, C]
        }
        return grounding_extra, {"hidden": token_x.float(), "masks": attention_mask > 0}

    def _spatial_prompt_tokens(self, pos_masks, mask_features, multi_scale_features):
        """Sample per-level visual-prompt tokens at the set pixels of each prompt mask.

        Returns a dict with the positive/negative masks, the per-level tokens
        (``[P, bz, C]`` each) and the per-level padding masks (``[bz, P]``,
        True = padded slot).
        """
        # Negative masks are all-zero, so every sampled point is a positive one.
        neg_masks = [p & False for p in pos_masks]
        _, h, w = pos_masks[0].shape
        divisor = torch.tensor([h, w], device=mask_features.device)[None,]
        # normalized (y, x) coordinates of set pixels, [2, P] per mask; level independent
        pos_coords = [(m.nonzero()[:, 1:] / divisor).t() for m in pos_masks]
        neg_coords = [(m.nonzero()[:, 1:] / divisor).t() for m in neg_masks]

        tokens_per_level, padding_per_level = [], []
        for level, level_features in enumerate(multi_scale_features):
            # [B, C, H, W] -> [H, W, B, C] so the projection acts on the channel dim
            projected = level_features.permute(2, 3, 0, 1) @ self.mask_sptial_embed[level]
            max_len = self.max_spatial_len[level]
            pos_points = [rand_sample(c, max_len).t() for c in pos_coords]
            neg_points = [rand_sample(c, max_len).t() for c in neg_coords]
            points = [torch.cat([p, n], dim=0) for p, n in zip(pos_points, neg_points)]
            indicator = [
                torch.cat([torch.ones(p.shape[0], device=p.device), -torch.ones(n.shape[0], device=n.device)])
                for p, n in zip(pos_points, neg_points)
            ]
            indicator = nn.utils.rnn.pad_sequence(indicator, padding_value=0)  # [P, B]
            points = nn.utils.rnn.pad_sequence(points, padding_value=-1).permute(1, 0, 2)  # [B, P, 2]
            padding = points.sum(dim=-1) < 0  # padded slots hold -1 coordinates
            points[padding] = 0
            # back to [B, C, H, W] for sampling; (y, x) -> (x, y) via flip
            tokens = point_sample(projected.permute(2, 3, 0, 1), points.flip(dims=(2,)).type(projected.dtype), align_corners=True).permute(2, 0, 1)
            tokens[indicator == 1] += self.pn_indicator.weight[0:1]
            tokens[indicator == -1] += self.pn_indicator.weight[1:2]
            tokens_per_level.append(tokens)
            padding_per_level.append(padding)

        return {
            'spatial_query_pos_mask': pos_masks,
            'spatial_query_neg_mask': neg_masks,
            'visual_prompt_tokens': tokens_per_level,  # [len, bz, C] per level
            'visual_prompt_nonzero_mask': padding_per_level,  # [bz, len] per level
        }

    def get_template(self, imgs, pad_masks, prompt_mode='scribble'):
        """Crop a square template around each prompt mask and encode it.

        Args:
            imgs: sequence of image tensors, each presumably ``[N, 3, H, W]``.
            pad_masks: sequence of prompt masks, each ``[1, 1, H, W]``; paired
                with ``imgs`` by ``zip``, so surplus entries of the longer one are ignored.
            prompt_mode: for ``'scribble'`` / ``'point'`` the mask is added onto
                the image before cropping; otherwise the raw image is cropped.

        Returns:
            ``(ref_feats, ref_masks)``: template tokens ``(bs, L, C)`` as float32
            and an all-True bool mask ``(bs, L)``.
        """
        crops = []
        for image_i, mask_i in zip(imgs, pad_masks):
            if prompt_mode in ['scribble', 'point']:
                image_with_mask = image_i + mask_i.to(image_i)
            else:
                image_with_mask = image_i

            box_i = masks_to_boxes(mask_i[0])  # [xyxy]
            box_i[:, 2:] = box_i[:, 2:] - box_i[:, :2]  # -> xywh
            x, y, w, h = box_i[0].long().tolist()

            # square crop centred on the box, clipped at the top/left image border
            crop_sz = math.ceil(math.sqrt(w * h) * self.search_area_factor)
            x1 = max(0, round(x + 0.5 * w - crop_sz * 0.5))
            y1 = max(0, round(y + 0.5 * h - crop_sz * 0.5))
            im_crop = image_with_mask[:, :, y1:y1 + crop_sz, x1:x1 + crop_sz]
            if im_crop.shape[-1] == 0 or im_crop.shape[-2] == 0:
                # degenerate crop: fall back to the full image
                im_crop = image_with_mask
            crops.append(F.interpolate(im_crop, (256, 256), mode='bilinear', align_corners=False))
        crops = torch.cat(crops, dim=0)  # [bz, 3, 256, 256]

        # The backbone pass over the template runs without gradients; the fuser does not.
        with torch.no_grad():
            ref_srcs = list(self.backbone(crops.contiguous()).values())
        ref_feats = self.sot_fuser(ref_srcs[1:]).float()  # presumably [bz, 256, 32, 32]

        ref_feats = ref_feats.flatten(-2).permute(0, 2, 1)  # (bs, L, C)
        ref_masks = torch.ones_like(ref_feats[:, :, 0]) > 0  # (bs, L), all True
        return ref_feats, ref_masks
|