AlexHung29629 committed
Commit c13395b · verified · 1 Parent(s): 5ab87b5

Update mllama_audio_model.py

Files changed (1): mllama_audio_model.py (+5 -0)
mllama_audio_model.py CHANGED

@@ -4,6 +4,7 @@ import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
 from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, MllamaPreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
 from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import _compute_new_attention_mask, _prepare_4d_attention_mask, Wav2Vec2BertFeedForward, Wav2Vec2BertSelfAttention, Wav2Vec2BertFeatureProjection
 from .configuration_llama3 import Llama3Config
 
@@ -50,6 +51,8 @@ class AudioAdapter(nn.Module):
         else:
             self.proj = None
         self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
+        self.final_adapter_norm = LlamaRMSNorm(config.output_hidden_size)
+        self.final_adapter_norm.weight.data.fill_(0.4)
 
     def forward(self, hidden_states, attention_mask=None):
         # down project hidden_states if necessary
@@ -61,6 +64,8 @@ class AudioAdapter(nn.Module):
             hidden_states
         )
 
+        hidden_states = self.final_adapter_norm(hidden_states)
+
         return hidden_states
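For context on what the change does: the commit appends a LlamaRMSNorm (an existing transformers class) after the adapter layer stack and initializes its gain to 0.4 instead of the default 1.0, so the audio features handed to the language model come out with a root-mean-square magnitude of roughly 0.4. Below is a minimal sketch of that effect, not the model code itself; hidden_size is a placeholder standing in for config.output_hidden_size.

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# Placeholder width; the real model reads config.output_hidden_size.
hidden_size = 4096

norm = LlamaRMSNorm(hidden_size)   # gain weight defaults to all-ones
norm.weight.data.fill_(0.4)        # the commit's down-scaled initialization

# Fake adapter output with shape (batch, seq_len, hidden)
hidden_states = torch.randn(2, 10, hidden_size)
out = norm(hidden_states)

# RMSNorm rescales each vector to unit root-mean-square over the last
# dim, then multiplies by the gain, so the output RMS lands near 0.4.
print(out.pow(2).mean(-1).sqrt().mean())  # ~0.4

Initializing the gain below 1.0 presumably shrinks the audio embeddings relative to the text embeddings early in training; the 0.4 value itself is specific to this commit.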