AlexHung29629
committed on
Update mllama_audio_model.py
mllama_audio_model.py CHANGED (+5 -0)
@@ -4,6 +4,7 @@ import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
 from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, MllamaPreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
 from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import _compute_new_attention_mask, _prepare_4d_attention_mask, Wav2Vec2BertFeedForward, Wav2Vec2BertSelfAttention, Wav2Vec2BertFeatureProjection
 from .configuration_llama3 import Llama3Config
 
@@ -50,6 +51,8 @@ class AudioAdapter(nn.Module):
         else:
             self.proj = None
         self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
+        self.final_adapter_norm = LlamaRMSNorm(config.output_hidden_size)
+        self.final_adapter_norm.weight.data.fill_(0.4)
 
     def forward(self, hidden_states, attention_mask=None):
         # down project hidden_states if necessary
@@ -61,6 +64,8 @@ class AudioAdapter(nn.Module):
             hidden_states
         )
 
+        hidden_states = self.final_adapter_norm(hidden_states)
+
         return hidden_states
 
 
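For reference, LlamaRMSNorm (imported above from transformers' Llama implementation) normalizes each hidden vector by its root mean square over the channel dimension and rescales it with a learned per-channel weight. Below is a minimal sketch of the behavior this commit adds; the hidden_size value stands in for config.output_hidden_size and the eps default follows Llama's usual 1e-6, both assumptions for illustration:

# Sketch of the norm added in this commit (assumes the standard
# RMSNorm definition used by transformers' Llama models).
import torch
from torch import nn

class RMSNormSketch(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # the commit fills this with 0.4
        self.eps = eps

    def forward(self, hidden_states):
        # Normalize each vector by its root mean square over the hidden dim,
        # then rescale per channel with the learned weight.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states

norm = RMSNormSketch(hidden_size=4096)        # 4096 is a placeholder for output_hidden_size
norm.weight.data.fill_(0.4)                   # as in the commit: damp the output at init
out = norm(torch.randn(2, 10, 4096))          # adapter output, RMS-normalized and scaled by 0.4

Filling the norm weight with 0.4 rather than the default 1.0 shrinks the adapter's output at initialization, a common way to keep a freshly attached audio adapter from perturbing the pretrained language model early in training.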