import math
from typing import Optional
import torch
from torch import nn
from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, MllamaPreTrainedModel
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
    _compute_new_attention_mask,
    _prepare_4d_attention_mask,
    Wav2Vec2BertFeedForward,
    Wav2Vec2BertSelfAttention,
    Wav2Vec2BertFeatureProjection,
)
from .configuration_llama3 import Llama3Config

class Wav2Vec2BertAdapterLayer(nn.Module):
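    """Conformer-style adapter block: a strided Conv1d + GLU sub-samples the sequence in time,
    followed by self-attention and a feed-forward layer."""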
    def __init__(self, config):
        super().__init__()
        embed_dim = config.output_hidden_size
        dropout = config.conformer_conv_dropout

        self.kernel_size = config.adapter_kernel_size
        self.stride = config.adapter_stride

        self.activation = nn.GLU(dim=1)

        # Self-Attention
        self.self_attn_conv = nn.Conv1d(
            embed_dim,
            2 * embed_dim,
            self.kernel_size,
            stride=self.stride,
            padding=self.stride // 2,
        )
        self.self_attn = Wav2Vec2BertSelfAttention(config, is_adapter_attention=True)
        self.self_attn_dropout = nn.Dropout(dropout)

        # Feed-forward
        self.ffn = Wav2Vec2BertFeedForward(config, act_fn=config.adapter_act, hidden_size=embed_dim)

    def forward(
        self,
        hidden_states,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        sub_sampled_lengths: Optional[torch.Tensor] = None,
    ):
        # Sub-sample along time with a strided Conv1d + GLU before the multi-head attention layer.
        # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len)
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.self_attn_conv(hidden_states)
        hidden_states = self.activation(hidden_states)
        # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim)
        hidden_states = hidden_states.transpose(1, 2)

        if attention_mask is not None:
            attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths)
            attention_mask = _prepare_4d_attention_mask(
                attention_mask,
                hidden_states.dtype,
            )

        # Self-attention followed by a feed-forward block (note: this adapter layer
        # has no residual connections or layer norms).
        hidden_states, attn_weights = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_dropout(hidden_states)
        hidden_states = self.ffn(hidden_states)

        return hidden_states

class AudioAdapter(nn.Module):
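    """Stack of `Wav2Vec2BertAdapterLayer`s that optionally down-projects the encoder hidden size
    to `output_hidden_size` and tracks the sub-sampled sequence lengths across layers."""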
    def __init__(self, config: Wav2Vec2BertConfig):
        super().__init__()
        # feature dim might need to be down-projected
        if config.output_hidden_size != config.hidden_size:
            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
        else:
            self.proj = None
        self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))

        self.kernel_size = config.adapter_kernel_size
        self.stride = config.adapter_stride

    def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
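        """Update sequence lengths after one adapter layer's strided convolution:
        L_out = floor((L_in + 2 * padding - kernel_size) / stride) + 1.
        """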
        if seq_lens is None:
            return seq_lens
        pad = self.stride // 2
        seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
        return seq_lens.floor()

    def forward(self, hidden_states, attention_mask=None):
        # down project hidden_states if necessary
        if self.proj is not None:
            hidden_states = self.proj(hidden_states)

        sub_sampled_lengths = None
        if attention_mask is not None:
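            # Number of non-padded frames per example (equivalent to attention_mask.sum(1)).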
            sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)

        for layer in self.layers:
            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
            hidden_states = layer(
                hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
            )

        return hidden_states

        
class Llama3Embedding(MllamaPreTrainedModel):
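    """Audio-augmented input embedding module.

    Wraps a `Wav2Vec2BertModel` encoder and an `AudioAdapter`, and splices the resulting audio
    embeddings into the text token embeddings at placeholder positions (negative token ids),
    bracketed by learned start-of-audio / end-of-audio vectors.

    Example (illustrative sketch only; the checkpoint path, tensor names, and feature shapes
    below are assumptions, not part of this module):

        config = Llama3Config.from_pretrained("path/to/checkpoint")
        embed = Llama3Embedding(config)
        # audio_features: (batch, num_clips, frames, feature_dim)
        # input_ids: token ids where the j-th clip of a sample is marked with id -(j + 1)
        merged_embeds = embed(
            input_ids=input_ids,
            input_embeddings=text_embeds,
            audio_features=audio_features,
        )
    """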
    config_class = Llama3Config
    base_model_prefix = "audio_model"
    def __init__(self, config: Llama3Config):
        super().__init__(config)
        if config.audio_config.output_hidden_size != config.text_config.hidden_size:
            raise ValueError(
                "`audio_config.output_hidden_size` must match `text_config.hidden_size` so that "
                "audio embeddings can be spliced into the text embedding stream."
            )
        config.audio_config.add_adapter = False
        self.audio_encoder = Wav2Vec2BertModel(config.audio_config)
        self.audio_adapter = AudioAdapter(config.audio_config)
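        # Learned boundary embeddings that bracket each audio segment spliced into the text stream.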
        self.start_of_audio = nn.Parameter(torch.empty((1, config.audio_config.output_hidden_size)))
        self.end_of_audio = nn.Parameter(torch.empty((1, config.audio_config.output_hidden_size)))
            
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_embeddings: Optional[torch.Tensor] = None,
        audio_features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if audio_features is None:
            return input_embeddings
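        # audio_features: (batch, num_clips, frames, feature_dim). Each clip is encoded,
        # down-sampled by the adapter, and spliced into the token embedding sequence below.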
        bs, max_num_audio, seq_len, feat_dim = audio_features.shape
        audio_embeddings = self.audio_encoder(
            input_features=audio_features.view((bs * max_num_audio, seq_len, feat_dim))
        )["last_hidden_state"]
        audio_embeddings = self.audio_adapter(audio_embeddings)
        audio_embeddings = audio_embeddings.view((bs, max_num_audio, -1, self.start_of_audio.shape[-1]))

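        # Placeholder convention: the j-th audio clip of sample i is marked in `input_ids` by the
        # negative token id -(j + 1). Those positions are overwritten with
        # [start_of_audio, adapted clip embeddings, end_of_audio]; the number of placeholder
        # positions is expected to match the adapted clip length plus the two boundary vectors.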
        for i in range(bs):
            for j in range(max_num_audio):
                audio_id = -1 - j
                if torch.any(input_ids[i] == audio_id):
                    positions = torch.nonzero(input_ids[i] == audio_id, as_tuple=True)
                    input_embeddings[i] = input_embeddings[i].index_put(
                        positions,
                        torch.cat([self.start_of_audio, audio_embeddings[i, j, :, :], self.end_of_audio]),
                        accumulate=False,
                    )
        return input_embeddings
    
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, Wav2Vec2BertSelfAttention):
            if hasattr(module, "pos_bias_u"):
                nn.init.xavier_uniform_(module.pos_bias_u)
            if hasattr(module, "pos_bias_v"):
                nn.init.xavier_uniform_(module.pos_bias_v)
        elif isinstance(module, Wav2Vec2BertFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.audio_config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)