from typing import Optional

from transformers import PretrainedConfig, RobertaConfig


class JapaneseCLIPVisionConfig(PretrainedConfig):
    """Configuration for the ViT image encoder of JapaneseCLIP."""

    model_type = "vit"
    is_composition = True

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        head_width: int,
        mlp_ratio: float,
        ls_init_value: Optional[float] = None,
        attentional_pool: bool = False,
        attn_pooler_queries: int = 256,
        attn_pooler_heads: int = 8,
        output_dim: int = 512,
        patch_dropout: float = 0.0,
        no_ln_pre: bool = False,
        pool_type: str = "tok",
        final_ln_after_pool: bool = False,
        output_tokens: bool = False,
        **kwargs,
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.head_width = head_width
        # The number of attention heads is derived from the embedding width.
        self.heads = width // head_width
        self.mlp_ratio = mlp_ratio
        self.ls_init_value = ls_init_value
        self.attentional_pool = attentional_pool
        self.attn_pooler_queries = attn_pooler_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.output_dim = output_dim
        self.patch_dropout = patch_dropout
        self.no_ln_pre = no_ln_pre
        self.pool_type = pool_type
        self.final_ln_after_pool = final_ln_after_pool
        self.output_tokens = output_tokens
        super().__init__(**kwargs)

class JapaneseCLIPConfig(PretrainedConfig):
    """Composite configuration holding the vision and text sub-configs."""

    model_type = "japanese_clip"
    is_composition = True

    def __init__(
        self,
        max_length: int = 77,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_length = max_length
        # Both sub-model configurations are required.
        if "vision_config" not in kwargs:
            raise ValueError("vision_config must be provided")
        if "text_config" not in kwargs:
            raise ValueError("text_config must be provided")
        vision_config = kwargs.pop("vision_config")
        text_config = kwargs.pop("text_config")
        # The vision tower uses the ViT config above; the text tower is a RoBERTa model.
        self.vision_config = JapaneseCLIPVisionConfig(**vision_config)
        self.text_config = RobertaConfig(**text_config)

    @classmethod
    def from_vision_text_configs(
        cls,
        vision_config: PretrainedConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`JapaneseCLIPConfig`] (or a derived class) from a vision model configuration and a text
        model configuration.

        Returns:
            [`JapaneseCLIPConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )
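

# A minimal usage sketch (not part of the original module): the ViT hyperparameters below
# (224px images, 16px patches, width 768, 12 layers) and the default RobertaConfig are
# illustrative assumptions, not the values of any released checkpoint.
if __name__ == "__main__":
    vision_config = JapaneseCLIPVisionConfig(
        image_size=224,
        patch_size=16,
        width=768,
        layers=12,
        head_width=64,
        mlp_ratio=4.0,
    )
    text_config = RobertaConfig()

    # Build the composite config directly from dicts ...
    config = JapaneseCLIPConfig(
        vision_config=vision_config.to_dict(),
        text_config=text_config.to_dict(),
    )
    # ... or through the convenience constructor.
    config = JapaneseCLIPConfig.from_vision_text_configs(vision_config, text_config)

    # heads is derived as width // head_width (768 // 64 = 12); max_length defaults to 77.
    print(config.vision_config.heads, config.max_length)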