File size: 6,465 Bytes
e187c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from transformers.modeling_utils import PretrainedConfig

class RasphiConfig(PretrainedConfig):
    r"""
    Configuration object for the Rasphi model (``model_type = "rasphi"``).

    Every constructor argument is stored verbatim as an attribute of the same
    name, except ``pad_token_id``, ``bos_token_id``, ``eos_token_id`` and
    ``tie_word_embeddings``, which, together with any extra ``**kwargs``, are
    forwarded to ``PretrainedConfig.__init__``.

    The arguments from ``vocab_size`` through ``lm_head_bias`` are the usual
    decoder / mixture-of-experts hyperparameters. How they are interpreted is
    up to the model code that consumes this config.

    Rasphi-specific arguments:
        num_reasoning_experts (int, default 8): number of experts dedicated to
            the reasoning stream.
        num_content_experts (int, default 8): number of experts dedicated to
            the content stream.
        reasoning_hidden_size (int, default 2048): hidden size for the
            reasoning stream.
        content_hidden_size (int, default 2048): hidden size for the content
            stream.
        stream_interaction (str, default "attention"): how the two streams
            interact; the documented values are "attention", "mlp" or "both".

    NOTE(review): neither ``stream_interaction`` nor the relation between
    ``num_reasoning_experts + num_content_experts`` and ``num_local_experts``
    is validated here. Confirm whether the model code relies on them agreeing.

    Raises:
        ValueError: if ``rope_scaling`` is not ``None`` and does not match the
            schema checked by ``_rope_scaling_validation``.
    """

    model_type = "rasphi"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32064,
        hidden_size=4096,
        intermediate_size=6400,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=1e6,
        rope_scaling=None,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=16,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.01,
        input_jitter_noise=0.0,
        attention_bias=False,
        lm_head_bias=False,
        # Rasphi specific configurations
        num_reasoning_experts=8,  # Number of experts dedicated to reasoning stream
        num_content_experts=8,  # Number of experts dedicated to content stream
        reasoning_hidden_size=2048,  # Hidden size for reasoning stream
        content_hidden_size=2048,  # Hidden size for content stream
        stream_interaction="attention",  # How the two streams interact: "attention", "mlp", or "both"
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.attention_bias = attention_bias
        self.lm_head_bias = lm_head_bias
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise
        self.input_jitter_noise = input_jitter_noise
        self.rope_scaling = rope_scaling
        # Must run after hidden_size / num_attention_heads / rope_scaling are
        # set: the validator reads all three from self.
        self._rope_scaling_validation()

        # Rasphi specific configurations
        self.num_reasoning_experts = num_reasoning_experts
        self.num_content_experts = num_content_experts
        self.reasoning_hidden_size = reasoning_hidden_size
        self.content_hidden_size = content_hidden_size
        self.stream_interaction = stream_interaction

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the ``rope_scaling`` configuration.

        ``None`` is accepted and means no RoPE scaling. Otherwise
        ``rope_scaling`` must be a dict with exactly six entries:

        - ``type``: must be ``"longrope"``.
        - ``short_factor`` / ``long_factor``: lists of numbers, each of length
          ``hidden_size // num_attention_heads // 2``.
        - ``short_mscale`` / ``long_mscale``: numbers (int or float).
        - ``original_max_position_embeddings``: an int.

        Checks run in the order listed above, and the first failure is raised.

        Raises:
            ValueError: describing the first field that fails validation.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
            raise ValueError(
                "`rope_scaling` must be a dictionary with six fields, `type`, `short_factor`, `long_factor`, "
                f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
            )

        rope_scaling_type = self.rope_scaling.get("type", None)
        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")

        # Both factor lists share the same required length. Presumably this is
        # one entry per rotary frequency (half the per-head dimension); confirm
        # against the rotary embedding implementation.
        expected_factor_len = self.hidden_size // self.num_attention_heads // 2
        for key in ("short_factor", "long_factor"):
            factor = self.rope_scaling.get(key, None)
            if not (isinstance(factor, list) and all(isinstance(x, (int, float)) for x in factor)):
                raise ValueError(f"`rope_scaling`'s {key} field must be a list of numbers, got {factor}")
            if len(factor) != expected_factor_len:
                raise ValueError(
                    f"`rope_scaling`'s {key} field must have length {expected_factor_len}, got {len(factor)}"
                )

        for key in ("short_mscale", "long_mscale"):
            mscale = self.rope_scaling.get(key, None)
            if not isinstance(mscale, (int, float)):
                raise ValueError(f"`rope_scaling`'s {key} field must be a number, got {mscale}")

        original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
        if not isinstance(original_max_position_embeddings, int):
            raise ValueError(
                f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}"
            )