""" OLMo configuration """ from transformers import AutoConfig, PretrainedConfig from transformers.utils import logging from .config import ModelConfig from .aliases import PathOrStr from .beam_search import Sampler from .exceptions import OLMoError from .initialization import ModuleType from .util import StrEnum from .torch_util import seed_all logger = logging.get_logger(__name__) class OLMoConfig(PretrainedConfig): model_type = "olmo" keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm def __init__(self, use_cache: bool = False, **kwargs): model_config = ModelConfig() all_kwargs = model_config.asdict() all_kwargs.update(kwargs) all_kwargs.update({"use_cache": use_cache}) all_kwargs.update( { "architectures": all_kwargs.get("architectures", ["OLMoModelForCausalLM"]) or ["OLMoModelForCausalLM"] } ) super().__init__(**all_kwargs) @property def num_attention_heads(self): return self.n_heads @property def num_hidden_layers(self): return self.n_layers @property def hidden_size(self): return self.d_model # Register the config class so that it is available for transformer pipelines, auto-loading etc. AutoConfig.register("olmo", OLMoConfig)