# File size: 7,358 Bytes
# 9d3cd60
import math
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput
from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, MllamaPreTrainedModel
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import _compute_new_attention_mask, _prepare_4d_attention_mask, Wav2Vec2BertFeedForward, Wav2Vec2BertSelfAttention, Wav2Vec2BertFeatureProjection
from .configuration_llama3 import Llama3Config
class Wav2Vec2BertAdapterLayer(nn.Module):
    """One adapter stage: strided GLU convolution, self-attention, feed-forward.

    The time axis is shortened by a strided ``Conv1d`` whose doubled channel
    output is gated back to ``config.output_hidden_size`` by a GLU. The pooled
    sequence is then passed through self-attention, dropout and a feed-forward
    block. This layer applies no layer norms or residual connections.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.output_hidden_size
        drop_prob = config.conformer_conv_dropout

        self.kernel_size = config.adapter_kernel_size
        self.stride = config.adapter_stride
        self.activation = nn.GLU(dim=1)

        # Pooling convolution feeding the attention block. It emits twice the
        # channels because the GLU consumes half of them as gates.
        self.self_attn_conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=2 * hidden_dim,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.stride // 2,
        )
        self.self_attn = Wav2Vec2BertSelfAttention(config, is_adapter_attention=True)
        self.self_attn_dropout = nn.Dropout(drop_prob)

        # Position-wise feed-forward block.
        self.ffn = Wav2Vec2BertFeedForward(config, act_fn=config.adapter_act, hidden_size=hidden_dim)

    def forward(
        self,
        hidden_states,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        sub_sampled_lengths: Optional[torch.Tensor] = None,
    ):
        """Pool ``hidden_states`` along time, then run attention and feed-forward.

        Args:
            hidden_states: Input laid out as (batch, seq_len, feature_dim).
            attention_mask: Only tested against ``None``. When given, a fresh
                mask is built for the pooled sequence from
                ``sub_sampled_lengths``.
            output_attentions: Forwarded to the attention module. The
                attention weights themselves are discarded.
            sub_sampled_lengths: Per-sample lengths after pooling, used to
                rebuild the mask.

        Returns:
            The transformed hidden states as (batch, pooled_len, feature_dim).
        """
        # Conv1d wants channels first: (batch, seq, dim) -> (batch, dim, seq).
        pooled = self.self_attn_conv(hidden_states.transpose(1, 2))
        # Gate the channels, then go back to (batch, seq, dim).
        pooled = self.activation(pooled).transpose(1, 2)

        pooled_mask = attention_mask
        if pooled_mask is not None:
            # The incoming mask refers to the un-pooled sequence, so it is
            # rebuilt from the sub-sampled lengths.
            pooled_mask = _compute_new_attention_mask(hidden_states=pooled, seq_lens=sub_sampled_lengths)
            pooled_mask = _prepare_4d_attention_mask(pooled_mask, pooled.dtype)

        attended, _ = self.self_attn(
            pooled,
            attention_mask=pooled_mask,
            output_attentions=output_attentions,
        )
        return self.ffn(self.self_attn_dropout(attended))
class AudioAdapter(nn.Module):
def __init__(self, config: Wav2Vec2BertConfig):
super().__init__()
# feature dim might need to be down-projected
if config.output_hidden_size != config.hidden_size:
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
else:
self.proj = None
self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
self.kernel_size = config.adapter_kernel_size
self.stride = config.adapter_stride
def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
if seq_lens is None:
return seq_lens
pad = self.stride // 2
seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
return seq_lens.floor()
def forward(self, hidden_states, attention_mask=None):
# down project hidden_states if necessary
if self.proj is not None:
hidden_states = self.proj(hidden_states)
sub_sampled_lengths = None
if attention_mask is not None:
sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)
for layer in self.layers:
sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
hidden_states = layer(
hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
)
return hidden_states
class Llama3Embedding(MllamaPreTrainedModel):
    """Encode audio clips and splice their embeddings into text embeddings.

    Audio features go through a ``Wav2Vec2BertModel`` encoder and an
    ``AudioAdapter``. Each adapted clip is wrapped between the learned
    ``start_of_audio`` and ``end_of_audio`` vectors and written over the
    placeholder positions of the matching row of ``input_embeddings``.
    """

    config_class = Llama3Config
    # NOTE(review): no attribute named "audio_model" is defined in this class
    # (the encoder is stored as ``audio_encoder``) - confirm the prefix.
    base_model_prefix = "audio_model"

    def __init__(self, config: Llama3Config):
        """Build the encoder, the adapter and the two audio boundary embeddings.

        Raises:
            ValueError: If the audio output width differs from the text hidden
                width, since audio embeddings are written directly into the
                text embedding tensor.
        """
        super().__init__(config)
        audio_dim = config.audio_config.output_hidden_size
        text_dim = config.text_config.hidden_size
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if audio_dim != text_dim:
            raise ValueError(
                f"audio_config.output_hidden_size ({audio_dim}) must equal "
                f"text_config.hidden_size ({text_dim})"
            )
        # The encoder's built-in adapter is disabled; ``AudioAdapter`` below is
        # used instead. This mutates the caller's config object.
        config.audio_config.add_adapter = False
        self.audio_encoder = Wav2Vec2BertModel(config.audio_config)
        self.audio_adapter = AudioAdapter(config.audio_config)
        # NOTE(review): both parameters are allocated with ``torch.empty`` and no
        # branch of ``_init_weights`` touches them, so their values are arbitrary
        # until set elsewhere (presumably by loading a checkpoint) - confirm.
        self.start_of_audio = nn.Parameter(data=torch.empty((1, audio_dim)), requires_grad=True)
        self.end_of_audio = nn.Parameter(data=torch.empty((1, audio_dim)), requires_grad=True)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_embeddings: Optional[torch.Tensor] = None,
        audio_features: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Overwrite audio placeholder positions with encoded audio embeddings.

        Args:
            input_ids: Token ids, indexed per batch row. Clip ``j`` of a row is
                marked by the placeholder id ``-1 - j``.
            input_embeddings: Embeddings for ``input_ids``. Modified in place
                and also returned.
            audio_features: 4-D tensor unpacked as
                (batch, clips_per_sample, frames, feature_dim), or ``None``.

        Returns:
            ``input_embeddings`` unchanged when ``audio_features`` is ``None``,
            otherwise ``input_embeddings`` with the placeholder positions
            replaced.

        For every clip whose placeholder id occurs in a row, the values
        written are ``[start_of_audio, clip embeddings, end_of_audio]``. The
        number of placeholder positions therefore has to be compatible with
        the adapted clip length plus two.
        """
        if audio_features is None:
            return input_embeddings

        batch_size, max_num_audio, num_frames, feat_dim = audio_features.shape
        # ``reshape`` behaves like ``view`` for contiguous input but also accepts
        # non-contiguous feature tensors.
        flat_features = audio_features.reshape(batch_size * max_num_audio, num_frames, feat_dim)
        audio_embeddings = self.audio_encoder(input_features=flat_features)["last_hidden_state"]
        audio_embeddings = self.audio_adapter(audio_embeddings)
        audio_embeddings = audio_embeddings.reshape(
            batch_size, max_num_audio, -1, self.start_of_audio.shape[-1]
        )

        for i in range(batch_size):
            for j in range(max_num_audio):
                # Compute the placeholder mask once and reuse it for both the
                # presence test and the position lookup.
                placeholder_mask = input_ids[i] == -1 - j
                if not torch.any(placeholder_mask):
                    continue
                positions = torch.nonzero(placeholder_mask, as_tuple=True)
                audio_block = torch.concat(
                    [self.start_of_audio, audio_embeddings[i, j, :, :], self.end_of_audio]
                )
                input_embeddings[i] = input_embeddings[i].index_put(
                    positions, audio_block, accumulate=False
                )
        return input_embeddings

    def _init_weights(self, module):
        """Initialize the weights of ``module`` according to its type."""
        if isinstance(module, Wav2Vec2BertSelfAttention):
            # The positional bias attributes may be absent, hence the guards.
            if hasattr(module, "pos_bias_u"):
                nn.init.xavier_uniform_(module.pos_bias_u)
            if hasattr(module, "pos_bias_v"):
                nn.init.xavier_uniform_(module.pos_bias_v)
        elif isinstance(module, Wav2Vec2BertFeatureProjection):
            # Uniform in [-k, k] with k = sqrt(1 / in_features).
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.audio_config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                # Uniform in [-k, k] with k = sqrt(groups / (in_channels * kernel_size)).
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)