# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/text_prompt.py
# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/utils.py

from typing import List

# import clip
from .clip import tokenize
import torch
from torch import nn

IMAGENET_PROMPT = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]
VILD_PROMPT = [
    "a photo of a {}.",
    "This is a photo of a {}",
    "There is a {} in the scene",
    "There is the {} in the scene",
    "a photo of a {} in the scene",
    "a photo of a small {}.",
    "a photo of a medium {}.",
    "a photo of a large {}.",
    "This is a photo of a small {}.",
    "This is a photo of a medium {}.",
    "This is a photo of a large {}.",
    "There is a small {} in the scene.",
    "There is a medium {} in the scene.",
    "There is a large {} in the scene.",
]

class PromptExtractor(nn.Module):
    """Base class: maps a list of class names to CLIP text embeddings."""

    def __init__(self):
        super().__init__()
        self._buffer_init = False

    def init_buffer(self, clip_model):
        self._buffer_init = True

    def forward(self, noun_list: List[str], clip_model: nn.Module):
        raise NotImplementedError()

class PredefinedPromptExtractor(PromptExtractor):
    """Embeds each class name with a fixed set of prompt templates and
    ensembles the resulting CLIP text features by averaging."""

    def __init__(self, templates: List[str]):
        super().__init__()
        self.templates = templates

    def forward(self, noun_list: List[str], clip_model: nn.Module):
        text_features_bucket = []
        for template in self.templates:
            # Fill the template with every class name and tokenize the prompts.
            noun_tokens = [tokenize(template.format(noun)) for noun in noun_list]
            text_inputs = torch.cat(noun_tokens).to(
                clip_model.text_projection.data.device
            )
            # Encode and L2-normalize the per-template text features.
            text_features = clip_model.encode_text(text_inputs)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_features_bucket.append(text_features)
            del text_inputs
        # Ensemble by averaging over templates, then re-normalize.
        text_features = torch.stack(text_features_bucket).mean(dim=0)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

class ImageNetPromptExtractor(PredefinedPromptExtractor):
    def __init__(self):
        super().__init__(IMAGENET_PROMPT)


class VILDPromptExtractor(PredefinedPromptExtractor):
    def __init__(self):
        super().__init__(VILD_PROMPT)
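

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). It assumes the
    # OpenAI CLIP package (https://github.com/openai/CLIP, hinted at by the
    # commented-out `import clip` above) is installed and that its models
    # expose the `encode_text` / `text_projection` attributes used by
    # PredefinedPromptExtractor. The model name and class list are purely
    # illustrative, and because `tokenize` is a relative import, this module
    # has to be run as part of its package rather than as a loose script.
    import clip  # assumed external dependency

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, _ = clip.load("ViT-B/16", device=device)

    extractor = VILDPromptExtractor()
    with torch.no_grad():
        # One L2-normalized embedding per class name, averaged over the
        # VILD_PROMPT templates: shape (num_classes, embed_dim).
        text_features = extractor(["cat", "dog", "grass"], clip_model)
    print(text_features.shape)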