# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/adapter.py

from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.structures import BitMasks

from .utils import build_clip_model, crop_with_mask
from .text_template import PromptExtractor

# CLIP's standard RGB normalization statistics (values in the [0, 1] range).
PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)

class ClipAdapter(nn.Module):
    def __init__(self, clip_model_name: str, mask_prompt_depth: int, text_templates: PromptExtractor):
        super().__init__()
        self.clip_model = build_clip_model(clip_model_name, mask_prompt_depth)
        self.text_templates = text_templates
        self.text_templates.init_buffer(self.clip_model)
        # Cache of per-noun text embeddings, so each class name is encoded
        # only once across calls.
        self.text_feature_buffer = {}

    def forward(self, image: torch.Tensor, text: List[str], **kwargs):
        image = self._preprocess_image(image, **kwargs)
        text_feature = self.get_text_features(text)  # [k, feat_dim]
        image_features = self.get_image_features(image)
        return self.get_sim_logits(text_feature, image_features)

    def _preprocess_image(self, image: torch.Tensor, **kwargs):
        # Accept **kwargs so forward can pass extra arguments through;
        # subclasses override this with a richer signature.
        return image

    def _get_text_features(self, noun_list: List[str]):
        # Encode only the nouns that are not cached yet.
        left_noun_list = [
            noun for noun in noun_list if noun not in self.text_feature_buffer
        ]
        if len(left_noun_list) > 0:
            left_text_features = self.text_templates(left_noun_list, self.clip_model)
            self.text_feature_buffer.update(
                {
                    noun: text_feature
                    for noun, text_feature in zip(left_noun_list, left_text_features)
                }
            )
        return torch.stack([self.text_feature_buffer[noun] for noun in noun_list])

    def get_text_features(self, noun_list: List[str]):
        return self._get_text_features(noun_list)

    def get_image_features(self, image: torch.Tensor):
        image_features = self.clip_model.visual(image)
        # L2-normalize so the dot product in get_sim_logits is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    def get_sim_logits(
        self,
        text_features: torch.Tensor,
        image_features: torch.Tensor,
        temperature: float = 100,
    ):
        # Temperature-scaled cosine similarity, matching CLIP's logit scale.
        return temperature * image_features @ text_features.T

    def normalize_feature(self, feat: torch.Tensor):
        return feat / feat.norm(dim=-1, keepdim=True)
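
# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal sketch of image-level classification with ClipAdapter, assuming a
# concrete PromptExtractor instance (`prompt_extractor` below is hypothetical)
# and a CLIP model name accepted by build_clip_model (e.g. "ViT-B/16"):
#
#   adapter = ClipAdapter("ViT-B/16", mask_prompt_depth=0,
#                         text_templates=prompt_extractor)
#   # image: [N, 3, 224, 224], already normalized with PIXEL_MEAN / PIXEL_STD
#   logits = adapter(image, ["cat", "dog", "grass"])  # -> [N, 3]
#   probs = logits.softmax(dim=-1)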

class MaskFormerClipAdapter(ClipAdapter):
    def __init__(
        self,
        clip_model_name: str,
        text_templates: PromptExtractor,
        mask_fill: str = "mean",
        mask_expand_ratio: float = 1.0,
        mask_thr: float = 0.5,
        mask_matting: bool = False,
        region_resized: bool = True,
        mask_prompt_depth: int = 0,
        mask_prompt_fwd: bool = False,
    ):
        super().__init__(clip_model_name, mask_prompt_depth, text_templates)
        # Learnable embedding for the "no object" category, appended to the
        # class embeddings in get_text_features.
        self.non_object_embedding = nn.Parameter(
            torch.empty(1, self.clip_model.text_projection.shape[-1])
        )
        nn.init.normal_(
            self.non_object_embedding.data,
            std=self.clip_model.transformer.width ** -0.5,
        )
        # for test
        # Color used to fill pixels outside the region mask (0-255 range).
        self.mask_fill = mask_fill
        if self.mask_fill == "zero":
            self.mask_fill = (0.0, 0.0, 0.0)
        elif self.mask_fill == "mean":
            self.mask_fill = [255.0 * c for c in PIXEL_MEAN]
        else:
            raise NotImplementedError(
                "Unknown mask_fill method: {}".format(self.mask_fill)
            )
        self.mask_expand_ratio = mask_expand_ratio
        self.mask_thr = mask_thr
        self.mask_matting = mask_matting
        self.region_resized = region_resized
        self.mask_prompt_fwd = mask_prompt_fwd
        self.register_buffer(
            "pixel_mean", torch.Tensor(PIXEL_MEAN).reshape(1, 3, 1, 1) * 255.0
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(PIXEL_STD).reshape(1, 3, 1, 1) * 255.0
        )
    def forward(
        self,
        image: torch.Tensor,
        text: List[str],
        mask: torch.Tensor,
        normalize: bool = True,
        fwd_w_region_mask: bool = False,
    ):
        (regions, unnorm_regions), region_masks, valid_flag = self._preprocess_image(
            image, mask, normalize=normalize
        )
        if regions is None:
            return None, valid_flag
        if isinstance(regions, list):
            # Regions were not resized to a common shape; encode them one by
            # one and concatenate the resulting features.
            image_features = torch.cat(
                [self.get_image_features(image_i) for image_i in regions], dim=0
            )
        else:
            if self.mask_prompt_fwd:
                image_features = self.get_image_features(regions, region_masks)
            else:
                image_features = self.get_image_features(regions)
        text_feature = self.get_text_features(text)  # [k, feat_dim]
        return self.get_sim_logits(text_feature, image_features), unnorm_regions, valid_flag

    def get_image_features(self, image, region_masks=None):
        image_features = self.clip_model.visual(image, region_masks)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features
    def _preprocess_image(
        self, image: torch.Tensor, mask: torch.Tensor, normalize: bool = True
    ):
        """Crop, mask, and normalize the image.

        Args:
            image: [C, H, W]
            mask: [K, H, W]
            normalize (bool, optional): whether to normalize regions with
                CLIP's pixel statistics. Defaults to True.
        """
        dtype = mask.dtype
        # Binarize the (possibly soft) masks and drop empty ones.
        bin_mask = mask > self.mask_thr
        valid = bin_mask.sum(dim=(-1, -2)) > 0
        bin_mask = bin_mask[valid]
        mask = mask[valid]
        if not self.mask_matting:
            mask = bin_mask
        bin_mask = BitMasks(bin_mask)
        bboxes = bin_mask.get_bounding_boxes()
        # Crop each region around its (expanded) bounding box and fill the
        # background with self.mask_fill.
        regions = []
        region_masks = []
        for bbox, single_mask in zip(bboxes, mask):
            region, region_mask = crop_with_mask(
                image.type(dtype),
                single_mask.type(dtype),
                bbox,
                fill=self.mask_fill,
                expand_ratio=self.mask_expand_ratio,
            )
            regions.append(region.unsqueeze(0))
            region_masks.append(region_mask.unsqueeze(0))
        if len(regions) == 0:
            # Match the arity of the non-empty return so forward can unpack it.
            return (None, None), None, valid
        unnorm_regions = regions
        if normalize:
            regions = [(r - self.pixel_mean) / self.pixel_std for r in regions]
        # Resize every region (and its mask) to CLIP's input resolution.
        if self.region_resized:
            regions = [
                F.interpolate(r, size=(224, 224), mode="bicubic") for r in regions
            ]
            regions = torch.cat(regions)
            region_masks = [
                F.interpolate(r, size=(224, 224), mode="nearest") for r in region_masks
            ]
            region_masks = torch.cat(region_masks)
            unnorm_regions = [
                F.interpolate(r, size=(224, 224), mode="bicubic") for r in unnorm_regions
            ]
            unnorm_regions = torch.cat(unnorm_regions)
        return (regions, unnorm_regions), region_masks, valid
    def get_text_features(self, noun_list: List[str]):
        object_text_features = self._get_text_features(noun_list)
        # Append the normalized "no object" embedding as an extra class.
        non_object_text_features = (
            self.non_object_embedding
            / self.non_object_embedding.norm(dim=-1, keepdim=True)
        )
        return torch.cat([object_text_features, non_object_text_features], dim=0)
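
# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal sketch of mask-conditioned classification, assuming `image` is a
# [C, H, W] tensor in the 0-255 range, `mask` is a [K, H, W] tensor of mask
# predictions, and `prompt_extractor` is a hypothetical concrete
# PromptExtractor instance:
#
#   adapter = MaskFormerClipAdapter("ViT-B/16", prompt_extractor)
#   out = adapter(image, ["cat", "dog"], mask)
#   if out[0] is not None:
#       # logits: [K_valid, num_classes + 1]; the extra column is "no object".
#       logits, unnorm_regions, valid = out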