from typing import Any, Dict, Optional, Union

from transformers import PretrainedConfig, RobertaConfig


class JapaneseCLIPVisionConfig(PretrainedConfig):
    """Configuration for the vision tower of a Japanese CLIP model.

    Every constructor argument is stored unchanged as an attribute of the same
    name. One extra attribute, ``heads``, is derived as ``width // head_width``.
    The model code that consumes these attributes is not part of this module.

    NOTE(review): the argument names look like open_clip's vision-tower
    options (``CLIPVisionCfg``) -- confirm against the model implementation.
    NOTE(review): ``model_type = "vit"`` is the same string used by
    ``transformers.ViTConfig``; registering this class with ``AutoConfig`` under
    that name would clash. Kept unchanged because saved configs may rely on it.
    NOTE(review): the first six arguments have no defaults, so
    ``JapaneseCLIPVisionConfig()`` raises ``TypeError``; transformers' default
    ``repr``/``to_json_string`` path may instantiate the class without
    arguments and therefore fail for a standalone instance.
    """

    model_type = "vit"

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        head_width: int,
        mlp_ratio: float,
        ls_init_value: Optional[float] = None,
        attentional_pool: bool = False,
        attn_pooler_queries: int = 256,
        attn_pooler_heads: int = 8,
        output_dim: int = 512,
        patch_dropout: float = 0.0,
        no_ln_pre: bool = False,
        pool_type: str = "tok",
        final_ln_after_pool: bool = False,
        output_tokens: bool = False,
        **kwargs: Any,
    ):
        """Store the vision-tower hyper-parameters.

        Args:
            image_size, patch_size, width, layers, head_width, mlp_ratio:
                Required; stored as-is.
            ls_init_value, attentional_pool, attn_pooler_queries,
            attn_pooler_heads, output_dim, patch_dropout, no_ln_pre, pool_type,
            final_ln_after_pool, output_tokens:
                Optional; stored as-is with the defaults shown in the signature.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ZeroDivisionError: if ``head_width`` is 0 (from computing ``heads``).
        """
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.head_width = head_width
        # Derived value (presumably the attention-head count). It is recomputed
        # here, after the base initializer, so it overrides any ``heads`` that
        # arrived via kwargs (e.g. from a round-tripped ``to_dict()``).
        # Integer division: ``width`` is assumed to be a multiple of
        # ``head_width``.
        self.heads = width // head_width
        self.mlp_ratio = mlp_ratio
        self.ls_init_value = ls_init_value
        self.attentional_pool = attentional_pool
        self.attn_pooler_queries = attn_pooler_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.output_dim = output_dim
        self.patch_dropout = patch_dropout
        self.no_ln_pre = no_ln_pre
        self.pool_type = pool_type
        self.final_ln_after_pool = final_ln_after_pool
        self.output_tokens = output_tokens


class JapaneseCLIPConfig(PretrainedConfig):
    """Composite configuration pairing a vision config with a RoBERTa text config.

    ``vision_config`` and ``text_config`` are required keyword arguments. Each
    may be a plain ``dict`` (as produced by ``to_dict()`` or loaded from JSON)
    or an already-built config object. They are stored as a
    ``JapaneseCLIPVisionConfig`` and a ``RobertaConfig`` respectively.
    """

    model_type = "japanese_clip"
    # This config cannot be built without its sub-configs. transformers'
    # ``to_diff_dict`` (behind ``repr``, ``to_json_string`` and
    # ``save_pretrained``) instantiates ``self.__class__()`` unless the class is
    # flagged as having no usable defaults. Older releases read
    # ``is_composition``; newer ones may read ``has_no_defaults_at_init``.
    # Setting both is harmless on either.
    is_composition = True
    has_no_defaults_at_init = True

    def __init__(self, max_length: int = 77, **kwargs: Any):
        """Build the composite config.

        Args:
            max_length: Stored as ``self.max_length`` (default 77).
                NOTE(review): presumably the text context length -- confirm.
                ``PretrainedConfig`` also treats ``max_length`` as a generation
                parameter, so some transformers versions warn when it is
                non-default at save time.
            **kwargs: Must contain ``vision_config`` and ``text_config`` (dicts
                or config instances); everything else is forwarded to
                ``PretrainedConfig.__init__``.

        Raises:
            ValueError: if ``vision_config`` or ``text_config`` is missing or
                ``None``.
        """
        # Take the sub-configs out first so the base initializer does not
        # briefly store them as raw attributes.
        vision_config = kwargs.pop("vision_config", None)
        text_config = kwargs.pop("text_config", None)
        if vision_config is None:
            raise ValueError("vision_config must be provided")
        if text_config is None:
            raise ValueError("text_config must be provided")

        super().__init__(**kwargs)
        self.max_length = max_length
        self.vision_config = self._as_config(vision_config, JapaneseCLIPVisionConfig)
        self.text_config = self._as_config(text_config, RobertaConfig)

    @staticmethod
    def _as_config(
        value: Union[Dict[str, Any], PretrainedConfig], config_cls: type
    ) -> PretrainedConfig:
        """Return ``value`` as a ``config_cls`` instance.

        An instance of ``config_cls`` is returned unchanged. Any other
        ``PretrainedConfig`` is converted through ``to_dict()``. A mapping is
        unpacked as keyword arguments.
        """
        if isinstance(value, config_cls):
            return value
        if isinstance(value, PretrainedConfig):
            value = value.to_dict()
        return config_cls(**value)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, with the sub-configs expanded to dicts.

        Some transformers releases leave nested config objects in the output
        of ``PretrainedConfig.to_dict``, which ``json.dumps`` cannot serialize.
        The conversion below only touches values that are still config
        objects, so it is a no-op when the base class already expanded them.
        """
        output = super().to_dict()
        for key in ("vision_config", "text_config"):
            sub_config = output.get(key)
            if isinstance(sub_config, PretrainedConfig):
                output[key] = sub_config.to_dict()
        return output

    @classmethod
    def from_vision_text_configs(
        cls,
        vision_config: PretrainedConfig,
        text_config: PretrainedConfig,
        **kwargs: Any,
    ) -> "JapaneseCLIPConfig":
        r"""
        Instantiate a [`JapaneseCLIPConfig`] (or a derived class) from a vision
        model configuration and a text model configuration.

        Both configs are copied through ``to_dict()``, so later changes to the
        arguments do not affect the returned object. Extra ``kwargs`` (e.g.
        ``max_length``) are forwarded to the constructor.

        Returns:
            [`JapaneseCLIPConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )