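from transformers import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ProteinGLMConfig(PretrainedConfig):
    """Configuration class that stores the hyperparameters of a ProteinGLM model.

    The defaults describe a 24-layer model with hidden size 2048 and 32 attention
    heads. Any extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "ProteinGLM"

    def __init__(
        self,
        num_layers=24,
        padded_vocab_size=128,
        hidden_size=2048,
        ffn_hidden_size=5461,  # equals int(8 * 2048 / 3)
        kv_channels=64,  # per-head projection size (2048 / 32)
        num_attention_heads=32,
        seq_length=1024,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        glu_activation='geglu',
        rmsnorm=False,
        deepnorm=True,
        apply_residual_connection_post_layernorm=True,
        post_layer_norm=True,
        add_bias_linear=True,
        add_qkv_bias=True,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        rotary_embedding_2d=False,
        use_pytorch_sdpa=True,
        is_causal=False,
        use_cache=True,
        initializer_range=0.02,
        moe=False,
        num_experts=0,
        experts_per_token=0,
        untie_head=False,
        head_num=1,
        **kwargs
    ):
        # This config pairs the post-layernorm residual with DeepNorm, so flag
        # the combination "post-layernorm residual without DeepNorm" but keep it.
        if not deepnorm and apply_residual_connection_post_layernorm:
            logger.warning(
                "apply_residual_connection_post_layernorm=True while deepnorm=False; "
                "check that this residual ordering is intended."
            )

        # With DeepNorm enabled, the post-layernorm residual is always used,
        # overriding whatever value the caller passed in.
        if deepnorm:
            apply_residual_connection_post_layernorm = True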

        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.glu_activation = glu_activation
        self.rmsnorm = rmsnorm
        self.deepnorm = deepnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.rotary_embedding_2d = rotary_embedding_2d
        self.is_causal = is_causal
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.use_pytorch_sdpa = use_pytorch_sdpa
        self.moe = moe
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.untie_head = untie_head
        self.head_num = head_num
        super().__init__(**kwargs)
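

# Usage sketch (illustrative only, not part of the original module): running this
# file directly builds a config, shows that deepnorm=True overrides an explicit
# apply_residual_connection_post_layernorm=False, and round-trips the config
# through save_pretrained / from_pretrained in a temporary directory. It assumes
# `transformers` is installed; nothing here is needed to load the model itself.
if __name__ == "__main__":
    import tempfile

    # deepnorm=True forces the post-layernorm residual on, even though the
    # caller explicitly asks for False.
    cfg = ProteinGLMConfig(deepnorm=True, apply_residual_connection_post_layernorm=False)
    assert cfg.apply_residual_connection_post_layernorm is True

    # vocab_size is set from padded_vocab_size, so the two always agree.
    assert cfg.vocab_size == cfg.padded_vocab_size == 128

    # Serialize to config.json and load it back from the same directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg.save_pretrained(tmp_dir)
        reloaded = ProteinGLMConfig.from_pretrained(tmp_dir)

    assert reloaded.model_type == "ProteinGLM"
    assert reloaded.num_layers == cfg.num_layers
    assert reloaded.apply_residual_connection_post_layernorm is True
    print("ProteinGLMConfig round-trip OK")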