takatosp1 committed on
Commit 0b45dd5
1 Parent(s): 492f418

Update modeling_qwen.py

Files changed (1)
  1. modeling_qwen.py +7 -7
modeling_qwen.py CHANGED
@@ -154,7 +154,7 @@ class QWenAttention(nn.Module):
             if self.rotary_ndims is not None
             else self.hidden_size_per_attention_head
         )
-        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
+        self.rotary_emb = QWenRotaryEmbedding(dim, base=config.rotary_emb_base)

         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn
@@ -386,12 +386,12 @@ class QWenBlock(nn.Module):
         hidden_size = config.hidden_size
         self.bf16 = config.bf16

-        self.ln_1 = RMSNorm(
+        self.ln_1 = QWenRMSNorm(
             hidden_size,
             eps=config.layer_norm_epsilon,
         )
         self.attn = QWenAttention(config)
-        self.ln_2 = RMSNorm(
+        self.ln_2 = QWenRMSNorm(
             hidden_size,
             eps=config.layer_norm_epsilon,
         )
@@ -460,7 +460,7 @@ class QWenPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, RMSNorm):
+        elif isinstance(module, QWenRMSNorm):
             module.weight.data.fill_(1.0)

         for name, p in module.named_parameters():
@@ -500,7 +500,7 @@ class QWenModel(QWenPreTrainedModel):
                 for i in range(config.num_hidden_layers)
             ]
         )
-        self.ln_f = RMSNorm(
+        self.ln_f = QWenRMSNorm(
             self.embed_dim,
             eps=config.layer_norm_epsilon,
         )
@@ -1041,7 +1041,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         )


-class RotaryEmbedding(torch.nn.Module):
+class QWenRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, base=10000):
         super().__init__()
         self.dim = dim
@@ -1104,7 +1104,7 @@ def apply_rotary_pos_emb(t, freqs):
     return torch.cat((t_, t_pass_), dim=-1).type_as(t)


-class RMSNorm(torch.nn.Module):
+class QWenRMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
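
For reference, every one of the 7 changed lines is a rename: RotaryEmbedding becomes QWenRotaryEmbedding and RMSNorm becomes QWenRMSNorm, both at the class definitions and at every use site (the ln_1/ln_2/ln_f constructors and the isinstance check in _init_weights). Below is a minimal sketch of what the renamed QWenRMSNorm looks like after this commit, assuming the standard RMSNorm formulation; only the first lines of its __init__ appear in the hunks above, so the rest of the body is illustrative rather than copied from the file.

import torch


class QWenRMSNorm(torch.nn.Module):
    # Illustrative sketch of the renamed class, assuming the usual RMSNorm
    # recipe: scale each hidden vector by the reciprocal of its
    # root-mean-square, then apply a learned per-dimension gain
    # (initialized to 1.0, which is what the _init_weights branch above fills in).
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Reciprocal root-mean-square over the last (hidden) dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, cast back, then scale.
        return self._norm(x.float()).type_as(x) * self.weight

Because only the class names change, checkpoint weights and the module.weight.data.fill_(1.0) initialization in _init_weights are unaffected; any code outside this file that referred to RMSNorm or RotaryEmbedding by name would need to switch to the QWen-prefixed names.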