File size: 1,351 Bytes
5169b80
 
 
 
 
 
 
2010c83
 
 
 
 
 
 
5169b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
OLMo configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from .config import ModelConfig
from .aliases import PathOrStr
from .beam_search import Sampler
from .exceptions import OLMoError
from .initialization import ModuleType
from .util import StrEnum
from .torch_util import seed_all

logger = logging.get_logger(__name__)


class OLMoConfig(PretrainedConfig):
    """Hugging Face ``PretrainedConfig`` for OLMo models.

    Every field produced by ``ModelConfig().asdict()`` becomes an attribute of
    this config. Fields not passed explicitly keep ``ModelConfig``'s defaults.

    The conventional Hugging Face names ``num_attention_heads``,
    ``num_hidden_layers`` and ``hidden_size`` are exposed as aliases of the
    OLMo-native ``n_heads``, ``n_layers`` and ``d_model``. The aliases are
    read/write: a getter-only property makes every assignment to its name
    raise ``AttributeError``. That includes assignments made by
    ``PretrainedConfig`` machinery when one of these names is supplied as a
    keyword (NOTE(review): e.g. overrides passed to ``from_pretrained``).
    """

    model_type = "olmo"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        """Build the config from ``ModelConfig`` defaults overlaid with ``kwargs``.

        Args:
            use_cache: Stored as the ``use_cache`` attribute.
            **kwargs: Values overriding the ``ModelConfig`` defaults. All
                remaining keywords are forwarded to ``PretrainedConfig``.
        """
        all_kwargs = ModelConfig().asdict()
        all_kwargs.update(kwargs)
        all_kwargs["use_cache"] = use_cache
        # Use the default architecture when the key is missing *or* maps to a
        # falsy value such as None or [].
        all_kwargs["architectures"] = all_kwargs.get("architectures") or ["OLMoModelForCausalLM"]
        super().__init__(**all_kwargs)

    @property
    def num_attention_heads(self):
        """Hugging Face alias for ``n_heads``."""
        return self.n_heads

    @num_attention_heads.setter
    def num_attention_heads(self, value):
        # Write through to the native field so the alias never diverges from it.
        self.n_heads = value

    @property
    def num_hidden_layers(self):
        """Hugging Face alias for ``n_layers``."""
        return self.n_layers

    @num_hidden_layers.setter
    def num_hidden_layers(self, value):
        # Write through to the native field so the alias never diverges from it.
        self.n_layers = value

    @property
    def hidden_size(self):
        """Hugging Face alias for ``d_model``."""
        return self.d_model

    @hidden_size.setter
    def hidden_size(self, value):
        # Write through to the native field so the alias never diverges from it.
        self.d_model = value


# Register the config class under the "olmo" model type (at import time) so that it is
# available for transformers pipelines, auto-loading, etc.
# NOTE(review): this key must match OLMoConfig.model_type ("olmo"); presumably transformers
# rejects a mismatch, so keep the two literals in sync.
AutoConfig.register("olmo", OLMoConfig)