|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
from typing import Any, Literal, Optional |
|
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
class CLYPConfig(PretrainedConfig):
    """Top-level configuration for the ``clyp`` model type.

    Bundles a vision-encoder sub-config, a text-encoder sub-config, an
    optional loss sub-config and four temperature-related settings.

    Args:
        vision_encoder_config: Keyword arguments forwarded to
            ``CLYPVisionEncoderConfig``. ``None`` or an empty dict builds
            that class with its own defaults.
        text_encoder_config: Keyword arguments forwarded to
            ``CLYPTextEncoderConfig``. ``None`` or an empty dict builds
            that class with its own defaults.
        itc_loss_config: Keyword arguments forwarded to ``CLYPLossConfig``.
            When ``None`` or empty, no loss config is built and the
            attribute is set to ``None``.
        learn_temperature: Stored unchanged on the instance.
        temperature_init: Stored unchanged on the instance.
        temperature_min: Stored unchanged on the instance.
        temperature_max: Stored unchanged on the instance.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "clyp"

    def __init__(
        self,
        vision_encoder_config: Optional[dict] = None,
        text_encoder_config: Optional[dict] = None,
        itc_loss_config: Optional[dict] = None,
        learn_temperature: bool = True,
        temperature_init: float = 0.07,
        temperature_min: float = 0.01,
        temperature_max: float = 1000.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # A missing or empty mapping means "use the sub-config defaults".
        vision_kwargs = vision_encoder_config if vision_encoder_config else {}
        text_kwargs = text_encoder_config if text_encoder_config else {}
        self.vision_encoder_config = CLYPVisionEncoderConfig(**vision_kwargs)
        self.text_encoder_config = CLYPTextEncoderConfig(**text_kwargs)

        # The loss config is optional: it stays None unless a non-empty
        # mapping was supplied.
        if itc_loss_config:
            self.itc_loss_config = CLYPLossConfig(**itc_loss_config)
        else:
            self.itc_loss_config = None

        self.learn_temperature = learn_temperature
        self.temperature_init = temperature_init
        self.temperature_min = temperature_min
        self.temperature_max = temperature_max

    def to_diff_dict(self) -> dict[str, Any]:
        """Return the parent's diff dict with both encoder entries replaced.

        The ``vision_encoder_config`` and ``text_encoder_config`` entries
        are overwritten with the result of ``_to_diff_dict`` on the
        corresponding sub-config; every other key (including
        ``itc_loss_config``) is left as the parent implementation
        produced it.
        """
        diff = super().to_diff_dict()
        encoder_diffs = {
            key: _to_diff_dict(getattr(self, key))
            for key in ("vision_encoder_config", "text_encoder_config")
        }
        self.dict_torch_dtype_to_str(encoder_diffs)
        diff.update(encoder_diffs)
        return diff
|
|
|
|
|
class CLYPVisionEncoderConfig(PretrainedConfig):
    """Configuration of the vision encoder: backbone, pooler and neck.

    Args:
        backbone_config: Keyword arguments for ``CLYPVisionBackboneConfig``;
            ``None`` or empty means that class's defaults.
        pooler_config: Keyword arguments for ``CLYPPoolerConfig``; ``None``
            or empty falls back to ``{"input_type": "timm"}``.
        neck_config: Keyword arguments for ``CLYPNeckConfig``; ``None`` or
            empty means that class's defaults.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        backbone_config: Optional[dict] = None,
        pooler_config: Optional[dict] = None,
        neck_config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        backbone_kwargs = backbone_config if backbone_config else {}
        pooler_kwargs = pooler_config if pooler_config else {"input_type": "timm"}
        neck_kwargs = neck_config if neck_config else {}

        self.backbone_config = CLYPVisionBackboneConfig(**backbone_kwargs)
        self.pooler_config = CLYPPoolerConfig(**pooler_kwargs)
        self.neck_config = CLYPNeckConfig(**neck_kwargs)

    def to_diff_dict(self) -> dict[str, Any]:
        """Return a dict holding only the three sub-config diffs.

        Unlike the parent implementation, nothing besides
        ``backbone_config``, ``pooler_config`` and ``neck_config`` is
        included in the result.
        """
        sub_config_names = ("backbone_config", "pooler_config", "neck_config")
        diff = {name: _to_diff_dict(getattr(self, name)) for name in sub_config_names}
        self.dict_torch_dtype_to_str(diff)
        return diff
|
|
|
|
|
class CLYPTextEncoderConfig(PretrainedConfig):
    """Configuration of the text encoder: backbone, pooler and neck.

    Args:
        backbone_config: Keyword arguments for ``CLYPTextBackboneConfig``;
            ``None`` or empty means that class's defaults.
        pooler_config: Keyword arguments for ``CLYPPoolerConfig``; ``None``
            or empty falls back to ``{"input_type": "huggingface"}``.
        neck_config: Keyword arguments for ``CLYPNeckConfig``; ``None`` or
            empty means that class's defaults.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        backbone_config: Optional[dict] = None,
        pooler_config: Optional[dict] = None,
        neck_config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Normalise missing/empty inputs before building the sub-configs.
        if not backbone_config:
            backbone_config = {}
        if not pooler_config:
            pooler_config = {"input_type": "huggingface"}
        if not neck_config:
            neck_config = {}

        self.backbone_config = CLYPTextBackboneConfig(**backbone_config)
        self.pooler_config = CLYPPoolerConfig(**pooler_config)
        self.neck_config = CLYPNeckConfig(**neck_config)

    def to_diff_dict(self) -> dict[str, Any]:
        """Return a dict holding only the three sub-config diffs.

        Only ``backbone_config``, ``pooler_config`` and ``neck_config``
        appear in the result; the parent implementation is not consulted.
        """
        diff: dict[str, Any] = {}
        diff["backbone_config"] = _to_diff_dict(self.backbone_config)
        diff["pooler_config"] = _to_diff_dict(self.pooler_config)
        diff["neck_config"] = _to_diff_dict(self.neck_config)
        self.dict_torch_dtype_to_str(diff)
        return diff
|
|
|
|
|
class CLYPVisionBackboneConfig(PretrainedConfig):
    """Settings identifying the vision backbone.

    Args:
        model_name: Backbone identifier; stored unchanged.
        pretrained: Stored unchanged.
        extra_kwargs: Optional extra keyword arguments; ``None`` or an
            empty dict is stored as ``{}``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        model_name: str = "eva02_base_patch16_clip_224.merged2b",
        pretrained: bool = True,
        extra_kwargs: Optional[dict] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.model_name, self.pretrained = model_name, pretrained
        # Never keep None here so the attribute is always a mapping.
        self.extra_kwargs = extra_kwargs if extra_kwargs else {}
|
|
|
|
|
class CLYPTextBackboneConfig(PretrainedConfig):
    """Settings identifying the text backbone.

    Args:
        model_name: Backbone identifier; stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(self, model_name: str = "rinna/japanese-clip-vit-b-16", **kwargs) -> None:
        super().__init__(**kwargs)
        self.model_name = model_name
|
|
|
|
|
class CLYPPoolerConfig(PretrainedConfig):
    """Settings for the pooler shared by the vision and text encoders.

    Args:
        input_type: One of ``"timm"`` / ``"huggingface"`` or ``None``;
            stored unchanged (no validation is performed here).
        return_patch_features: Stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        input_type: Literal["timm", "huggingface"] | None = None,
        return_patch_features: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.input_type, self.return_patch_features = (
            input_type,
            return_patch_features,
        )
|
|
|
|
|
class CLYPNeckConfig(PretrainedConfig):
    """Settings for the neck shared by the vision and text encoders.

    Args:
        in_channels: Stored unchanged.
        out_channels: Stored unchanged.
        bias: Stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        in_channels: int = 768,
        out_channels: int = 512,
        bias: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Channel sizes first, then the bias flag (same order as the signature).
        self.in_channels, self.out_channels = in_channels, out_channels
        self.bias = bias
|
|
|
|
|
class CLYPLossConfig(PretrainedConfig):
    """Settings for the loss referenced by ``CLYPConfig.itc_loss_config``.

    Every argument is stored unchanged under an attribute of the same
    name; how the values are consumed is not visible in this file.

    Args:
        learn_temperature: Stored unchanged.
        init_temperature: Stored unchanged.
        max_temperature: Stored unchanged; may be ``None``.
        min_temperature: Stored unchanged; may be ``None``.
        label_smoothing: Stored unchanged.
        gather_with_grad: Stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        learn_temperature: bool = True,
        init_temperature: float = 0.07,
        max_temperature: Optional[float] = None,
        min_temperature: Optional[float] = None,
        label_smoothing: float = 0.0,
        gather_with_grad: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Temperature-related settings.
        self.learn_temperature = learn_temperature
        self.init_temperature, self.max_temperature, self.min_temperature = (
            init_temperature,
            max_temperature,
            min_temperature,
        )

        # Remaining loss options.
        self.label_smoothing, self.gather_with_grad = label_smoothing, gather_with_grad
|
|
|
|
|
def _to_diff_dict(c: PretrainedConfig) -> dict:
    """Return ``c.to_diff_dict()`` without the ``transformers_version`` key.

    Used in place of calling ``PretrainedConfig.to_diff_dict()`` directly
    on sub-configs.

    NOTE
    ----
    In transformers==4.38.1,
    PretrainedConfig.__repr__ may not be able to show configs that have
    sub-configs.
    """
    diff = c.to_diff_dict()
    # Drop the version stamp if present; a missing key is not an error.
    diff.pop("transformers_version", None)
    return diff
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: load a config from "config.json" and print it.
    print(CLYPConfig.from_pretrained("config.json"))
|
|