""" ViTamin |
|
|
|
Paper: Designing Scalable Vison Models in the Vision-Language Era |
|
|
|
@misc{chen2023designing, |
|
title={Designing Scalable Vison Models in the Vision-Language Era}, |
|
author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen}, |
|
year={2023}, |
|
archivePrefix={arXiv}, |
|
primaryClass={cs.CV} |
|
} |
|
|
|
Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin |
|
|
|
by Jieneng Chen 2024 |
|
""" |
|
|
|
import copy
import os
from typing import Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ViTaminTextConfig(PretrainedConfig):
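    r"""
    Configuration for the ViTamin text tower, a CLIP-style text transformer.

    Args:
        context_length (`int`, *optional*, defaults to 77):
            Maximum number of text tokens per sequence.
        vocab_size (`int`, *optional*, defaults to 49408):
            Vocabulary size of the text model.
        width (`int`, *optional*, defaults to 1024):
            Hidden size of the transformer.
        heads (`int`, *optional*, defaults to 16):
            Number of attention heads in each layer.
        layers (`int`, *optional*, defaults to 24):
            Number of transformer layers.

    Example (illustrative; the import path depends on how this module is packaged):

    ```python
    >>> config = ViTaminTextConfig(width=512, heads=8, layers=12)
    >>> config.model_type
    'vitamin_text_model'
    ```
    """
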
    model_type = "vitamin_text_model"

    def __init__(
        self,
        context_length=77,
        vocab_size=49408,
        width=1024,
        heads=16,
        layers=24,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.context_length = context_length
        self.width = width
        self.heads = heads
        self.layers = layers

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
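        """Load the text configuration, unwrapping the nested `text_config` entry when the target file stores a
        full composite ViTamin config rather than a bare text config."""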
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "text_config" in config_dict:
            config_dict = config_dict["text_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class ViTaminVisionConfig(PretrainedConfig):
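    r"""
    Configuration for the ViTamin vision tower. The backbone is created through `timm`, so most fields are
    forwarded to the `timm` model constructor (field descriptions below follow the usual `timm` semantics).

    Args:
        timm_model_name (`str`, *optional*, defaults to `"vitamin_large"`):
            Name of the `timm` model used as the image backbone.
        timm_model_pretrained (`bool`, *optional*, defaults to `False`):
            Whether to load pretrained `timm` weights for the backbone.
        timm_pool (`str`, *optional*, defaults to `""`):
            Feature pooling mode passed to the `timm` backbone.
        timm_proj (`str`, *optional*, defaults to `"linear"`):
            Type of projection head applied to the backbone features.
        timm_drop (`float`, *optional*, defaults to 0.0):
            Head dropout rate.
        timm_drop_path (`float`, *optional*, defaults to 0.1):
            Stochastic depth (drop path) rate for the backbone.
        image_size (`int`, *optional*, defaults to 256):
            Input image resolution.
        timm_proj_bias (`bool`, *optional*, defaults to `False`):
            Whether the projection layer has a bias.
        patch_dropout (`float`, *optional*, defaults to 0.0):
            Fraction of image patches dropped during training.
        drop_path (`float`, *optional*):
            Optional drop path rate override.
    """
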
    model_type = "vitamin_vision_model"

    def __init__(
        self,
        timm_model_name="vitamin_large",
        timm_model_pretrained=False,
        timm_pool="",
        timm_proj="linear",
        timm_drop=0.0,
        timm_drop_path=0.1,
        image_size=256,
        timm_proj_bias=False,
        patch_dropout=0.0,
        drop_path=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.timm_model_name = timm_model_name
        self.timm_model_pretrained = timm_model_pretrained
        self.timm_pool = timm_pool
        self.timm_proj = timm_proj
        self.timm_drop = timm_drop
        self.timm_drop_path = timm_drop_path
        self.timm_proj_bias = timm_proj_bias
        self.patch_dropout = patch_dropout
        self.image_size = image_size
        self.drop_path = drop_path

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
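        """Load the vision configuration, unwrapping the nested `vision_config` entry when the target file stores
        a full composite ViTamin config rather than a bare vision config."""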
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "vision_config" in config_dict:
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class ViTaminConfig(PretrainedConfig):
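    r"""
    Composite configuration that bundles a [`ViTaminTextConfig`], a [`ViTaminVisionConfig`], and the dimension of
    the shared image-text embedding space.

    Args:
        text_config (`dict`, *optional*):
            Keyword arguments for [`ViTaminTextConfig`]. Defaults are used when `None`.
        vision_config (`dict`, *optional*):
            Keyword arguments for [`ViTaminVisionConfig`]. Defaults are used when `None`.
        embed_dim (`int`, *optional*, defaults to 512):
            Dimension of the joint image-text embedding space.

    Example (illustrative only):

    ```python
    >>> config = ViTaminConfig(embed_dim=768)
    >>> config.vision_config.timm_model_name
    'vitamin_large'
    ```
    """
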
    model_type = "vitamin"
    is_composition = True

    def __init__(self, text_config=None, vision_config=None, embed_dim=512, **kwargs):
        super().__init__(**kwargs)
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `ViTaminTextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `ViTaminVisionConfig` with default values.")

        self.embed_dim = embed_dim
        self.text_config = ViTaminTextConfig(**text_config)
        self.vision_config = ViTaminVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: ViTaminTextConfig, vision_config: ViTaminVisionConfig, **kwargs):
        r"""
        Instantiate a [`ViTaminConfig`] (or a derived class) from a ViTamin text model configuration and a ViTamin
        vision model configuration.

        Returns:
            [`ViTaminConfig`]: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["text_config"] = self.text_config.to_dict()
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
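

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): builds
# a composite config from hand-rolled sub-configs and round-trips it through
# `to_dict`. The hyperparameter values below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    text_cfg = ViTaminTextConfig(width=768, heads=12, layers=12)
    vision_cfg = ViTaminVisionConfig(timm_model_name="vitamin_large", image_size=256)
    config = ViTaminConfig.from_text_vision_configs(text_cfg, vision_cfg, embed_dim=768)

    as_dict = config.to_dict()
    assert as_dict["model_type"] == "vitamin"
    assert as_dict["text_config"]["width"] == 768

    # A composite config can also be rebuilt from its serialized form.
    rebuilt = ViTaminConfig(**{k: as_dict[k] for k in ("text_config", "vision_config", "embed_dim")})
    assert rebuilt.embed_dim == 768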