# Copyright 2023 Stability AI team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union

from transformers import PretrainedConfig, CLIPVisionConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging


logger = logging.get_logger(__name__)


class LlavaMlpConfig(PretrainedConfig):
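    r"""
    Configuration class for the MLP module of a LLaVA-style model (see
    [`LlavaConfig`]).

    Args:
        num_hidden_layers (`int`, *optional*, defaults to 2):
            Number of hidden layers in the MLP.
    """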
    model_type = "llava_mlp"

    def __init__(
        self,
        num_hidden_layers=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_hidden_layers = num_hidden_layers

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # get the mlp config dict if we are loading from LlavaConfig
        if config_dict.get("model_type") == "llava":
            config_dict = config_dict["mlp_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class LlavaConfig(PretrainedConfig):
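    r"""
    Composite configuration that groups a [`CLIPVisionConfig`], a
    [`LlavaMlpConfig`] and a text (language model) configuration.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary used to initialize [`CLIPVisionConfig`].
        mlp_config (`dict`, *optional*):
            Dictionary used to initialize [`LlavaMlpConfig`].
        text_config (`dict`, *optional*):
            Dictionary used to initialize the text model configuration
            (defaults to [`OPTConfig`] when no `model_type` is given).
        vision_select_layer (`int`, *optional*, defaults to -2):
            Index of the vision encoder layer whose output is selected.
        vision_select_feature (`str`, *optional*, defaults to `"patch"`):
            Which vision features to select, either `"patch"` or `"cls_patch"`.
    """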
    model_type = "llava"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        mlp_config=None,
        text_config=None,
        vision_select_layer=-2,
        vision_select_feature="patch",
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. Initializing the CLIPVisionConfig with default values."
            )

        if mlp_config is None:
            mlp_config = {}
            logger.info(
                "mlp_config is None. Initializing the LlavaMlpConfig with default values."
            )

        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the text config with default values (`OPTConfig`)."
            )

        self.vision_config = CLIPVisionConfig(**vision_config)
        self.mlp_config = LlavaMlpConfig(**mlp_config)
        # Fall back to "opt" when no model_type is provided, matching the log
        # message above and avoiding a KeyError when text_config is None.
        text_model_type = (
            text_config["model_type"] if "model_type" in text_config else "opt"
        )
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        self.use_decoder_only_language_model = (
            self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        self.vision_select_layer = vision_select_layer
        assert vision_select_feature in [
            "cls_patch",
            "patch",
        ], f"Unexpected select feature: {vision_select_feature}"
        self.vision_select_feature = vision_select_feature
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_mlp_text_configs(
        cls,
        vision_config: CLIPVisionConfig,
        mlp_config: LlavaMlpConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        return cls(
            vision_config=vision_config.to_dict(),
            mlp_config=mlp_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )
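

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library API): compose a
# LlavaConfig from individual sub-configs. The choice of GPTNeoXConfig for the
# text backbone is an assumption made for demonstration; any model type
# registered in CONFIG_MAPPING can be used instead.
if __name__ == "__main__":
    from transformers import GPTNeoXConfig

    vision_config = CLIPVisionConfig()
    mlp_config = LlavaMlpConfig(num_hidden_layers=2)
    text_config = GPTNeoXConfig()

    config = LlavaConfig.from_vision_mlp_text_configs(
        vision_config=vision_config,
        mlp_config=mlp_config,
        text_config=text_config,
        vision_select_layer=-2,
        vision_select_feature="patch",
    )

    # "gpt_neox" is a causal-LM architecture, so the composed config treats
    # the text model as decoder-only.
    print(config.use_decoder_only_language_model)  # True
    print(config.mlp_config.num_hidden_layers)  # 2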